Set2Set and Attention Readout: Order-Invariant Graph Summaries

TL;DR: Attention readout weights node embeddings by learned importance scores before summing — nodes that matter more for the task contribute more to the graph embedding. Set2Set extends this with an LSTM that makes T passes over the node set, each time computing a different attention query. This yields a richer, order-invariant graph summary.

Beyond Uniform Pooling

Mean and sum pooling treat all nodes identically. But for most tasks, nodes differ greatly in importance:

  • In a molecule, the reactive functional group matters more than inert backbone atoms
  • In a social network, hubs matter more than peripheral nodes
  • In a citation graph, landmark papers matter more than derivative works

Attention readout learns these importance differences during training.

Attention Readout (Global Attention Pooling)

For each node v, compute a scalar importance score:

a_v = MLP_gate( h_v ) ∈ ℝ

Normalise scores with softmax:

α_v = exp(a_v) / Σ_{u ∈ V} exp(a_u)

Compute the graph embedding as a weighted sum:

h_G = Σ_{v ∈ V} α_v · MLP_out(h_v)

This is single-pass soft attention over all nodes. The model learns which nodes to weight highly for the specific prediction task.
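The three formulas above can be sketched in plain Python. This is a minimal, untrained sketch rather than a reference implementation. The shapes of MLP_gate (one tanh hidden layer producing a scalar) and MLP_out (a single linear map d → d), the helper names `linear`, `rand_matrix` and `AttentionReadout`, and the random initialisation are all assumptions made for illustration.

```python
import math
import random


def softmax(scores):
    # Subtract the max score before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def linear(W, b, x):
    # y = W x + b, with W stored as a list of rows
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


class AttentionReadout:
    """Global attention pooling with untrained, randomly initialised weights (hypothetical helper)."""

    def __init__(self, d, hidden, rng):
        # Assumed MLP_gate: d -> hidden (tanh) -> 1 scalar score
        self.W1, self.b1 = rand_matrix(hidden, d, rng), [0.0] * hidden
        self.W2, self.b2 = rand_matrix(1, hidden, rng), [0.0]
        # Assumed MLP_out: a single linear map d -> d
        self.Wo, self.bo = rand_matrix(d, d, rng), [0.0] * d

    def gate(self, h):
        z = [math.tanh(v) for v in linear(self.W1, self.b1, h)]
        return linear(self.W2, self.b2, z)[0]  # scalar score a_v

    def __call__(self, H):
        alpha = softmax([self.gate(h) for h in H])        # α_v
        outs = [linear(self.Wo, self.bo, h) for h in H]   # MLP_out(h_v)
        d_out = len(outs[0])
        # h_G = Σ_v α_v · MLP_out(h_v)
        h_G = [sum(a * o[j] for a, o in zip(alpha, outs)) for j in range(d_out)]
        return h_G, alpha


rng = random.Random(0)
d = 4
H = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(6)]  # 6 node embeddings h_v
readout = AttentionReadout(d, hidden=8, rng=rng)
h_G, alpha = readout(H)
h_G_rev, _ = readout(H[::-1])  # same nodes, reversed order
print(all(math.isclose(x, y, abs_tol=1e-9) for x, y in zip(h_G, h_G_rev)))  # True
```

Reversing the node list changes only the order of floating-point additions, so the two embeddings agree up to rounding, which is the permutation invariance claimed in the properties that follow.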

Properties:

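  • Permutation-invariant: the softmax normaliser and the weighted sum are both sums over V, so reordering the nodes leaves h_G unchanged
  • Differentiable: all operations are smooth
  • Task-conditioned: α_v depends on h_v, which encodes the node's local neighbourhood, and the gate is trained end-to-end on the task loss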

Limitation: each node is scored independently. The raw score a_v depends only on h_v, and the softmax merely rescales the scores against one another, so the gate cannot take into account what other nodes already contribute.
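The limitation can be made precise: in any ratio of two weights the softmax denominator cancels, so the relative weight of nodes u and v depends only on their own scores.

```latex
\frac{\alpha_u}{\alpha_v}
  = \frac{\exp(a_u) / \sum_{w \in V} \exp(a_w)}{\exp(a_v) / \sum_{w \in V} \exp(a_w)}
  = \exp(a_u - a_v)
```

For fixed h_u and h_v, adding, removing, or changing other nodes rescales every α by a common factor but never changes how u is weighted against v. In particular, two nodes with identical embeddings always receive identical weights.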

Set2Set (Vinyals et al., 2015)

Set2Set produces a graph embedding using T steps of LSTM-driven attention. At each step t, the LSTM maintains a query vector q_t, which is used to compute attention over all nodes:

Step t:

e^t_v = q_t · h_v   (attention score for node v)

α^t_v = exp(e^t_v) / Σ_{u ∈ V} exp(e^t_u)   (softmax over nodes)

m_t = Σ_{v ∈ V} α^t_v h_v   (weighted sum at step t)

Update the LSTM:

(q_{t+1}, c_{t+1}) = LSTM( [q_t ; m_t], c_t )

After T steps, the final graph embedding is:

h_G = [q_T ; m_T]

(concatenation of LSTM hidden state and final attended message)
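A minimal plain-Python sketch of the T-step loop follows. It assumes, as common implementations do, that the LSTM state and the first input [q_0 ; m_0] start at zero, so q_1 is the LSTM's output for an all-zero input. Each iteration applies the LSTM update to the previous [q ; m] and then performs one attention read; the final update that would produce q_{T+1} is simply never computed, because h_G only needs q_T and m_T. The `LSTMCell` helper, its random untrained weights, and the choice T = 3 are hypothetical.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    # Max-subtraction keeps exp() from overflowing
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


class LSTMCell:
    """Minimal LSTM cell with random, untrained weights (hypothetical helper)."""

    def __init__(self, in_dim, hid_dim, rng):
        cols = in_dim + hid_dim  # each gate reads the concatenation [x ; h]
        # One weight matrix per gate: input (i), forget (f), candidate (g), output (o)
        self.W = {k: [[rng.gauss(0.0, 0.3) for _ in range(cols)] for _ in range(hid_dim)]
                  for k in "ifgo"}
        self.b = {k: [0.0] * hid_dim for k in "ifgo"}

    def step(self, x, h, c):
        z = x + h  # list concatenation [x ; h]
        pre = {k: [dot(row, z) + bk for row, bk in zip(self.W[k], self.b[k])] for k in "ifgo"}
        i = [sigmoid(v) for v in pre["i"]]
        f = [sigmoid(v) for v in pre["f"]]
        g = [math.tanh(v) for v in pre["g"]]
        o = [sigmoid(v) for v in pre["o"]]
        c_new = [fj * cj + ij * gj for fj, cj, ij, gj in zip(f, c, i, g)]
        h_new = [oj * math.tanh(cj) for oj, cj in zip(o, c_new)]
        return h_new, c_new


def set2set(H, lstm, T):
    """Return h_G = [q_T ; m_T] after T attention reads over node embeddings H."""
    d = len(H[0])
    h, c = [0.0] * d, [0.0] * d    # LSTM hidden state (the query) and cell state
    q_star = [0.0] * (2 * d)       # [q ; m] from the previous step, zero before step 1
    for _ in range(T):
        h, c = lstm.step(q_star, h, c)              # q_t from [q_{t-1} ; m_{t-1}]
        alpha = softmax([dot(h, hv) for hv in H])   # α^t_v from e^t_v = q_t · h_v
        m = [sum(a * hv[j] for a, hv in zip(alpha, H)) for j in range(d)]  # m_t
        q_star = h + m                              # [q_t ; m_t]
    return q_star


rng = random.Random(0)
d, T = 4, 3
H = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(5)]  # 5 node embeddings
lstm = LSTMCell(2 * d, d, rng)
h_G = set2set(H, lstm, T)
h_G_rev = set2set(H[::-1], lstm, T)  # same nodes, reversed order
print(len(h_G))  # 8, i.e. 2d
```

The attention here is a plain dot product, matching the formulas above; every step's m_t is a convex combination of the node embeddings, and only the LSTM state carries information from one read to the next.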

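Why multiple passes? The first query q_1 does not depend on the graph: with the usual zero initialisation, it is the LSTM's output for an all-zero input and state, so the first read uses a fixed, learned query. At step t=2, q_2 has seen m_1, the summary read at step 1, and can direct attention elsewhere. By step T, the LSTM has built up a sequence of queries, each conditioned on earlier reads, so each step can "read" a different aspect of the node set. This is loosely analogous to multi-head attention reading different subspaces, except that the reads happen one after another rather than in parallel.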

Set2Set vs Attention Readout vs Sum

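| Property | Sum | Attention readout | Set2Set |
|---|---|---|---|
| Weights nodes uniformly | Yes | No | No |
| Learns importance | No | Yes (independently) | Yes (iteratively) |
| Multiple passes over nodes | No | No | Yes (T passes) |
| Output dimension | d | d | 2d |
| Pooling cost (excluding MLP/LSTM layers) | O(N d) | O(N d) | O(T N d) |
| Permutation-invariant | Yes | Yes | Yes |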

When Set2Set Helps

Set2Set is particularly effective when:

  • Graph-level prediction requires integrating information from multiple disjoint node subsets
  • Different “aspects” of the graph matter for the prediction (Set2Set reads each in turn)
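  • The graph size varies widely across the dataset (the attention weights sum to 1 at every step, so the summary stays on a comparable scale regardless of node count, unlike sum pooling, while still letting a few nodes dominate)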

On molecular benchmarks such as QM9 (molecular property prediction), Set2Set significantly outperforms mean/sum pooling and slightly outperforms single-pass attention.

Multi-head Attention Readout

A simpler extension of attention readout: compute K independent attention heads, each with its own gate MLP:

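h_G = concat( Σ_v α^1_v h_v, ..., Σ_v α^K_v h_v )

where α^k_v = exp(a^k_v) / Σ_{u ∈ V} exp(a^k_u) and a^k_v = MLP^k_gate(h_v). The output dimension is K·d.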

Each head can learn to attend to a different subset of important nodes, giving multi-aspect graph summarisation without the LSTM overhead of Set2Set. Unlike Set2Set's reads, the heads are independent of one another and can be computed in parallel.
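A minimal sketch of K-head attention readout in plain Python, following the concatenation formula (weighted sums of the raw h_v, with no MLP_out). The gate architecture (one tanh hidden layer, no biases), the helper names `make_gate` and `multi_head_readout`, and the random untrained weights are assumptions for illustration.

```python
import math
import random


def softmax(scores):
    # Max-subtraction for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def make_gate(d, hidden, rng):
    """Hypothetical MLP^k_gate: d -> hidden (tanh) -> scalar, random untrained weights."""
    W1 = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(hidden)]
    w2 = [rng.gauss(0.0, 0.5) for _ in range(hidden)]

    def gate(h):
        z = [math.tanh(sum(w * x for w, x in zip(row, h))) for row in W1]
        return sum(w * zi for w, zi in zip(w2, z))

    return gate


def multi_head_readout(H, gates):
    """h_G = concat over heads k of Σ_v α^k_v h_v."""
    d = len(H[0])
    h_G = []
    for gate in gates:
        alpha = softmax([gate(h) for h in H])  # each head has its own α^k
        pooled = [sum(a * h[j] for a, h in zip(alpha, H)) for j in range(d)]
        h_G.extend(pooled)
    return h_G


rng = random.Random(1)
d, K = 4, 3
H = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(7)]  # 7 node embeddings
gates = [make_gate(d, hidden=8, rng=rng) for _ in range(K)]
h_G = multi_head_readout(H, gates)
print(len(h_G))  # 12, i.e. K * d
```

Nothing in this sketch forces the heads to differ: distinct random gates give distinct weightings here, but whether trained heads actually specialise depends on the task and training.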

Summary

| Method | Core idea | Strength |
|---|---|---|
| Sum/Mean | Uniform aggregation | Simple, fast |
| Attention readout | Learned per-node weights | Task-adaptive |
| Set2Set | LSTM queries the node set T times | Rich multi-pass summary |
| Multi-head attention | Multiple independent attention pools | Balances expressiveness and cost |

For small graphs (molecules, proteins), Set2Set and multi-head attention often provide meaningful improvements over flat pooling. For large graphs, the O(T N d) cost of Set2Set's T sequential passes may be prohibitive, making single-pass attention readout the preferred choice.

References