# Source code for sheaf_mpnn.nsd.nsd_model

# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
#   Mario Severino, Emanuele Mule, Dario Loi,
#   Francesco Restuccia, Fabrizio Silvestri, Pietro Liò

from enum import Enum, auto
from typing import Any, Literal

import torch.nn.functional as F
from torch import nn
from torch_geometric.nn.models import JumpingKnowledge

from sheaf_mpnn.nsd.nsd_layers import (
    BaseNSDConv,
    DiagonalNSDConv,
    GeneralNSDConv,
    LowRankNSDConv,
    OrthogonalNSDConv,
)


class NSDVariant(Enum):
    """Restriction-map families supported by the NSD model.

    Each member selects the convolution layer class used to build the model
    (``layer_class``) and the keyword arguments forwarded to that layer's
    constructor (``layer_kwargs`` / ``build_kwargs``).
    """

    DIAGONAL = auto()
    GENERAL = auto()
    ORTHOGONAL = auto()
    GENERAL_ATTENTION = auto()
    ORTHOGONAL_ATTENTION = auto()
    LOW_RANK = auto()

    @property
    def layer_class(self):
        """Return the convolution layer class that implements this variant.

        The attention variants reuse the layer class of their plain
        counterpart; they differ only in the ``use_attention`` keyword.
        """
        # Built on access so the enum itself carries no reference to the
        # layer classes until one is actually requested.
        mapping = {
            NSDVariant.DIAGONAL: DiagonalNSDConv,
            NSDVariant.GENERAL: GeneralNSDConv,
            NSDVariant.ORTHOGONAL: OrthogonalNSDConv,
            NSDVariant.LOW_RANK: LowRankNSDConv,
            NSDVariant.GENERAL_ATTENTION: GeneralNSDConv,
            NSDVariant.ORTHOGONAL_ATTENTION: OrthogonalNSDConv,
        }
        return mapping[self]

    @property
    def layer_kwargs(self) -> dict[str, Any]:
        """Return the variant-specific ``use_attention`` setting, if any.

        ``DIAGONAL`` and ``LOW_RANK`` yield an empty dict. Unlike
        ``build_kwargs``, this property never includes ``rank`` or
        ``orth_strategy``.
        """
        if self in (NSDVariant.DIAGONAL, NSDVariant.LOW_RANK):
            return {}
        uses_attention = self in (
            NSDVariant.GENERAL_ATTENTION,
            NSDVariant.ORTHOGONAL_ATTENTION,
        )
        return {"use_attention": uses_attention}

    def build_kwargs(
        self,
        orth_strategy: Literal["cayley", "fasth"] = "cayley",
        rank: int = 1,
    ) -> dict[str, Any]:
        """Build the full layer keyword-argument dict for this variant.

        Args:
            orth_strategy: Orthogonality strategy. Only included in the
                result for ``ORTHOGONAL`` and ``ORTHOGONAL_ATTENTION``.
            rank: Rank of each restriction map. Only included in the result
                (and only validated) for ``LOW_RANK``.

        Returns:
            Keyword arguments to pass to ``layer_class`` in addition to the
            arguments shared by every variant.

        Raises:
            ValueError: If this variant is ``LOW_RANK`` and ``rank`` is not
                positive.
        """
        if self is NSDVariant.DIAGONAL:
            return {}

        if self is NSDVariant.LOW_RANK:
            # Fail here with a clear message instead of deep inside the layer.
            if rank <= 0:
                raise ValueError("rank must be positive")
            return {"rank": rank}

        kwargs: dict[str, Any] = {
            "use_attention": self
            in (NSDVariant.GENERAL_ATTENTION, NSDVariant.ORTHOGONAL_ATTENTION)
        }
        if self in (NSDVariant.ORTHOGONAL, NSDVariant.ORTHOGONAL_ATTENTION):
            kwargs["orth_strategy"] = orth_strategy
        return kwargs
