Graph Classification: From Node Embeddings to Graph Embeddings
The Graph Classification Pipeline
Given a dataset of graphs {(G₁, y₁), …, (Gₙ, yₙ)}, the goal is to learn a function f: G → y. Unlike node classification (predicting a label per node) or link prediction (predicting edge existence), graph classification must process entire graphs of varying sizes.
The standard pipeline:
Input graph G = (V, E, X)
↓
[Message Passing: K layers]
↓
Node embeddings {h^(K)_v : v ∈ V}
↓
[Readout: global pooling]
↓
Graph embedding h_G ∈ ℝ^d
↓
[MLP classifier]
↓
Prediction ŷ
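The stages above can be sketched in plain Python. This is a minimal, untrained illustration rather than a reference implementation: the function names, the adjacency-list representation, and the classifier weights are all assumptions made for the example, and a single linear layer stands in for the MLP.

```python
def message_passing(adj, feats, num_layers):
    """Run `num_layers` rounds of untrained sum aggregation (self + neighbours)."""
    h = [list(x) for x in feats]
    for _ in range(num_layers):
        updated = []
        for v, neighbours in enumerate(adj):
            agg = list(h[v])  # start from the node's own embedding
            for u in neighbours:
                agg = [a + b for a, b in zip(agg, h[u])]
            updated.append(agg)
        h = updated
    return h


def sum_readout(node_embeddings):
    """Collapse a variable number of node vectors into one fixed-size vector."""
    return [sum(column) for column in zip(*node_embeddings)]


def classify(h_graph, weights, bias):
    """Linear stand-in for the MLP classifier: return the arg-max class index."""
    logits = [sum(w * x for w, x in zip(row, h_graph)) + b
              for row, b in zip(weights, bias)]
    return logits.index(max(logits))


def constant_features(num_nodes):
    """Give every node the same 2-d feature vector."""
    return [[1.0, 0.0] for _ in range(num_nodes)]


# Two graphs of different sizes, stored as adjacency lists.
triangle = [[1, 2], [0, 2], [0, 1]]
edge = [[1], [0]]

h_triangle = sum_readout(message_passing(triangle, constant_features(3), 2))
h_edge = sum_readout(message_passing(edge, constant_features(2), 2))

# Hypothetical hand-picked classifier weights, for illustration only.
weights, bias = [[1.0, 0.0], [-1.0, 0.0]], [-10.0, 10.0]
pred_triangle = classify(h_triangle, weights, bias)
pred_edge = classify(h_edge, weights, bias)
```

Both graphs end up as vectors of the same dimension even though they have different numbers of nodes, which is what lets one classifier handle the whole dataset.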
Message Passing for Graph Classification
The message passing stage is the same as for node-level tasks. The only difference is that we do not use the final node embeddings directly; we aggregate them into a single graph-level vector.
JK-Net readout: rather than using only the last-layer embeddings, JK-Net concatenates all intermediate embeddings before pooling:
h_v = CONCAT(h^(1)_v, h^(2)_v, …, h^(K)_v)
This is particularly useful for graph classification: different nodes may require different receptive field sizes, and combining all layers ensures no scale is lost.
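A small sketch of JK-style concatenation followed by a sum readout, assuming an untrained sum-aggregation layer as a stand-in for a real GNN layer. The helper names are hypothetical.

```python
def sum_layer(adj, h):
    """One untrained propagation step: each node adds its neighbours' vectors."""
    out = []
    for v, neighbours in enumerate(adj):
        agg = list(h[v])
        for u in neighbours:
            agg = [a + b for a, b in zip(agg, h[u])]
        out.append(agg)
    return out


def jk_concat(adj, feats, num_layers):
    """Return, per node, the concatenation [h^(1)_v, ..., h^(K)_v]."""
    h = [list(x) for x in feats]
    per_node = [[] for _ in feats]
    for _ in range(num_layers):
        h = sum_layer(adj, h)
        for v, h_v in enumerate(h):
            per_node[v].extend(h_v)  # keep every layer's view of node v
    return per_node


# Path graph 0 - 1 - 2 with 1-d node features.
path = [[1], [0, 2], [1]]
feats = [[2.0], [0.0], [4.0]]

node_repr = jk_concat(path, feats, num_layers=2)
graph_repr = [sum(column) for column in zip(*node_repr)]  # sum readout
```

Each node vector has K * d entries (here 2 * 1), so the readout sees the 1-hop and 2-hop views side by side instead of only the deepest one.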
The GIN Recipe for Graph Classification
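GIN (Graph Isomorphism Network) matches the discriminative power of the 1-WL (Weisfeiler-Lehman) isomorphism test, which is the upper bound for message passing GNNs. For graph classification: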
- K layers of GIN message passing (sum aggregation + injective MLP)
- Sum readout over all layer outputs: h_G = Σ_{k=1}^{K} Σ_{v ∈ V} h^(k)_v. This double sum ensures both layer-wise and node-wise information is captured.
- MLP classifier mapping h_G → ŷ
The combination of sum aggregation (injective over multisets), sum readout (preserves count information), and an MLP (universal approximator) reaches the 1-WL upper bound on the expressiveness of MPNNs.
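The recipe above might look like the following sketch. It assumes the standard GIN update h_v ← MLP((1 + ε)·h_v + Σ_u h_u), uses a single linear layer with ReLU as a simplified stand-in for the injective MLP, and sums the per-layer readouts as described. The weights and function names are hypothetical.

```python
def linear_relu(x, weight, bias):
    """One linear layer followed by ReLU (simplified stand-in for GIN's MLP)."""
    return [max(0.0, sum(w * a for w, a in zip(row, x)) + b)
            for row, b in zip(weight, bias)]


def gin_layer(adj, h, weight, bias, eps=0.0):
    """GIN-style update: transform (1 + eps) * h_v + sum of neighbour embeddings."""
    out = []
    for v, neighbours in enumerate(adj):
        agg = [(1.0 + eps) * a for a in h[v]]
        for u in neighbours:
            agg = [a + b for a, b in zip(agg, h[u])]
        out.append(linear_relu(agg, weight, bias))
    return out


def gin_graph_embedding(adj, feats, layers, eps=0.0):
    """Sum over nodes within each layer, then sum across layers."""
    h = [list(x) for x in feats]
    h_graph = None
    for weight, bias in layers:
        h = gin_layer(adj, h, weight, bias, eps)
        layer_sum = [sum(column) for column in zip(*h)]
        h_graph = layer_sum if h_graph is None else [
            a + b for a, b in zip(h_graph, layer_sum)]
    return h_graph


identity = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])  # hypothetical weights
layers = [identity, identity]


def ones(num_nodes):
    return [[1.0, 0.0] for _ in range(num_nodes)]


triangle = [[1, 2], [0, 2], [0, 1]]
path3 = [[1], [0, 2], [1]]
two_triangles = [[1, 2], [0, 2], [0, 1], [4, 5], [3, 5], [3, 4]]
hexagon = [[5, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 0]]

emb_triangle = gin_graph_embedding(triangle, ones(3), layers)
emb_path3 = gin_graph_embedding(path3, ones(3), layers)
emb_two_triangles = gin_graph_embedding(two_triangles, ones(6), layers)
emb_hexagon = gin_graph_embedding(hexagon, ones(6), layers)
```

The triangle and the 3-node path get different embeddings. Two disjoint triangles and a 6-cycle get the same one, because 1-WL cannot tell them apart when all nodes carry identical features. That is the ceiling any MPNN inherits.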
Benchmarks and Datasets
TUDatasets (standard graph classification benchmarks):
- MUTAG (188 graphs, 2 classes): mutagenic aromatic compounds
- PROTEINS (1113 graphs, 2 classes): enzyme vs non-enzyme proteins
- IMDB-B (1000 graphs, 2 classes): movie collaboration graphs
- REDDIT-B (2000 graphs, 2 classes): discussion thread graphs
- COLLAB (5000 graphs, 3 classes): collaboration networks
Note: these benchmarks have been criticised for high variance and potential data leakage. OGB (Open Graph Benchmark) provides more rigorous benchmarks.
OGB graph classification benchmarks:
- ogbg-molhiv: HIV activity prediction (41,127 molecules)
- ogbg-molpcba: molecular property prediction (437,929 molecules)
- ogbg-ppa: taxonomic group prediction (158,100 protein association neighborhood graphs)
Baseline vs State-of-the-Art Performance
On MUTAG and similar small datasets, the performance hierarchy is roughly:
GCN + mean pooling: ~73%
GCN + sum pooling: ~80%
GIN + sum pooling: ~89%
DiffPool: ~87%
Set2Set + MPNN: ~91%
Graph Transformers: ~92%+
(Illustrative; exact numbers vary by split and implementation.)
Common Failure Modes
Readout bottleneck: mean pooling discards count information even when paired with a powerful GNN. Two graphs of different sizes whose node embeddings follow the same proportions receive the same graph embedding.
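A quick way to see the bottleneck is to pool two sets of node embeddings that differ only in size. The embeddings below are toy values chosen for illustration.

```python
def sum_pool(nodes):
    return [sum(column) for column in zip(*nodes)]


def mean_pool(nodes):
    return [sum(column) / len(nodes) for column in zip(*nodes)]


# Toy node embeddings: the large graph repeats the small graph's nodes 3 times,
# so both have the same proportions of each node type.
small_graph = [[1.0, 0.0], [0.0, 1.0]]
large_graph = small_graph * 3

mean_small, mean_large = mean_pool(small_graph), mean_pool(large_graph)
sum_small, sum_large = sum_pool(small_graph), sum_pool(large_graph)
```

Mean pooling maps both graphs to the same vector, while sum pooling keeps them apart because it retains how many nodes contributed.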
Depth collapse: adding too many message passing layers → oversmoothing → all node embeddings identical → graph embeddings identical regardless of structure.
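Oversmoothing can be illustrated with repeated mean aggregation on a small connected graph. This sketch uses no learned weights, so it shows only the smoothing effect of propagation itself; the `spread` helper is a hypothetical measure introduced for the example.

```python
def mean_aggregate(adj, h):
    """Replace each node embedding by the mean over itself and its neighbours."""
    out = []
    for v, neighbours in enumerate(adj):
        group = [h[v]] + [h[u] for u in neighbours]
        out.append([sum(column) / len(group) for column in zip(*group)])
    return out


def spread(h):
    """Largest gap between any two nodes, taken over all feature dimensions."""
    return max(max(column) - min(column) for column in zip(*h))


# Path graph 0 - 1 - 2 - 3 where only node 0 starts with a non-zero feature.
path = [[1], [0, 2], [1, 3], [2]]
h = [[1.0], [0.0], [0.0], [0.0]]

spreads = []
for _ in range(60):
    h = mean_aggregate(path, h)
    spreads.append(spread(h))
```

The gap between node embeddings shrinks towards zero as layers are added, so any readout applied afterwards sees nearly identical vectors.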
Benchmark overfitting: TUDataset benchmarks are small and high-variance. Performance differences < 2% should not be interpreted as meaningful without statistical testing.
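One simple check is a paired t-statistic over per-fold accuracies from the same cross-validation splits. The sketch below uses only the standard library; the fold accuracies are made-up placeholder values, not measured results, and the function name is hypothetical. Fold-level tests are themselves only approximate, because folds share training data.

```python
import math
import statistics


def paired_t_statistic(acc_a, acc_b):
    """Return (mean difference, t statistic) for paired per-fold accuracies."""
    diffs = [a - b for a, b in zip(acc_a, acc_b)]
    mean_diff = statistics.mean(diffs)
    sd = statistics.stdev(diffs)  # sample standard deviation
    if sd == 0.0:
        raise ValueError("differences have zero variance")
    return mean_diff, mean_diff / (sd / math.sqrt(len(diffs)))


# Placeholder accuracies for two models on the same 10 folds (NOT real results).
placeholder_model_a = [0.89, 0.84, 0.95, 0.79, 0.89, 0.84, 0.95, 0.89, 0.84, 0.89]
placeholder_model_b = [0.84, 0.89, 0.89, 0.84, 0.84, 0.89, 0.89, 0.84, 0.89, 0.84]

mean_diff, t_stat = paired_t_statistic(placeholder_model_a, placeholder_model_b)

# Two-sided critical value at the 5% level for 9 degrees of freedom.
T_CRITICAL_DF9 = 2.262
is_significant = abs(t_stat) > T_CRITICAL_DF9
```

With 10 folds there are 9 degrees of freedom, so the gap counts as significant at the 5% level only if the absolute t-statistic exceeds 2.262.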
End-to-End Training
The entire pipeline (GNN + readout + MLP) is trained end-to-end with a single loss (cross-entropy for classification, MSE for regression). The readout step is differentiable for all standard choices (sum/mean/max are differentiable; attention readout is differentiable; DiffPool is differentiable via soft assignment; TopKPool is approximately differentiable via score gating).
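To make "a single loss" concrete, here is a deliberately tiny sketch with hand-derived gradients: one linear message passing layer with a scalar weight, a sum readout, and a logistic classifier, all updated from the same binary cross-entropy loss. The model, the initial values, and the two-graph dataset are assumptions for illustration; in practice an autograd framework computes these gradients.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def forward(adj, feats, w_gnn, w_cls, b_cls):
    """Linear GNN layer (self + neighbour sum) -> sum readout -> logistic output."""
    agg_total = 0.0
    for v, neighbours in enumerate(adj):
        agg_total += feats[v] + sum(feats[u] for u in neighbours)
    h_graph = w_gnn * agg_total  # sum readout of h_v = w_gnn * agg_v
    prob = sigmoid(w_cls * h_graph + b_cls)
    return agg_total, h_graph, prob


def train(dataset, steps=200, lr=0.01):
    w_gnn, w_cls, b_cls = 0.1, 0.1, 0.0
    losses = []
    for _ in range(steps):
        g_gnn = g_cls = g_b = 0.0
        loss = 0.0
        for adj, feats, y in dataset:
            agg_total, h_graph, prob = forward(adj, feats, w_gnn, w_cls, b_cls)
            loss -= (y * math.log(prob) + (1 - y) * math.log(1.0 - prob)) / len(dataset)
            d_logit = (prob - y) / len(dataset)    # d(loss) / d(logit)
            g_b += d_logit
            g_cls += d_logit * h_graph             # gradient for the classifier weight
            g_gnn += d_logit * w_cls * agg_total   # gradient reaching the GNN weight
        losses.append(loss)
        w_gnn -= lr * g_gnn
        w_cls -= lr * g_cls
        b_cls -= lr * g_b
    return (w_gnn, w_cls, b_cls), losses


# Toy dataset: triangle labelled 1, 3-node path labelled 0, scalar node features.
dataset = [
    ([[1, 2], [0, 2], [0, 1]], [1.0, 1.0, 1.0], 1),
    ([[1], [0, 2], [1]], [1.0, 1.0, 1.0], 0),
]
params, losses = train(dataset)
```

The gradient for `w_gnn` passes through the classifier weight and the readout, which is what end-to-end training means here: the message passing weight is updated by the classification loss, not by a separate objective.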
Summary
| Design choice | Recommendation |
|---|---|
| Message passing | GIN (most expressive MPNN) |
| Readout | Sum (most expressive), or attention for task-adaptive |
| Hierarchical pooling | DiffPool (small graphs), TopKPool/SAGPool (large graphs) |
| MLP depth | 2-3 layers with batch norm |
| Layer combination | JK-Net style concatenation before readout |
Graph classification ties together all the concepts in the pooling section: the choice of message passing determines per-node expressiveness; the readout determines what graph-level information is preserved; the MLP maps the graph summary to the prediction. Getting all three right is what separates a weak baseline from state-of-the-art performance.
References
- Xu, K., Hu, W., Leskovec, J., & Jegelka, S. (2019). How Powerful are Graph Neural Networks? ICLR 2019 (GIN: most expressive MPNN for graph classification).
- Xu, K., Li, C., Tian, Y., Sonobe, T., Kawarabayashi, K., & Jegelka, S. (2018). Representation Learning on Graphs with Jumping Knowledge Networks. ICML 2018 (JK-Net readout).
- Hu, W., Fey, M., Zitnik, M., Dong, Y., Ren, H., Liu, B., Catasta, M., & Leskovec, J. (2020). Open Graph Benchmark: Datasets for Machine Learning on Graphs. NeurIPS 2020 (OGB benchmarks).
