Graph Classification: From Node Embeddings to Graph Embeddings

TL;DR: A graph classifier has three stages: (1) message passing to build node embeddings; (2) readout to collapse node embeddings into a graph embedding; (3) an MLP to predict from the graph embedding. The expressiveness bottleneck is usually the readout step, not the message passing. Choosing sum readout + GIN + MLP achieves 1-WL expressiveness for graph-level tasks.

The Graph Classification Pipeline

Given a dataset of graphs {(G₁, y₁), …, (Gₙ, yₙ)}, the goal is to learn a function f: G → y. Unlike node classification (predict per-node label) or link prediction (predict edge existence), graph classification must process entire graphs of varying sizes.

The standard pipeline:

Input graph G = (V, E, X)
         ↓
[Message Passing: K layers]
         ↓
Node embeddings {h^(K)_v : v ∈ V}
         ↓
[Readout: global pooling]
         ↓
Graph embedding h_G ∈ ℝ^d
         ↓
[MLP classifier]
         ↓
Prediction ŷ
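
The three stages can be sketched in plain Python. This is a toy illustration only: the function names, the weight-free sum aggregation, and the single linear layer standing in for the MLP are all assumptions made for the example, not part of any library.

```python
# Hypothetical toy pipeline: names and weights are illustrative assumptions.

def message_passing(adj, feats, num_layers):
    """Sum each node's embedding with its neighbours' for num_layers rounds."""
    h = [list(x) for x in feats]
    for _ in range(num_layers):
        new_h = []
        for v, neighbours in enumerate(adj):
            agg = list(h[v])  # start from the node's own embedding
            for u in neighbours:
                agg = [a + b for a, b in zip(agg, h[u])]
            new_h.append(agg)
        h = new_h
    return h

def sum_readout(node_embeddings):
    """Collapse a variable-size set of node embeddings into one fixed-size vector."""
    return [sum(col) for col in zip(*node_embeddings)]

def linear_classifier(h_graph, weights, bias):
    """Stand-in for the MLP: one linear layer followed by a threshold."""
    score = sum(w * x for w, x in zip(weights, h_graph)) + bias
    return 1 if score > 0 else 0

# Two graphs of different sizes: (adjacency list, 2-d node features).
triangle = ([[1, 2], [0, 2], [0, 1]], [[1.0, 0.0]] * 3)
path4 = ([[1], [0, 2], [1, 3], [2]], [[1.0, 0.0]] * 4)

emb_triangle = sum_readout(message_passing(*triangle, num_layers=2))
emb_path4 = sum_readout(message_passing(*path4, num_layers=2))
```

Both graphs end up as vectors of the same dimension even though they have different numbers of nodes, which is what lets one classifier handle the whole dataset.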

Message Passing for Graph Classification

The message passing stage is the same as for node-level tasks. The only difference is that the final node embeddings are not used directly for prediction; they are aggregated into a single graph embedding.

JK-Net readout: rather than using only the last-layer embeddings, JK-Net concatenates all intermediate embeddings before pooling:

h_G = READOUT( { concat(h^(1)_v, h^(2)_v, ..., h^(K)_v) : v ∈ V } )

This is particularly useful for graph classification: different nodes may require different receptive field sizes, and combining all layers ensures no scale is lost.
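
A minimal sketch of the concatenate-then-pool idea, assuming weight-free sum aggregation in place of learned layers and sum pooling as the READOUT. The helper names `propagate` and `jk_concat_readout` are hypothetical.

```python
def propagate(adj, h):
    """One round of sum aggregation over each node and its neighbours."""
    return [
        [sum(vals) for vals in zip(h[v], *(h[u] for u in adj[v]))]
        for v in range(len(adj))
    ]

def jk_concat_readout(adj, feats, num_layers):
    """Concatenate each node's embeddings from layers 1..K, then sum-pool over nodes."""
    h = [list(x) for x in feats]
    per_node = [[] for _ in adj]      # running concatenation for each node
    for _ in range(num_layers):
        h = propagate(adj, h)
        for v, h_v in enumerate(h):
            per_node[v].extend(h_v)   # append this layer's embedding
    return [sum(col) for col in zip(*per_node)]

# Star graph: node 0 is the centre, nodes 1-3 are leaves; 1-d features.
star = [[1, 2, 3], [0], [0], [0]]
feats = [[1.0], [1.0], [1.0], [1.0]]
h_G = jk_concat_readout(star, feats, num_layers=3)
```

The graph embedding has K × d entries (here 3 × 1), one block per layer, so the classifier can weight shallow and deep views of the graph separately.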

The GIN Recipe for Graph Classification

GIN (Graph Isomorphism Network) achieves 1-WL expressiveness. For graph classification:

  1. K layers of GIN message passing (sum aggregation + injective MLP)
  2. Sum readout over all layer outputs:
h_G = Σ_{k=0}^{K} Σ_{v ∈ V} h^{(k)}_v

This double sum ensures both layer-wise and node-wise information is captured.

  3. MLP classifier on h_G → ŷ

The combination of sum aggregation (injective over multisets) + sum readout (preserves count information) + MLP (universal approximator) matches the discriminative power of the 1-WL test, which is the upper bound for standard message-passing GNNs.
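
The recipe can be sketched with scalar node features. The update used here is the usual GIN form, h_v ← MLP((1 + ε) · h_v + Σ_{u ∈ N(v)} h_u); the fixed-weight `mlp` below is a hypothetical stand-in for a trained network and is not guaranteed to be injective in general.

```python
def mlp(x):
    """Hypothetical fixed 1-2-1 MLP with ReLU; a real model would learn these weights."""
    hidden = [max(0.0, 1.0 * x + 0.5), max(0.0, -1.0 * x + 2.0)]
    return 2.0 * hidden[0] + 1.0 * hidden[1]

def gin_layer(adj, h, eps=0.0):
    """GIN update: weight the node's own value by (1 + eps), add neighbours, apply the MLP."""
    return [
        mlp((1.0 + eps) * h[v] + sum(h[u] for u in adj[v]))
        for v in range(len(adj))
    ]

def gin_graph_embedding(adj, feats, num_layers):
    """Double sum: over nodes within each layer, and over layers k = 0..K."""
    h = list(feats)
    h_G = sum(h)                      # k = 0 term (raw input features)
    for _ in range(num_layers):
        h = gin_layer(adj, h)
        h_G += sum(h)                 # add this layer's node sum
    return h_G

triangle = [[1, 2], [0, 2], [0, 1]]
path3 = [[1], [0, 2], [1]]
ones = [1.0, 1.0, 1.0]

emb_triangle = gin_graph_embedding(triangle, ones, num_layers=2)
emb_path3 = gin_graph_embedding(path3, ones, num_layers=2)
```

Both graphs have three identical nodes, so the k = 0 term is the same; the embeddings separate only once message passing has exposed the difference in degree structure.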

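Why sum over all layers? The GIN paper motivates using every layer's output: node embeddings become more global as depth increases, and features from earlier layers may sometimes generalise better. Combining the outputs of all layers lets the graph embedding reflect neighbourhood structure at every radius from 0 to K hops, so structural differences that show up at any depth can contribute to it, rather than only the K-hop view that the final layer provides. (The paper itself concatenates the per-layer readouts; summing them, as above, is a more compact variant.)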

Benchmarks and Datasets

TUDatasets (standard graph classification benchmarks):

  • MUTAG (188 graphs, 2 classes): mutagenic aromatic compounds
  • PROTEINS (1113 graphs, 2 classes): enzyme vs non-enzyme proteins
  • IMDB-B (1000 graphs, 2 classes): movie collaboration graphs
  • REDDIT-B (2000 graphs, 2 classes): discussion thread graphs
  • COLLAB (5000 graphs, 3 classes): collaboration networks

Note: these benchmarks have been criticised for high variance and potential data leakage. OGB (Open Graph Benchmark) provides more rigorous benchmarks.

OGB graph classification benchmarks:

  • ogbg-molhiv: HIV activity prediction (41,127 molecules)
  • ogbg-molpcba: molecular property prediction (437,929 molecules)
  • ogbg-ppa: taxonomic group prediction (158,100 protein association neighbourhood graphs)

Baseline vs State-of-the-Art Performance

On MUTAG and similar small datasets, the performance hierarchy is roughly:

GCN + mean pooling: ~73%
GCN + sum pooling:  ~80%
GIN + sum pooling:  ~89%
DiffPool:           ~87%
Set2Set + MPNN:     ~91%
Graph Transformers: ~92%+

(Illustrative; exact numbers vary by split and implementation.)

Common Failure Modes

Readout bottleneck: mean pooling after a powerful GNN discards count information, so two graphs with different sizes but proportionally identical node distributions get the same embedding.
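
A small sketch of this failure, assuming the node embeddings have already been computed. The `readout` helper is hypothetical and the two embedding sets are hand-written for illustration.

```python
def readout(node_embeddings, mode):
    """Pool a list of node embedding vectors into one graph vector."""
    columns = list(zip(*node_embeddings))
    if mode == "sum":
        return [sum(c) for c in columns]
    if mode == "mean":
        return [sum(c) / len(c) for c in columns]
    if mode == "max":
        return [max(c) for c in columns]
    raise ValueError(f"unknown readout mode: {mode}")

small = [[1.0, 0.0], [0.0, 1.0]]   # 2 nodes, one of each type
large = small * 2                   # 4 nodes, same proportions

same_under_mean = readout(small, "mean") == readout(large, "mean")
same_under_max = readout(small, "max") == readout(large, "max")
same_under_sum = readout(small, "sum") == readout(large, "sum")
```

Mean and max pooling map both graphs to the same vector, while sum pooling keeps them apart because it retains how many nodes of each type are present.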

Depth collapse: too many message passing layers cause oversmoothing. Node embeddings converge towards the same vector, so graph embeddings become nearly indistinguishable regardless of structure.
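
The effect can be seen in a toy setting, assuming weight-free mean aggregation with self-loops on a 5-node path. Real GNN layers have learned weights and nonlinearities, so this only illustrates the tendency; the helper names are hypothetical.

```python
def mean_aggregate(adj, h):
    """One layer of mean aggregation over each node and its neighbours (self included)."""
    return [
        (h[v] + sum(h[u] for u in adj[v])) / (1 + len(adj[v]))
        for v in range(len(adj))
    ]

def spread(h):
    """Gap between the largest and smallest node embedding."""
    return max(h) - min(h)

path = [[1], [0, 2], [1, 3], [2, 4], [3]]
h = [0.0, 0.0, 0.0, 0.0, 1.0]       # only the last node carries signal

spreads = [spread(h)]
for _ in range(200):
    h = mean_aggregate(path, h)
    spreads.append(spread(h))
```

The spread starts at 1.0 and shrinks towards zero as layers are stacked, at which point every node carries essentially the same value and the readout has nothing structural left to summarise.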

Benchmark overfitting: TUDataset benchmarks are small and high-variance. Performance differences < 2% should not be interpreted as meaningful without statistical testing.
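
One simple check is a paired t-statistic over cross-validation folds. The sketch below uses made-up placeholder accuracies for two hypothetical models, purely to show the arithmetic; they are not results from any experiment. Because folds share training data, the test is only approximate.

```python
import math
import statistics

def paired_t_statistic(scores_a, scores_b):
    """Mean per-fold difference and its t-statistic (degrees of freedom = folds - 1)."""
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    mean_diff = statistics.mean(diffs)
    std_err = statistics.stdev(diffs) / math.sqrt(len(diffs))
    return mean_diff, mean_diff / std_err

# Placeholder 10-fold accuracies (invented for illustration only).
model_a = [0.89, 0.84, 0.95, 0.79, 0.89, 0.84, 0.95, 0.89, 0.84, 0.89]
model_b = [0.84, 0.89, 0.89, 0.84, 0.89, 0.79, 0.95, 0.84, 0.89, 0.84]

mean_diff, t_stat = paired_t_statistic(model_a, model_b)

# Two-sided 5% critical value of the t-distribution with 9 degrees of freedom.
T_CRITICAL_DF9 = 2.262
significant = abs(t_stat) > T_CRITICAL_DF9
```

With these placeholder numbers the average gap is about one percentage point, yet the fold-to-fold variation is large enough that the difference is not significant at the 5% level.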

End-to-End Training

The entire pipeline (GNN + readout + MLP) is trained end-to-end with a single loss (cross-entropy for classification, MSE for regression). The readout step is differentiable for all standard choices (sum/mean/max are differentiable; attention readout is differentiable; DiffPool is differentiable via soft assignment; TopKPool is approximately differentiable via score gating).
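
To make the gradient flow concrete, here is a sketch with hand-derived gradients for a deliberately tiny model: one linear message-passing weight `w`, a sum readout, and a logistic classifier with parameters `a` and `b`. The model, initial values, and two-graph dataset are assumptions chosen so that the chain rule fits in a few lines; a real pipeline would use an autodiff framework.

```python
import math

def forward(adj, feats, w, a, b):
    """GNN layer (weight w) -> sum readout -> logistic classifier (a, b)."""
    # s_G: sum over nodes of (own feature + neighbour features).
    s_G = sum(feats[v] + sum(feats[u] for u in adj[v]) for v in range(len(adj)))
    h_G = w * s_G                       # the linear layer and sum readout commute here
    p = 1.0 / (1.0 + math.exp(-(a * h_G + b)))
    return s_G, h_G, p

def train(dataset, steps=200, lr=0.01):
    w, a, b = 0.1, 0.1, 0.0             # hypothetical initial parameters
    losses = []
    for _ in range(steps):
        grad_w = grad_a = grad_b = total = 0.0
        for adj, feats, y in dataset:
            s_G, h_G, p = forward(adj, feats, w, a, b)
            total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
            delta = p - y               # d(loss)/d(logit) for sigmoid + cross-entropy
            grad_b += delta
            grad_a += delta * h_G       # gradient for the classifier weight
            grad_w += delta * a * s_G   # gradient flows through the readout into the GNN
        n = len(dataset)
        losses.append(total / n)        # loss measured before this step's update
        w -= lr * grad_w / n
        a -= lr * grad_a / n
        b -= lr * grad_b / n
    return (w, a, b), losses

dataset = [
    ([[1, 2], [0, 2], [0, 1]], [1.0, 1.0, 1.0], 1),   # triangle, label 1
    ([[1], [0, 2], [1]], [1.0, 1.0, 1.0], 0),         # 3-node path, label 0
]
params, losses = train(dataset)
```

A single cross-entropy loss updates the message-passing weight and the classifier together, because the sum readout passes the gradient straight through to every node.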

Summary

Design choice         Recommendation
Message passing       GIN (most expressive MPNN)
Readout               Sum (most expressive), or attention for a task-adaptive readout
Hierarchical pooling  DiffPool (small graphs), TopKPool/SAGPool (large graphs)
MLP depth             2-3 layers with batch norm
Layer combination     JK-Net style concatenation before readout

Graph classification ties together all the concepts in the pooling section: the choice of message passing determines per-node expressiveness; the readout determines what graph-level information is preserved; the MLP maps the graph summary to the prediction. Getting all three right is what separates a weak baseline from a competitive model.
