Source code for exp.run

# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
#   Mario Severino, Emanuele Mule, Dario Loi,
#   Francesco Restuccia, Fabrizio Silvestri, Pietro Liò

"""Neural Sheaf Diffusion - 10-fold cross-validation experiment runner.

Usage
-----
    # Via the unified CLI (recommended):
    sheaf run --preset cora
    sheaf run --preset texas --model.num_layers 3

    # Direct module invocation:
    python -m exp.run --preset cora
"""

from __future__ import annotations

import dataclasses
import logging
import random
import sys
import tempfile
import warnings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from lightning.pytorch.loggers import WandbLogger as _WandbLogger

import numpy as np
import torch
import tyro
from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.utilities.model_summary import ModelSummary
from rich.console import Console
from rich.panel import Panel
from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn
from rich.table import Table

from exp.config import Config
from exp.data import DatasetInfo, SheafDataModule
from exp.module import SheafLightningModule
from exp.registries.presets import preset_registry
from sheaf_mpnn.utils import setup_torch

log = logging.getLogger(__name__)
_console = Console()

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)  # noqa: NPY002
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _parse_config() -> Config:
    """Handle ``--preset <name>`` before handing the rest to tyro.

    The preset is stripped from ``sys.argv`` and used to set the ``default``
    argument of ``tyro.cli``, so every field can still be overridden.
    """
    argv = sys.argv[1:]
    preset_name: str | None = None
    clean: list[str] = []

    i = 0
    while i < len(argv):
        a = argv[i]
        if a == "--preset":
            if i + 1 < len(argv):
                preset_name = argv[i + 1]
                i += 2
            else:
                i += 1
        elif a.startswith("--preset="):
            preset_name = a.split("=", 1)[1]
            i += 1
        else:
            clean.append(a)
            i += 1

    sys.argv = [sys.argv[0]] + clean
    if preset_name is not None:
        if preset_name not in preset_registry:
            known = sorted(preset_registry.list_keys())
            raise SystemExit(
                f"Unknown preset {preset_name!r}. "
                f"Run with --help to see available presets.\n"
                f"Known: {', '.join(known)}"
            )
        default: Config | None = preset_registry.get(preset_name)
    else:
        default = None
    result = tyro.cli(Config, default=default)
    assert result is not None
    return result


# ---------------------------------------------------------------------------
# Per-fold training
# ---------------------------------------------------------------------------


def _make_logger(
    cfg: Config, info: DatasetInfo, fold: int, run_name: str
) -> _WandbLogger | bool:
    if not cfg.wandb.enabled:
        return False
    from lightning.pytorch.loggers import WandbLogger  # lazy - wandb is optional

    return WandbLogger(
        project=cfg.wandb.project or f"nsd-{info.name}",
        entity=cfg.wandb.entity,
        name=f"{run_name}-fold{fold}",
        group=run_name,
        config=dataclasses.asdict(cfg),
    )


def _model_label(cfg: Config) -> str:
    """Short identifier for logging and W&B run names."""
    return f"{cfg.model.type}-{cfg.model.variant}"


def _run_fold(
    cfg: Config,
    info: DatasetInfo,
    fold: int,
    monitor: str,
    ckpt_mode: str,
) -> float:
    """Train and evaluate one cross-validation fold; return the test metric."""
    run_name = (
        f"{_model_label(cfg)}-d{cfg.model.stalk_dim}"
        f"-h{cfg.model.hidden_dim}-L{cfg.model.num_layers}"
    )
    dm = SheafDataModule(
        cfg.dataset.name,
        root=cfg.dataset.root,
        fold=fold,
        batch_size=cfg.optim.batch_size,
        num_workers=cfg.hardware.num_workers,
        pin_memory=cfg.hardware.pin_memory,
        persistent_workers=cfg.hardware.persistent_workers,
    )
    module = SheafLightningModule(cfg, info)
    logger = _make_logger(cfg, info, fold, run_name)

    with tempfile.TemporaryDirectory() as ckpt_dir:
        ckpt_cb = ModelCheckpoint(
            dirpath=ckpt_dir,
            monitor=monitor,
            mode=ckpt_mode,
            save_top_k=1,
            filename="best",
        )

        trainer = Trainer(
            max_epochs=cfg.optim.epochs,
            callbacks=[
                EarlyStopping(
                    monitor=monitor,
                    patience=cfg.optim.early_stopping,
                    mode=ckpt_mode,
                ),
                ckpt_cb,
            ],
            logger=logger,
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=[cfg.hardware.cuda] if torch.cuda.is_available() else "auto",
            enable_progress_bar=False,
            enable_model_summary=False,
            log_every_n_steps=1,
        )

        trainer.fit(module, dm)
        [test_res] = trainer.test(module, dm, ckpt_path="best", verbose=False)

    return float(test_res.get(f"test_{info.metric}", 0.0))


