Message Passing: The Universal GNN Framework

TL;DR: Message Passing Neural Networks (Gilmer et al., 2017) provide a unified framework that covers most widely used GNNs. Each layer runs three steps: MESSAGE (compute what each neighbour sends), AGGREGATE (combine the incoming messages), UPDATE (compute the new node representation). Choosing different functions for each step gives you different GNN architectures.

The Framework

The MPNN framework (Gilmer et al., 2017, ICML) defines GNN computation as a sequence of message passing steps. At each step t, every node v computes:

m^{t+1}_v = AGGREGATE({ MSG(h^t_v, h^t_u, e_{uv}) : u ∈ N(v) })

h^{t+1}_v = UPDATE(h^t_v, m^{t+1}_v)

Where:

  • h^t_v — representation of node v at step t.
  • N(v) — neighbours of v.
  • e_{uv} — optional edge feature between u and v.
  • MSG — the message function.
  • AGGREGATE — combines all messages (must be permutation-invariant).
  • UPDATE — computes new representation from old + aggregated message.

[Figure 1 diagram: node B (h_B) with neighbours A (h_A), C (h_C), D (h_D). ① Compute messages MSG(h_A, h_B), MSG(h_C, h_B), MSG(h_D, h_B); ② AGGREGATE; ③ UPDATE, giving the new h_B.]
Figure 1: Node B receives messages from its three neighbours A, C, D. The messages are aggregated (e.g., summed or averaged), then combined with B's own representation in an UPDATE function to produce a new h_B.
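
The three steps can be sketched in a few lines of plain Python. This is only an illustrative sketch, not the reference MPNN implementation: the names `mpnn_step` and `vec_sum`, the toy feature vectors, and the "add" update are assumptions made for the example, and edge features are left out to keep it short.

```python
from typing import Callable, Dict, List

Vector = List[float]


def vec_sum(vectors: List[Vector]) -> Vector:
    """Element-wise sum of equally sized vectors (a permutation-invariant AGGREGATE)."""
    return [sum(components) for components in zip(*vectors)]


def mpnn_step(
    h: Dict[str, Vector],
    neighbours: Dict[str, List[str]],
    msg: Callable[[Vector, Vector], Vector],
    aggregate: Callable[[List[Vector]], Vector],
    update: Callable[[Vector, Vector], Vector],
) -> Dict[str, Vector]:
    """One synchronous step: every node reads the OLD representations of its neighbours."""
    new_h = {}
    for v, h_v in h.items():
        messages = [msg(h_v, h[u]) for u in neighbours[v]]  # 1. MESSAGE
        m_v = aggregate(messages)                            # 2. AGGREGATE
        new_h[v] = update(h_v, m_v)                          # 3. UPDATE
    return new_h


# The star graph of Figure 1: B is connected to A, C and D (toy 2-d features).
neighbours = {"A": ["B"], "B": ["A", "C", "D"], "C": ["B"], "D": ["B"]}
h = {"A": [1.0, 0.0], "B": [0.0, 1.0], "C": [2.0, 0.0], "D": [0.0, 3.0]}

new_h = mpnn_step(
    h,
    neighbours,
    msg=lambda h_v, h_u: h_u,                                   # send neighbour features as-is
    aggregate=vec_sum,
    update=lambda h_v, m_v: [a + b for a, b in zip(h_v, m_v)],  # toy update: add
)
print(new_h["B"])  # [3.0, 4.0]
```

Swapping the `msg`, `aggregate`, or `update` argument is all it takes to move between variants, which is the point of the framework.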

Step 1: Message Function

The message function computes what each neighbour sends. The simplest choice: just send the neighbour’s features.

MSG(h_v, h_u, e_uv) = h_u             # Simplest: pass neighbour features unchanged
MSG(h_v, h_u, e_uv) = W · h_u         # Linear transform first (GCN also applies a degree-based normalisation)
MSG(h_v, h_u, e_uv) = α_vu · W · h_u  # GAT: scale by a learned attention weight α_vu

None of these examples uses e_uv. Feeding the edge feature into MSG as well allows the model to distinguish bond types in a molecule or relationship types in a knowledge graph.
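
One simple way to make a message depend on the edge is to keep a separate weight matrix per edge type. The sketch below assumes a categorical edge feature (the bond type) and hand-picked weights; `W_BY_BOND`, `matvec`, and `edge_aware_msg` are hypothetical names for illustration, and in a real model the matrices would be learned.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product W · x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


# Hypothetical, hand-picked weights: one matrix per bond type.
W_BY_BOND: Dict[str, Matrix] = {
    "single": [[1.0, 0.0], [0.0, 1.0]],
    "double": [[2.0, 0.0], [0.0, 2.0]],
}


def edge_aware_msg(h_v: Vector, h_u: Vector, e_uv: str) -> Vector:
    """MSG(h_v, h_u, e_uv) = W[e_uv] · h_u, so the bond type changes the message."""
    return matvec(W_BY_BOND[e_uv], h_u)


h_u = [1.0, 3.0]
print(edge_aware_msg([0.0, 0.0], h_u, "single"))  # [1.0, 3.0]
print(edge_aware_msg([0.0, 0.0], h_u, "double"))  # [2.0, 6.0]
```

The same neighbour now sends a different message over a double bond than over a single bond, which a message function that ignores e_uv cannot do.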

Step 2: Aggregate Function

The aggregation combines all messages. It must be permutation-invariant (the order of neighbours shouldn’t matter):

Aggregator          Formula               Properties
Sum                 Σ_u m_u               Captures size of neighbourhood
Mean                (1/|N(v)|) Σ_u m_u    Normalised, size-invariant
Max                 max_u m_u             Element-wise maximum; captures the most extreme feature
Attention-weighted  Σ_u α_u m_u           Adaptive, like GAT

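The GIN paper (Xu et al., 2019; see separate post) shows that, among these simple aggregators, sum is the most powerful for distinguishing graph structures: combined with an injective update it can match the 1-WL (Weisfeiler-Lehman) test. Mean and max lose information: mean cannot tell a neighbourhood apart from the same neighbourhood with every message duplicated, and max ignores repeated messages altogether.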
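
A tiny experiment makes the information loss concrete. The sketch below (helper names `agg_sum`, `agg_mean`, `agg_max` are illustrative, and the messages are toy vectors) compares a neighbourhood with one message against a neighbourhood with two identical messages.

```python
from typing import List

Vector = List[float]


def agg_sum(messages: List[Vector]) -> Vector:
    return [sum(c) for c in zip(*messages)]


def agg_mean(messages: List[Vector]) -> Vector:
    return [sum(c) / len(messages) for c in zip(*messages)]


def agg_max(messages: List[Vector]) -> Vector:
    return [max(c) for c in zip(*messages)]


one_neighbour = [[1.0, 0.0]]
two_neighbours = [[1.0, 0.0], [1.0, 0.0]]  # same message, sent twice

for name, agg in [("sum", agg_sum), ("mean", agg_mean), ("max", agg_max)]:
    distinguishes = agg(one_neighbour) != agg(two_neighbours)
    print(f"{name}: distinguishes the two neighbourhoods = {distinguishes}")
# sum: True, mean: False, max: False
```

All three functions are permutation-invariant, but only sum keeps the count of neighbours in this example.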

Step 3: Update Function

Given the aggregated message and the old representation, compute the new one:

h_v^new = σ(W · agg_message)                # GCN-style (v is included in the aggregation via a self-loop)
h_v^new = σ(W · concat(h_v, agg_message))   # GraphSAGE-style
h_v^new = GRU(h_v, agg_message)             # Recurrent update
h_v^new = MLP((1 + ε) · h_v + agg_message)  # GIN-style
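
As a concrete instance, here is a minimal sketch of the concatenate-then-transform update σ(W · concat(h_v, agg_message)) with ReLU as σ. The weight matrix is hand-picked for illustration (in practice it is learned), and `concat_update`, `matvec`, and `relu` are hypothetical helper names.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product W · x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def relu(x: Vector) -> Vector:
    return [max(0.0, xi) for xi in x]


def concat_update(h_v: Vector, agg_message: Vector, W: Matrix) -> Vector:
    """UPDATE = ReLU(W · concat(h_v, agg_message)); W has len(h_v) + len(agg_message) columns."""
    return relu(matvec(W, h_v + agg_message))  # list "+" concatenates


h_v = [1.0, -1.0]
agg_message = [2.0, 0.5]
W = [
    [1.0, 0.0, 1.0, 0.0],   # first output mixes h_v[0] and agg_message[0]
    [0.0, 1.0, 0.0, -4.0],  # second output goes negative and is clipped by ReLU
]
print(concat_update(h_v, agg_message, W))  # [3.0, 0.0]
```

Keeping h_v as a separate input lets the node weigh its own state differently from what its neighbours sent.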

A Running Example: Molecule Property Prediction

Consider predicting if a molecule is toxic:

  • Nodes = atoms (features: atom type, charge, is_aromatic)
  • Edges = bonds (features: bond type: single/double/triple)
  • After k MPNN layers, each atom knows about its k-hop neighbourhood.
  • A readout aggregates all atom embeddings into a graph embedding.
  • An MLP predicts toxicity from the graph embedding.

After 3 layers, an atom “knows” about every atom up to 3 bonds away, which is enough to capture local chemical environments such as functional groups.
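
The k-hop claim is easy to check on a toy graph. The sketch below assumes a chain of five atoms with one-hot features (so each embedding dimension reveals which atom contributed to it), sum aggregation, an "add" update with no learned weights, and sum pooling as the readout; the final MLP is left out because it would need trained weights.

```python
from typing import Dict, List

Vector = List[float]

# Toy "molecule": a chain of five atoms 0-1-2-3-4.
neighbours: Dict[int, List[int]] = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

# One-hot features: dimension i is non-zero only if atom i has influenced the embedding.
h: Dict[int, Vector] = {v: [1.0 if i == v else 0.0 for i in range(5)] for v in neighbours}


def layer(h: Dict[int, Vector]) -> Dict[int, Vector]:
    """One message passing layer: sum the neighbours' vectors, then add the node's own."""
    new_h = {}
    for v, h_v in h.items():
        agg = [sum(c) for c in zip(*(h[u] for u in neighbours[v]))]
        new_h[v] = [a + b for a, b in zip(h_v, agg)]
    return new_h


for _ in range(3):  # three layers
    h = layer(h)

print(h[0])  # [4.0, 5.0, 3.0, 1.0, 0.0]
reached = [x != 0.0 for x in h[0]]
print(reached)  # [True, True, True, True, False]

# Readout: sum pooling turns the per-atom embeddings into one graph embedding.
graph_embedding = [sum(c) for c in zip(*h.values())]
```

Atom 0 has received information from atoms 1, 2, and 3 but not from atom 4, which is four bonds away; a fourth layer would be needed to reach it.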

✅ Key Takeaways

  • Most widely used GNNs (e.g., GCN, GraphSAGE, GAT, GIN) are instances of MPNN: choose MSG, AGGREGATE, and UPDATE functions.
  • AGGREGATE must be permutation-invariant. Among sum, mean, and max, sum is the most expressive choice (GIN).
  • After k layers, each node's embedding captures its k-hop neighbourhood.
  • Graph-level predictions require a readout function that pools node embeddings into a single vector.