PersLay: A Neural Network Layer for Persistence Diagrams

TL;DR: PersLay defines a general class of persistence diagram vectorisations as ρ({wᵢ · φ(pᵢ)}ᵢ), where pᵢ are the diagram points, φ is a point transformation, w is a weight function (typically emphasising long-lived features), and ρ is a permutation-invariant operation (sum, mean, max, k-th largest value). Several classical vectorisations, including persistence images, landscapes, silhouettes and Betti curves, fit this template. Training jointly with the downstream task learns φ and w from the data, while ρ is chosen from this menu as a design choice.

The Vectorisation Problem

Persistence diagrams are multisets of points in \(\mathbb{R}^2\): a diagram \(D = \{(b_i, d_i)\}\) has variable size and is not directly compatible with neural network inputs (which expect fixed-size vectors).

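Two properties make this hard:

  1. The map from data to persistence diagram is only piecewise differentiable: it fails to be differentiable where the ordering of filtration values changes (for example, when two simplices enter the filtration at the same value).
  2. The diagram is an unordered multiset, so any vectorisation must be permutation-invariant.

PersLay is permutation-invariant by construction and differentiable in the diagram coordinates, so gradients can pass back through the diagram wherever the data-to-diagram map is differentiable.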
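
To make the problem concrete, here is a minimal NumPy sketch. The two diagrams and the three per-point features (birth, death, persistence) are made up for illustration. Flattening the (birth, death) array gives a vector whose length depends on the number of points and whose entries depend on their order, while summing a fixed per-point feature gives a fixed-size, order-independent vector.

```python
import numpy as np

# Two hypothetical diagrams as (n, 2) arrays of (birth, death) pairs; sizes differ.
d1 = np.array([[0.0, 1.0], [0.2, 0.5]])
d2 = np.array([[0.1, 0.9], [0.0, 0.3], [0.4, 0.6]])

# Naive flattening: the length depends on the number of points...
print(d1.ravel().shape, d2.ravel().shape)  # (4,) (6,)

# ...and the entries depend on the order in which points are listed.
d1_perm = d1[::-1]
print(np.array_equal(d1.ravel(), d1_perm.ravel()))  # False

def pooled(diagram):
    """Sum a fixed per-point feature (birth, death, persistence) over all points."""
    b, d = diagram[:, 0], diagram[:, 1]
    feats = np.stack([b, d, d - b], axis=1)  # shape (n, 3)
    return feats.sum(axis=0)                 # shape (3,), independent of n and order

print(pooled(d1).shape, pooled(d2).shape)        # (3,) (3,)
print(np.allclose(pooled(d1), pooled(d1_perm)))  # True
```

PersLay keeps this sum-of-per-point-features structure but makes the per-point feature and its weight learnable.
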

The PersLay Framework

A PersLay layer computes:

$$\mathrm{PersLay}(D) = \rho\!\left(\left\{\, w(p) \cdot \varphi(p) \;:\; p \in D \,\right\}\right)$$

where:

  • \(\varphi: \mathbb{R}^2 \to \mathbb{R}^q\): a learnable point transformation (e.g., a small MLP or a bank of Gaussian RBFs).
  • \(w: \mathbb{R}^2 \to \mathbb{R}_{\geq 0}\): a weight function, fixed or learnable, that emphasises important points; a common choice is \(w(b,d) = (d - b)^\alpha\) with \(\alpha > 0\), which grows with persistence.
  • \(\rho\): a permutation-invariant operation over the weighted point features: a coordinate-wise sum (giving \(\sum_{p \in D} w(p) \cdot \varphi(p)\)), mean, max or k-th largest value, or a DeepSets-style operation.
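
A minimal NumPy sketch of the layer above, assuming diagrams are stored as (n, 2) arrays of (birth, death) pairs. The function name `perslay`, the Gaussian RBF choice for φ, and the fixed centres and bandwidth are illustrative assumptions rather than the paper's reference implementation; in training, the centres (and possibly the weight function) would be parameters updated by gradient descent.

```python
import numpy as np

def perslay(diagram, phi, w, rho="sum", k=1):
    """Sketch of rho({ w(p) * phi(p) : p in D }) for a diagram of shape (n, 2).

    phi maps (n, 2) -> (n, q); w maps (n, 2) -> (n,); rho pools over the n points.
    """
    feats = w(diagram)[:, None] * phi(diagram)  # (n, q) weighted point features
    if rho == "sum":
        return feats.sum(axis=0)
    if rho == "mean":
        return feats.mean(axis=0)
    if rho == "max":
        return feats.max(axis=0)
    if rho == "kth":
        # k-th largest value in each coordinate (k = 1 is the max); needs n >= k
        return np.sort(feats, axis=0)[-k]
    raise ValueError(f"unknown rho: {rho}")

# Hypothetical phi: Gaussian RBFs at q = 8 centres (stand-ins for learnable parameters).
rng = np.random.default_rng(0)
centers = rng.uniform(0.0, 1.0, size=(8, 2))
sigma, alpha = 0.2, 1.0  # hypothetical bandwidth and weight exponent

def phi_rbf(D):
    sq = ((D[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)  # (n, q) squared distances
    return np.exp(-sq / (2 * sigma**2))

def w_pers(D):
    return (D[:, 1] - D[:, 0]) ** alpha  # w(b, d) = (d - b)^alpha

D = np.array([[0.0, 1.0], [0.2, 0.5], [0.1, 0.15]])
out = perslay(D, phi_rbf, w_pers, rho="sum")
print(out.shape)  # (8,)
```

Swapping `rho` only changes how the weighted features are pooled; the per-point computation stays the same, which is what lets the classical vectorisations appear as special cases.
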

Recovering Classical Vectorisations

The PersLay framework recovers several classical vectorisations as special cases:

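| Method | \(\varphi(p)\) | \(w(p)\) | \(\rho\) |
|---|---|---|---|
| Persistence image | Gaussian centred at \(p\), evaluated at grid points | \(\vert d-b \vert\) | Sum |
| Persistence landscape | Piecewise linear tent, evaluated at grid points | \(1\) | \(k\)-th largest value (max for \(k = 1\)) |
| Persistence silhouette | Tent function, evaluated at grid points | \(\vert d-b \vert\) | Weighted mean (sum divided by \(\sum_{p \in D} w(p)\)) |
| Betti curve | Indicator \(\mathbf{1}[b \le t < d]\) at grid values \(t\) | \(1\) | Sum |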

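Each classical method corresponds to a specific, fixed choice of \(\varphi, w, \rho\). PersLay instead learns the parameters of \(\varphi\) and \(w\) jointly with the downstream task, with \(\rho\) chosen among the permutation-invariant operations above.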
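
The fixed choices in the table can be written directly on a 1-D grid. This sketch assumes a hypothetical evenly spaced grid `ts` on [0, 1] and computes the Betti curve, the k-th landscape and the silhouette with the table's (φ, w, ρ); the persistence image follows the same pattern with a 2-D Gaussian φ and a sum.

```python
import numpy as np

ts = np.linspace(0.0, 1.0, 11)  # hypothetical sampling grid t_0, ..., t_10

def tents(D):
    """phi(p) for landscapes/silhouettes: max(0, min(t - b, d - t)) at each grid t."""
    b, d = D[:, :1], D[:, 1:]  # column vectors, shape (n, 1)
    return np.maximum(0.0, np.minimum(ts - b, d - ts))  # shape (n, len(ts))

def betti_curve(D):
    # phi = indicator 1[b <= t < d], w = 1, rho = sum
    b, d = D[:, :1], D[:, 1:]
    return ((b <= ts) & (ts < d)).astype(float).sum(axis=0)

def landscape(D, k=1):
    # phi = tent, w = 1, rho = k-th largest value at each grid point
    if k > len(D):
        return np.zeros_like(ts)  # fewer than k points: lambda_k is identically 0
    return np.sort(tents(D), axis=0)[-k]

def silhouette(D, power=1.0):
    # phi = tent, w = (d - b)^power, rho = sum normalised by the total weight
    w = (D[:, 1] - D[:, 0]) ** power
    return (w[:, None] * tents(D)).sum(axis=0) / w.sum()

D = np.array([[0.05, 0.95], [0.25, 0.55]])
print(betti_curve(D))  # [0. 1. 1. 2. 2. 2. 1. 1. 1. 1. 0.]
```

The landscape is the one row whose ρ is not a (possibly normalised) sum: taking the k-th largest tent value at each t is what separates the successive landscape functions.
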

Differentiability

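With \(\rho\) = sum, the output \(\sum_{p \in D} w(p) \cdot \varphi(p)\) is differentiable with respect to the diagram points \(p = (b_i, d_i)\) and the parameters of \(\varphi\) and \(w\), as long as \(\varphi\) and \(w\) are themselves differentiable; the same holds for the mean, while max and k-th largest value are differentiable almost everywhere. This enables: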

  1. Backpropagation through the diagram to the data (if the filtration is differentiable).
  2. End-to-end learning: PersLay → dense layers → classifier, trained jointly.
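
A small check of this claim, assuming a single hypothetical Gaussian centre `c` and bandwidth `sigma` for φ, with w(b, d) = d - b and ρ = sum: the hand-derived gradient with respect to every birth and death coordinate matches a central finite-difference estimate.

```python
import numpy as np

c = np.array([0.3, 0.7])  # hypothetical Gaussian centre (a parameter of phi)
sigma = 0.25              # hypothetical bandwidth

def layer(D):
    """Scalar PersLay output: sum over points of (d - b) * Gaussian(p; c, sigma)."""
    w = D[:, 1] - D[:, 0]
    g = np.exp(-((D - c) ** 2).sum(axis=1) / (2 * sigma**2))
    return float((w * g).sum())

def layer_grad(D):
    """Analytic gradient of `layer` with respect to every (b, d) coordinate."""
    w = D[:, 1] - D[:, 0]
    g = np.exp(-((D - c) ** 2).sum(axis=1) / (2 * sigma**2))
    dg = -g[:, None] * (D - c) / sigma**2  # dg/d(b, d), shape (n, 2)
    dw = np.array([-1.0, 1.0])             # dw/d(b, d) for w = d - b
    return g[:, None] * dw + w[:, None] * dg  # product rule

def numeric_grad(f, D, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    G = np.zeros_like(D)
    for idx in np.ndindex(D.shape):
        Dp, Dm = D.copy(), D.copy()
        Dp[idx] += eps
        Dm[idx] -= eps
        G[idx] = (f(Dp) - f(Dm)) / (2 * eps)
    return G

D = np.array([[0.1, 0.8], [0.4, 0.6]])
print(np.allclose(layer_grad(D), numeric_grad(layer, D), atol=1e-6))  # True
```

An automatic-differentiation framework would produce the same gradient without the hand derivation. Since each b and d equals the filtration value of a specific simplex, this gradient can then be routed back to those filtration values, which is the chain behind backpropagation to the data.
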

Key Insight: The PersLay framework reveals a close connection between TDA and DeepSets: both are permutation-invariant aggregators of per-point features. The difference is that persistence diagram points carry geometric meaning (birth-death coordinates with d > b), which the weight function w(p) can exploit. Learned weights are free to emphasise long-lived features (large |d-b|) when those matter for the task, giving a task-specific notion of "topologically significant" structure rather than a fixed one.

Empirical Performance

On graph classification benchmarks (MUTAG, PROTEINS, REDDIT-B), PersLay with learned parameters matches or exceeds classical kernel methods and some GNN baselines, especially when topology is a meaningful inductive bias (e.g., molecular graphs where rings matter).

References

  • M. Carrière, F. Chazal, Y. Ike, T. Lacombe, M. Royer, Y. Umeda, “PersLay: A Neural Network Layer for Persistence Diagrams and New Graph Topological Signatures,” AISTATS, 2020. arXiv:1904.09378.
  • M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Póczos, R. Salakhutdinov, A. Smola, “Deep Sets,” NeurIPS, 2017. arXiv:1703.06114.