# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
# Mario Severino, Emanuele Mule, Dario Loi,
# Francesco Restuccia, Fabrizio Silvestri, Pietro Liò
"""Dataset loading and Lightning DataModule for the NSD benchmark suite.

Supported datasets
------------------
Homophilic:
    cora, citeseer
Heterophilic:
    chameleon, squirrel, chameleon_filtered, squirrel_filtered,
    cornell, texas, film
Heterophilous (Platonov et al. 2023):
    amazon_ratings, minesweeper, questions, roman_empire, tolokers

All datasets are downloaded automatically to ``root`` on first use.
Filtered Wikipedia datasets are fetched from the yandex-research GitHub
repository if not already cached locally.
"""
from __future__ import annotations
import os
import urllib.request
from dataclasses import dataclass
from typing import cast
import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.datasets import (
Actor,
HeterophilousGraphDataset,
Planetoid,
WebKB,
WikipediaNetwork,
)
from torch_geometric.utils import coalesce, to_undirected
# Default dataset cache directory: two levels above this file's directory,
# then ``exp/data``; normalised so the parent references are collapsed.
_DATA_DIR: str = os.path.normpath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "exp", "data")
)
# ---------------------------------------------------------------------------
# Dataset metadata
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class DatasetInfo:
    """Immutable metadata record returned alongside a loaded ``Data`` object.

    Attributes:
        name: Canonical dataset name.
        num_features: Number of node-feature columns.
        num_classes: Number of target classes.
        num_splits: Number of available train/val/test folds.
        metric: Evaluation metric, ``"acc"`` or ``"roc_auc"``; depends on the
            dataset's evaluation protocol.
        split_type: Where the splits come from, ``"npz_file"`` or
            ``"pyg_mask"``.
    """

    name: str
    num_features: int
    num_classes: int
    num_splits: int  # number of available train/val/test folds
    metric: str  # "acc" | "roc_auc" -- depends on dataset evaluation protocol
    split_type: str  # "npz_file" | "pyg_mask"
# Binary node-classification datasets, scored with ROC-AUC instead of accuracy.
ROC_AUC_DATASETS: frozenset[str] = frozenset(
    ("minesweeper", "tolokers", "questions")
)

# Canonical dataset name -> (loader_kind, pyg_key).
# ``loader_kind`` selects the loading branch; ``pyg_key`` is the dataset name
# handed to the underlying dataset class.
_LOADER: dict[str, tuple[str, str]] = {
    # loader kind "planetoid"
    "cora": ("planetoid", "Cora"),
    "citeseer": ("planetoid", "CiteSeer"),
    # loader kind "wiki"
    "chameleon": ("wiki", "chameleon"),
    "squirrel": ("wiki", "squirrel"),
    # loader kind "filtered_wiki"
    "chameleon_filtered": ("filtered_wiki", "chameleon_filtered"),
    "squirrel_filtered": ("filtered_wiki", "squirrel_filtered"),
    # loader kind "webkb"
    "cornell": ("webkb", "Cornell"),
    "texas": ("webkb", "Texas"),
    # loader kind "actor" (no dataset key, hence the empty string)
    "film": ("actor", ""),
    # loader kind "heterophilous"
    "amazon_ratings": ("heterophilous", "amazon-ratings"),
    "minesweeper": ("heterophilous", "minesweeper"),
    "questions": ("heterophilous", "questions"),
    "roman_empire": ("heterophilous", "roman-empire"),
    "tolokers": ("heterophilous", "tolokers"),
}

# Alternative spellings mapped to canonical names: "fil" is accepted for "film".
_ALIASES: dict[str, str] = dict(fil="film")

# Datasets whose splits live in exp/splits/*.npz (60/20/20, 10 folds).
NPZ_SPLIT_DATASETS: frozenset[str] = frozenset(
    ("cora", "citeseer", "chameleon", "squirrel", "cornell", "texas", "film")
)
def _canonical(name: str) -> str:
    """Normalise a user-supplied dataset name to its canonical key.

    The name is lower-cased, hyphens become underscores, surrounding
    whitespace is stripped, and a known alias is replaced by its target.
    """
    key = name.lower().replace("-", "_").strip()
    if key in _ALIASES:
        return _ALIASES[key]
    return key
