# Copyright (c) 2026 "Sheaf Neural Networks as Message Passing"
# Authors: Alessio Borgi, Gabriele Onorato, Luke Braithwaite,
# Mario Severino, Emanuele Mule, Dario Loi,
# Francesco Restuccia, Fabrizio Silvestri, Pietro Liò
"""Preset registry: per-dataset best-known configurations."""
from __future__ import annotations
from typing import Literal
from exp.config import (
Config,
CVConfig,
DatasetConfig,
HardwareConfig,
ModelConfig,
ModelType,
OptimConfig,
RegConfig,
WandBConfig,
)
from exp.registries.base import Registry
class PresetRegistry(Registry[str, Config]):
    """Map preset names to ready-to-run experiment :class:`Config` objects."""

    def get_or_default(self, name: str | None) -> Config:
        """Resolve ``name`` to a preset, falling back to a default ``Config``.

        Args:
            name: Key of a registered preset, or ``None`` when no preset was
                requested.

        Returns:
            A freshly constructed ``Config()`` when ``name`` is ``None``;
            otherwise whatever the inherited ``Registry.get`` returns for
            ``name`` (its behaviour for unknown names is defined by the base
            class, not here).
        """
        if name is not None:
            return self.get(name)
        return Config()
# Sheaf variants accepted by ``generate_config(variant=...)``.
_VariantLiteral = Literal[
    "diagonal", "general", "orthogonal",
    "low_rank",
    "general_attention", "orthogonal_attention",
]

# Values accepted by ``generate_config(stop_strategy=...)``; forwarded to
# ``OptimConfig.stop_strategy`` (presumably the validation metric watched for
# early stopping -- confirm in exp.config).
_StopStrategyLiteral = Literal["loss", "acc"]
def generate_config(
    name: str,
    *,
    variant: _VariantLiteral,
    stalk_dim: int,
    hidden_dim: int,
    num_layers: int,
    input_dropout: float,
    dropout: float,
    lr: float,
    weight_decay: float,
    alpha: float = 1.0,
    stop_strategy: _StopStrategyLiteral = "loss",
    model_type: ModelType = ModelType.NSD,
    num_heads: int = 1,
    leaky_relu_slope: float = 0.2,
    clamp_val: float = 10.0,
    normalize_output: bool = True,
    jknet: bool = False,
) -> Config:
    """Assemble a complete :class:`Config` for a single preset.

    Only the dataset, model, regularisation and optimiser sections are
    parameterised; the cross-validation, hardware and W&B sections come from
    their no-argument constructors.

    Args:
        name: Dataset name, stored as ``DatasetConfig.name``.
        variant: Sheaf variant, forwarded to ``ModelConfig.variant``.
        stalk_dim, hidden_dim, num_layers, alpha, num_heads,
        leaky_relu_slope, clamp_val, normalize_output, jknet:
            Forwarded unchanged to the matching ``ModelConfig`` fields.
        model_type: Forwarded as ``ModelConfig.type``.
        input_dropout, dropout: Forwarded to ``RegConfig``.
        lr, weight_decay, stop_strategy: Forwarded to ``OptimConfig``.

    Returns:
        A newly constructed ``Config`` on every call.
    """
    # Sections are built in the same order the original keyword arguments
    # were evaluated: dataset, model, reg, optim, then cv/hardware/wandb.
    dataset_section = DatasetConfig(name=name)
    model_section = ModelConfig(
        type=model_type,
        variant=variant,
        stalk_dim=stalk_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        alpha=alpha,
        num_heads=num_heads,
        leaky_relu_slope=leaky_relu_slope,
        clamp_val=clamp_val,
        normalize_output=normalize_output,
        jknet=jknet,
    )
    reg_section = RegConfig(input_dropout=input_dropout, dropout=dropout)
    optim_section = OptimConfig(
        lr=lr,
        weight_decay=weight_decay,
        stop_strategy=stop_strategy,
    )
    return Config(
        dataset=dataset_section,
        model=model_section,
        reg=reg_section,
        optim=optim_section,
        cv=CVConfig(),
        hardware=HardwareConfig(),
        wandb=WandBConfig(),
    )
