DiffPool: Learning Hierarchical Graph Pooling

TL;DR: DiffPool (Ying et al., 2018) uses a second GNN to learn soft cluster assignments, in which each node is assigned fractionally to every cluster. The coarsened graph is then passed to another GNN, building a hierarchy from fine-grained nodes to coarse super-nodes. It is end-to-end differentiable and captures multi-scale structure that flat pooling misses.

The Limitation of Flat Global Pooling

Global mean/sum/max pooling jumps directly from N node embeddings to a single graph embedding. For graphs with hierarchical structure, such as molecules (atoms → functional groups → whole molecule) or social networks (people → communities → factions), this skips all intermediate scales.

CNNs solve this with hierarchical pooling (max-pool after each conv layer). DiffPool brings the same idea to graphs, with the key challenge: graph pooling must be permutation-invariant and must handle variable numbers of nodes.

The DiffPool Architecture

DiffPool processes each pooling level l with two GNNs:

1. Embedding GNN, which computes node embeddings:

Z^{(l)} = GNN_{embed}^{(l)}( A^{(l)}, H^{(l)} )

2. Pooling GNN, which computes cluster assignments:

S^{(l)} = softmax( GNN_{pool}^{(l)}( A^{(l)}, H^{(l)} ) )

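S^{(l)} ∈ ℝ^{N_l × k_l} is the soft assignment matrix, where N_l is the number of nodes at level l and k_l is the number of clusters they are pooled into. S[i,j] is the probability that node i belongs to cluster j. The softmax is applied row-wise, so each row of S^{(l)} sums to 1.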
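The two GNNs above can be sketched in NumPy as follows. This is a minimal illustration, not the reference implementation: `dense_gnn_layer` is a hypothetical single-layer, mean-aggregating GNN chosen for brevity (the source does not fix a GNN architecture), the weights are random and untrained, and the toy graph is invented for the example.

```python
import numpy as np

def dense_gnn_layer(A, H, W):
    """One dense message-passing step: ReLU(D^-1 (A + I) H W).

    Assumption: a simple mean-aggregation layer stands in for GNN_embed / GNN_pool.
    """
    A_hat = A + np.eye(A.shape[0])            # add self-loops
    deg = A_hat.sum(axis=1, keepdims=True)    # row degrees, always >= 1
    return np.maximum((A_hat / deg) @ H @ W, 0.0)

def row_softmax(X):
    """Softmax over each row, shifted by the row max for numerical stability."""
    E = np.exp(X - X.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

# Toy graph (hypothetical): two triangles, 0-1-2 and 3-4-5, joined by the edge 2-3.
N, d_in, d, k = 6, 4, 8, 2
A = np.zeros((N, N))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

rng = np.random.default_rng(0)
H = rng.normal(size=(N, d_in))          # input node features H^{(l)}
W_embed = rng.normal(size=(d_in, d))    # untrained weights of the embedding GNN
W_pool = rng.normal(size=(d_in, k))     # untrained weights of the pooling GNN

Z = dense_gnn_layer(A, H, W_embed)                # node embeddings, shape (N, d)
S = row_softmax(dense_gnn_layer(A, H, W_pool))    # soft assignments, shape (N, k)

print(Z.shape, S.shape)
print(S.sum(axis=1))   # each row sums to 1, up to floating-point error
```

Both networks read the same inputs A and H; only the output width differs (d for embeddings, k for cluster logits).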

Coarsening: compute the next level's adjacency and embeddings:

H^{(l+1)} = S^{(l)T} Z^{(l)} ∈ ℝ^{k_l × d}
A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)} ∈ ℝ^{k_l × k_l}

The new adjacency A^{(l+1)} is the cluster-to-cluster connectivity: entry (a, b) measures how strongly clusters a and b are connected through edges between their constituent nodes.
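A small worked example may make the coarsening step concrete. The sketch below assumes a hypothetical six-node graph (two triangles joined by one edge) and uses a one-hot S purely so the numbers are easy to check by hand; a soft S goes through the same two matrix products. The function name `diffpool_coarsen` is illustrative.

```python
import numpy as np

def diffpool_coarsen(A, Z, S):
    """Return (H_next, A_next) = (S^T Z, S^T A S) for one pooling level."""
    H_next = S.T @ Z          # (k, d): assignment-weighted sum of node embeddings
    A_next = S.T @ A @ S      # (k, k): cluster-to-cluster connectivity
    return H_next, A_next

# Hypothetical toy graph: triangles 0-1-2 and 3-4-5 joined by the edge 2-3.
A = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

Z = np.arange(12, dtype=float).reshape(6, 2)   # made-up node embeddings, d = 2

# One-hot assignment for readability: nodes 0-2 -> cluster 0, nodes 3-5 -> cluster 1.
S = np.array([[1, 0], [1, 0], [1, 0],
              [0, 1], [0, 1], [0, 1]], dtype=float)

H_next, A_next = diffpool_coarsen(A, Z, S)
print(H_next)   # [[ 6.  9.]
                #  [24. 27.]]
print(A_next)   # [[6. 1.]
                #  [1. 6.]]
```

The diagonal entries of the coarsened adjacency are weighted self-loops: each of the three edges inside a triangle is counted once per direction, giving 6. The off-diagonal 1 is the single bridging edge.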

Why Soft Assignment?

Hard assignment (each node assigned to exactly one cluster) would require an argmax, which is not differentiable. Soft assignment (each node fractionally assigned to all clusters) allows end-to-end gradient flow.

The assignment S^{(l)} is learned jointly with the rest of the network. The model discovers which nodes should be clustered together without any external supervision on the clustering.

The analogy to attention: DiffPool's soft assignment is conceptually similar to attention in Transformers, in that it learns a soft selection over items. Each cluster embedding is a weighted sum of node embeddings (S^T Z), so each cluster can be read as "attending" to nodes. There are two differences. The softmax normalizes over clusters for each node rather than over nodes for each cluster, and DiffPool reduces the number of tokens (nodes coarsened to clusters), while self-attention preserves the token count.

Auxiliary Loss Terms

DiffPool adds two regularisation losses to guide clustering quality:

Link prediction loss: encourages nodes connected by an edge to be assigned to the same cluster:

L_{LP} = ||A^{(l)} - S^{(l)} S^{(l)T}||_F

If S^{(l)} S^{(l)T} approximates A^{(l)}, then nodes that share an edge have similar assignment rows, so the learned clusters tend to follow the graph's connectivity.
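As a rough illustration of the link prediction loss, the sketch below compares an assignment that follows the graph's two communities with one that ignores them. The toy graph and both assignment matrices are invented for the example, and `link_prediction_loss` is an illustrative name.

```python
import numpy as np

def link_prediction_loss(A, S):
    """Frobenius norm of A - S S^T, as in the formula above."""
    return np.linalg.norm(A - S @ S.T, ord="fro")

# Hypothetical toy graph: triangles 0-1-2 and 3-4-5 joined by the edge 2-3.
A = np.zeros((6, 6))
for i, j in [(0, 1), (0, 2), (1, 2), (3, 4), (3, 5), (4, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1.0

# Assignment aligned with the two triangles.
S_aligned = np.array([[1, 0], [1, 0], [1, 0],
                      [0, 1], [0, 1], [0, 1]], dtype=float)

# Assignment that splits nodes by index parity, ignoring the edges.
S_parity = np.array([[1, 0], [0, 1], [1, 0],
                     [0, 1], [1, 0], [0, 1]], dtype=float)

loss_aligned = link_prediction_loss(A, S_aligned)
loss_parity = link_prediction_loss(A, S_parity)
print(loss_aligned, loss_parity)   # the aligned assignment gives the smaller loss
```

The loss is not zero even for the aligned assignment: S S^T has ones on its diagonal while A has no self-loops, and the bridging edge 2-3 crosses clusters.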

Entropy loss: encourages each node's assignment to be concentrated (not uniformly spread):

L_E = (1/N) Σ_i H( S_i )

where H is the Shannon entropy and S_i is the i-th row of S^{(l)}. Low entropy → sharp cluster assignment → more interpretable clusters.
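The entropy term can be sketched in a few lines. The `eps` guard inside the logarithm is an implementation assumption added here to avoid log(0) for exactly one-hot rows, and the two example matrices are made up to show the extremes.

```python
import numpy as np

def entropy_loss(S, eps=1e-12):
    """Mean row-wise Shannon entropy of the assignment matrix S (natural log)."""
    row_entropy = -(S * np.log(S + eps)).sum(axis=1)   # H(S_i) for each node i
    return row_entropy.mean()

# Every node spread evenly over k = 2 clusters: maximal entropy, ln(2) per node.
S_uniform = np.full((4, 2), 0.5)

# Every node assigned to exactly one cluster: entropy is (numerically) zero.
S_sharp = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)

loss_uniform = entropy_loss(S_uniform)
loss_sharp = entropy_loss(S_sharp)
print(loss_uniform, loss_sharp)
```

Minimizing this term pushes each row of S away from the uniform case and toward the one-hot case, while keeping the assignment differentiable.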

Computational Cost

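DiffPool runs two GNNs at each level. For a graph with N nodes and d-dimensional embeddings, coarsened to k clusters:

  • Forward pass: O(N² d) per GNN layer for message passing over the dense adjacency, plus O(N² k) for the coarsening product S^T A S
  • Memory: O(N²), because DiffPool operates on dense adjacency matrices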

This makes DiffPool quadratic in N: practical for graphs with hundreds of nodes (molecules), but not for social networks or knowledge graphs with millions of nodes.
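A back-of-the-envelope calculation shows how quickly the quadratic term grows. The sketch assumes 4-byte (float32) entries and counts only a single dense N × N matrix, ignoring activations, gradients, and batching; `dense_adjacency_bytes` is a hypothetical helper.

```python
def dense_adjacency_bytes(num_nodes, bytes_per_entry=4):
    """Bytes needed to store one dense N x N matrix (float32 by assumption)."""
    return num_nodes * num_nodes * bytes_per_entry

for n in (500, 10_000, 1_000_000):
    size = dense_adjacency_bytes(n)
    print(f"N = {n:>9,}: {size:>17,} bytes  ({size / 1e9:,.3f} GB)")
```

Under these assumptions a 500-node graph needs about 1 MB per dense matrix, while a million-node graph needs about 4 TB, which is why dense pooling is confined to small graphs.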

When DiffPool Helps

DiffPool outperforms flat pooling when:

  1. The task requires multi-scale understanding: molecular property prediction benefits from atom-level and functional-group-level representations simultaneously
  2. Graphs have natural hierarchical structure: trees, clustered communities, hierarchical molecules
  3. The graph is small enough: typically ≤ 1000 nodes

It is a standard strong baseline on the TUDataset graph classification benchmarks (PROTEINS, MUTAG, IMDB, RDT).

Limitations

  1. Quadratic memory: no scaling to large graphs
  2. Fixed number of clusters: k_l must be specified for each level before training
  3. No guarantee of meaningful clusters: the auxiliary losses help but do not force semantically meaningful groupings
  4. Sensitive to the number of pooling levels: too many levels → over-compression; too few → close to flat pooling
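Limitations 2 and 4 interact, which a short sketch can show. It assumes one possible way of choosing the per-level cluster counts: shrinking by a fixed ratio at every level. Both the rule and the name `cluster_schedule` are illustrative assumptions, not something the source prescribes.

```python
import math

def cluster_schedule(num_nodes, ratio, num_levels):
    """Cluster counts per level when each level keeps ceil(ratio * previous size).

    Assumption: a fixed pooling ratio; the counts must still be chosen before training.
    """
    sizes = []
    n = num_nodes
    for _ in range(num_levels):
        n = max(1, math.ceil(n * ratio))
        sizes.append(n)
    return sizes

print(cluster_schedule(100, 0.25, 3))   # [25, 7, 2]
print(cluster_schedule(100, 0.25, 5))   # [25, 7, 2, 1, 1]
```

With a ratio of 0.25, a 100-node graph is already down to a single cluster by the fourth level, so any further level has nothing left to coarsen.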

Summary

| Component | Role |
| --- | --- |
| GNN_{embed} | Compute node representations at this scale |
| GNN_{pool} | Learn soft cluster assignments |
| S^T Z | Aggregate node embeddings into cluster embeddings |
| S^T A S | Coarsen adjacency to cluster graph |
| L_{LP} + L_E | Encourage connectivity-aligned, sharp clusters |

DiffPool introduced the idea of learned hierarchical pooling for graphs: differentiable, end-to-end, and structure-aware. Its quadratic complexity limits scale, but for small-graph tasks (molecules, proteins), it remains a reference architecture.

References