exp#

Experiment framework built on PyTorch Lightning, tyro, and Optuna.

Configuration#

Typed configuration dataclasses for the NSD benchmark runner.

Consumed by tyro.cli in exp/run.py. Per-dataset preset defaults live in exp/registries/presets.py and are injected via tyro.cli(Config, default=...), so every field remains overridable from the command line.

class ModelType[source]#

Bases: StrEnum

NSD = 'nsd'#
__new__(value)#
class ModelVariant[source]#

Bases: StrEnum

DIAGONAL = 'diagonal'#
GENERAL = 'general'#
ORTHOGONAL = 'orthogonal'#
GENERAL_ATTENTION = 'general_attention'#
ORTHOGONAL_ATTENTION = 'orthogonal_attention'#
LOW_RANK = 'low_rank'#
__new__(value)#
class DatasetConfig[source]#

Bases: object

Dataset identity and storage location.

Parameters:
  • name (str, default: 'cora')

  • root (str, default: <factory>)

name: str = 'cora'#
root: str#
__init__(name='cora', root=<factory>)#
Parameters:
  • name (str, default: 'cora')

  • root (str, default: <factory>)

Return type:

None

class ModelConfig[source]#

Bases: object

NSD architecture hyperparameters.

Parameters:
  • type (ModelType, default: <ModelType.NSD: 'nsd'>)

  • variant (Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'], default: 'general')

  • stalk_dim (int, default: 4)

  • hidden_dim (int, default: 16)

  • num_layers (int, default: 2)

  • alpha (float, default: 1.0)

  • rank (int, default: 1)

  • orth_strategy (Literal['cayley', 'fasth'], default: 'cayley')

  • normalize_output (bool, default: True)

  • jknet (bool, default: False)

  • num_heads (int, default: 1)

  • leaky_relu_slope (float, default: 0.2)

  • clamp_val (float, default: 10.0)

type: ModelType = 'nsd'#
variant: Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'] = 'general'#
stalk_dim: int = 4#
hidden_dim: int = 16#
num_layers: int = 2#
alpha: float = 1.0#
rank: int = 1#
orth_strategy: Literal['cayley', 'fasth'] = 'cayley'#
normalize_output: bool = True#
jknet: bool = False#
num_heads: int = 1#
leaky_relu_slope: float = 0.2#
clamp_val: float = 10.0#
__init__(type=ModelType.NSD, variant='general', stalk_dim=4, hidden_dim=16, num_layers=2, alpha=1.0, rank=1, orth_strategy='cayley', normalize_output=True, jknet=False, num_heads=1, leaky_relu_slope=0.2, clamp_val=10.0)#
Parameters:
  • type (ModelType, default: <ModelType.NSD: 'nsd'>)

  • variant (Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'], default: 'general')

  • stalk_dim (int, default: 4)

  • hidden_dim (int, default: 16)

  • num_layers (int, default: 2)

  • alpha (float, default: 1.0)

  • rank (int, default: 1)

  • orth_strategy (Literal['cayley', 'fasth'], default: 'cayley')

  • normalize_output (bool, default: True)

  • jknet (bool, default: False)

  • num_heads (int, default: 1)

  • leaky_relu_slope (float, default: 0.2)

  • clamp_val (float, default: 10.0)

Return type:

None

class RegConfig[source]#

Bases: object

Regularisation hyperparameters.

Parameters:
  • input_dropout (float, default: 0.0)

  • dropout (float, default: 0.0)

input_dropout: float = 0.0#
dropout: float = 0.0#
__init__(input_dropout=0.0, dropout=0.0)#
Parameters:
  • input_dropout (float, default: 0.0)

  • dropout (float, default: 0.0)

Return type:

None

class OptimConfig[source]#

Bases: object

Optimisation schedule.

Parameters:
  • lr (float, default: 0.01)

  • weight_decay (float, default: 0.0005)

  • epochs (int, default: 1000)

  • early_stopping (int, default: 200)

  • stop_strategy (Literal['loss', 'acc'], default: 'loss')

  • batch_size (int, default: 1)

lr: float = 0.01#
weight_decay: float = 0.0005#
epochs: int = 1000#
early_stopping: int = 200#
stop_strategy: Literal['loss', 'acc'] = 'loss'#
batch_size: int = 1#
__init__(lr=0.01, weight_decay=0.0005, epochs=1000, early_stopping=200, stop_strategy='loss', batch_size=1)#
Parameters:
  • lr (float, default: 0.01)

  • weight_decay (float, default: 0.0005)

  • epochs (int, default: 1000)

  • early_stopping (int, default: 200)

  • stop_strategy (Literal['loss', 'acc'], default: 'loss')

  • batch_size (int, default: 1)

Return type:

None

class CVConfig[source]#

Bases: object

Cross-validation setup.

Parameters:
  • folds (int, default: 10)

  • seed (int, default: 42)

  • min_acc (float, default: 0.0)

folds: int = 10#
seed: int = 42#
min_acc: float = 0.0#
__init__(folds=10, seed=42, min_acc=0.0)#
Parameters:
  • folds (int, default: 10)

  • seed (int, default: 42)

  • min_acc (float, default: 0.0)

Return type:

None

class HardwareConfig[source]#

Bases: object

Hardware selection and data loading performance.

Parameters:
  • cuda (int, default: 0)

  • num_workers (int, default: 0)

  • pin_memory (bool, default: True)

  • persistent_workers (bool, default: False)

cuda: int = 0#
num_workers: int = 0#
pin_memory: bool = True#
persistent_workers: bool = False#
__init__(cuda=0, num_workers=0, pin_memory=True, persistent_workers=False)#
Parameters:
  • cuda (int, default: 0)

  • num_workers (int, default: 0)

  • pin_memory (bool, default: True)

  • persistent_workers (bool, default: False)

Return type:

None

class WandBConfig[source]#

Bases: object

Weights & Biases integration.

Parameters:
  • enabled (bool, default: False)

  • entity (str | None, default: None)

  • project (str | None, default: None)

enabled: bool = False#
entity: str | None = None#
project: str | None = None#
__init__(enabled=False, entity=None, project=None)#
Parameters:
  • enabled (bool, default: False)

  • entity (str | None, default: None)

  • project (str | None, default: None)

Return type:

None

class Config[source]#

Bases: object

Root configuration for a benchmark run.

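Parameters:
  • dataset (DatasetConfig, default: <factory>)

  • model (ModelConfig, default: <factory>)

  • reg (RegConfig, default: <factory>)

  • optim (OptimConfig, default: <factory>)

  • cv (CVConfig, default: <factory>)

  • hardware (HardwareConfig, default: <factory>)

  • wandb (WandBConfig, default: <factory>)
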
dataset: DatasetConfig#
model: ModelConfig#
reg: RegConfig#
optim: OptimConfig#
cv: CVConfig#
hardware: HardwareConfig#
wandb: WandBConfig#
__init__(dataset=<factory>, model=<factory>, reg=<factory>, optim=<factory>, cv=<factory>, hardware=<factory>, wandb=<factory>)#
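Parameters:
  • dataset (DatasetConfig, default: <factory>)

  • model (ModelConfig, default: <factory>)

  • reg (RegConfig, default: <factory>)

  • optim (OptimConfig, default: <factory>)

  • cv (CVConfig, default: <factory>)

  • hardware (HardwareConfig, default: <factory>)

  • wandb (WandBConfig, default: <factory>)
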
Return type:

None

Generic registry base class.

class Registry[source]#

Bases: Generic

A typed key-value store that raises on duplicate keys and missing lookups.

__init__()[source]#
Return type:

None

register(key, value)[source]#
Parameters:
  • key (TypeVar(Key))

  • value (TypeVar(Value))

Return type:

None

get(key)[source]#
Parameters:

key (TypeVar(Key))

Return type:

TypeVar(Value)

list_keys()[source]#
Return type:

list[TypeVar(Key)]

Preset registry: per-dataset best-known configurations.

class PresetRegistry[source]#

Bases: Registry[str, Config]

Registry for named experiment presets.

get_or_default(name)[source]#

Return the named preset, or a bare Config() if name is None.

Parameters:

name (str | None)

Return type:

Config

generate_config(name, *, variant, stalk_dim, hidden_dim, num_layers, input_dropout, dropout, lr, weight_decay, alpha=1.0, stop_strategy='loss', model_type=ModelType.NSD, num_heads=1, leaky_relu_slope=0.2, clamp_val=10.0, normalize_output=True, jknet=False)[source]#
Parameters:
  • name (str)

  • variant (Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'])

  • stalk_dim (int)

  • hidden_dim (int)

  • num_layers (int)

  • input_dropout (float)

  • dropout (float)

  • lr (float)

  • weight_decay (float)

  • alpha (float, default: 1.0)

  • stop_strategy (Literal['loss', 'acc'], default: "loss")

  • model_type (ModelType, default: ModelType.NSD)

  • num_heads (int, default: 1)

  • leaky_relu_slope (float, default: 0.2)

  • clamp_val (float, default: 10.0)

  • normalize_output (bool, default: True)

  • jknet (bool, default: False)

Return type:

Config

Data#

Dataset loading and Lightning DataModule for the NSD benchmark suite.

Supported datasets#

Homophilic: cora, citeseer

Heterophilic: chameleon, squirrel, chameleon_filtered, squirrel_filtered, cornell, texas, film

Heterophilous (Platonov et al. 2023):

amazon_ratings, minesweeper, questions, roman_empire, tolokers

All datasets are downloaded automatically to root on first use. Filtered Wikipedia datasets are fetched from the yandex-research GitHub release if not already cached locally.

class DatasetInfo[source]#

Bases: object

Lightweight descriptor returned alongside a loaded Data object.

Parameters:
  • name (str)

  • num_features (int)

  • num_classes (int)

  • num_splits (int)

  • metric (str)

  • split_type (str)

name: str#
num_features: int#
num_classes: int#
num_splits: int#
metric: str#
split_type: str#
__init__(name, num_features, num_classes, num_splits, metric, split_type)#
Parameters:
  • name (str)

  • num_features (int)

  • num_classes (int)

  • num_splits (int)

  • metric (str)

  • split_type (str)

Return type:

None

class FilteredWikipediaDataset[source]#

Bases: InMemoryDataset

Filtered chameleon / squirrel networks (Platonov et al., 2023).

The raw .npz file embeds 10 pre-computed train/val/test splits as boolean mask matrices of shape (10, N). After processing they are stored as (N, 10) tensors so they match the convention used by other PyG heterophilous datasets.

Parameters:
__init__(root, name)[source]#
Parameters:
Return type:

None

property raw_file_names: list[str]#

The name of the files in the self.raw_dir folder that must be present in order to skip downloading.

property processed_file_names: list[str]#

The name of the files in the self.processed_dir folder that must be present in order to skip processing.

download()[source]#

Downloads the dataset to the self.raw_dir folder.

Return type:

None

process()[source]#

Processes the dataset to the self.processed_dir folder.

Return type:

None

load_dataset(name, root=_DATA_DIR)[source]#

Load a benchmark dataset, downloading it automatically if needed.

Parameters:
  • name (str) – Dataset identifier. One of the 14 supported names.

  • root (str, default: _DATA_DIR) – Base directory used by PyG for download caching.

Returns:

A (data, info) tuple. data: A single torch_geometric.data.Data object. info: A DatasetInfo with metadata about the dataset.

Return type:

tuple[Data, DatasetInfo]

class SheafDataModule[source]#

Bases: LightningDataModule

Full-graph DataModule for transductive node classification.

All three dataloaders return the same graph; the model selects nodes via train_mask / val_mask / test_mask inside each step.

Parameters:
  • name (str)

  • root (str, default: _DATA_DIR)

  • fold (int, default: 0)

  • batch_size (int, default: 1)

  • num_workers (int, default: 0)

  • pin_memory (bool, default: False)

  • persistent_workers (bool, default: False)

__init__(name, root=_DATA_DIR, fold=0, batch_size=1, num_workers=0, pin_memory=False, persistent_workers=False)[source]#
prepare_data_per_node#

If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.

allow_zero_length_dataloader_with_multiple_devices#

If True, dataloader with zero length within local rank is allowed. Default value is False.

Parameters:
  • name (str)

  • root (str, default: _DATA_DIR)

  • fold (int, default: 0)

  • batch_size (int, default: 1)

  • num_workers (int, default: 0)

  • pin_memory (bool, default: False)

  • persistent_workers (bool, default: False)

Return type:

None

property info: DatasetInfo#
property num_nodes: int#

Number of nodes in the loaded graph (requires setup() to have run).

property num_edges: int#

Number of undirected edges (requires setup() to have run).

property homophily: float#

Edge homophily h in [0,1]: fraction of same-class edges.

property split_sizes: tuple[int, int, int]#

(train, val, test) node counts for the current fold.

setup(stage=None)[source]#

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str | None, default: None) – either 'fit', 'validate', 'test', or 'predict'

Return type:

None

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
train_dataloader()[source]#

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set ~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type:

DataLoader

val_dataloader()[source]#

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see this section.

The dataloader you return will not be reloaded unless you set ~lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Return type:

DataLoader

test_dataloader()[source]#

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see this section.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Return type:

DataLoader

Train / val / test split management for 10-fold cross-validation.

Two split strategies are used depending on the dataset:

npz_file

Pre-computed 60 / 20 / 20 splits stored as exp/splits/<dataset>_split_0.6_0.2_<fold>.npz. Each file has keys train_mask, val_mask, test_mask - 1-D boolean arrays of length N. Used for: cora, citeseer, chameleon, squirrel, cornell, texas, film.

pyg_mask

Fold column extracted from the multi-column masks that PyG attaches to HeterophilousGraphDataset, FilteredWikipediaDataset, etc. The masks are (N, num_splits) boolean tensors; apply_split selects column fold. Used for: amazon_ratings, minesweeper, questions, roman_empire, tolokers, chameleon_filtered, squirrel_filtered.

The public function apply_split() dispatches automatically based on DatasetInfo.split_type.

apply_split(data, info, fold)[source]#

Return a clone of data with 1-D boolean masks for the given fold.

Parameters:
  • data (Data) – The full graph dataset (masks may be multi-column at this point).

  • info (DatasetInfo) – Metadata returned by load_dataset().

  • fold (int) – Zero-based fold index in [0, info.num_splits).

Return type:

Data

Returns:

A cloned Data object whose train_mask, val_mask, and test_mask are 1-D boolean tensors of length N.

Generate or download pre-computed 60/20/20 train/val/test splits.

Two sources are supported:

canonical (default)

Downloads the official Geom-GCN (Pei et al. 2020) splits from GitHub. These are the same splits used in the original NSD paper and most heterophily benchmarks. 10 folds per dataset.

generate

Creates 10 stratified splits locally using StratifiedShuffleSplit. Useful as a fallback when offline or for datasets not covered by the canonical repository.

Files are saved to exp/splits/ as:

exp/splits/{name}_split_0.6_0.2_{fold}.npz

Each .npz contains three 1-D boolean arrays of length N:

train_mask, val_mask, test_mask

Usage#

# Via the unified CLI (recommended):
sheaf splits
sheaf splits --datasets cora citeseer texas
sheaf splits --source generate
sheaf splits --root /data/pyg --splits-dir /data/splits
sheaf splits --folds 5
sheaf splits --overwrite

# Direct module invocation:
python -m exp.gen_splits

class SplitsConfig[source]#

Bases: object

Configuration for downloading or generating dataset splits.

Parameters:
  • datasets (list[str], default: <factory>)

  • source (Literal['canonical', 'generate'], default: 'canonical')

  • root (str, default: 'exp/data')

  • splits_dir (str, default: '/home/runner/work/sheaf-mpnn/sheaf-mpnn/src/exp/splits')

  • folds (int, default: 10)

  • overwrite (bool, default: False)

datasets: list[str]#
source: Literal['canonical', 'generate'] = 'canonical'#
root: str = 'exp/data'#
splits_dir: str = '/home/runner/work/sheaf-mpnn/sheaf-mpnn/src/exp/splits'#
folds: int = 10#
overwrite: bool = False#
__init__(datasets=<factory>, source='canonical', root='exp/data', splits_dir='/home/runner/work/sheaf-mpnn/sheaf-mpnn/src/exp/splits', folds=10, overwrite=False)#
Parameters:
  • datasets (list[str], default: <factory>)

  • source (Literal['canonical', 'generate'], default: 'canonical')

  • root (str, default: 'exp/data')

  • splits_dir (str, default: '/home/runner/work/sheaf-mpnn/sheaf-mpnn/src/exp/splits')

  • folds (int, default: 10)

  • overwrite (bool, default: False)

Return type:

None

download_canonical_splits(name, splits_dir=_DEFAULT_SPLITS_DIR, n_folds=_N_FOLDS, overwrite=False)[source]#

Download the official Geom-GCN splits for name from GitHub.

Skips any fold whose file already exists unless overwrite is True.

Parameters:
  • name (str)

  • splits_dir (str, default: _DEFAULT_SPLITS_DIR)

  • n_folds (int, default: _N_FOLDS)

  • overwrite (bool, default: False)

Return type:

None

generate_splits(name, root='exp/data', splits_dir=_DEFAULT_SPLITS_DIR, n_folds=_N_FOLDS, overwrite=False)[source]#

Generate and save n_folds stratified 60/20/20 splits for name.

Skips any fold whose file already exists unless overwrite is True.

Parameters:
  • name (str)

  • root (str, default: "exp/data")

  • splits_dir (str, default: _DEFAULT_SPLITS_DIR)

  • n_folds (int, default: _N_FOLDS)

  • overwrite (bool, default: False)

Return type:

None

splits(cfg)[source]#

Process splits for all requested datasets according to cfg.

Parameters:

cfg (SplitsConfig)

Return type:

None

main()[source]#

Entry point for python -m exp.gen_splits.

Return type:

None

Training#

PyTorch Lightning module wrapping Sheaf models.

This module provides the SheafLightningModule, which handles the training loop, evaluation metrics, and optimizer configuration for all sheaf-based models. It is designed to be compatible with both transductive (node classification) and inductive (graph classification) tasks.

class SheafLightningModule[source]#

Bases: LightningModule

Wraps Sheaf models with Lightning training / evaluation logic.

This class serves as the interface between the raw model and the PyTorch Lightning Trainer. It handles loss calculation, metric tracking (ACC/AUC), and hardware-agnostic execution.

Parameters:
  • cfg (Config) – Global configuration object containing model and optimization params.

  • info (DatasetInfo) – Metadata about the dataset (num_features, num_classes, metric, etc.).

__init__(cfg, info)[source]#
Parameters:
  • cfg (Config)

  • info (DatasetInfo)

Return type:

None

training_step(batch, batch_idx)[source]#

Standard training step: forward pass + cross-entropy loss.

validation_step(batch, batch_idx)[source]#

Validation step using the shared evaluation logic.

Return type:

None

test_step(batch, batch_idx)[source]#

Test step using the shared evaluation logic.

Return type:

None

configure_optimizers()[source]#

Set up the Adam optimizer with parameters from the config.

Neural Sheaf Diffusion - 10-fold cross-validation experiment runner.

Usage#

# Via the unified CLI (recommended):
sheaf run --preset cora
sheaf run --preset texas --model.num_layers 3

# Direct module invocation:
python -m exp.run --preset cora

run(cfg)[source]#

Run 10-fold cross-validation; return per-fold test metrics.

Parameters:

cfg (Config)

Return type:

list[float]

main()[source]#

Entry point for python -m exp.run.

Return type:

None

YAML-driven hyperparameter sweep using Optuna and the model registry.

Usage:

# Via the unified CLI (recommended):
sheaf sweep --yaml-path nsd_cora.yaml --preset cora

# Direct module invocation:
python -m exp.sweeps.sweep --yaml-path nsd_cora.yaml --preset cora

# Without a preset (uses config defaults):
sheaf sweep --yaml-path nsd_cora.yaml

Example YAML:

model: nsd
dataset:
  name: texas
  root: exp/data
search_space:
  variant:
    type: categorical
    choices: [diagonal, general, orthogonal]
  stalk_dim:
    type: int
    low: 2
    high: 8
  lr:
    type: float
    low: 0.0001
    high: 0.1
    log: true
config:
  n_trials: 100
  study_name: nsd-texas
  storage: sqlite:///sweep.db
sweep(yaml_path, preset=None)[source]#

Run a YAML-driven Optuna hyperparameter sweep.

Parameters:
  • yaml_path (Path)

  • preset (str | None, default: None)

Return type:

None

main(yaml_path, preset=None)[source]#

Entry point for python -m exp.sweeps.sweep.

Parameters:
  • yaml_path (Path) – Path to a YAML file describing the model, search space, and Optuna config. See module docstring for the expected format.

  • preset (str | None, default: None) – Named preset to use as base config.

Return type:

None