# ---------------------------------------------------------------------------
# NSD node-classification hyperparameters (one row per dataset)
# ---------------------------------------------------------------------------
# Every NSD preset for a dataset shares the row below:
#   * ``--preset <dataset>`` uses the ``variant`` listed in the row;
#   * ``--preset <dataset>_nsd_<variant>`` keeps every other value and swaps
#     only ``variant`` (see ``_NSD_VARIANTS``).
# Keeping a single row per dataset means a tuned value is edited in exactly
# one place instead of seven copy-pasted ones.
_NSD_HPARAMS: dict[str, dict] = {
    "cora": dict(
        variant="general", stalk_dim=4, hidden_dim=32, num_layers=2,
        alpha=1.0, input_dropout=0.5, dropout=0.0, lr=0.01, weight_decay=5e-4,
    ),
    "citeseer": dict(
        variant="general", stalk_dim=4, hidden_dim=32, num_layers=2,
        alpha=1.0, input_dropout=0.5, dropout=0.0, lr=0.01, weight_decay=5e-4,
    ),
    "chameleon": dict(
        variant="orthogonal", stalk_dim=5, hidden_dim=32, num_layers=3,
        alpha=1.0, input_dropout=0.0, dropout=0.0, lr=0.01, weight_decay=1e-3,
    ),
    "squirrel": dict(
        variant="orthogonal", stalk_dim=5, hidden_dim=32, num_layers=3,
        alpha=1.0, input_dropout=0.0, dropout=0.0, lr=0.01, weight_decay=1e-3,
    ),
    "chameleon_filtered": dict(
        variant="orthogonal", stalk_dim=4, hidden_dim=32, num_layers=3,
        alpha=1.0, input_dropout=0.0, dropout=0.0, lr=0.01, weight_decay=1e-3,
    ),
    "squirrel_filtered": dict(
        variant="orthogonal", stalk_dim=4, hidden_dim=32, num_layers=3,
        alpha=1.0, input_dropout=0.0, dropout=0.0, lr=0.01, weight_decay=1e-3,
    ),
    "cornell": dict(
        variant="general", stalk_dim=3, hidden_dim=16, num_layers=2,
        alpha=1.0, input_dropout=0.5, dropout=0.0, lr=0.01, weight_decay=5e-4,
        stop_strategy="acc",
    ),
    "texas": dict(
        variant="general", stalk_dim=3, hidden_dim=16, num_layers=2,
        alpha=1.0, input_dropout=0.5, dropout=0.0, lr=0.01, weight_decay=5e-4,
        stop_strategy="acc",
    ),
    "film": dict(
        variant="diagonal", stalk_dim=4, hidden_dim=32, num_layers=3,
        alpha=1.0, input_dropout=0.5, dropout=0.5, lr=0.005, weight_decay=5e-4,
    ),
    "amazon_ratings": dict(
        variant="general", stalk_dim=4, hidden_dim=32, num_layers=4,
        alpha=1.0, input_dropout=0.2, dropout=0.2, lr=0.01, weight_decay=5e-4,
    ),
    "minesweeper": dict(
        variant="general", stalk_dim=3, hidden_dim=16, num_layers=4,
        alpha=1.0, input_dropout=0.0, dropout=0.0, lr=0.01, weight_decay=5e-4,
    ),
    "questions": dict(
        variant="general", stalk_dim=4, hidden_dim=32, num_layers=3,
        alpha=1.0, input_dropout=0.2, dropout=0.2, lr=0.005, weight_decay=5e-4,
    ),
    "roman_empire": dict(
        variant="general", stalk_dim=4, hidden_dim=32, num_layers=4,
        alpha=1.0, input_dropout=0.5, dropout=0.5, lr=0.01, weight_decay=5e-4,
    ),
    "tolokers": dict(
        variant="general", stalk_dim=3, hidden_dim=16, num_layers=3,
        alpha=1.0, input_dropout=0.2, dropout=0.0, lr=0.01, weight_decay=5e-4,
    ),
}

# Variants that get a dedicated ``<dataset>_nsd_<variant>`` preset, in the
# order their blocks are registered.
_NSD_VARIANTS: tuple[_VariantLiteral, ...] = (
    "diagonal",
    "general",
    "orthogonal",
    "general_attention",
    "orthogonal_attention",
    "low_rank",
)