# ---------------------------------------------------------------------------
# Rich output helpers
# ---------------------------------------------------------------------------


def _display_startup(
    cfg: Config,
    info: DatasetInfo,
    n_train: int,
    n_val: int,
    n_test: int,
    avg_deg: float,
    n_folds: int,
) -> None:
    model_line = (
        f"{cfg.model.type.upper()}-{cfg.model.variant}  "
        f"d={cfg.model.stalk_dim}  hidden={cfg.model.hidden_dim}  "
        f"layers={cfg.model.num_layers}  alpha={cfg.model.alpha}  "
        f"dropout={cfg.reg.dropout}"
    )
    summary = str(ModelSummary(SheafLightningModule(cfg, info), max_depth=1))
    content = (
        f"[bold]Dataset[/bold]  {info.name}  "
        f"N={info.num_features:,}  E={info.num_features:,}  avg_deg={avg_deg:.1f}\n"
        f"         F={info.num_features}  C={info.num_classes}  metric={info.metric}\n"
        f"[bold]Split[/bold]    train={n_train:,}  val={n_val:,}  test={n_test:,}  "
        f"({n_folds}-fold CV)\n"
        f"[bold]Model[/bold]    {model_line}\n\n"
        f"{summary}"
    )
    _console.print(Panel(content, title="NSD Cross-Validation"))


def _display_results(cfg: Config, info: DatasetInfo, results: list[float]) -> None:
    scale = 100.0 if info.metric == "acc" else 1.0
    suffix = "%" if info.metric == "acc" else ""

    table = Table(title=f"{_model_label(cfg)} on {info.name} [{len(results)} folds]")
    table.add_column("Fold", justify="right", style="cyan")
    table.add_column(f"Test {info.metric}", justify="right")

    for fold, metric in enumerate(results):
        table.add_row(str(fold), f"{metric * scale:.4f}{suffix}")

    arr = np.array(results)
    table.add_section()
    table.add_row(
        "Mean ± Std",
        f"[bold]{arr.mean() * scale:.2f} ± {arr.std() * scale:.2f}{suffix}[/bold]",
    )
    _console.print(table)


# ---------------------------------------------------------------------------
# Main logic (testable without CLI)
# ---------------------------------------------------------------------------


[docs] def run(cfg: Config) -> list[float]: """Run 10-fold cross-validation; return per-fold test metrics.""" _silence_third_party() setup_torch(precision="high", seed=cfg.cv.seed) dm_meta = SheafDataModule(cfg.dataset.name, root=cfg.dataset.root, fold=0) dm_meta.setup() info = dm_meta.info n_folds = min(cfg.cv.folds, info.num_splits) n_train, n_val, n_test = dm_meta.split_sizes avg_deg = (dm_meta.num_edges * 2) / dm_meta.num_nodes _display_startup(cfg, info, n_train, n_val, n_test, avg_deg, n_folds) monitor = "val_loss" if cfg.optim.stop_strategy == "loss" else f"val_{info.metric}" ckpt_mode = "min" if cfg.optim.stop_strategy == "loss" else "max" results: list[float] = [] with Progress( TextColumn("[progress.description]{task.description}"), BarColumn(), TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeElapsedColumn(), console=_console, transient=True, ) as progress: task = progress.add_task("Running folds...", total=n_folds) for fold in range(n_folds): _set_seed(cfg.cv.seed + fold) test_metric = _run_fold(cfg, info, fold, monitor, ckpt_mode) results.append(test_metric) progress.update( task, advance=1, description=f"fold {fold:2d} {info.metric}={test_metric:.4f}", ) if fold == 0 and info.metric == "acc" and test_metric < cfg.cv.min_acc: _console.print( f"[yellow]fold 0 test {info.metric}={test_metric:.4f} " f"< min_acc={cfg.cv.min_acc:.4f}; aborting run.[/yellow]" ) break _display_results(cfg, info, results) return results
# --------------------------------------------------------------------------- # CLI entry point # --------------------------------------------------------------------------- def _silence_third_party() -> None: logging.getLogger("lightning.pytorch").setLevel(logging.WARNING) warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=UserWarning, module="lightning.*") warnings.filterwarnings("ignore", category=UserWarning, module="torch_geometric.*") warnings.filterwarnings("ignore", message=".*LeafSpec.*")
[docs] def main() -> None: """Entry point for ``python -m exp.run``.""" logging.basicConfig(level=logging.INFO, format="%(message)s") cfg = _parse_config() run(cfg)
if __name__ == "__main__": main()