MPNN: The General Message Passing Neural Network Framework

TL;DR: MPNN (Gilmer et al., 2017) defines a general GNN layer as three functions: (1) a message function M that computes a message along each edge; (2) an aggregation function □ that combines the messages arriving at a node; (3) an update function U that produces the new node state. GCN, GAT, GIN, and GraphSAGE can all be written as special cases.

Why a Unified Framework?

By 2017, several successful GNN architectures existed — GCN, GGNN (Gated Graph Neural Network), Interaction Networks, etc. Each was described differently, making comparison and design difficult.

Gilmer et al. (2017) introduced MPNN to describe these spatial GNNs with one abstraction. This framework:

  • Makes design choices explicit and comparable
  • Enables systematic ablation and design
  • Identifies what is shared and what differs between architectures
  • Reveals the fundamental limits of the class (later formalised by comparison with the 1-WL test)

The MPNN Framework

A single MPNN layer for node v computes:

Phase 1 — Message Computation:

m^{(k)}_{u→v} = M^{(k)}( h^{(k-1)}_v, h^{(k-1)}_u, e_{uv} )

For each neighbour u of v, compute a message. The message can depend on the sender’s state h_u, the receiver’s state h_v, and the edge feature e_{uv}.

Phase 2 — Aggregation:

m^{(k)}_v = □_{u ∈ N(v)} m^{(k)}_{u→v}

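Aggregate all messages from neighbours. □ must be permutation-invariant (the order of neighbours does not matter), as sum, mean, and max are. An LSTM aggregator, as used in GraphSAGE, is not permutation-invariant; it is applied to a random permutation of the neighbours as an approximation.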
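To make the permutation-invariance requirement concrete, here is a minimal sketch in plain Python. The function names are hypothetical, and `agg_recurrent` is only a toy order-sensitive aggregator standing in for an LSTM-style one, not an actual LSTM.

```python
def agg_sum(messages):
    """Element-wise sum of a list of equal-length message vectors."""
    return [sum(col) for col in zip(*messages)]

def agg_mean(messages):
    """Element-wise mean of the message vectors."""
    return [sum(col) / len(messages) for col in zip(*messages)]

def agg_max(messages):
    """Element-wise maximum of the message vectors."""
    return [max(col) for col in zip(*messages)]

def agg_recurrent(messages, decay=0.5):
    """Toy order-sensitive aggregator: running state s <- decay * s + m.
    A stand-in for a recurrent aggregator (assumption, not a real LSTM)."""
    state = [0.0] * len(messages[0])
    for m in messages:
        state = [decay * s + x for s, x in zip(state, m)]
    return state

# Three messages arriving at one node, and the same messages in another order.
messages = [[1.0, 0.0], [2.0, 4.0], [3.0, -1.0]]
shuffled = [messages[2], messages[0], messages[1]]

for agg in (agg_sum, agg_mean, agg_max, agg_recurrent):
    same = agg(messages) == agg(shuffled)
    print(f"{agg.__name__}: order-independent on this input = {same}")
```

Sum, mean, and max return the same vector for both orderings; the recurrent aggregator does not, which is why such aggregators need a workaround like random permutation.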

Phase 3 — Update:

h^{(k)}_v = U^{(k)}( h^{(k-1)}_v, m^{(k)}_v )

Combine the aggregated message with the node’s previous state to produce the new state.

After K rounds, each node’s representation h^{(K)}_v encodes information from its K-hop neighbourhood.
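The growth of the receptive field can be illustrated with a small sketch. It assumes scalar node states, an identity message, sum aggregation, and an additive update with no learned weights; `mpnn_round` is a hypothetical helper name.

```python
def mpnn_round(h, neighbours):
    """One message-passing round over all nodes."""
    new_h = []
    for v in range(len(h)):
        m_v = sum(h[u] for u in neighbours[v])  # M = identity, aggregation = sum
        new_h.append(h[v] + m_v)                # U(h_v, m_v) = h_v + m_v
    return new_h

# Path graph 0 - 1 - 2 - 3, stored as an adjacency list.
neighbours = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
h = [1.0, 0.0, 0.0, 0.0]  # only node 0 carries a signal initially

history = [h]
for k in range(3):
    h = mpnn_round(h, neighbours)
    history.append(h)

for k, states in enumerate(history):
    print(f"after {k} rounds: {states}")
```

In this toy setting, node 3 sits three hops from node 0 and its state stays at zero until the third round, when the signal first reaches it.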

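| Model | Message M | Aggregation □ | Update U |
|---|---|---|---|
| GCN | h_u / √(d̃_u d̃_v), with degrees d̃ counted after adding self-loops | Sum over N(v) ∪ {v} | Linear + ReLU |
| GraphSAGE | W h_u | Mean (or max) | W [h_v ‖ agg] + ReLU |
| GAT | α_{uv} W h_u | Sum | Nonlinearity σ (ELU in the original GAT) |
| GIN | h_u | Sum | MLP( (1+ε) h_v + m_v ) |
| MPNN (original) | r(h_u, e_{uv}) | Sum | GRU(h_v, m_v) |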

Where r is a learned edge network and GRU is a gated recurrent unit.
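One way to read the table is as a set of (M, □, U) triples plugged into the same generic layer. The sketch below assumes scalar node states and replaces learned weights with fixed constants; `mpnn_layer`, `gin_like`, `sage_like`, `EPS`, `W_SELF`, and `W_NEIGH` are hypothetical names, and the two instances are simplified stand-ins rather than faithful GIN or GraphSAGE implementations.

```python
def mpnn_layer(h, neighbours, message, aggregate, update):
    """Generic layer: h maps node -> scalar state, neighbours maps node -> list."""
    new_h = {}
    for v, nbrs in neighbours.items():
        msgs = [message(h[v], h[u]) for u in nbrs]
        new_h[v] = update(h[v], aggregate(msgs))
    return new_h

def relu(x):
    return max(0.0, x)

# GIN-style triple: message h_u, sum aggregation, update f((1 + eps) h_v + m_v).
# The MLP is replaced by a plain ReLU to avoid learned weights (assumption).
EPS = 0.5
gin_like = dict(
    message=lambda h_v, h_u: h_u,
    aggregate=sum,
    update=lambda h_v, m_v: relu((1 + EPS) * h_v + m_v),
)

# GraphSAGE-style triple with mean aggregation; W [h_v || agg] is reduced to
# two hand-picked scalar weights (assumption).
W_SELF, W_NEIGH = 0.5, 0.5
sage_like = dict(
    message=lambda h_v, h_u: h_u,
    aggregate=lambda msgs: sum(msgs) / len(msgs),
    update=lambda h_v, m_v: relu(W_SELF * h_v + W_NEIGH * m_v),
)

# Star graph: node 0 is connected to nodes 1, 2 and 3.
neighbours = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
h = {0: 1.0, 1: 2.0, 2: 2.0, 3: 2.0}

h_gin = mpnn_layer(h, neighbours, **gin_like)
h_sage = mpnn_layer(h, neighbours, **sage_like)
print("GIN-like: ", h_gin)
print("SAGE-like:", h_sage)
```

On this input the sum-based layer gives the centre a different state from the leaves, while the mean-based layer happens to give every node the same state, which hints at why the choice of □ affects what a layer can distinguish.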

Edge Features in MPNN

A key advantage of the MPNN formulation: edge features are first-class citizens. The message function M can freely use e_{uv}:

m_{u→v} = A(e_{uv}) h_u

Here A is a learned network that maps the edge feature e_{uv} to a d × d matrix (the “edge network” idea from Gilmer et al.), so each edge applies its own linear map to h_u. This is essential for molecules, where bond types (single, double, aromatic) are critical features.
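A small sketch of the message m_{u→v} = A(e_{uv}) h_u in plain Python. It assumes one-hot bond-type edge features and a purely linear edge network with hand-picked (not learned) weights; `edge_network`, `message`, and `BASIS` are hypothetical names.

```python
def matvec(A, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]

def edge_network(e_uv, basis):
    """A(e_uv) = sum_k e_uv[k] * basis[k]: maps an edge feature vector
    to a d x d matrix. A linear stand-in for a learned edge network."""
    d = len(basis[0])
    return [
        [sum(e_k * B[i][j] for e_k, B in zip(e_uv, basis)) for j in range(d)]
        for i in range(d)
    ]

def message(h_u, e_uv, basis):
    """m_{u->v} = A(e_uv) h_u."""
    return matvec(edge_network(e_uv, basis), h_u)

# Hypothetical weights: one 2 x 2 matrix per bond type.
BASIS = [
    [[1.0, 0.0], [0.0, 1.0]],  # single bond: identity map
    [[0.0, 2.0], [2.0, 0.0]],  # double bond: swap coordinates and scale by 2
]
SINGLE, DOUBLE = [1.0, 0.0], [0.0, 1.0]  # one-hot edge features
h_u = [1.0, 3.0]

m_single = message(h_u, SINGLE, BASIS)
m_double = message(h_u, DOUBLE, BASIS)
print("single bond:", m_single)
print("double bond:", m_double)
```

The same sender state h_u produces different messages depending on the bond type, which is the behaviour the edge network is meant to provide.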

Readout for Graph-Level Prediction

After K message-passing rounds, MPNN computes a graph-level representation via a readout function:

h_G = R( { h^{(K)}_v : v ∈ V } )

The readout R must be permutation-invariant. Common choices:

  • Sum: h_G = Σ h_v (sensitive to graph size)
  • Mean: h_G = (1/N) Σ h_v (normalised)
  • Set2Set: an attention-based readout with memory
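The difference between the sum and mean readouts can be sketched in a few lines of plain Python (Set2Set is omitted). The helper names `readout_sum` and `readout_mean` and the node states are made up for illustration.

```python
def readout_sum(states):
    """h_G = sum of node states, element-wise."""
    return [sum(col) for col in zip(*states)]

def readout_mean(states):
    """h_G = mean of node states, element-wise."""
    return [sum(col) / len(states) for col in zip(*states)]

small = [[1.0, 2.0], [3.0, 4.0]]  # final node states of a 2-node graph
large = small + small             # same states, twice as many nodes

print("sum: ", readout_sum(small), readout_sum(large))
print("mean:", readout_mean(small), readout_mean(large))

# Both readouts ignore node order.
reordered = list(reversed(small))
print(readout_sum(reordered) == readout_sum(small))
print(readout_mean(reordered) == readout_mean(small))
```

The sum readout separates the two graphs by size, while the mean readout maps them to the same vector; which behaviour is preferable depends on whether the target quantity scales with graph size.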

The Limits of MPNN

All MPNN models share the same fundamental limitation: their expressive power is bounded by the 1-Weisfeiler-Lehman graph isomorphism test (1-WL test, see the Expressivity section).

Any two graphs that 1-WL cannot distinguish are also indistinguishable by any MPNN, regardless of the specific M, □, and U functions. For example, with identical initial node features, a 6-cycle and two disjoint triangles receive the same 1-WL colours (every node has exactly two neighbours), so no MPNN can tell them apart.
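A minimal sketch of 1-WL colour refinement makes this concrete. It assumes uniform initial node colours and keeps colours as nested tuples rather than hashing them, so colours stay comparable across graphs; `wl_colour_histogram` is a hypothetical helper name.

```python
from collections import Counter

def wl_colour_histogram(neighbours, rounds):
    """Run 1-WL colour refinement and return the multiset of final colours."""
    colours = {v: () for v in neighbours}  # uniform initial colour
    for _ in range(rounds):
        # New colour = (own colour, sorted multiset of neighbour colours).
        colours = {
            v: (colours[v], tuple(sorted(colours[u] for u in nbrs)))
            for v, nbrs in neighbours.items()
        }
    return Counter(colours.values())

# A 6-cycle and two disjoint triangles: both are 2-regular graphs on 6 nodes.
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
# A 6-node path, for contrast: its endpoints have only one neighbour.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

same = wl_colour_histogram(hexagon, 3) == wl_colour_histogram(triangles, 3)
differ = wl_colour_histogram(hexagon, 3) != wl_colour_histogram(path, 3)
print("hexagon vs triangles indistinguishable:", same)
print("hexagon vs path distinguishable:", differ)
```

The hexagon and the pair of triangles end with identical colour histograms, so an MPNN with uniform input features gives them the same graph-level output, while the path is separated after the first round.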

This limitation motivates:

  • Higher-order GNNs (k-WL for k > 1)
  • Structural encodings (adding positional/structural features that break symmetry)
  • Graph Transformers (attend globally, not just to neighbours)

Why is this limit important? However expressive the message and update functions are, an MPNN assigns the same representation to two nodes whose neighbourhoods 1-WL cannot tell apart (identical multisets of neighbour features at every hop). Knowing this limit tells you when to add structural encodings or move to a more expressive architecture.

Practical Implementation

In PyTorch Geometric, MPNN-style models are implemented by inheriting MessagePassing:

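from torch import nn
from torch_geometric.nn import MessagePassing

class MyGNNLayer(MessagePassing):
    # dim and edge_dim are illustrative constructor arguments (assumed here).
    def __init__(self, dim, edge_dim):
        super().__init__(aggr='add')  # □ = sum
        self.dim = dim
        # A(e_uv): a single linear layer stands in for the edge network
        self.edge_net = nn.Linear(edge_dim, dim * dim)
        # U: GRU cell, aggregate as input and previous state as hidden state
        self.gru = nn.GRUCell(dim, dim)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # M: message from source j to target i, computed as A(e_ij) x_j
        A = self.edge_net(edge_attr).view(-1, self.dim, self.dim)
        return (A @ x_j.unsqueeze(-1)).squeeze(-1)

    def update(self, aggr_out, x):
        # U: update node state from the aggregate and the previous state
        return self.gru(aggr_out, x)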

Summary

The MPNN framework describes a spatial GNN layer with three functions:

| Function | Role | Inputs and constraints |
|---|---|---|
| M (message) | Compute edge messages | Can use sender, receiver, edge features |
| □ (aggregation) | Combine neighbour messages | Permutation invariant |
| U (update) | Update node state | Combines old state + aggregate |

The space of MPNNs is defined by the choices for M, □, and U. Understanding this space, and its limits, is a foundation for reading most message-passing GNN research from 2016 to the present.

References