exp
Experiment framework built on PyTorch Lightning, tyro, and Optuna.
Configuration
Typed configuration dataclasses for the NSD benchmark runner.
Consumed by tyro.cli in exp/run.py. Per-dataset preset defaults live
in exp/registries/presets.py and are injected via tyro.cli(Config, default=...),
so every field remains overridable from the command line.
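The preset-plus-override pattern can be illustrated without tyro itself. The sketch below is a stdlib-only stand-in. The trimmed-down ModelConfig/OptimConfig/Config classes, the EXAMPLE_PRESET values, and the apply_overrides helper are all hypothetical. tyro.cli does the equivalent work when it parses a flag such as --model.num_layers against the injected default.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict


@dataclass(frozen=True)
class ModelConfig:
    # Hypothetical subset of the real ModelConfig fields.
    variant: str = "general"
    stalk_dim: int = 4
    num_layers: int = 2


@dataclass(frozen=True)
class OptimConfig:
    lr: float = 0.01
    weight_decay: float = 0.0005


@dataclass(frozen=True)
class Config:
    model: ModelConfig = field(default_factory=ModelConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)


def apply_overrides(cfg: Any, overrides: Dict[str, Any]) -> Any:
    """Return a copy of cfg with dotted-path overrides (e.g. 'model.num_layers') applied."""
    for path, value in overrides.items():
        head, _, rest = path.partition(".")
        if rest:
            # Recurse into the nested dataclass, then rebuild the parent around it.
            child = apply_overrides(getattr(cfg, head), {rest: value})
            cfg = replace(cfg, **{head: child})
        else:
            cfg = replace(cfg, **{head: value})
    return cfg


# Hypothetical preset: only the fields that differ from the dataclass defaults.
EXAMPLE_PRESET = Config(model=ModelConfig(variant="diagonal", stalk_dim=3))

# Roughly what "--preset <name> --model.num_layers 3" resolves to.
cfg = apply_overrides(EXAMPLE_PRESET, {"model.num_layers": 3})
print(cfg.model)  # ModelConfig(variant='diagonal', stalk_dim=3, num_layers=3)
```

In this sketch the dataclasses are frozen, so the preset object itself is never mutated and can be reused as the default for every run.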
- class ModelVariant
Bases: StrEnum
- DIAGONAL = 'diagonal'
- GENERAL = 'general'
- ORTHOGONAL = 'orthogonal'
- GENERAL_ATTENTION = 'general_attention'
- ORTHOGONAL_ATTENTION = 'orthogonal_attention'
- LOW_RANK = 'low_rank'
- __new__(value)
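Because ModelVariant is a StrEnum, its members compare equal to the plain strings accepted by ModelConfig.variant. A string from the command line can also be converted back to a member. The sketch below reproduces the enum with the (str, Enum) mixin so it also runs on Python versions before 3.11, where enum.StrEnum does not exist. The is_attention helper is hypothetical.

```python
from enum import Enum


class ModelVariant(str, Enum):
    # Same members and values as the documented StrEnum.
    DIAGONAL = "diagonal"
    GENERAL = "general"
    ORTHOGONAL = "orthogonal"
    GENERAL_ATTENTION = "general_attention"
    ORTHOGONAL_ATTENTION = "orthogonal_attention"
    LOW_RANK = "low_rank"


def is_attention(variant: str) -> bool:
    """Hypothetical helper: True for the attention-based variants."""
    return ModelVariant(variant) in (
        ModelVariant.GENERAL_ATTENTION,
        ModelVariant.ORTHOGONAL_ATTENTION,
    )


assert ModelVariant("general") is ModelVariant.GENERAL  # lookup by value
assert ModelVariant.LOW_RANK == "low_rank"              # members compare equal to str
print(is_attention("orthogonal_attention"))  # True
```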
- class ModelConfig
Bases: object
NSD architecture hyperparameters.
- Parameters:
  - type (ModelType, default: <ModelType.NSD: 'nsd'>)
  - variant (Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'], default: 'general')
  - stalk_dim (int, default: 4)
  - hidden_dim (int, default: 16)
  - num_layers (int, default: 2)
  - alpha (float, default: 1.0)
  - rank (int, default: 1)
  - orth_strategy (Literal['cayley', 'fasth'], default: 'cayley')
  - normalize_output (bool, default: True)
  - jknet (bool, default: False)
  - num_heads (int, default: 1)
  - leaky_relu_slope (float, default: 0.2)
  - clamp_val (float, default: 10.0)
- variant: Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'] = 'general'
- __init__(type=ModelType.NSD, variant='general', stalk_dim=4, hidden_dim=16, num_layers=2, alpha=1.0, rank=1, orth_strategy='cayley', normalize_output=True, jknet=False, num_heads=1, leaky_relu_slope=0.2, clamp_val=10.0)
- Return type:
None
- class OptimConfig
Bases: object
Optimisation schedule.
- __init__(lr=0.01, weight_decay=0.0005, epochs=1000, early_stopping=200, stop_strategy='loss', batch_size=1)
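OptimConfig pairs early_stopping=200 with stop_strategy='loss', and generate_config also accepts 'acc'. One plausible reading is patience-based early stopping. Under that reading, training halts once the monitored validation quantity has not improved for early_stopping consecutive epochs, where loss is minimised and accuracy is maximised. The sketch below illustrates that reading as an assumption, not the runner's actual code. The EarlyStopper class is hypothetical.

```python
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class EarlyStopper:
    """Hypothetical patience tracker mirroring OptimConfig.early_stopping and stop_strategy."""

    patience: int = 200
    strategy: Literal["loss", "acc"] = "loss"
    best: Optional[float] = None
    bad_epochs: int = 0

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        # Validation loss should go down; validation accuracy should go up.
        return value < self.best if self.strategy == "loss" else value > self.best

    def step(self, value: float) -> bool:
        """Record one epoch's validation metric; return True once patience is exhausted."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2, strategy="acc")
for epoch, val_acc in enumerate([0.50, 0.55, 0.54, 0.53, 0.60]):
    if stopper.step(val_acc):
        print(f"stopping at epoch {epoch}; best val acc {stopper.best}")  # epoch 3, 0.55
        break
```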
- class HardwareConfig
Bases: object
Hardware selection and data loading performance.
- class Config
Bases: object
Root configuration for a benchmark run.
- Parameters:
  - dataset (DatasetConfig, default: <factory>)
  - model (ModelConfig, default: <factory>)
  - reg (RegConfig, default: <factory>)
  - optim (OptimConfig, default: <factory>)
  - cv (CVConfig, default: <factory>)
  - hardware (HardwareConfig, default: <factory>)
  - wandb (WandBConfig, default: <factory>)
- dataset: DatasetConfig
- model: ModelConfig
- optim: OptimConfig
- hardware: HardwareConfig
- wandb: WandBConfig
- __init__(dataset=<factory>, model=<factory>, reg=<factory>, optim=<factory>, cv=<factory>, hardware=<factory>, wandb=<factory>)
- Return type:
None
Generic registry base class.
- class Registry
Bases: Generic
A typed key-value store that raises on duplicate keys and missing lookups.
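The following is a minimal sketch of the contract described above: a typed key-value store that raises on duplicate keys and on missing lookups. It is not the class's documented API. The method names register and get are assumptions, as is the choice of ValueError for duplicates and KeyError for missing keys.

```python
from typing import Dict, Generic, Iterator, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """Typed key-value store: duplicate registrations and unknown lookups both raise."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._items: Dict[str, T] = {}

    def register(self, key: str, value: T) -> None:
        if key in self._items:
            raise ValueError(f"{self._name}: key {key!r} is already registered")
        self._items[key] = value

    def get(self, key: str) -> T:
        if key not in self._items:
            known = ", ".join(sorted(self._items)) or "<empty>"
            raise KeyError(f"{self._name}: unknown key {key!r} (known: {known})")
        return self._items[key]

    def __contains__(self, key: object) -> bool:
        return key in self._items

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)


presets: Registry[dict] = Registry("presets")
presets.register("demo", {"stalk_dim": 4})
print(presets.get("demo"))  # {'stalk_dim': 4}
try:
    presets.register("demo", {"stalk_dim": 8})
except ValueError:
    print("duplicate key rejected")
```

Listing the known keys in the KeyError message makes a mistyped preset name easy to diagnose from the command line.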
Preset registry: per-dataset best-known configurations.
- generate_config(name, *, variant, stalk_dim, hidden_dim, num_layers, input_dropout, dropout, lr, weight_decay, alpha=1.0, stop_strategy='loss', model_type=ModelType.NSD, num_heads=1, leaky_relu_slope=0.2, clamp_val=10.0, normalize_output=True, jknet=False)
- Parameters:
  - name (str)
  - variant (Literal['diagonal', 'general', 'orthogonal', 'low_rank', 'general_attention', 'orthogonal_attention'])
  - stalk_dim (int)
  - hidden_dim (int)
  - num_layers (int)
  - input_dropout (float)
  - dropout (float)
  - lr (float)
  - weight_decay (float)
  - alpha (float, default: 1.0)
  - stop_strategy (Literal['loss', 'acc'], default: 'loss')
  - model_type (ModelType, default: ModelType.NSD)
  - num_heads (int, default: 1)
  - leaky_relu_slope (float, default: 0.2)
  - clamp_val (float, default: 10.0)
  - normalize_output (bool, default: True)
  - jknet (bool, default: False)
Data
Dataset loading and Lightning DataModule for the NSD benchmark suite.
Supported datasets
- Homophilic: cora, citeseer
- Heterophilic: chameleon, squirrel, chameleon_filtered, squirrel_filtered, cornell, texas, film
- Heterophilous (Platonov et al. 2023): amazon_ratings, minesweeper, questions, roman_empire, tolokers
All datasets are downloaded automatically to root on first use.
Filtered Wikipedia datasets are fetched from the yandex-research GitHub
release if not already cached locally.
- class DatasetInfo
Bases: object
Lightweight descriptor returned alongside a loaded Data object.
- class FilteredWikipediaDataset
Bases: InMemoryDataset
Filtered chameleon / squirrel networks (Platonov et al., 2023).
The raw .npz file embeds 10 pre-computed train/val/test splits as boolean mask matrices of shape (10, N). After processing they are stored as (N, 10) tensors so they match the convention used by other PyG heterophilous datasets.
- property raw_file_names: list[str]
The names of the files in the self.raw_dir folder that must be present in order to skip downloading.
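The storage-layout change described above is a transpose from one row per split to one column per split. The sketch below uses NumPy and a tiny in-memory .npz in place of the raw file. The key name train_masks and the random mask values are assumptions for illustration only. The real dataset stores the result as torch tensors, but the transpose is the same.

```python
import io

import numpy as np

# Hypothetical raw archive: 10 splits x N nodes, one boolean row per split.
N, NUM_SPLITS = 6, 10
rng = np.random.default_rng(0)
raw_train = rng.random((NUM_SPLITS, N)) < 0.5

buf = io.BytesIO()
np.savez(buf, train_masks=raw_train)  # key name is an assumption
buf.seek(0)

with np.load(buf) as raw:
    stored = raw["train_masks"]    # shape (10, N), as in the raw file
    train_mask = stored.T.copy()   # shape (N, 10), one column per split

print(stored.shape, train_mask.shape)  # (10, 6) (6, 10)
# Column `fold` of the processed mask equals row `fold` of the raw matrix.
assert (train_mask[:, 3] == raw_train[3]).all()
```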
- load_dataset(name, root=_DATA_DIR)
Load a benchmark dataset, downloading it automatically if needed.
- Returns:
A single torch_geometric.data.Data object.
info: A DatasetInfo with metadata about the dataset.
- class SheafDataModule
Bases: LightningDataModule
Full-graph DataModule for transductive node classification.
All three dataloaders return the same graph; the model selects nodes via train_mask / val_mask / test_mask inside each step.
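Because every dataloader yields the whole graph, each step computes predictions for all N nodes. It then restricts the loss or metric to the nodes in the relevant mask. The NumPy sketch below shows that masking for accuracy. The masked_accuracy name is hypothetical. The real module presumably works on torch tensors inside its Lightning steps.

```python
import numpy as np


def masked_accuracy(logits: np.ndarray, labels: np.ndarray, mask: np.ndarray) -> float:
    """Accuracy over the nodes selected by a 1-D boolean mask (hypothetical helper)."""
    preds = logits.argmax(axis=1)
    return float((preds[mask] == labels[mask]).mean())


# Full-graph predictions for N = 5 nodes and 3 classes.
logits = np.array([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.1, 0.2, 3.0],
    [1.0, 0.9, 0.1],
    [0.0, 0.1, 0.2],
])
labels = np.array([0, 1, 0, 1, 2])
train_mask = np.array([True, True, False, False, False])
test_mask = ~train_mask  # the same graph, a different node subset

print(masked_accuracy(logits, labels, train_mask))  # 1.0
print(masked_accuracy(logits, labels, test_mask))   # one of three test nodes correct
```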
- __init__(name, root=_DATA_DIR, fold=0, batch_size=1, num_workers=0, pin_memory=False, persistent_workers=False)
- prepare_data_per_node
If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices
If True, dataloader with zero length within local rank is allowed. Default value is False.
- property info: DatasetInfo
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
  - stage (str | None, default: None) – either 'fit', 'validate', 'test', or 'predict'
Example:
    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this: prepare_data() does not run on every process
            self.something = ...

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
- train_dataloader()
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
The dataloader you return will not be reloaded unless you set lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
The dataloader you return will not be reloaded unless you set lightning.pytorch.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note: If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- test_dataloader()
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the Lightning documentation on multiple dataloaders.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning: do not assign state in prepare_data().
Note: Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note: If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
Train / val / test split management for 10-fold cross-validation.
Two split strategies are used depending on the dataset:
- npz_file
Pre-computed 60 / 20 / 20 splits stored as exp/splits/<dataset>_split_0.6_0.2_<fold>.npz. Each file has keys train_mask, val_mask, test_mask: 1-D boolean arrays of length N. Used for: cora, citeseer, chameleon, squirrel, cornell, texas, film.
- pyg_mask
Fold column extracted from the multi-column masks that PyG attaches to HeterophilousGraphDataset, FilteredWikipediaDataset, etc. The masks are (N, num_splits) boolean tensors; apply_split selects column fold. Used for: amazon_ratings, minesweeper, questions, roman_empire, tolokers, chameleon_filtered, squirrel_filtered.
The public function apply_split() dispatches automatically based on DatasetInfo.split_type.
- apply_split(data, info, fold)
Return a clone of data with 1-D boolean masks for the given fold.
- Parameters:
  - data (Data) – The full graph dataset (masks may be multi-column at this point).
  - info (DatasetInfo) – Metadata returned by load_dataset().
  - fold (int) – Zero-based fold index in [0, info.num_splits).
- Returns:
A cloned Data object whose train_mask, val_mask, and test_mask are 1-D boolean tensors of length N.
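The column selection that apply_split performs for the pyg_mask strategy can be sketched with NumPy arrays standing in for the graph's mask tensors. The select_fold_masks name and the dict-of-masks input are hypothetical. Only the column-selection path is shown here. Under the npz_file strategy the masks are already 1-D, so the sketch passes them through unchanged.

```python
from typing import Dict

import numpy as np


def select_fold_masks(masks: Dict[str, np.ndarray], fold: int) -> Dict[str, np.ndarray]:
    """Reduce (N, num_splits) masks to 1-D masks for one fold; 1-D masks pass through."""
    out = {}
    for key, mask in masks.items():
        if mask.ndim == 2:
            num_splits = mask.shape[1]
            if not 0 <= fold < num_splits:
                raise IndexError(f"fold {fold} outside [0, {num_splits})")
            mask = mask[:, fold]
        out[key] = mask.astype(bool).copy()  # copy, mirroring apply_split's clone
    return out


N, S = 4, 3
train = np.zeros((N, S), dtype=bool)
train[0, :] = True   # node 0 is a training node in every split
train[1, 2] = True   # node 1 only in split 2
fold_masks = select_fold_masks({"train_mask": train}, fold=2)
print(fold_masks["train_mask"])  # [ True  True False False]
```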
Generate or download pre-computed 60/20/20 train/val/test splits.
Two sources are supported:
- canonical (default)
Downloads the official Geom-GCN (Pei et al. 2020) splits from GitHub. These are the same splits used in the original NSD paper and most heterophily benchmarks. 10 folds per dataset.
- generate
Creates 10 stratified splits locally using StratifiedShuffleSplit. Useful as a fallback when offline or for datasets not covered by the canonical repository.
- Files are saved to exp/splits/ as:
exp/splits/{name}_split_0.6_0.2_{fold}.npz
- Each .npz contains three 1-D boolean arrays of length N:
train_mask, val_mask, test_mask
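The generate source produces stratified 60/20/20 splits with StratifiedShuffleSplit. To keep the example dependency-free, the sketch below swaps in a NumPy-only per-class shuffle instead. It yields the same kind of stratified, disjoint partition, but not the same random draws. The stratified_split function and its per-class rounding rule are assumptions, not the module's code.

```python
import numpy as np


def stratified_split(labels: np.ndarray, seed: int, train: float = 0.6, val: float = 0.2):
    """Return disjoint boolean train/val/test masks that preserve class proportions."""
    rng = np.random.default_rng(seed)
    n = labels.shape[0]
    train_mask = np.zeros(n, dtype=bool)
    val_mask = np.zeros(n, dtype=bool)
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)  # random order within this class only
        n_train = round(train * len(idx))
        n_val = round(val * len(idx))
        train_mask[idx[:n_train]] = True
        val_mask[idx[n_train:n_train + n_val]] = True
    test_mask = ~(train_mask | val_mask)  # every remaining node goes to test
    return train_mask, val_mask, test_mask


labels = np.repeat([0, 1], 10)  # 10 nodes per class
folds = [stratified_split(labels, seed=fold) for fold in range(10)]
tr, va, te = folds[0]
print(tr.sum(), va.sum(), te.sum())  # 12 4 4
```

Seeding each fold with its index keeps the 10 splits reproducible across runs.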
Usage
# Via the unified CLI (recommended):
sheaf splits
sheaf splits --datasets cora citeseer texas
sheaf splits --source generate
sheaf splits --root /data/pyg --splits-dir /data/splits
sheaf splits --folds 5
sheaf splits --overwrite
# Direct module invocation:
python -m exp.gen_splits
- class SplitsConfig
Bases: object
Configuration for downloading or generating dataset splits.
- __init__(datasets=<factory>, source='canonical', root='exp/data', splits_dir=_DEFAULT_SPLITS_DIR, folds=10, overwrite=False)
- Return type: None
- download_canonical_splits(name, splits_dir=_DEFAULT_SPLITS_DIR, n_folds=_N_FOLDS, overwrite=False)
Download the official Geom-GCN splits for name from GitHub.
Skips any fold whose file already exists unless overwrite is True.
- generate_splits(name, root='exp/data', splits_dir=_DEFAULT_SPLITS_DIR, n_folds=_N_FOLDS, overwrite=False)
Generate and save n_folds stratified 60/20/20 splits for name.
Skips any fold whose file already exists unless overwrite is True.
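Both download_canonical_splits and generate_splits skip existing fold files unless overwrite is True. The hypothetical save_fold helper below sketches that check. It also reproduces the documented file name pattern and the three mask keys. It is an illustration, not the module's own writer.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_fold(
    splits_dir: Path,
    name: str,
    fold: int,
    train_mask: np.ndarray,
    val_mask: np.ndarray,
    test_mask: np.ndarray,
    overwrite: bool = False,
) -> bool:
    """Write one fold as {name}_split_0.6_0.2_{fold}.npz; return False if skipped."""
    path = splits_dir / f"{name}_split_0.6_0.2_{fold}.npz"
    if path.exists() and not overwrite:
        return False  # existing fold is left untouched
    splits_dir.mkdir(parents=True, exist_ok=True)
    np.savez(
        path,
        train_mask=np.asarray(train_mask, dtype=bool),
        val_mask=np.asarray(val_mask, dtype=bool),
        test_mask=np.asarray(test_mask, dtype=bool),
    )
    return True


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    train = np.array([True, True, False, False, False])
    val = np.array([False, False, True, False, False])
    test = ~(train | val)
    print(save_fold(out, "demo", 0, train, val, test))  # True
    print(save_fold(out, "demo", 0, train, val, test))  # False: file already exists
    with np.load(out / "demo_split_0.6_0.2_0.npz") as f:
        print(sorted(f.files))  # ['test_mask', 'train_mask', 'val_mask']
```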
- splits(cfg)
Process splits for all requested datasets according to cfg.
- Parameters:
  - cfg (SplitsConfig)
Training
PyTorch Lightning module wrapping Sheaf models.
This module provides the SheafLightningModule, which handles the training loop, evaluation metrics, and optimizer configuration for all sheaf-based models. It is designed to be compatible with both transductive (node classification) and inductive (graph classification) tasks.
- class SheafLightningModule
Bases: LightningModule
Wraps Sheaf models with Lightning training / evaluation logic.
This class serves as the interface between the raw model and the PyTorch Lightning Trainer. It handles loss calculation, metric tracking (ACC/AUC), and hardware-agnostic execution.
- Parameters:
  - cfg (Config) – Global configuration object containing model and optimization params.
  - info (DatasetInfo) – Metadata about the dataset (num_features, num_classes, metric, etc.).
- __init__(cfg, info)
- Parameters:
  - cfg (Config)
  - info (DatasetInfo)
- Return type: None
Neural Sheaf Diffusion - 10-fold cross-validation experiment runner.
Usage
# Via the unified CLI (recommended):
sheaf run --preset cora
sheaf run --preset texas --model.num_layers 3
# Direct module invocation:
python -m exp.run --preset cora
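A 10-fold runner typically trains one model per fold and reports the mean and standard deviation of the per-fold test metric. The sketch below shows only that aggregation loop. Here run_fold is a hypothetical callable. It stands in for building SheafDataModule(fold=...) and SheafLightningModule, then fitting with a Lightning Trainer. Whether exp.run reports a population or sample standard deviation is not stated here. The sketch uses the population form.

```python
import statistics
from typing import Callable, Tuple


def cross_validate(run_fold: Callable[[int], float], num_folds: int = 10) -> Tuple[float, float]:
    """Run every fold and return (mean, population std) of the per-fold scores."""
    scores = [run_fold(fold) for fold in range(num_folds)]
    for fold, score in enumerate(scores):
        print(f"fold {fold}: {score:.4f}")
    return statistics.fmean(scores), statistics.pstdev(scores)


def fake_run_fold(fold: int) -> float:
    # Placeholder for: SheafDataModule(fold=fold) + SheafLightningModule + Trainer.fit/test.
    # Dummy values only, not experimental results.
    return 0.80 if fold % 2 == 0 else 0.82


mean, std = cross_validate(fake_run_fold, num_folds=4)
print(f"test metric: {mean:.3f} +/- {std:.3f}")
```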
YAML-driven hyperparameter sweep using Optuna and the model registry.
Usage:
# Via the unified CLI (recommended):
sheaf sweep --yaml-path nsd_cora.yaml --preset cora
# Direct module invocation:
python -m exp.sweeps.sweep --yaml-path nsd_cora.yaml --preset cora
# Without a preset (uses config defaults):
sheaf sweep --yaml-path nsd_cora.yaml
Example YAML:
model: nsd
dataset:
name: texas
root: exp/data
search_space:
variant:
type: categorical
choices: [diagonal, general, orthogonal]
stalk_dim:
type: int
low: 2
high: 8
lr:
type: float
low: 0.0001
high: 0.1
log: true
config:
n_trials: 100
study_name: nsd-texas
storage: sqlite:///sweep.db
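Each search_space entry in the YAML maps naturally onto one of Optuna's Trial.suggest_* methods. The sketch below shows that mapping as an assumption about how the sweep consumes the file; suggest_params is hypothetical. The YAML is given as an already-parsed dict to avoid a PyYAML dependency. A small stub replaces optuna.trial.Trial so the example runs without Optuna installed. With Optuna installed, you would pass the trial object that the objective receives.

```python
from typing import Any, Dict

SEARCH_SPACE: Dict[str, Dict[str, Any]] = {
    "variant": {"type": "categorical", "choices": ["diagonal", "general", "orthogonal"]},
    "stalk_dim": {"type": "int", "low": 2, "high": 8},
    "lr": {"type": "float", "low": 0.0001, "high": 0.1, "log": True},
}


def suggest_params(trial: Any, space: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Translate search_space entries into Trial.suggest_* calls."""
    params = {}
    for name, spec in space.items():
        kind = spec["type"]
        if kind == "categorical":
            params[name] = trial.suggest_categorical(name, spec["choices"])
        elif kind == "int":
            params[name] = trial.suggest_int(name, spec["low"], spec["high"], log=spec.get("log", False))
        elif kind == "float":
            params[name] = trial.suggest_float(name, spec["low"], spec["high"], log=spec.get("log", False))
        else:
            raise ValueError(f"unsupported search-space type {kind!r} for {name!r}")
    return params


class StubTrial:
    """Minimal stand-in for optuna.trial.Trial: picks the first choice or lower bound."""

    def suggest_categorical(self, name, choices):
        return choices[0]

    def suggest_int(self, name, low, high, log=False):
        return low

    def suggest_float(self, name, low, high, log=False):
        return low


print(suggest_params(StubTrial(), SEARCH_SPACE))
# {'variant': 'diagonal', 'stalk_dim': 2, 'lr': 0.0001}
```

Rejecting unknown type values up front turns a typo in the YAML into an immediate error instead of a silently missing hyperparameter.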