# ---------------------------------------------------------------------------
# Custom dataset for Filtered Wikipedia networks (Platonov et al., 2023).
# ---------------------------------------------------------------------------
# Download location of each filtered dataset's raw ``<name>.npz`` file; every
# entry shares the same repository prefix and differs only in the file name.
_FILTERED_WIKI_URLS: dict[str, str] = {
    ds_name: (
        "https://github.com/yandex-research/heterophilous-graphs"
        f"/raw/main/data/{ds_name}.npz"
    )
    for ds_name in ("chameleon_filtered", "squirrel_filtered")
}
class FilteredWikipediaDataset(InMemoryDataset):
    """Filtered chameleon / squirrel networks (Platonov et al., 2023).

    The raw ``.npz`` file embeds 10 pre-computed train/val/test splits as
    boolean mask matrices of shape ``(10, N)``. After processing they are
    stored as ``(N, 10)`` tensors so they match the convention used by other
    PyG heterophilous datasets.
    """

    def __init__(self, root: str, name: str) -> None:
        """Load the processed dataset ``name`` stored under ``root``.

        Args:
            root: Directory handed to ``InMemoryDataset`` as its root.
            name: Key of ``_FILTERED_WIKI_URLS`` identifying the dataset.

        Raises:
            ValueError: If ``name`` is not a known filtered dataset.
        """
        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``.
        if name not in _FILTERED_WIKI_URLS:
            raise ValueError(f"Unknown filtered dataset: {name!r}")
        # Set before calling the base constructor: the file-name properties
        # below read it. NOTE(review): the base class is expected to invoke
        # download()/process() from its constructor -- confirm against the
        # installed PyG version.
        self._ds_name = name
        super().__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def raw_file_names(self) -> list[str]:
        """File name of the raw archive: ``<name>.npz``."""
        return [f"{self._ds_name}.npz"]

    @property
    def processed_file_names(self) -> list[str]:
        """File name of the processed tensor file: ``<name>.pt``."""
        return [f"{self._ds_name}.pt"]

    def download(self) -> None:
        """Fetch the raw ``.npz`` archive into ``raw_dir``.

        The archive is first written to a temporary ``.part`` file and only
        moved to its final name once the transfer has finished, so an
        interrupted download never leaves a truncated file under the final
        name.
        """
        url = _FILTERED_WIKI_URLS[self._ds_name]
        dst = os.path.join(self.raw_dir, f"{self._ds_name}.npz")
        tmp = f"{dst}.part"
        print(f"Downloading {self._ds_name} from\n {url}")
        try:
            urllib.request.urlretrieve(url, tmp)
            os.replace(tmp, dst)
        finally:
            # Only present if the transfer or the move failed.
            if os.path.exists(tmp):
                os.remove(tmp)

    def process(self) -> None:
        """Convert the raw ``.npz`` archive into a collated ``Data`` file."""
        raw_path = os.path.join(self.raw_dir, f"{self._ds_name}.npz")
        # Context manager closes the archive's file handle once all arrays
        # have been read.
        with np.load(raw_path) as npz:
            x = torch.from_numpy(npz["node_features"]).float()
            y = torch.from_numpy(npz["node_labels"]).long()
            edges = torch.from_numpy(npz["edges"]).t().contiguous()  # (2, E)
            # raw masks: (10, N) bool -> (N, 10) for consistency with PyG convention
            train_mask = torch.from_numpy(npz["train_masks"]).T.bool()
            val_mask = torch.from_numpy(npz["val_masks"]).T.bool()
            test_mask = torch.from_numpy(npz["test_masks"]).T.bool()

        edge_index = to_undirected(edges)
        edge_index = coalesce(edge_index, num_nodes=x.size(0))
        data = Data(
            x=x,
            edge_index=edge_index,
            y=y,
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask,
        )
        torch.save(self.collate([data]), self.processed_paths[0])
