Learning Filtrations: Task-Optimised Topology

TL;DR: A filtration f: K → ℝ is a function assigning an "importance" value to each simplex, with one constraint: a face never gets a larger value than the simplices it bounds. If we parameterise f by a neural network (e.g., a GNN on graph nodes/edges), we can learn f end-to-end by backpropagating through the persistence diagram. Graph Filtration Learning (Hofer et al. 2020) does exactly this: a 1-layer GNN outputs node values that induce a filtration, whose persistence diagram is then vectorised and classified.

The Fixed Filtration Problem

In classical TDA, the filtration is fixed by the data geometry:

  • Rips: \(f(\sigma) = \max_{u,v \in \sigma} d(u,v)\).
  • Sublevel set: \(f(\sigma) = \max_{v \in \sigma} h(v)\) for some fixed height function \(h\).
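To make the two fixed constructions concrete, here is a minimal pure-Python sketch that evaluates each filtration value on a single simplex. The function names and the toy distances and heights are illustrative assumptions, not taken from any TDA library.

```python
from itertools import combinations
from typing import Callable, Dict, Hashable, Sequence

Vertex = Hashable


def rips_value(simplex: Sequence[Vertex], d: Callable[[Vertex, Vertex], float]) -> float:
    """Rips value: largest pairwise distance inside the simplex (0 for a single vertex)."""
    return max((d(u, v) for u, v in combinations(simplex, 2)), default=0.0)


def sublevel_value(simplex: Sequence[Vertex], h: Dict[Vertex, float]) -> float:
    """Sublevel-set value: largest height h(v) of any vertex in the simplex."""
    return max(h[v] for v in simplex)


# Toy data (hypothetical): three points on a line and an arbitrary height function.
coords = {"a": 0.0, "b": 1.0, "c": 3.0}
heights = {"a": 0.5, "b": 2.0, "c": 1.0}


def dist(u: Vertex, v: Vertex) -> float:
    return abs(coords[u] - coords[v])


print(rips_value(("a", "b", "c"), dist))    # 3.0
print(sublevel_value(("a", "c"), heights))  # 1.0
```

Both functions are fixed once the data is given: nothing in them can adapt to a downstream task.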

But for graph classification, the “right” filtration depends on the task:

  • For classifying molecules by toxicity, bond lengths matter.
  • For social networks, community structure matters.
  • For protein folding graphs, secondary structure matters.

A fixed filtration cannot be simultaneously optimal for all tasks.

Graph Filtration Learning

Setup: Given a graph \(G = (V, E)\) with node features \(X \in \mathbb{R}^{|V| \times d}\), define:
  1. A parameterised filtration \(f_\theta: V \cup E \to \mathbb{R}\) using a GNN:
    • Node values: \(f_\theta(v) = \mathrm{GNN}_\theta(v, X)\)
    • Edge values: \(f_\theta(\{u,v\}) = \max(f_\theta(u), f_\theta(v))\) (lower-star convention, so an edge never enters before its endpoints)
  2. Compute the persistence diagram \(\mathrm{dgm}(f_\theta(G))\) of the induced sublevel-set filtration \(G_t = \{\sigma \in V \cup E : f_\theta(\sigma) \leq t\}\).

  3. Vectorise via PersLay or persistence images → dense layers → classification.

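  4. Train end-to-end: gradients flow back through PersLay → through \(\mathrm{dgm}\) → to \(\theta\) in the GNN. This works because every birth or death coordinate equals \(f_\theta\) at one specific vertex or edge: as long as the ordering of the simplices does not change, the diagram is differentiable in \(\theta\), and each coordinate inherits the gradient of the node value it was read from. Over training pairs \((G_i, y_i)\):
$$\theta^* = \arg\min_\theta \sum_i \mathcal{L}\big(\mathrm{PersLay}(\mathrm{dgm}(f_\theta(G_i))), y_i\big)$$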
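Steps 1, 2 and 4 can be sketched for \(H_0\) in pure Python. This is a minimal sketch under stated assumptions, not the GFL implementation: the node values `f` are plain floats standing in for the GNN output, `h0_persistence` is a hypothetical helper, and only 0-dimensional persistence is computed, using union-find with the elder rule. Besides each (birth, death) pair it records the vertex whose value each coordinate was read from, which is the bookkeeping that lets gradients reach \(\theta\).

```python
from typing import List, Optional, Sequence, Tuple

Edge = Tuple[int, int]
# (birth, death, birth_vertex, death_vertex); death_vertex is None for essential classes
Pair = Tuple[float, float, int, Optional[int]]


def h0_persistence(f: Sequence[float], edges: Sequence[Edge]) -> List[Pair]:
    """H0 persistence of the sublevel filtration induced by node values f.

    Edges get lower-star values max(f[u], f[v]). Components are merged with
    union-find; by the elder rule the component born later dies at the merge.
    """
    parent = list(range(len(f)))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    edge_value = [max(f[u], f[v]) for u, v in edges]
    pairs: List[Pair] = []
    # Process edges in filtration order; both endpoints are already present.
    for i in sorted(range(len(edges)), key=lambda j: edge_value[j]):
        u, v = edges[i]
        ru, rv = find(u), find(v)
        if ru == rv:
            continue  # edge closes a cycle: an H1 birth, not an H0 event
        # Each root is the lowest-valued vertex of its component,
        # so f[root] is that component's birth value.
        young, old = (ru, rv) if f[ru] > f[rv] else (rv, ru)
        death_vertex = u if f[u] >= f[v] else v  # endpoint attaining the max
        pairs.append((f[young], edge_value[i], young, death_vertex))
        parent[young] = old
    # Components that never merge give essential (infinite) bars.
    for r in sorted({find(x) for x in range(len(f))}):
        pairs.append((f[r], float("inf"), r, None))
    return pairs


# Path graph 0 - 1 - 2 with hypothetical learned node values.
print(h0_persistence([0.0, 2.0, 1.0], [(0, 1), (1, 2)]))
# [(2.0, 2.0, 1, 1), (1.0, 2.0, 2, 1), (0.0, inf, 0, None)]
```

In an autodiff framework one would instead index the tensor of node values at the recorded vertices, so each birth or death coordinate receives the gradient of \(f_\theta(v)\) at its vertex; the index bookkeeping itself is not differentiated.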

Expressive Power

Theorem (Hofer et al. 2020): Graph Filtration Learning is strictly more powerful than 1-WL GNNs on certain graph families. Two graphs that are indistinguishable by the Weisfeiler-Leman test can be distinguished by their persistent \(H_0\) under a learned filtration.

The intuition: topology sees global structure (connectivity, cycles) that local message-passing misses.
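A concrete instance: two disjoint triangles and a 6-cycle are both 2-regular, so 1-WL colour refinement never separates them, yet the first graph has two connected components and the second only one. The sketch below (pure Python, hypothetical helper names) checks both facts. The number of connected components equals the number of essential \(H_0\) classes for any filtration, and a message-passing GNN bounded by 1-WL that starts from identical node features assigns the same value to every node here, so even a constant learned filtration separates the two graphs.

```python
from collections import Counter
from typing import Dict, List

Graph = Dict[int, List[int]]


def wl_colour_histogram(adj: Graph, rounds: int) -> Counter:
    """1-WL colour refinement from uniform initial colours.

    Colours are kept as nested signatures (own colour, sorted neighbour colours)
    instead of compressed integers, so histograms of different graphs compare directly.
    """
    colours = {v: () for v in adj}
    for _ in range(rounds):
        colours = {
            v: (colours[v], tuple(sorted(colours[u] for u in adj[v])))
            for v in adj
        }
    return Counter(colours.values())


def n_components(adj: Graph) -> int:
    """Number of connected components, i.e. the number of essential H0 bars."""
    seen, count = set(), 0
    for start in adj:
        if start in seen:
            continue
        count += 1
        seen.add(start)
        stack = [start]
        while stack:
            v = stack.pop()
            for u in adj[v]:
                if u not in seen:
                    seen.add(u)
                    stack.append(u)
    return count


two_triangles: Graph = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon: Graph = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}

print(wl_colour_histogram(two_triangles, 3) == wl_colour_histogram(hexagon, 3))  # True
print(n_components(two_triangles), n_components(hexagon))  # 2 1
```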

Extended to Higher Dimensions

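In the graph filtration above, every cycle is born when its last edge enters and never dies, and there are no 2-simplices, so finite \(H_1\) classes and any \(H_2\) need higher-dimensional simplices. For \(H_1\), \(H_2\) persistence, one option is a Rips-type filtration on the learned node values: \(K_t = \{S \subseteq V : f_\theta(v) \leq t \ \forall v \in S, \mathrm{diam}(S) \leq t\}\), where \(\mathrm{diam}(S) = \max_{u,v \in S} d(u,v)\) for a metric \(d\) on \(V\), such as the shortest-path distance in \(G\).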

This captures loops and voids in the graph, weighted by learned node importance.
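By the definition of \(K_t\), a subset \(S\) enters at \(t(S) = \max\big(\max_{v \in S} f_\theta(v), \mathrm{diam}(S)\big)\), and both terms can only grow when vertices are added, so faces never enter after the simplices they bound. The minimal sketch below evaluates this entry time, assuming (hypothetically) that \(d\) is the hop distance in \(G\) and that the node values are given as plain floats; it uses pure Python and illustrative names only.

```python
from collections import deque
from itertools import combinations
from typing import Dict, List, Sequence

Graph = Dict[int, List[int]]


def hop_distances(adj: Graph, source: int) -> Dict[int, int]:
    """Breadth-first search: hop distance from source to every reachable node."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        v = queue.popleft()
        for u in adj[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    return dist


def entry_time(simplex: Sequence[int], f: Dict[int, float], adj: Graph) -> float:
    """Smallest t with simplex in K_t: max(largest node value, diameter).

    Nodes in different components are at infinite distance,
    so a simplex spanning components never enters.
    """
    dists = {v: hop_distances(adj, v) for v in simplex}
    diam = max(
        (dists[u].get(v, float("inf")) for u, v in combinations(simplex, 2)),
        default=0,
    )
    return max(max(f[v] for v in simplex), diam)


path: Graph = {0: [1], 1: [0, 2], 2: [1]}
values = {0: 0.2, 1: 0.5, 2: 1.5}  # hypothetical learned node values
print(entry_time((0,), values, path))       # 0.2
print(entry_time((0, 1), values, path))     # 1 (the diameter dominates)
print(entry_time((0, 1, 2), values, path))  # 2
```

Because hop counts and learned values are compared on one scale, the relative scale of \(f_\theta\) and \(d\) decides which term dominates; this is a modelling choice, not something the construction fixes.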

Comparison: Fixed vs. Learned Filtrations

| Property | Fixed (Rips/Sublevel) | Learned |
|---|---|---|
| Computation | Fast | Requires GNN forward pass |
| Task optimality | No | Yes |
| Interpretability | High (geometric meaning) | Task-dependent |
| Requires labels | No | Yes (supervised) |

Key Insight: Learning filtrations addresses a long-standing tension in TDA-ML. Classical TDA computes topological features that are provably stable and interpretable, but they may not be discriminative for a given task. Learned filtrations trade some interpretability for task relevance. In practice, learned filtrations outperform fixed filtrations on benchmark graph classification tasks by 5–15%, particularly on datasets with rich node features (e.g., biological networks and molecular graphs).

References

  • C. Hofer, F. Graf, B. Rieck, M. Niethammer, R. Kwitt, “Graph Filtration Learning,” ICML 2020. arXiv:1905.10996.
  • B. Rieck, C. Bock, K. Borgwardt, โ€œA Persistent Weisfeiler-Lehman Procedure for Graph Classification,โ€ ICML 2019.