Graph Multiset Transformer (GMT)
- The paper introduces GMT, a novel graph pooling approach that reformulates node aggregation as a multiset encoding task augmented with explicit structural modeling.
- GMT employs a graph multi-head attention mechanism and Transformer-style self-attention to transform variable-sized node sets into fixed-dimensional embeddings, with expressive power matching the 1-WL test.
- Experimental results show that GMT delivers state-of-the-art or competitive accuracy on graph classification, reconstruction, and generation tasks while scaling efficiently to large graphs.
The Graph Multiset Transformer (GMT) is a global graph pooling mechanism that formulates the graph pooling problem as a multiset encoding challenge augmented with explicit modeling of structural dependencies. GMT offers a permutation-invariant, injective, and computationally efficient approach for condensing arbitrary node sets into fixed-dimensional graph representations, using a graph-structured multi-head attention mechanism. The design directly addresses the limitations of both simple sum/average pooling and established hierarchical pooling protocols by integrating auxiliary graph structure within the pooling operation and achieving expressive power up to the 1-dimensional Weisfeiler–Lehman graph isomorphism test (Baek et al., 2021).
1. Motivation and Background
In graph neural networks (GNNs), effective node and edge representations are achieved via iterative neighborhood aggregation. However, numerous downstream tasks—such as graph classification, reconstruction, and generation—require a global, fixed-dimensional graph embedding. Common strategies for this readout step apply simple aggregators across the set of node embeddings, such as sum, mean, or max pooling. These aggregators weight all nodes uniformly and cannot adapt to task relevance or exploit higher-order structural dependencies. Hierarchical pooling via node dropping (e.g., TopKPool, SAGPool, SortPool) risks losing critical information, while clustering-based schemes (e.g., DiffPool, MinCutPool) require dense assignment matrices and adjacency coarsening, incurring substantial computational cost and potentially sacrificing injectiveness—so different, non-isomorphic graphs can collapse to indistinguishable representations.
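A minimal NumPy sketch of the failure mode above: two graphs with different node-feature multisets (here hand-picked toy features, no message passing) that mean pooling maps to the same readout.

```python
import numpy as np

# Two different node-feature multisets that mean pooling cannot distinguish.
H_a = np.array([[1.0, 0.0],
                [0.0, 1.0]])          # graph A: two distinct nodes
H_b = np.array([[0.5, 0.5],
                [0.5, 0.5]])          # graph B: two identical nodes

mean_a = H_a.mean(axis=0)
mean_b = H_b.mean(axis=0)
print(np.allclose(mean_a, mean_b))    # True: the uniform readout collapses them
```

An injective multiset encoder, by contrast, must assign these two inputs distinct embeddings.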
Reconceptualizing the problem, GMT considers the node embeddings $H \in \mathbb{R}^{n \times d}$ as a multiset and incorporates the adjacency matrix $A \in \mathbb{R}^{n \times n}$ as auxiliary structure. The graph pooling objective is then to design a function

$$\mathrm{Pool}: (H, A) \mapsto h_G \in \mathbb{R}^{d}$$

that is permutation-invariant, injective over pairs $(H, A)$, and computationally efficient.
2. Model Architecture and Pooling Mechanism
GMT proceeds by compressing node features into “seed” embeddings using a multi-head attention architecture tailored for graphs, followed (optionally) by Transformer-style self-attention among the condensed vectors.
Key Components:
- Seed Matrix $S \in \mathbb{R}^{k \times d}$: A fixed-size, learnable set of $k$ query vectors.
- Graph Multi-Head Attention (GMH): For $h$ attention heads, each head $i$ computes projections:
- $Q_i = S W_i^{Q}$, $K_i = \mathrm{GNN}_i^{K}(H, A)$, $V_i = \mathrm{GNN}_i^{V}(H, A)$
- Attention weights: $\alpha_i = \mathrm{softmax}\!\big(Q_i K_i^{\top} / \sqrt{d_k}\big)$
- Output: $O_i = \alpha_i V_i$
- Aggregate: $\mathrm{GMH}(S, H, A) = [O_1; \ldots; O_h]\, W^{O}$
- Graph Multiset Pooling (GMPool): Applies a residual connection, row-wise feed-forward layer, and layer normalization: $\mathrm{GMPool}_k(H, A) = \mathrm{LN}\big(Z + \mathrm{rFF}(Z)\big)$,
where $Z = \mathrm{LN}\big(S + \mathrm{GMH}(S, H, A)\big)$.
- Self-Attention Over Condensed Vectors: A Transformer block allowing the pooled vectors to interact.
- Readout Pipeline: Typically, $\mathrm{Pool}(H, A) = \mathrm{GMPool}_1\big(\mathrm{SelfAtt}(\mathrm{GMPool}_k(H, A)),\, I\big)$
- Final embedding $h_G \in \mathbb{R}^{d}$ (with $I$ denoting the identity matrix, used in place of the adjacency once structure has been absorbed by the first pooling step)
Notably, graph structure is injected into attention keys and values through the local GNN blocks, crucially distinguishing GMT from structure-agnostic attention pooling.
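The GMH computation above can be sketched in a few lines of NumPy. This is an illustrative toy, not the authors' implementation: a one-layer mean-aggregation GNN stands in for $\mathrm{GNN}^K$/$\mathrm{GNN}^V$, and the rFF and LayerNorm steps of GMPool are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def gnn(H, A, W):
    # One-layer, mean-aggregation GNN standing in for GNN^K / GNN^V.
    A_hat = A + np.eye(A.shape[0])               # add self-loops
    deg = A_hat.sum(axis=1, keepdims=True)
    return (A_hat / deg) @ H @ W                 # neighborhood mean + projection

def gmh(S, H, A, params):
    # Graph multi-head attention: queries from seeds, keys/values from GNNs.
    d_k = S.shape[1]
    heads = []
    for Wq, Wk, Wv in params:
        Q = S @ Wq                               # (k, d) seed queries
        K = gnn(H, A, Wk)                        # structure-aware keys
        V = gnn(H, A, Wv)                        # structure-aware values
        attn = softmax(Q @ K.T / np.sqrt(d_k))   # (k seeds) x (n nodes)
        heads.append(attn @ V)
    return np.concatenate(heads, axis=1)         # [O_1; ...; O_h]

d, k, n, n_heads = 4, 2, 5, 2
S = rng.normal(size=(k, d))                       # learnable seed queries
H = rng.normal(size=(n, d))                       # node embeddings
A = (rng.random((n, n)) < 0.4).astype(float)
A = np.triu(A, 1); A = A + A.T                    # symmetric, no self-loops
params = [tuple(rng.normal(size=(d, d)) for _ in range(3)) for _ in range(n_heads)]

Z = gmh(S, H, A, params)
print(Z.shape)                                    # (2, 8): k seeds, h*d features
```

The key point the sketch makes concrete: the output has a fixed number of rows ($k$ seeds) regardless of the number of nodes $n$, and the adjacency enters only through the key/value GNNs.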
3. Theoretical Properties
GMT guarantees permutation invariance: Any permutation of node order induces a corresponding permutation of key and value rows, but the set of outputs is invariant as a function of the input multiset. The construction yields injectiveness, contingent on the injectiveness of the underlying GNN projections in each attention head. Specifically, for each fixed seed, the output vector is an injective function over the input node multiset. Consequently, the architecture replicates the expressive capacity of the 1-dimensional Weisfeiler–Lehman (1-WL) graph isomorphism test: two non-isomorphic graphs distinguishable by 1-WL will map to different outputs under GMT.
The following table summarizes these properties:
| Property | Guarantee | Condition |
|---|---|---|
| Permutation invariance | Always | Applies to any input ordering |
| Injectiveness | Yes, up to 1-WL power | Each GNN projection must be injective (e.g., GIN or stronger multilayer) |
| Computational efficiency | $O(nk)$ time and memory | $n$ nodes, $k$ seeds, with $k \ll n$ |
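Permutation invariance is easy to verify numerically. The sketch below uses a stripped-down attention pooling (identity key/value projections, no graph structure) purely to show the mechanism: permuting the node rows permutes keys and values consistently, leaving the seed-wise outputs unchanged.

```python
import numpy as np

rng = np.random.default_rng(1)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention_pool(S, H):
    # Seed-query attention over the node multiset (structure omitted for brevity).
    attn = softmax(S @ H.T / np.sqrt(S.shape[1]))  # (k, n) weights
    return attn @ H                                # (k, d) pooled output

k, n, d = 2, 6, 4
S = rng.normal(size=(k, d))
H = rng.normal(size=(n, d))

perm = rng.permutation(n)                # arbitrary reordering of the nodes
out = attention_pool(S, H)
out_perm = attention_pool(S, H[perm])
print(np.allclose(out, out_perm))        # True: node order does not matter
```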
4. Extension to Hierarchical Pooling
GMT's soft attention coefficients can be recast as a soft assignment matrix for cluster-based pooling, where each seed acts as a cluster centroid. This enables hierarchical coarsening analogous to DiffPool, but with $O(nk)$ rather than $O(n^2)$ assignment complexity, since attention is computed between $k$ seeds and $n$ nodes rather than over dense node-to-node assignments. Coarsened adjacency and features are obtained efficiently from the soft assignment. By stacking multiple levels with fresh seeds, a sparse, hierarchical pooling framework is achieved, inheriting the injectiveness and permutation invariance of the underlying GMT mechanisms.
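A minimal sketch of one coarsening level, assuming a DiffPool-style coarsening rule with the seed-to-node attention matrix $W$ as the soft assignment (the paper's exact normalization may differ):

```python
import numpy as np

rng = np.random.default_rng(2)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

n, k, d = 8, 3, 4
H = rng.normal(size=(n, d))                        # node features
A = (rng.random((n, n)) < 0.3).astype(float)
A = np.triu(A, 1); A = A + A.T                     # symmetric adjacency
S = rng.normal(size=(k, d))                        # seeds act as cluster centroids

W = softmax(S @ H.T / np.sqrt(d))                  # soft assignment, (k, n)
H_coarse = W @ H                                   # coarsened features, (k, d)
A_coarse = W @ A @ W.T                             # coarsened adjacency, (k, k)
print(H_coarse.shape, A_coarse.shape)              # (3, 4) (3, 3)
```

Stacking such levels with fresh seeds per level yields the hierarchical variant; note the coarsened adjacency stays symmetric when the input adjacency is.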
5. Experimental Results
GMT demonstrates superior or competitive performance across a variety of graph learning tasks:
- Graph Classification: Across TU datasets (DD, PROTEINS, MUTAG, etc.) and OGB datasets (HIV, Tox21, ToxCast, BBBP), GMT achieves or matches state-of-the-art performance, outperforming both node-drop and clustering-based pooling methods; its reported mean accuracies on DD, PROTEINS, and MUTAG and its ROC-AUC on HIV exceed or match all baselines (Baek et al., 2021).
- Ablation Studies: Removal of graph-structured attention, self-attention, or GNN-based key/value projections degrades performance, establishing all components as contributing significant gains.
- Efficiency: GMT scales to graphs with up to 10,000 nodes at a fraction of the GPU memory and forward-pass time required by dense clustering methods, which demand far more memory on graphs of that size.
- Graph Reconstruction and Generation: GMT yields near-perfect reconstruction on synthetic graphs and higher atom-type exact-match ratios on ZINC molecules than MinCutPool ($0.55$) at the same compression rate. In generative settings (MolGAN, retrosynthetic prediction with a Graph Logic Network), replacing traditional pooling with GMT improves molecule validity and top-$k$ retrosynthesis accuracy, respectively.
6. Practical Usage and Extensions
Practical recommendations include:
- For graph classification, replacing standard pooling with GMPool is effective in any GNN pipeline.
- A single GMPool improves graph auto-encoder reconstruction when used as pooling and unpooling.
- Stacking GMPool layers facilitates hierarchical abstraction, with graph coarsening via the soft assignment (e.g., $A' = W A W^{\top}$ for attention matrix $W$) when memory allows.
- Standard hyperparameters: a modest number of seeds, up to $8$ attention heads, and one or two self-attention layers typically suffice.
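The recommended readout pipeline $\mathrm{GMPool}_1 \circ \mathrm{SelfAtt} \circ \mathrm{GMPool}_k$ can be sketched end-to-end as below. This is a simplified stand-in (no learned projections, GNN key/value encoders, rFF, or LayerNorm), intended only to show how the three stages compose into a single graph vector.

```python
import numpy as np

rng = np.random.default_rng(3)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attn_pool(S, H):
    # Simplified GMPool: seed queries attend over H (GNN/rFF/LayerNorm omitted).
    return softmax(S @ H.T / np.sqrt(S.shape[1])) @ H

def self_attn(Z):
    # Simplified Transformer block: self-attention without learned projections.
    return softmax(Z @ Z.T / np.sqrt(Z.shape[1])) @ Z

n, d, k = 10, 8, 4
H = rng.normal(size=(n, d))                          # node embeddings from a GNN
S_k = rng.normal(size=(k, d))                        # k seeds for GMPool_k
S_1 = rng.normal(size=(1, d))                        # single seed for GMPool_1

h_G = attn_pool(S_1, self_attn(attn_pool(S_k, H)))   # (1, d) graph embedding
print(h_G.shape)                                     # (1, 8)
```

In a real pipeline, `h_G` would feed a classification head; a full implementation is available as the `GraphMultisetTransformer` pooling layer in PyTorch Geometric.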
Future directions proposed include designing specialized GNNs for different trade-offs between receptive-field size and locality, experimenting with seed initialization, and extending the framework to edge- or subgraph-level pooling using multiple query seeds. GMT's injectiveness, efficiency, and expressive power make it particularly suitable as a default choice for global graph representation learning across diverse domains (Baek et al., 2021).