# Source code for exp.config

# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
#   Mario Severino, Emanuele Mule, Dario Loi,
#   Francesco Restuccia, Fabrizio Silvestri, Pietro Liò

"""Typed configuration dataclasses for the NSD benchmark runner.

Consumed by ``tyro.cli`` in ``exp/run.py``.  Per-dataset preset defaults live
in ``exp/registries/presets.py`` and are injected via ``tyro.cli(Config, default=...)``,
so every field remains overridable from the command line.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from enum import StrEnum, auto
from typing import Literal

# Default dataset storage directory: start from the directory containing this
# module, climb two levels, then descend into ``exp/data``.  ``normpath``
# collapses the ``..`` segments lexically (symlinks are not resolved).
# NOTE(review): this assumes the module sits two directories below the folder
# that contains ``exp/data`` — confirm against the repository layout.
_DATA_DIR: str = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__),
        "..",
        "..",
        "exp",
        "data",
    )
)


[docs] class ModelType(StrEnum): NSD = auto()
[docs] class ModelVariant(StrEnum): DIAGONAL = auto() GENERAL = auto() ORTHOGONAL = auto() GENERAL_ATTENTION = auto() ORTHOGONAL_ATTENTION = auto() LOW_RANK = auto()
@dataclass
class DatasetConfig:
    """Dataset identity and storage location."""

    name: str = "cora"
    # A factory (rather than a plain default) reads the module-level
    # ``_DATA_DIR`` at instantiation time instead of at class-definition time.
    root: str = field(default_factory=lambda: _DATA_DIR)
@dataclass
class ModelConfig:
    """NSD architecture hyperparameters."""

    type: ModelType = ModelType.NSD
    # Plain-string variant name; the allowed values are the same set as the
    # members of the ``ModelVariant`` enum.
    variant: Literal[
        "diagonal",
        "general",
        "orthogonal",
        "low_rank",
        "general_attention",
        "orthogonal_attention",
    ] = "general"
    stalk_dim: int = 4
    hidden_dim: int = 16
    num_layers: int = 2
    alpha: float = 1.0
    rank: int = 1
    orth_strategy: Literal["cayley", "fasth"] = "cayley"
    normalize_output: bool = True
    # Unused for NSD, kept so SweepConfig can reference ModelConfig fields
    # generically.
    jknet: bool = False
    num_heads: int = 1
    leaky_relu_slope: float = 0.2
    clamp_val: float = 10.0
@dataclass
class RegConfig:
    """Regularisation hyperparameters.

    Both rates default to ``0.0``.  Presumably these are dropout
    probabilities for the input features and the hidden layers
    respectively — confirm against the model implementation.
    """

    input_dropout: float = 0.0
    dropout: float = 0.0
@dataclass
class OptimConfig:
    """Optimisation schedule."""

    lr: float = 0.01
    weight_decay: float = 5e-4
    epochs: int = 1000
    # NOTE(review): presumably a patience counted in epochs — confirm in the
    # training loop.
    early_stopping: int = 200
    # Which metric drives early stopping: validation loss or accuracy.
    stop_strategy: Literal["loss", "acc"] = "loss"
    batch_size: int = 1
@dataclass
class CVConfig:
    """Cross-validation setup."""

    folds: int = 10
    seed: int = 42
    # NOTE(review): threshold semantics are defined by the caller; the
    # default of 0.0 disables any effect a minimum-accuracy gate could have.
    min_acc: float = 0.0
@dataclass
class HardwareConfig:
    """Hardware selection and data loading performance.

    NOTE(review): ``cuda`` looks like a device index, and the remaining
    fields mirror PyTorch ``DataLoader`` options — confirm where they are
    consumed.
    """

    cuda: int = 0
    num_workers: int = 0
    pin_memory: bool = True
    persistent_workers: bool = False
[docs] @dataclass class WandBConfig: """Weights & Biases integration.""" enabled: bool = False entity: str | None = None project: str | None = None
@dataclass
class Config:
    """Root configuration for a benchmark run.

    Each section is built through ``default_factory`` so every ``Config``
    instance owns fresh sub-config objects rather than sharing them.
    Field order matters: it fixes the positional order of ``__init__``.
    """

    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    reg: RegConfig = field(default_factory=RegConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)
    cv: CVConfig = field(default_factory=CVConfig)
    hardware: HardwareConfig = field(default_factory=HardwareConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)