# Source code for exp.gen_splits

# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
#   Mario Severino, Emanuele Mule, Dario Loi,
#   Francesco Restuccia, Fabrizio Silvestri, Pietro Liò

"""Generate or download pre-computed 60/20/20 train/val/test splits.

Two sources are supported:

canonical (default)
    Downloads the official Geom-GCN (Pei et al. 2020) splits from GitHub.
    These are the same splits used in the original NSD paper and most
    heterophily benchmarks. 10 folds per dataset.

generate
    Creates stratified splits locally (10 by default; see ``--folds``)
    using StratifiedShuffleSplit. Useful as a fallback when offline or for
    datasets not covered by the canonical repository.

Files are saved to exp/splits/ as:
    exp/splits/{name}_split_0.6_0.2_{fold}.npz

Each .npz contains three 1-D boolean arrays of length N:
    train_mask, val_mask, test_mask

Usage
-----
    # Via the unified CLI (recommended):
    sheaf splits
    sheaf splits --datasets cora citeseer texas
    sheaf splits --source generate
    sheaf splits --root /data/pyg --splits-dir /data/splits
    sheaf splits --folds 5
    sheaf splits --overwrite

    # Direct module invocation:
    python -m exp.gen_splits
"""

from __future__ import annotations

import os
import urllib.request
from dataclasses import dataclass, field
from typing import Literal

import numpy as np
import torch
import tyro
from rich.console import Console
from sklearn.model_selection import StratifiedShuffleSplit

from exp.data import load_dataset
from exp.registries import dataset_registry

# Shared rich console used for all progress and warning output in this module.
_console = Console()

# Default output directory: a ``splits`` folder next to this module file.
_DEFAULT_SPLITS_DIR = os.path.join(os.path.dirname(__file__), "splits")
# Default number of folds to download or generate.
_N_FOLDS = 10
# Train / validation fractions used by generate_splits; the test set is the
# remainder (1 - 0.6 - 0.2 = 0.2).
# NOTE(review): split filenames hard-code "0.6_0.2" independently of these.
_TRAIN_RATIO = 0.6
_VAL_RATIO = 0.2

# Base URL of the ``splits`` folder in the Geom-GCN GitHub repository.
_GEOM_GCN_BASE = "https://github.com/graphdml-uiuc-jlu/geom-gcn/raw/master/splits"


def _npz_split_datasets() -> frozenset[str]:
    """Return the registered dataset names whose ``split_type`` is ``"npz_file"``.

    Only these datasets have split files that this module can download or
    generate.
    """
    selected: set[str] = set()
    for key in dataset_registry.list_keys():
        entry = dataset_registry.get(key)
        if entry.split_type == "npz_file":
            selected.add(key)
    return frozenset(selected)


def _canonical_url(name: str, fold: int) -> str:
    """Build the GitHub URL of one Geom-GCN split file.

    For example, ``_canonical_url("cora", 3)`` yields
    ``f"{_GEOM_GCN_BASE}/cora_split_0.6_0.2_3.npz"``.
    """
    filename = f"{name}_split_0.6_0.2_{fold}.npz"
    return "/".join((_GEOM_GCN_BASE, filename))


@dataclass
class SplitsConfig:
    """Configuration for downloading or generating dataset splits."""

    # Dataset names to process. Defaults to every registered dataset whose
    # split_type is "npz_file", in sorted order.
    datasets: list[str] = field(
        default_factory=lambda: sorted(_npz_split_datasets())
    )
    # "canonical" downloads the Geom-GCN split files; "generate" builds
    # stratified splits locally.
    source: Literal["canonical", "generate"] = "canonical"
    # Root directory passed to load_dataset (only used when source="generate").
    root: str = "exp/data"
    # Directory the .npz split files are written to.
    splits_dir: str = _DEFAULT_SPLITS_DIR
    # Number of folds to process (fold indices 0 .. folds-1).
    folds: int = _N_FOLDS
    # Re-download / regenerate folds whose split files already exist.
    overwrite: bool = False
