Graph Neural Networks: Learning on Graphs

TL;DR: GNNs learn vector representations for nodes (and whole graphs) by iteratively aggregating information from each node's neighbourhood. They are a natural fit for relational data such as molecules, social graphs, knowledge graphs, and road networks, where they often outperform models that ignore the graph structure.

Graphs Are Everywhere

A graph G = (V, E) consists of:

  • Nodes (V): entities — atoms, people, papers, intersections.
  • Edges (E): relationships — bonds, friendships, citations, roads.
  • Features on nodes and/or edges: atom type, age, year, speed limit.
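To make the definition concrete, here is a minimal sketch of a graph with node and edge features in plain Python. The `Graph` class and its method names are hypothetical, written only for this illustration; real projects would more likely use a graph library.

```python
class Graph:
    """A tiny undirected graph with features on nodes and edges (illustrative only)."""

    def __init__(self):
        self.node_features = {}   # node id -> dict of features
        self.edge_features = {}   # frozenset({u, v}) -> dict of features
        self.adjacency = {}       # node id -> set of neighbouring node ids

    def add_node(self, node, **features):
        self.node_features[node] = features
        self.adjacency.setdefault(node, set())

    def add_edge(self, u, v, **features):
        # A frozenset key makes (u, v) and (v, u) the same undirected edge.
        self.edge_features[frozenset((u, v))] = features
        self.adjacency[u].add(v)
        self.adjacency[v].add(u)

    def neighbours(self, node):
        return sorted(self.adjacency[node])


# Toy molecule (formaldehyde, H2C=O): atoms are nodes, bonds are edges.
mol = Graph()
mol.add_node("C", atom="C")
mol.add_node("O", atom="O")
mol.add_node("H1", atom="H")
mol.add_node("H2", atom="H")
mol.add_edge("C", "O", bond="double")
mol.add_edge("C", "H1", bond="single")
mol.add_edge("C", "H2", bond="single")

print(mol.neighbours("C"))
print(mol.edge_features[frozenset(("O", "C"))])
```

The adjacency sets are what a GNN layer walks over when it gathers information from a node's neighbours.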

Real-world data that’s naturally a graph:

  • Molecules: atoms = nodes, bonds = edges. Predicting drug toxicity or binding affinity.
  • Social networks: users = nodes, follows/friends = edges. Recommendation, fraud detection.
  • Knowledge graphs: entities = nodes, relations = edges. Question answering, link prediction.
  • Citation networks: papers = nodes, citations = edges. Classifying papers by topic.
  • Road networks: intersections = nodes, roads = edges. Route planning, traffic prediction.

Why Not Just Use Standard Neural Networks?

A standard MLP takes a fixed-size vector as input. Graphs have:

  • Variable size — different graphs have different numbers of nodes and edges.
  • No canonical ordering — there’s no “first” node; permuting nodes shouldn’t change predictions.
  • Relational structure — the patterns live in the connections, not just the individual features.

GNNs are designed to respect all three of these properties.
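The first two properties can be seen in a few lines of Python. In this sketch (the helper names `flatten` and `sum_readout` are made up for illustration), concatenating node features into one vector depends on the node order and on the number of nodes, while an element-wise sum depends on neither.

```python
def flatten(features, order):
    """Concatenate node features in a given order (what a plain MLP would see)."""
    return [x for node in order for x in features[node]]


def sum_readout(features, order):
    """Element-wise sum over nodes; the visiting order does not affect the result."""
    dim = len(next(iter(features.values())))
    return [sum(features[node][i] for node in order) for i in range(dim)]


features = {"a": [1, 0], "b": [0, 1], "c": [2, 2]}
order_1 = ["a", "b", "c"]
order_2 = ["c", "a", "b"]   # same graph, nodes listed in a different order

flat_1, flat_2 = flatten(features, order_1), flatten(features, order_2)
sum_1, sum_2 = sum_readout(features, order_1), sum_readout(features, order_2)

print(flat_1 == flat_2)   # False: the flattened input changes with the ordering
print(sum_1 == sum_2)     # True: the summed representation does not

# A graph with one more node still yields a vector of the same length.
bigger = dict(features, d=[1, 1])
sum_bigger = sum_readout(bigger, ["a", "b", "c", "d"])
print(len(sum_bigger) == len(sum_1))
```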

[Figure: a simple graph with four nodes A, B, C, D labelled with atom types C, N, O, H and joined by single and double bonds; a GNN maps it to node embeddings (A → [0.2, 0.8, ...], B → [0.5, 0.3, ...], C → [0.1, 0.9, ...]) used for node classification, link prediction, and graph classification.]
Figure 1: A GNN takes a graph with node features (atom types) and produces rich node embeddings that capture local and global structure. These embeddings support downstream tasks.

The Core Idea: Aggregate from Neighbours

Every GNN follows the same fundamental principle, called message passing:

Each node’s new representation = function(its current representation, representations of its neighbours)
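As a rough illustration of this update rule, here is one message-passing round in plain Python. It is a sketch under simplifying assumptions: the name `message_passing_step` and the fixed scalar weights `w_self` and `w_neigh` are hypothetical stand-ins for the weight matrices and non-linearity that a real GNN layer would learn, and the mean is only one possible aggregator.

```python
def message_passing_step(adjacency, h, w_self=0.5, w_neigh=0.5):
    """One round: each node mixes its own vector with the mean of its neighbours'."""
    new_h = {}
    for v, neighbours in adjacency.items():
        dim = len(h[v])
        if neighbours:
            mean = [sum(h[u][i] for u in neighbours) / len(neighbours)
                    for i in range(dim)]
        else:
            mean = [0.0] * dim   # isolated node: nothing to aggregate
        # Assumption: fixed scalar weights instead of learned parameters.
        new_h[v] = [w_self * h[v][i] + w_neigh * mean[i] for i in range(dim)]
    return new_h


# Path graph A - B - C with 2-dimensional node features.
adjacency = {"A": ["B"], "B": ["A", "C"], "C": ["B"]}
h0 = {"A": [1.0, 0.0], "B": [0.0, 1.0], "C": [1.0, 1.0]}

h1 = message_passing_step(adjacency, h0)
print(h1["B"])   # B now blends its own features with those of A and C
```

All new vectors are computed from the previous round's vectors, so the result does not depend on the order in which nodes are visited.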

After k iterations, node v’s embedding captures information from all nodes up to k hops away (its k-hop neighbourhood).

This is beautiful because:

  • Nearby nodes influence each other (just like in the real world).
  • The same aggregation function works on graphs of any size.
  • The function is learned from data, so it adapts to the task.
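The k-hop claim can be checked without any learning at all. The sketch below (the helper `receptive_fields` is a made-up name for illustration) tracks, for each node, the set of nodes whose initial features could have reached it after k rounds of neighbour aggregation.

```python
def receptive_fields(adjacency, k):
    """Return, per node, the set of nodes that can have influenced it after k rounds."""
    seen = {v: {v} for v in adjacency}   # round 0: each node only knows itself
    for _ in range(k):
        # Each node merges what its neighbours knew in the previous round.
        seen = {v: seen[v].union(*(seen[u] for u in adjacency[v]))
                for v in adjacency}
    return seen


# Path graph: A - B - C - D
path = {"A": ["B"], "B": ["A", "C"], "C": ["B", "D"], "D": ["C"]}

for k in range(4):
    print(k, sorted(receptive_fields(path, k)["A"]))
```

On this path, node A sees only itself and B after one round, and needs three rounds before anything from D can reach it.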

Three Task Levels

GNNs can produce predictions at three granularities:

| Level | What you predict | Example |
|-------|------------------|---------|
| Node | Label for each node | Is this user a bot? |
| Edge | Label or score for each edge | Will A befriend B? |
| Graph | Label for the whole graph | Is this molecule toxic? |

For node tasks: use the node embeddings directly. For edge tasks: combine the embeddings of the two endpoint nodes into a score or label. For graph tasks: pool the node embeddings into a single graph vector (the readout step).
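A minimal sketch of how finished node embeddings can feed graph-level and edge-level predictions. The function names are hypothetical, mean pooling is only one possible readout, and scoring an edge by the dot product of its endpoints is an assumption chosen for simplicity rather than something this post prescribes.

```python
def mean_readout(h):
    """Pool all node embeddings into one graph-level vector (mean pooling)."""
    dim = len(next(iter(h.values())))
    return [sum(vec[i] for vec in h.values()) / len(h) for i in range(dim)]


def edge_score(h, u, v):
    """Score a candidate edge as the dot product of its endpoint embeddings."""
    return sum(a * b for a, b in zip(h[u], h[v]))


# Pretend these embeddings came out of a trained GNN.
h = {"A": [1.0, 0.0], "B": [0.0, 1.0], "C": [2.0, 2.0]}

graph_vector = mean_readout(h)        # one vector for the whole graph
score_ac = edge_score(h, "A", "C")    # one scalar for the pair (A, C)
score_ab = edge_score(h, "A", "B")

print(graph_vector, score_ac, score_ab)
```

In practice the graph vector or edge score would be passed to a small classifier trained together with the GNN.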

The Landscape of GNN Architectures

| Model | Year | Key idea |
|-------|------|----------|
| GCN | 2016 | Spectral convolution → normalised averaging |
| GAT | 2018 | Attention weights on edges |
| GraphSAGE | 2017 | Inductive learning via neighbourhood sampling |
| GIN | 2019 | Sum aggregation + MLP; as expressive as the 1-WL graph isomorphism test |
| Sheaf NN | 2022+ | Section-space diffusion, generalises GCN |
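The choice of aggregator matters more than it may seem. The toy comparison below (the `aggregate` helper is hypothetical and uses scalar features for brevity) shows two neighbourhoods that mean and max cannot tell apart but sum can, which is the intuition behind GIN's use of sum aggregation.

```python
def aggregate(neighbour_features, how):
    """Combine a multiset of scalar neighbour features with the chosen aggregator."""
    if how == "mean":
        return sum(neighbour_features) / len(neighbour_features)
    if how == "max":
        return max(neighbour_features)
    if how == "sum":
        return sum(neighbour_features)
    raise ValueError(f"unknown aggregator: {how}")


# Two different neighbourhoods: two identical neighbours vs. a single one.
two_neighbours = [1.0, 1.0]
one_neighbour = [1.0]

results = {
    how: (aggregate(two_neighbours, how), aggregate(one_neighbour, how))
    for how in ("mean", "max", "sum")
}

for how, (x, y) in results.items():
    print(how, x, y, "distinguishes" if x != y else "cannot distinguish")
```

Mean and max collapse both neighbourhoods to the same value, so any layer built on them gives the two centre nodes identical updates; sum keeps the neighbour count.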

✅ Key Takeaways

  • Graphs model relational data: atoms, users, papers, intersections — any entities with relationships.
  • GNNs learn by iterative neighbourhood aggregation: after k layers, each node knows about its k-hop neighbourhood.
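  • The same model works on graphs of any size and under any node ordering: node-level outputs are permutation equivariant, and graph-level outputs are permutation invariant.
  • GNNs support node-, edge-, and graph-level predictions; graph-level predictions pool the node embeddings with a readout.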