class NSDModel(nn.Module):
    """End-to-end Neural Sheaf Diffusion (NSD) model.

    The wrapper lifts raw node features into stalk features, applies a stack
    of NSD convolution layers, and decodes the flattened stalk representation
    back to the requested output dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stalk_dim: int = 4,
        hidden_dim: int = 16,
        num_layers: int = 2,
        variant: NSDVariant = NSDVariant.GENERAL,
        alpha: float = 1.0,
        add_self_loops: bool = True,
        orth_strategy: str = "cayley",
        rank: int = 1,
        input_dropout: float = 0.0,
        dropout: float = 0.0,
        normalize_output: bool = True,
        jknet: bool = False,
    ) -> None:
        """Initializes an NSD model for node-level prediction.

        Args:
            in_channels (int): Number of raw input features per node.
            out_channels (int): Number of output channels per node
                (e.g. num classes).
            stalk_dim (int, optional): Stalk dimension. Each node is
                represented internally as a matrix with shape
                ``[stalk_dim, hidden_dim]``. Must be positive.
            hidden_dim (int, optional): Feature dimension inside each stalk
                channel. The encoded node state has size
                ``stalk_dim * hidden_dim``. Must be positive.
            num_layers (int, optional): Number of NSD convolution layers.
                Must be positive.
            variant (NSDVariant, optional): Restriction-map family.
                ``DIAGONAL`` is cheapest, ``GENERAL`` is most expressive,
                ``ORTHOGONAL`` uses orthogonal maps (via Cayley or Householder
                parameterisation). ``GENERAL_ATTENTION`` and
                ``ORTHOGONAL_ATTENTION`` use an attention-based map
                initialisation.
            alpha (float, optional): Initial learnable diffusion step size per
                layer.
            add_self_loops (bool, optional): If ``True``, self-loops are added
                to the graph before computing degree normalization in each
                layer. Defaults to ``True``.
            orth_strategy (str, optional): Orthogonality strategy for the
                ``ORTHOGONAL`` variant: "cayley" or "fasth". Defaults to
                "cayley".
            rank (int, optional): Rank of each restriction map for the
                ``LOW_RANK`` variant. Must be positive. Ignored for other
                variants. Defaults to 1.
            input_dropout (float, optional): Dropout probability applied to
                raw input features before encoding. Defaults to 0.0.
            dropout (float, optional): Dropout probability applied to stalk
                features between layers. Defaults to 0.0.
            normalize_output (bool, optional): If ``True``, L2-normalise the
                representation before the decoder (Lv et al., 2021). If
                ``jknet`` is ``True``, each layer's output is also normalised
                before concatenation. Defaults to ``True``.
            jknet (bool, optional): If ``True``, collect hidden states from
                every layer and concatenate them before the decoder
                (Xu et al., 2018). Normalization is controlled by
                ``normalize_output``. Intended for link prediction.
                Defaults to ``False``.

        Raises:
            ValueError: If ``stalk_dim``, ``hidden_dim`` or ``num_layers`` is
                not positive, or if ``variant`` is ``LOW_RANK`` and ``rank``
                is not positive.
        """
        super().__init__()

        if stalk_dim <= 0:
            raise ValueError("stalk_dim must be positive")
        if hidden_dim <= 0:
            raise ValueError("hidden_dim must be positive")
        if num_layers <= 0:
            raise ValueError("must have at least one NSD layer")
        # rank is only consumed by the LOW_RANK variant, so only check it there.
        if variant is NSDVariant.LOW_RANK and rank <= 0:
            raise ValueError("rank must be positive")

        self.stalk_dim = stalk_dim
        self.hidden_dim = hidden_dim
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.rank = rank
        self.normalize_output = normalize_output
        self.jknet = jknet

        # Size of a node's flattened stalk representation.
        context_dim = stalk_dim * hidden_dim
        layer_class = variant.layer_class

        self.input_dropout_layer = nn.Dropout(p=input_dropout)
        self.dropout_layer = nn.Dropout(p=dropout)
        self.encoder = nn.Linear(in_channels, context_dim)

        extra_kwargs = variant.build_kwargs(orth_strategy=orth_strategy, rank=rank)
        self.layers = nn.ModuleList(
            [
                layer_class(
                    stalk_dim=stalk_dim,
                    in_channels=hidden_dim,  # 'f' for W2 [f x f]
                    hidden_dim=hidden_dim,
                    context_dim=context_dim,  # 'd*f' for MLP input [2*df x hidden]
                    alpha=alpha,
                    add_self_loops=add_self_loops,
                    **extra_kwargs,
                )
                for _ in range(num_layers)
            ]
        )

        # JKNet expands decoder input to num_layers * context_dim
        # (Xu et al., 2018).
        decoder_in = (num_layers if jknet else 1) * context_dim
        self.decoder = nn.Linear(decoder_in, out_channels)

        # self.jk only exists when jknet is enabled; forward() guards on the
        # same flag before using it.
        if jknet:
            self.jk = JumpingKnowledge(
                mode="cat", channels=context_dim, num_layers=num_layers
            )

    def reset_parameters(self) -> None:
        """Re-initialise the encoder, every NSD layer, and the decoder."""
        self.encoder.reset_parameters()
        for layer in self.layers:
            assert isinstance(layer, BaseNSDConv)
            layer.reset_parameters()
        self.decoder.reset_parameters()

    def forward(self, x, edge_index):
        """Runs the NSD encoder, diffusion layers, and decoder.

        Args:
            x (torch.Tensor): Raw node features with shape
                ``[num_nodes, in_channels]``.
            edge_index (torch.Tensor): Graph connectivity in COO format with
                shape ``[2, num_edges]``.

        Returns:
            torch.Tensor: Node outputs with shape ``[num_nodes, out_channels]``.

        Raises:
            ValueError: If ``x`` is not 2-D with ``in_channels`` columns, or
                if ``edge_index`` is not 2-D with a first dimension of 2.
        """
        if x.dim() != 2 or x.size(1) != self.encoder.in_features:
            raise ValueError(
                f"x must be [num_nodes, {self.encoder.in_features}], "
                f"got {tuple(x.shape)}"
            )
        # Check the rank too, so a 1-D tensor of length 2 is rejected here.
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            raise ValueError(
                "edge_index must have shape [2, num_edges], "
                f"got {tuple(edge_index.shape)}"
            )

        # Lift raw features to stalk space: [N, in_channels] -> [N, d, f].
        x_stalk = self.encoder(self.input_dropout_layer(x)).view(
            -1, self.stalk_dim, self.hidden_dim
        )

        layer_reps = []
        for layer in self.layers:
            # Flatten stalk to [N, d*f] as context for restriction-map generation.
            x_feat = self.dropout_layer(x_stalk.reshape(x_stalk.size(0), -1))
            x_stalk = layer(x_feat, x_stalk, edge_index)
            if self.jknet:
                # Keep every layer's flattened output for the JK concatenation.
                h = x_stalk.reshape(x_stalk.size(0), -1)
                if self.normalize_output:
                    h = F.normalize(h, p=2, dim=-1)
                layer_reps.append(h)

        if self.jknet:
            x_out = self.jk(layer_reps)
        else:
            x_out = x_stalk.reshape(x_stalk.size(0), -1)

        if self.normalize_output:
            # L2-normalise final representation — Lv et al. (2021).
            x_out = F.normalize(x_out, p=2, dim=-1)

        return self.decoder(x_out)
# Names exported by ``from <this module> import *``.
__all__ = ["NSDVariant", "NSDModel"]