# ---------------------------------------------------------------------------
# Core logic (testable without CLI)
# ---------------------------------------------------------------------------
def download_canonical_splits(
    name: str,
    splits_dir: str = _DEFAULT_SPLITS_DIR,
    n_folds: int = _N_FOLDS,
    overwrite: bool = False,
) -> None:
    """Download the official Geom-GCN splits for *name* from GitHub.

    Fold ``k`` is fetched from ``_canonical_url(name, k)`` and stored as
    ``{splits_dir}/{name}_split_0.6_0.2_{k}.npz``; the train/val/test node
    counts of each fold are printed. Skips any fold whose file already exists
    unless *overwrite* is True.

    Each file is first downloaded to a temporary ``.part`` path. It is then
    opened with :func:`numpy.load` to check that it contains ``train_mask``,
    ``val_mask`` and ``test_mask``, and only then moved into place. An
    interrupted or failed download therefore never leaves a truncated file
    that a later run would skip as "already exists". It also never destroys
    a previously good file when *overwrite* is True.

    Raises:
        RuntimeError: if a fold cannot be downloaded or is not a readable
            split file. Folds completed before the failure are kept.
    """
    _console.print(f"[bold][{name}][/bold] Downloading canonical Geom-GCN splits...")
    os.makedirs(splits_dir, exist_ok=True)
    for fold in range(n_folds):
        out_path = os.path.join(splits_dir, f"{name}_split_0.6_0.2_{fold}.npz")
        if os.path.exists(out_path) and not overwrite:
            _console.print(f" fold {fold:2d}: already exists, skipping.")
            continue
        url = _canonical_url(name, fold)
        tmp_path = out_path + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            # The context manager closes the NpzFile handle so the file can be
            # renamed (or removed) afterwards; open handles block this on Windows.
            with np.load(tmp_path) as arr:
                train_n = int(arr["train_mask"].sum())
                val_n = int(arr["val_mask"].sum())
                test_n = int(arr["test_mask"].sum())
            # Atomic on POSIX/Windows: out_path is either the old file or the
            # complete new one, never a partial download.
            os.replace(tmp_path, out_path)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to download {url}: {exc}\n"
                "Tip: run with --source generate to create splits locally."
            ) from exc
        finally:
            # Remove any leftover partial file, including after
            # KeyboardInterrupt; after a successful os.replace it no longer
            # exists.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        _console.print(
            f" fold {fold:2d}: "
            f"train={train_n:4d} val={val_n:4d} test={test_n:4d} "
            f"[dim]<- {url}[/dim]"
        )
    _console.print(f"[bold][{name}][/bold] Done.\n")
def _split_positions(
    labels: np.ndarray, train_size: float, seed: int
) -> tuple[np.ndarray, np.ndarray, bool]:
    """Randomly split positions ``0 .. len(labels) - 1`` into two index arrays.

    Args:
        labels: Class label of each position, used for stratification.
        train_size: Fraction of positions placed in the first array.
        seed: ``random_state`` passed to the scikit-learn splitter.

    Returns:
        ``(first_idx, second_idx, stratified)``. ``stratified`` is False when
        some class has fewer than two members. StratifiedShuffleSplit rejects
        such labels, so a plain ShuffleSplit with the same seed is used
        instead.
    """
    from sklearn.model_selection import ShuffleSplit

    placeholder_x = np.zeros(len(labels))
    _, class_counts = np.unique(labels, return_counts=True)
    if class_counts.size > 0 and int(class_counts.min()) >= 2:
        splitter = StratifiedShuffleSplit(
            n_splits=1, train_size=train_size, random_state=seed
        )
        first_idx, second_idx = next(splitter.split(placeholder_x, labels))
        return first_idx, second_idx, True
    fallback = ShuffleSplit(n_splits=1, train_size=train_size, random_state=seed)
    first_idx, second_idx = next(fallback.split(placeholder_x))
    return first_idx, second_idx, False


