# Source code for sheaf_mpnn.nsd.nsd_layers

# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
#   Mario Severino, Emanuele Mule, Dario Loi,
#   Francesco Restuccia, Fabrizio Silvestri, Pietro Liò

from abc import abstractmethod
from typing import Literal

import torch
from torch import nn
from torch_geometric.utils import add_self_loops

from sheaf_mpnn.base_conv import BaseSheafConv
from sheaf_mpnn.utils import (
    apply_diagonal_norm,
    apply_general_norm,
    apply_low_rank_norm,
    apply_orthogonal_norm,
    attention_cayley,
    cayley,
    householder,
)


class BaseNSDConv(BaseSheafConv):
    """Base class for NSD convolutions (MPSNN formulation).

    The class computes the full sheaf Laplacian action inside message()
    without a separate scatter for the diagonal term.  Each subclass
    implements get_map_products(), which precomputes the composed
    restriction-map products self_map = restriction_maps_dst^T restriction_maps_dst
    and cross_map = restriction_maps_dst^T restriction_maps_src per edge before the
    message loop, reducing message() to two matmuls instead of three.

    ``forward`` also calls ``self._apply_norm(...)`` and reads
    ``self.map_generator``; neither is defined in this class, so a concrete
    subclass (or the parent class) must provide them.
    """

    def __init__(
        self,
        stalk_dim: int,
        in_channels: int,
        hidden_dim: int,
        alpha: float = 1.0,
        context_dim: int | None = None,
        add_self_loops: bool = True,
    ):
        """Initializes the shared NSD convolution parameters.

        Args:
            stalk_dim: Stalk dimension. Each node state handled by the layer has shape
                ``[stalk_dim, in_channels]``.
            in_channels: Feature dimension inside each stalk channel (f).
            hidden_dim: Hidden width of the restriction-map generator MLP.
            alpha: Initial residual diffusion step size. Stored as a learnable
                scalar ``nn.Parameter``; integer values are accepted and
                converted to floating point.
            context_dim: Width of each node context vector ``x_feat``.
            add_self_loops: Whether to add self-loops for degree normalization.
        """
        super().__init__(
            stalk_dim, in_channels, hidden_dim, context_dim, add_self_loops
        )
        # Coerce to float first: ``torch.tensor(1)`` is an int64 tensor, and
        # nn.Parameter refuses integer tensors because they cannot require
        # gradients.  For float inputs this is identical to the plain call.
        self.alpha = nn.Parameter(torch.tensor(float(alpha)))

    def forward(
        self,
        x_feat: torch.Tensor,
        x_stalk: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Applies one NSD diffusion step to lifted node features.

        Args:
            x_feat (torch.Tensor): Node context features [num_nodes, context_dim].
            x_stalk (torch.Tensor): Lifted node features [num_nodes, d, in_channels].
            edge_index (torch.Tensor): Graph connectivity [2, num_edges].

        Returns:
            torch.Tensor: Updated stalk features [num_nodes, d, in_channels],
            computed as ``x_stalk - alpha * sigma(laplacian_out)``.
        """
        # NOTE(review): _apply_stalk_transform and sigma are not defined in this
        # class; presumably inherited from BaseSheafConv -- confirm there.
        z = self._apply_stalk_transform(x_stalk)
        num_nodes = x_stalk.size(0)

        # ``add_self_loops`` here is the torch_geometric utility imported at
        # module level; the boolean switch is the instance attribute.
        if self.add_self_loops:
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Row 0 of edge_index is treated as the source, row 1 as the destination.
        src_idx, dst_idx = edge_index

        # Per-edge map products, then the subclass-specific normalization.
        self_map, cross_map = self.get_map_products(x_feat, edge_index)
        norm_self, norm_cross = self._apply_norm(
            self_map, cross_map, edge_index, num_nodes
        )

        # Gather both endpoint states up front so message() receives per-edge
        # tensors directly.
        z_src, z_dst = z[src_idx], z[dst_idx]
        laplacian_out = self.propagate(  # ty: ignore[missing-argument]
            edge_index,
            z_dst=z_dst,
            z_src=z_src,
            self_map=norm_self,
            cross_map=norm_cross,
            size=(num_nodes, num_nodes),
        )

        # Residual update scaled by the learnable step size.
        return x_stalk - self.alpha * self.sigma(laplacian_out)

    @abstractmethod
    def get_map_products(
        self, x_feat: torch.Tensor, edge_index: torch.Tensor
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        """Precompute self_map and cross_map restriction-map products per edge.

        Args:
            x_feat: Node context features.
            edge_index: Graph connectivity as passed on by ``forward`` (after
                optional self-loop insertion).

        Returns:
            A ``(self_map, cross_map)`` pair.  ``self_map`` may be ``None`` when
            a subclass does not need to materialize it; both values are handed
            unchanged to ``self._apply_norm``.
        """

    def _bidirectional_input(
        self, x_feat: torch.Tensor, edge_index: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run self.map_generator on both edge orientations in one forward pass.

        The ``[x_dst, x_src]`` and ``[x_src, x_dst]`` feature concatenations are
        stacked along the batch dimension, pushed through ``self.map_generator``
        once, and split back into two equal halves.

        Args:
            x_feat: Node context features, indexed by node id.
            edge_index: Graph connectivity; row 0 is the source index and row 1
                the destination index.

        Returns:
            ``(raw_dst, raw_src)``: generator outputs for the destination-first
            and source-first orderings, respectively.
        """
        src_idx, dst_idx = edge_index
        x_dst, x_src = x_feat[dst_idx], x_feat[src_idx]
        inp = torch.cat(
            [
                torch.cat([x_dst, x_src], dim=-1),
                torch.cat([x_src, x_dst], dim=-1),
            ],
            dim=0,
        )
        # First half of the batch is the (dst, src) ordering, second half (src, dst).
        raw_dst, raw_src = self.map_generator(inp).chunk(2, dim=0)
        return raw_dst, raw_src


class DiagonalNSDConv(BaseNSDConv):
    """NSD convolution layer with diagonal restriction maps.

    The map generator emits ``stalk_dim`` values per edge orientation, read as
    the diagonal of a restriction map, so every map product is an element-wise
    product of vectors rather than a dense matrix product.
    """

    def __init__(
        self,
        stalk_dim: int,
        in_channels: int,
        hidden_dim: int,
        alpha: float = 1.0,
        context_dim: int | None = None,
        add_self_loops: bool = True,
    ):
        """Initializes a diagonal NSD convolution layer.

        Args:
            stalk_dim: Stalk dimension; also the number of diagonal entries the
                map generator produces per edge orientation.
            in_channels: Feature dimension inside each stalk channel.
            hidden_dim: Hidden width of the restriction-map generator MLP.
            alpha: Initial diffusion step size.
            context_dim: Width of ``x_feat``; forwarded to the base class.
            add_self_loops: Whether to add self-loops for degree normalization.
        """
        super().__init__(
            stalk_dim, in_channels, hidden_dim, alpha, context_dim, add_self_loops
        )
        # The generator sees the concatenated context of both edge endpoints.
        generator_layers = [
            nn.Linear(2 * self.context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, stalk_dim),
        ]
        self.map_generator = nn.Sequential(*generator_layers)
        self.reset_parameters()

    def _apply_norm(self, self_map, cross_map, edge_index, num_nodes):
        """Normalize the diagonal map products via ``apply_diagonal_norm``."""
        normalized = apply_diagonal_norm(self_map, cross_map, edge_index, num_nodes)
        return normalized

    def get_map_products(self, x_feat, edge_index):
        """Return the per-edge diagonal products ``(w_dst^2, w_dst * w_src)``.

        Both results stay as ``[E, d]`` vectors of diagonal entries instead of
        being expanded to ``[E, d, d]`` diagonal matrices.
        """
        diag_dst, diag_src = self._bidirectional_input(x_feat, edge_index)
        squared_dst = diag_dst**2
        mixed = diag_dst * diag_src
        return squared_dst, mixed

    def message(
        self,
        z_dst: torch.Tensor,
        z_src: torch.Tensor,
        self_map: torch.Tensor,
        cross_map: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the per-edge Laplacian contribution for diagonal maps.

        ``self_map`` and ``cross_map`` hold diagonal entries per edge; a new
        axis is inserted at position 2 so they broadcast over the channel axis
        of ``z_dst`` / ``z_src``.
        """
        diagonal_term = self_map.unsqueeze(2) * z_dst
        off_diagonal_term = cross_map.unsqueeze(2) * z_src
        return diagonal_term - off_diagonal_term
class GeneralNSDConv(BaseNSDConv):
    """NSD convolution layer with dense ``d x d`` restriction maps.

    The map generator emits ``stalk_dim * stalk_dim`` values per edge
    orientation, which are reshaped into a full matrix.
    """

    def __init__(
        self,
        stalk_dim: int,
        in_channels: int,
        hidden_dim: int,
        alpha: float = 1.0,
        context_dim: int | None = None,
        add_self_loops: bool = True,
        use_attention: bool = False,
    ):
        """Initializes a general NSD convolution layer.

        Args:
            stalk_dim: Stalk dimension and side length of each restriction map.
            in_channels: Feature dimension inside each stalk channel.
            hidden_dim: Hidden width of the restriction-map generator MLP.
            alpha: Initial diffusion step size.
            context_dim: Width of ``x_feat``; forwarded to the base class.
            add_self_loops: Whether to add self-loops for degree normalization.
            use_attention: If ``True``, each raw map is replaced by
                ``I - softmax(raw)`` (softmax over the last axis) before the
                products are formed.
        """
        super().__init__(
            stalk_dim, in_channels, hidden_dim, alpha, context_dim, add_self_loops
        )
        self.use_attention = use_attention
        # One output per matrix entry of the d x d restriction map.
        generator_layers = [
            nn.Linear(2 * self.context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, stalk_dim * stalk_dim),
        ]
        self.map_generator = nn.Sequential(*generator_layers)
        self.reset_parameters()

    def _apply_norm(self, self_map, cross_map, edge_index, num_nodes):
        """Normalize the dense map products via ``apply_general_norm``.

        The stalk dimension and the module's current ``training`` flag are
        passed along to the helper.
        """
        return apply_general_norm(
            self_map,
            cross_map,
            edge_index,
            num_nodes,
            self.stalk_dim,
            self.training,
        )

    def get_map_products(self, x_feat, edge_index):
        """Return ``(phi_dst^T phi_dst, phi_dst^T phi_src)`` for every edge."""
        d = self.stalk_dim
        raw_dst, raw_src = self._bidirectional_input(x_feat, edge_index)
        phi_dst = raw_dst.view(-1, d, d)
        phi_src = raw_src.view(-1, d, d)

        if self.use_attention:
            identity = torch.eye(
                d, device=x_feat.device, dtype=x_feat.dtype
            ).unsqueeze(0)
            # softmax yields row-stochastic matrices; subtracting them from I
            # gives zero-row-sum maps whose Gram products have a
            # sheaf-Laplacian structure.
            phi_dst = identity - torch.softmax(phi_dst, dim=-1)
            phi_src = identity - torch.softmax(phi_src, dim=-1)

        phi_dst_t = phi_dst.transpose(-2, -1)
        self_map = torch.matmul(phi_dst_t, phi_dst)  # [E, d, d]
        cross_map = torch.matmul(phi_dst_t, phi_src)  # [E, d, d]
        return self_map, cross_map
class OrthogonalNSDConv(BaseNSDConv):
    """NSD convolution layer with orthogonal restriction maps."""

    def __init__(
        self,
        stalk_dim: int,
        in_channels: int,
        hidden_dim: int,
        alpha: float = 1.0,
        context_dim: int | None = None,
        add_self_loops: bool = True,
        clamp_val: float = 10.0,
        use_attention: bool = False,
        orth_strategy: Literal["cayley", "fasth"] = "cayley",
    ):
        """Initializes an orthogonal NSD convolution layer.

        The map generator outputs the parameters of one of three
        parameterisations: the ``d * (d - 1) / 2`` free entries consumed by
        ``cayley``, a ``d * d`` block consumed by ``householder`` ("fasth"), or
        a ``d * d`` block consumed by ``attention_cayley``.

        Args:
            stalk_dim (int): Stalk dimension and restriction-map matrix size.
            in_channels (int): Feature dimension inside each stalk channel.
            hidden_dim (int): Hidden width of the restriction-map generator MLP.
            alpha (float, optional): Initial learnable diffusion step size.
                Defaults to 1.0.
            context_dim (int, optional): Width of ``x_feat``; forwarded to the
                base class, which resolves ``None``.
            add_self_loops (bool, optional): If ``True``, self-loops augment the
                degree used for normalization. Defaults to ``True``.
            clamp_val (float, optional): Clamp value handed to ``cayley``.
                Defaults to 10.0.
            use_attention (bool, optional): If ``True``, maps are built with
                ``attention_cayley``; this takes precedence over
                ``orth_strategy``. Defaults to ``False``.
            orth_strategy (str, optional): "cayley" or "fasth". Defaults to
                "cayley".

        Raises:
            ValueError: If ``orth_strategy`` is neither "cayley" nor "fasth".
        """
        allowed_strategies = {"cayley", "fasth"}
        if orth_strategy not in allowed_strategies:
            raise ValueError(
                f"orth_strategy must be 'cayley' or 'fasth', got {orth_strategy!r}"
            )

        super().__init__(
            stalk_dim, in_channels, hidden_dim, alpha, context_dim, add_self_loops
        )
        self.clamp_val = clamp_val
        self.use_attention = use_attention
        self.orth_strategy = orth_strategy

        # The attention and Householder variants consume a full d*d block; the
        # plain Cayley variant only needs the strictly-triangular entries.
        needs_full_block = use_attention or orth_strategy == "fasth"
        if needs_full_block:
            num_params = stalk_dim * stalk_dim
        else:
            num_params = (stalk_dim * (stalk_dim - 1)) // 2

        generator_layers = [
            nn.Linear(2 * self.context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_params),
        ]
        self.map_generator = nn.Sequential(*generator_layers)
        self.reset_parameters()

    def _apply_norm(self, self_map, cross_map, edge_index, num_nodes):
        """Normalize via ``apply_orthogonal_norm``; ``self_map`` is not used."""
        return apply_orthogonal_norm(cross_map, edge_index, num_nodes)

    def get_map_products(self, x_feat, edge_index):
        """Return ``(None, W_dst^T W_src)`` for every edge.

        The self product is not materialized: with orthogonal maps
        ``W_dst^T W_dst`` is the identity, so only the cross product is needed.
        """

        def build_map(params):
            # Parameterisation precedence: attention, then Householder, then Cayley.
            if self.use_attention:
                return attention_cayley(
                    params, self.stalk_dim, x_feat.device, x_feat.dtype
                )
            if self.orth_strategy == "fasth":
                return householder(params, self.stalk_dim)
            return cayley(params, self.stalk_dim, self.clamp_val)

        params_dst, params_src = self._bidirectional_input(x_feat, edge_index)
        W_dst = build_map(params_dst)
        W_src = build_map(params_src)

        cross_map = torch.matmul(W_dst.transpose(-2, -1), W_src)  # [E, d, d]
        return None, cross_map

    def message(
        self,
        z_dst: torch.Tensor,
        z_src: torch.Tensor,
        self_map: torch.Tensor,
        cross_map: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the per-edge Laplacian contribution for orthogonal maps.

        Here ``self_map`` is the normalized degree factor D^{-1}_dst [E, 1, 1]
        (a broadcast scale) and ``cross_map`` is
        D^{-1/2}_dst W^T W D^{-1/2}_src.
        """
        scaled_dst = self_map * z_dst
        transported_src = torch.matmul(cross_map, z_src)
        return scaled_dst - transported_src
class LowRankNSDConv(BaseNSDConv):
    """Low-rank NSD convolution layer.

    Each restriction map is parameterized as F = A @ B^T with A, B ∈ R^{dxr},
    which bounds the rank of every map by ``rank``.  Compared with
    ``GeneralNSDConv`` this needs 2*d*r instead of d*d generated values per
    map while keeping the full d-dimensional stalk.

    The Laplacian products reduce to:
        restriction_maps_dst^T restriction_maps_dst
            = right_dst (left_dst^T left_dst) right_dst^T   [d, d]
        restriction_maps_dst^T restriction_maps_src
            = right_dst (left_dst^T left_src) right_src^T   [d, d]
    """

    def __init__(
        self,
        stalk_dim: int,
        in_channels: int,
        hidden_dim: int,
        alpha: float = 1.0,
        context_dim: int | None = None,
        add_self_loops: bool = True,
        rank: int = 1,
    ):
        """Initializes a low-rank NSD convolution layer.

        Args:
            stalk_dim: Stalk dimension.
            in_channels: Feature dimension inside each stalk channel.
            hidden_dim: Hidden width of the restriction-map generator MLP.
            alpha: Initial learnable diffusion step size. Defaults to 1.0.
            context_dim: Width of ``x_feat``; forwarded to the base class.
            add_self_loops: Whether to add self-loops for degree normalization.
            rank: Rank of each restriction map F = A @ B^T. Must be positive.

        Raises:
            ValueError: If ``rank`` is zero or negative.
        """
        super().__init__(
            stalk_dim, in_channels, hidden_dim, alpha, context_dim, add_self_loops
        )
        if rank <= 0:
            raise ValueError("rank must be positive")
        self.rank = rank

        # Two d x r factors (left and right) are generated per map.
        generator_layers = [
            nn.Linear(2 * self.context_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * stalk_dim * rank),
        ]
        self.map_generator = nn.Sequential(*generator_layers)
        self.reset_parameters()

    def _apply_norm(self, self_map, cross_map, edge_index, num_nodes):
        """Normalize the map products via ``apply_low_rank_norm``.

        The stalk dimension and the module's current ``training`` flag are
        passed along to the helper.
        """
        return apply_low_rank_norm(
            self_map,
            cross_map,
            edge_index,
            num_nodes,
            self.stalk_dim,
            self.training,
        )

    def get_map_products(self, x_feat, edge_index):
        """Return the dense ``[E, d, d]`` self and cross products per edge.

        The products are assembled from the low-rank factors: the small
        ``[E, r, r]`` Gram matrices of the left factors are formed first and
        then sandwiched between the right factors.
        """

        def split_factors(raw):
            # [E, d, 2r] -> left [E, d, r] and right [E, d, r].
            # Restriction map F = left @ right^T, so F^T = right @ left^T.
            return raw.view(-1, self.stalk_dim, 2 * self.rank).chunk(2, dim=-1)

        raw_dst, raw_src = self._bidirectional_input(x_feat, edge_index)
        left_dst, right_dst = split_factors(raw_dst)
        left_src, right_src = split_factors(raw_src)

        left_dst_t = left_dst.transpose(-2, -1)
        gram_dst_dst = torch.matmul(left_dst_t, left_dst)  # [E, r, r]
        gram_dst_src = torch.matmul(left_dst_t, left_src)  # [E, r, r]

        self_inner = torch.matmul(gram_dst_dst, right_dst.transpose(-2, -1))
        cross_inner = torch.matmul(gram_dst_src, right_src.transpose(-2, -1))

        self_map = torch.matmul(right_dst, self_inner)
        cross_map = torch.matmul(right_dst, cross_inner)
        return self_map, cross_map


__all__ = [
    "BaseNSDConv",
    "DiagonalNSDConv",
    "GeneralNSDConv",
    "OrthogonalNSDConv",
    "LowRankNSDConv",
]