def load_dataset(name: str, root: str = _DATA_DIR) -> tuple[Data, DatasetInfo]:
    """Load a benchmark dataset, downloading it automatically if needed.

    Args:
        name: Dataset identifier. One of the 14 supported names (aliases and
            hyphenated / mixed-case spellings are normalised first).
        root: Base directory used by PyG for download caching; each dataset
            is stored in the sub-directory ``root/<canonical name>``.

    Returns:
        data: A single ``torch_geometric.data.Data`` object.
        info: A :class:`DatasetInfo` with metadata about the dataset.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    name = _canonical(name)
    if name not in _LOADER:
        raise ValueError(
            f"Unknown dataset {name!r}. Supported: {sorted(_LOADER.keys())}"
        )
    kind, pykey = _LOADER[name]
    metric = "roc_auc" if name in ROC_AUC_DATASETS else "acc"
    ds_root = os.path.join(root, name)

    if kind == "planetoid":
        ds = Planetoid(ds_root, pykey, split="public")
    elif kind == "wiki":
        ds = WikipediaNetwork(ds_root, pykey, geom_gcn_preprocess=True)
    elif kind == "webkb":
        ds = WebKB(ds_root, pykey)
    elif kind == "actor":
        ds = Actor(ds_root)
    elif kind == "heterophilous":
        ds = HeterophilousGraphDataset(ds_root, pykey)
    elif kind == "filtered_wiki":
        ds = FilteredWikipediaDataset(ds_root, name)
    else:
        raise AssertionError(f"Unhandled loader kind: {kind!r}")
    data = cast(Data, ds[0])

    if kind in ("planetoid", "webkb", "actor"):
        # These datasets take their folds from the 10-fold .npz split files
        # (Planetoid itself ships only a single fixed split).
        num_splits = 10
    else:
        # One fold per column of a 2-D train mask; a 1-D mask falls back to
        # 1 for "wiki" and to 10 for the other kinds.
        fallback = 1 if kind == "wiki" else 10
        mask = data.train_mask
        num_splits = int(mask.size(1)) if mask.dim() > 1 else fallback

    assert data.x is not None
    assert isinstance(data.y, torch.Tensor)
    info = DatasetInfo(
        name=name,
        num_features=int(data.x.size(1)),
        # Labels are taken to be 0-based, so the class count is max label + 1.
        num_classes=int(data.y.max().item()) + 1,
        num_splits=num_splits,
        metric=metric,
        split_type="npz_file" if name in NPZ_SPLIT_DATASETS else "pyg_mask",
    )
    return data, info
# ---------------------------------------------------------------------------
# Lightning DataModule
# ---------------------------------------------------------------------------
try:
    from lightning import LightningDataModule
    from torch_geometric.loader import DataLoader as _PyGLoader
except ImportError:
    pass  # lightning is optional; SheafDataModule is unavailable without it
else:

    class SheafDataModule(LightningDataModule):
        """Full-graph DataModule for transductive node classification.

        All three dataloaders return the same graph; the model selects nodes
        via ``train_mask`` / ``val_mask`` / ``test_mask`` inside each step.
        """

        def __init__(
            self,
            name: str,
            root: str = _DATA_DIR,
            fold: int = 0,
            batch_size: int = 1,
            num_workers: int = 0,
            pin_memory: bool = False,
            persistent_workers: bool = False,
        ) -> None:
            """Store the configuration; nothing is loaded until ``setup()``.

            Args:
                name: Dataset identifier passed to ``load_dataset``.
                root: Base directory for the dataset cache.
                fold: Index of the split passed to ``apply_split``.
                batch_size: Forwarded to the PyG ``DataLoader``.
                num_workers: Forwarded to the PyG ``DataLoader``.
                pin_memory: Forwarded to the PyG ``DataLoader``.
                persistent_workers: Forwarded to the PyG ``DataLoader``, but
                    only takes effect when ``num_workers > 0``.
            """
            super().__init__()
            self._name = name
            self._root = root
            self._fold = fold
            self._batch_size = batch_size
            self._num_workers = num_workers
            self._pin_memory = pin_memory
            self._persistent_workers = persistent_workers
            # Populated by setup().
            self._data: Data | None = None
            self._info: DatasetInfo | None = None
            self._split: Data | None = None

        @property
        def info(self) -> DatasetInfo:
            """Dataset metadata (requires setup() to have run)."""
            if self._info is None:
                raise RuntimeError("Call setup() before accessing .info")
            return self._info

        @property
        def num_nodes(self) -> int:
            """Number of nodes in the loaded graph (requires setup() to have run)."""
            if self._data is None:
                raise RuntimeError("Call setup() before accessing .num_nodes")
            assert self._data.num_nodes is not None
            return int(self._data.num_nodes)

        @property
        def num_edges(self) -> int:
            """Number of undirected edges (requires setup() to have run)."""
            if self._data is None:
                raise RuntimeError("Call setup() before accessing .num_edges")
            # NOTE(review): halving assumes edge_index stores both directions
            # of every edge -- confirm for every supported dataset.
            return int(self._data.num_edges) // 2

        @property
        def homophily(self) -> float:
            """Edge homophily h in [0,1]: fraction of same-class edges."""
            if self._data is None:
                raise RuntimeError("Call setup() before accessing .homophily")
            from torch_geometric.utils import homophily as _h

            assert isinstance(self._data.y, torch.Tensor)
            assert self._data.edge_index is not None
            return float(_h(self._data.edge_index, self._data.y, method="edge"))

        @property
        def split_sizes(self) -> tuple[int, int, int]:
            """(train, val, test) node counts for the current fold."""
            if self._split is None:
                raise RuntimeError("Call setup() before accessing .split_sizes")
            return (
                int(self._split.train_mask.sum()),
                int(self._split.val_mask.sum()),
                int(self._split.test_mask.sum()),
            )

        def setup(self, stage: str | None = None) -> None:
            """Load the dataset on first call and apply the configured fold.

            Args:
                stage: Accepted for the Lightning hook signature; not used.
            """
            if self._data is None:
                self._data, self._info = load_dataset(self._name, root=self._root)
            # Lazy import avoids the circular dependency: splits.py -> data.py
            from exp.splits import apply_split  # noqa: PLC0415

            assert self._info is not None
            self._split = apply_split(self._data, self._info, self._fold)

        def _loader(self) -> _PyGLoader:
            """Build a loader over the single split graph.

            Raises:
                RuntimeError: If ``setup()`` has not been called yet.
            """
            # RuntimeError (not ``assert``) for consistency with the
            # properties above and so the check survives ``python -O``.
            if self._split is None:
                raise RuntimeError("Call setup() first")
            return _PyGLoader(
                [self._split],
                batch_size=self._batch_size,
                num_workers=self._num_workers,
                pin_memory=self._pin_memory,
                # Only requested when worker processes are actually used.
                persistent_workers=self._persistent_workers and self._num_workers > 0,
            )

        def train_dataloader(self) -> _PyGLoader:
            """Return a loader over the full graph for training."""
            return self._loader()

        def val_dataloader(self) -> _PyGLoader:
            """Return a loader over the full graph for validation."""
            return self._loader()

        def test_dataloader(self) -> _PyGLoader:
            """Return a loader over the full graph for testing."""
            return self._loader()