def generate_splits(
    name: str,
    root: str = "exp/data",
    splits_dir: str = _DEFAULT_SPLITS_DIR,
    n_folds: int = _N_FOLDS,
    overwrite: bool = False,
) -> None:
    """Generate and save *n_folds* stratified 60/20/20 splits for *name*.

    The dataset is loaded with ``load_dataset(name, root=root)`` and its node
    labels ``data.y`` drive the stratification. Fold ``k`` uses ``k`` as the
    random seed and is written to
    ``{splits_dir}/{name}_split_0.6_0.2_{k}.npz``. The file holds boolean
    ``train_mask``/``val_mask``/``test_mask`` arrays of length N. Skips any
    fold whose file already exists unless *overwrite* is True.

    If a class has fewer than two nodes, stratification is impossible. That
    split stage then falls back to an unstratified random split, and the fold
    is marked "not stratified" in the printed output.

    Raises:
        TypeError: if ``data.y`` is not a ``torch.Tensor``.
    """
    _console.print(f"[bold][{name}][/bold] Loading dataset...")
    data, _info = load_dataset(name, root=root)
    if not isinstance(data.y, torch.Tensor):
        raise TypeError(
            f"Expected data.y to be a torch.Tensor, got {type(data.y).__name__}"
        )
    # .cpu() is a no-op for CPU tensors and lets .numpy() work for GPU ones.
    labels = data.y.cpu().numpy()
    n = int(labels.shape[0])
    # Share of the non-train remainder that becomes validation (0.2 / 0.4);
    # whatever is left of the remainder becomes the test set.
    val_of_rest = _VAL_RATIO / (1.0 - _TRAIN_RATIO)

    os.makedirs(splits_dir, exist_ok=True)
    for fold in range(n_folds):
        out_path = os.path.join(splits_dir, f"{name}_split_0.6_0.2_{fold}.npz")
        if os.path.exists(out_path) and not overwrite:
            _console.print(f" fold {fold:2d}: already exists, skipping.")
            continue

        # The fold index doubles as the seed so every fold is reproducible.
        train_idx, rest_idx, train_stratified = _split_positions(
            labels, _TRAIN_RATIO, fold
        )
        val_local, test_local, val_stratified = _split_positions(
            labels[rest_idx], val_of_rest, fold
        )
        # val_local/test_local index into rest_idx; map back to node ids.
        val_idx = rest_idx[val_local]
        test_idx = rest_idx[test_local]

        train_mask = np.zeros(n, dtype=bool)
        val_mask = np.zeros(n, dtype=bool)
        test_mask = np.zeros(n, dtype=bool)
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True

        np.savez(
            out_path, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask
        )
        note = (
            ""
            if train_stratified and val_stratified
            else " [yellow](not stratified: a class has < 2 nodes)[/yellow]"
        )
        _console.print(
            f" fold {fold:2d}: "
            f"train={int(train_mask.sum()):4d} val={int(val_mask.sum()):4d} "
            f"test={int(test_mask.sum()):4d} [dim]-> {out_path}[/dim]{note}"
        )
    _console.print(f"[bold][{name}][/bold] Done.\n")
def splits(cfg: SplitsConfig) -> None:
    """Process splits for all requested datasets according to *cfg*.

    For each name in ``cfg.datasets``:
    - Names that are not registered at all are skipped with a warning.
    - Registered datasets whose ``split_type`` is not ``"npz_file"`` are also
      skipped with a warning.
    - Every other dataset goes to download_canonical_splits when
      ``cfg.source == "canonical"``, otherwise to generate_splits.

    Raises:
        ValueError: if ``cfg.folds`` is less than 1. Also raised when
            ``cfg.source == "canonical"`` and ``cfg.folds`` exceeds the 10
            folds the Geom-GCN repository publishes per dataset.
    """
    if cfg.folds < 1:
        raise ValueError(f"folds must be >= 1, got {cfg.folds}")
    # Geom-GCN publishes 10 folds per dataset (ids 0-9). Fail fast here
    # rather than erroring midway after some folds were already downloaded.
    canonical_max_folds = 10
    if cfg.source == "canonical" and cfg.folds > canonical_max_folds:
        raise ValueError(
            f"The canonical Geom-GCN splits provide only {canonical_max_folds} "
            f"folds, but folds={cfg.folds} was requested. "
            "Use --source generate for more folds."
        )

    os.makedirs(cfg.splits_dir, exist_ok=True)
    registered = set(dataset_registry.list_keys())
    valid = _npz_split_datasets()
    for name in cfg.datasets:
        if name not in registered:
            # Distinguish typos / unknown names from non-NPZ datasets.
            _console.print(
                f"[yellow]Warning:[/yellow] '{name}' is not a registered dataset"
                " — skipping."
            )
            continue
        if name not in valid:
            _console.print(
                f"[yellow]Warning:[/yellow] '{name}' does not use NPZ splits"
                " — skipping."
            )
            continue
        if cfg.source == "canonical":
            download_canonical_splits(
                name,
                splits_dir=cfg.splits_dir,
                n_folds=cfg.folds,
                overwrite=cfg.overwrite,
            )
        else:
            generate_splits(
                name,
                root=cfg.root,
                splits_dir=cfg.splits_dir,
                n_folds=cfg.folds,
                overwrite=cfg.overwrite,
            )
    _console.print("[green]All requested splits processed.[/green]")
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Parse a :class:`SplitsConfig` from the command line and run :func:`splits`.

    Entry point for ``python -m exp.gen_splits``.
    """
    splits(tyro.cli(SplitsConfig))


if __name__ == "__main__":
    main()