# ---------------------------------------------------------------------------
# Families whose datasets all share one hyperparameter set
# ---------------------------------------------------------------------------
# Each entry is ``(preset suffix, datasets, generate_config kwargs)`` and
# yields ``--preset <dataset>_<suffix>`` for every listed dataset.
_FAMILY_PRESETS: tuple[tuple[str, tuple[str, ...], dict], ...] = (
    # HGB (--preset <dataset>_hgb)
    (
        "hgb",
        ("imdb", "dblp", "acm", "freebase"),
        dict(
            variant="general", stalk_dim=4, hidden_dim=32, num_layers=2,
            alpha=1.0, input_dropout=0.5, dropout=0.0, lr=0.01,
            weight_decay=5e-4,
        ),
    ),
    # Graph Classification (--preset <dataset>_gc)
    (
        "gc",
        (
            "mutag", "proteins", "enzymes", "nci1", "nci109", "ptcmr",
            "dd", "collab", "imdb_b", "imdb_m", "reddit_b", "reddit_5k",
        ),
        dict(
            variant="general", stalk_dim=4, hidden_dim=32, num_layers=3,
            alpha=1.0, input_dropout=0.2, dropout=0.2, lr=0.01,
            weight_decay=5e-4,
        ),
    ),
    # Link Prediction (--preset <dataset>_lp)
    (
        "lp",
        ("lastfm", "movielens"),
        dict(
            variant="general", stalk_dim=2, hidden_dim=32, num_layers=2,
            alpha=1.0, input_dropout=0.0, dropout=0.0, lr=0.01,
            weight_decay=5e-4, jknet=True, normalize_output=False,
        ),
    ),
)


def _build_presets() -> dict[str, Config]:
    """Expand the hyperparameter tables into the flat ``name -> Config`` map.

    Insertion order (which is also registration order) is: the bare NSD
    presets, then one ``_nsd_<variant>`` block per entry of ``_NSD_VARIANTS``,
    then the HGB, graph-classification and link-prediction families.

    Raises:
        ValueError: If two table entries would produce the same preset name
            (a dict literal would silently keep only the last one).
    """
    presets: dict[str, Config] = {}

    def add(key: str, dataset: str, hparams: dict) -> None:
        # Fail loudly on collisions instead of silently overwriting.
        if key in presets:
            raise ValueError(f"duplicate preset name: {key!r}")
        presets[key] = generate_config(dataset, **hparams)

    # NSD presets (--preset <dataset>)
    for dataset, hparams in _NSD_HPARAMS.items():
        add(dataset, dataset, hparams)

    # NSD per-variant presets (--preset <dataset>_nsd_<variant>): same row,
    # only ``variant`` overridden.
    for variant in _NSD_VARIANTS:
        for dataset, hparams in _NSD_HPARAMS.items():
            add(f"{dataset}_nsd_{variant}", dataset, {**hparams, "variant": variant})

    # Shared-hyperparameter families (--preset <dataset>_<suffix>)
    for suffix, datasets, hparams in _FAMILY_PRESETS:
        for dataset in datasets:
            add(f"{dataset}_{suffix}", dataset, hparams)

    return presets


_PRESETS: dict[str, Config] = _build_presets()
# ---------------------------------------------------------------------------
# Registry instance – populated from _PRESETS at import time
# ---------------------------------------------------------------------------
def _register_presets(
    registry: PresetRegistry, presets: dict[str, Config]
) -> PresetRegistry:
    """Register every ``(name, config)`` pair of ``presets`` on ``registry``.

    Running the loop inside a function keeps its variables out of the module
    namespace. The previous module-level ``del _name, _cfg`` clean-up would
    raise ``NameError`` whenever ``presets`` was empty, because the loop
    variables were then never bound.

    Returns:
        ``registry`` itself, to allow one-line construction below.
    """
    for preset_name, preset_cfg in presets.items():
        registry.register(preset_name, preset_cfg)
    return registry


preset_registry: PresetRegistry = _register_presets(PresetRegistry(), _PRESETS)