Source code for sheaf_mpnn.base_conv

# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
#   Mario Severino, Emanuele Mule, Dario Loi,
#   Francesco Restuccia, Fabrizio Silvestri, Pietro Liò

from abc import abstractmethod

import torch
from torch import nn
from torch_geometric.nn import MessagePassing


class BaseSheafConv(MessagePassing):
    """Common parent of every NSD and NSP sheaf convolution layer.

    The pieces that do not vary between the concrete layers of either model
    family live here:

    * ``W1`` / ``W2`` -- the bilateral stalk transforms; ``W1`` is the left
      ``[stalk_dim, stalk_dim]`` matrix and ``W2`` the right
      ``[in_channels, in_channels]`` matrix.
    * ``sigma`` -- the activation module (``nn.Tanh``).
    * ``reset_parameters()`` -- Xavier-uniform init for ``W1``, ``W2`` and
      the ``nn.Linear`` layers of any ``map_generator`` /
      ``map_generators`` a subclass has attached.
    * ``_apply_stalk_transform(x)`` -- evaluates ``W1 @ x @ W2``.
    * ``_apply_norm(...)`` -- abstract hook; each concrete layer forwards to
      the matching ``apply_*_norm`` helper in ``sheaf_mpnn.utils``.
    * ``message(...)`` -- default per-edge sheaf Laplacian message.

    Concrete subclasses are expected to provide:
        ``get_map_products(x_feat, edge_index) -> (self_map, cross_map)``
        ``_apply_norm(self_map, cross_map, edge_index, num_nodes)``
        ``forward(x_feat, x_stalk, edge_index) -> updated stalk``
        ``message(...)`` (only when the default below is not suitable)
    """

    def __init__(
        self,
        stalk_dim: int,
        in_channels: int,
        hidden_dim: int,
        context_dim: int | None = None,
        add_self_loops: bool = True,
    ):
        """Set up the shared parameters of a sheaf convolution layer.

        Args:
            stalk_dim: Stalk dimension ``d``; ``W1`` is ``[d, d]``.
            in_channels: Feature dimension ``f`` per stalk entry; ``W2`` is
                ``[f, f]``.
            hidden_dim: Accepted for the subclasses' benefit; this base
                class neither stores nor uses it.
            context_dim: Stored as ``self.context_dim``. When ``None`` it
                falls back to ``stalk_dim * in_channels``.
            add_self_loops: Stored as ``self.add_self_loops``; not consumed
                by the base class itself.

        Note:
            ``W1`` and ``W2`` are allocated with ``torch.empty`` and hold
            arbitrary values until ``reset_parameters()`` is called; this
            constructor does not call it.
        """
        # Sum aggregation, with nodes indexed along the leading axis.
        super().__init__(aggr="add", node_dim=0)

        self.stalk_dim = stalk_dim
        self.in_channels = in_channels  # 'f' (feature dimension per stalk entry)

        if context_dim is None:
            context_dim = stalk_dim * in_channels
        self.context_dim = context_dim

        self.add_self_loops = add_self_loops

        # Registration order (W1, W2, sigma) is kept stable so that
        # ``parameters()`` / ``state_dict()`` ordering does not change.
        self.W1 = nn.Parameter(torch.empty(stalk_dim, stalk_dim))  # [d, d]
        self.W2 = nn.Parameter(torch.empty(in_channels, in_channels))  # [f, f]
        self.sigma = nn.Tanh()

    def reset_parameters(self):
        """Re-initialise ``W1``, ``W2`` and every attached map generator.

        ``W1`` and ``W2`` get Xavier-uniform weights. Afterwards each
        ``nn.Linear`` found directly inside an ``nn.Sequential`` generator
        gets Xavier-uniform weights and a zero bias. Generators are taken
        from ``self.map_generator`` (a single ``nn.Sequential``) and from
        ``self.map_generators`` (an ``nn.ModuleList`` of them), in that
        order; attributes that are missing or of another type are skipped.
        """
        for stalk_weight in (self.W1, self.W2):
            nn.init.xavier_uniform_(stalk_weight)

        # Handles both the singular ``map_generator`` (NSD) and the
        # ``map_generators`` ModuleList (SheafAttnConv).
        single = getattr(self, "map_generator", None)
        pool: list[nn.Sequential] = (
            [single] if isinstance(single, nn.Sequential) else []
        )

        several = getattr(self, "map_generators", None)
        if isinstance(several, nn.ModuleList):
            pool.extend(g for g in several if isinstance(g, nn.Sequential))

        for generator in pool:
            for layer in generator:
                if not isinstance(layer, nn.Linear):
                    continue
                # NOTE(review): the comment previously at this spot described
                # ``gain=0.01`` (a near-zero warm start so the Laplacian is
                # off at init), but the call uses the default gain. Confirm
                # which of the two is intended.
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0.0)

    def _apply_stalk_transform(self, x):
        """Return the bilateral stalk transform ``W1 @ x @ W2`` of ``x``."""
        left_applied = self.W1 @ x
        return left_applied @ self.W2

    @abstractmethod
    def _apply_norm(self, self_map, cross_map, edge_index, num_nodes):
        """Normalize restriction-map products by the sheaf degree ``D^{-1/2}``.

        Every concrete subclass forwards to the corresponding
        ``apply_*_norm`` helper in ``sheaf_mpnn.utils``:

            Diagonal   -> apply_diagonal_norm(self_map, cross_map, edge_index, num_nodes)
            Orthogonal -> apply_orthogonal_norm(cross_map, edge_index, num_nodes)
            Low-rank   -> apply_low_rank_norm(self_map, cross_map, ..., stalk_dim, ...)
            General    -> apply_general_norm(self_map, cross_map, ..., stalk_dim, ...)

        Returns:
            (norm_self, norm_cross): Normalized products ready for
            ``message()``.

        Raises:
            NotImplementedError: Always, when the base implementation is
                reached.
        """
        raise NotImplementedError("Subclasses must implement _apply_norm.")

    def message(  # ty: ignore[invalid-method-override]
        self, z_dst, z_src, self_map, cross_map
    ):
        """Assemble the sheaf Laplacian message for every edge.

        Args:
            z_dst: Destination-node transformed stalks ``[E, d, f]``.
            z_src: Source-node transformed stalks ``[E, d, f]``.
            self_map: Normalized ``F_dst^T F_dst`` per edge ``[E, d, d]``.
            cross_map: Normalized ``F_dst^T F_src`` per edge ``[E, d, d]``.

        Returns:
            torch.Tensor: Per-edge messages ``[E, d, f]``, computed as
            ``self_map @ z_dst - cross_map @ z_src``.
        """
        own_term = torch.matmul(self_map, z_dst)
        neighbour_term = torch.matmul(cross_map, z_src)
        return own_term - neighbour_term
# Public API of this module: only the shared base layer is exported.
__all__ = ["BaseSheafConv"]