Sinkhorn-Routed Encoder
- The paper shows that the Sinkhorn-Routed Encoder achieves quasi-global receptive fields through data-driven block sorting, significantly reducing memory and compute requirements.
- It employs a differentiable Sinkhorn operator to approximate permutation matrices, enabling efficient local attention and effective sequence truncation via SortCut.
- Empirical results on language modeling and image generation show accuracy competitive with full self-attention, with improvements in several settings, at a lower memory cost.
A Sinkhorn-Routed Encoder, also known as Sparse Sinkhorn Attention, is a memory- and compute-efficient attention architecture for Transformer models that leverages differentiable sorting via a Sinkhorn operator to route attention computation through dynamically re-ordered blocks of input sequences. The data-driven block sorting enables quasi-global receptive fields with local attention mechanisms, substantially reducing the memory and computation requirements compared to standard full self-attention while retaining competitive accuracy on tasks such as language modeling, sequence-to-sequence sorting, image generation, and textual entailment (Tay et al., 2020).
1. Architecture Overview
The Sinkhorn-Routed Encoder operates within a modified Transformer encoder block. An input sequence $X \in \mathbb{R}^{\ell \times d}$ is partitioned into $N_B$ contiguous blocks, each of size $b = \ell / N_B$. A meta-sorting network (SortNet) summarizes each block (e.g., via sum-pooling or first-token selection) and scores block-level relationships through a small MLP, producing a score matrix $R \in \mathbb{R}^{N_B \times N_B}$. This matrix is transformed into a doubly-stochastic matrix $S(R)$ via the differentiable Sinkhorn operator, approximating a permutation matrix.
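A minimal NumPy sketch of the SortNet scoring step follows, assuming the depth-one variant: a single bias-free linear layer applied to a pooled block summary. The function name `sortnet_scores` and the weight matrix `w` are hypothetical illustrations, not the authors' code.

```python
import numpy as np

def sortnet_scores(x, n_blocks, w, pool="sum"):
    """Hypothetical depth-one SortNet: pool each block, then one linear layer.

    x: (seq_len, d) input; seq_len must be divisible by n_blocks.
    w: (d, n_blocks) weight of the assumed single, bias-free linear layer.
    Returns R: (n_blocks, n_blocks) block-to-block score matrix.
    """
    seq_len, d = x.shape
    if seq_len % n_blocks != 0:
        raise ValueError("seq_len must be divisible by n_blocks")
    b = seq_len // n_blocks
    blocks = x.reshape(n_blocks, b, d)      # B(X): (N_B, b, d)
    if pool == "sum":
        summary = blocks.sum(axis=1)        # sum-pooling over the tokens of each block
    elif pool == "first":
        summary = blocks[:, 0, :]           # first-token selection
    else:
        raise ValueError(f"unknown pool: {pool!r}")
    return summary @ w                      # R: (N_B, N_B)


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 8))                # 16 tokens, d = 8
w = rng.normal(size=(8, 4))                 # N_B = 4 blocks of b = 4 tokens
R = sortnet_scores(x, n_blocks=4, w=w)
print(R.shape)                              # (4, 4)
```

Row $i$ of `R` scores block $i$ against every sorted position; Sinkhorn normalization then turns these raw scores into an approximate permutation.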
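The (soft) permutation is used to re-order ("sort") the sequence blocks: $X_S = U\big(S(R)\,B(X)\big)$, where $S(R)$ is the Sinkhorn-normalized score matrix, $B(\cdot)$ reshapes the sequence into $N_B$ blocks of $b$ tokens, and $U(\cdot)$ reverses that reshaping. The sorted sequence is thus refolded to length $\ell$, and standard block-local scaled dot-product attention is applied independently within each sorted block. Optionally, the SortCut operation truncates the sorted sequence to its leading blocks, up to a budget of $N_k$ tokens, and a mixture with standard full attention over the unsorted input $X$ may be added for greater expressivity.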
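The block re-ordering $X_S = U(P\,B(X))$ can be sketched as below, assuming the (soft) permutation `p` has already been produced by Sinkhorn normalization; `sort_blocks` is a hypothetical helper name. Row $i$ of `p` mixes the original blocks into sorted slot $i$, so a hard permutation moves whole blocks while a doubly-stochastic matrix blends them.

```python
import numpy as np

def sort_blocks(x, p):
    """Apply X_S = U(P B(X)): re-order the blocks of x with permutation-like p.

    x: (seq_len, d); p: (n_blocks, n_blocks). Row i of p gives the weights
    with which each original block contributes to sorted slot i.
    """
    n_blocks = p.shape[0]
    seq_len, d = x.shape
    if seq_len % n_blocks != 0:
        raise ValueError("seq_len must be divisible by n_blocks")
    blocks = x.reshape(n_blocks, seq_len // n_blocks, d)    # B(X)
    mixed = np.einsum("ij,jtd->itd", p, blocks)              # P B(X)
    return mixed.reshape(seq_len, d)                         # U(.): back to (seq_len, d)


x = np.arange(12, dtype=float).reshape(6, 2)   # 3 blocks of 2 tokens
hard = np.eye(3)[[2, 0, 1]]                    # slot 0 <- block 2, slot 1 <- block 0, slot 2 <- block 1
print(sort_blocks(x, hard)[:, 0])              # first column: 8, 10, 0, 2, 4, 6
```

With a hard permutation the tokens inside each block keep their internal order; only whole blocks move, which is what lets local attention reach distant content.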
2. Differentiable Sorting and Sinkhorn Operator
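Differentiable sorting is accomplished using the Sinkhorn operator $S(\cdot)$, which seeks a matrix in the Birkhoff polytope $\mathcal{B}_N$ (the set of $N \times N$ doubly-stochastic matrices) that minimizes the entropy-regularized cost:
$$S(R/\tau) = \arg\min_{P \in \mathcal{B}_N} \; -\langle P, R \rangle_F - \tau\, H(P), \qquad H(P) = -\sum_{i,j} P_{ij} \log P_{ij},$$
where $R$ is the score matrix and $\tau > 0$ is the entropic regularization parameter. In practice, iterative Sinkhorn-Knopp row/column normalization is performed on $S^0 = \exp\big((R + \varepsilon)/\tau\big)$, where $\varepsilon$ is optional Gumbel noise and $\tau$ acts as a temperature (the same $\tau$ that weights the entropy term). Each iteration alternates between row and column normalization:
$$S^k = F_c\big(F_r(S^{k-1})\big), \qquad F_r(X) = X \oslash \big(X \mathbf{1}\mathbf{1}^\top\big), \qquad F_c(X) = X \oslash \big(\mathbf{1}\mathbf{1}^\top X\big),$$
where $\oslash$ denotes element-wise division. A small number of iterations $K$ is typically sufficient for a high-fidelity approximation to a permutation. For numerical stability, the operations are often carried out in the log domain.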
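A log-domain Gumbel-Sinkhorn sketch in NumPy follows. It implements the alternating row/column normalization described above, with optional Gumbel noise; the defaults `tau=0.75` and `n_iters=8` are illustrative assumptions rather than the paper's settings, and `gumbel_sinkhorn` is a hypothetical name.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis, keeping dims."""
    m = np.max(a, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))

def gumbel_sinkhorn(r, tau=0.75, n_iters=8, rng=None):
    """Approximate a permutation from an (N, N) score matrix via log-domain Sinkhorn.

    If rng is given, Gumbel(0, 1) noise is added to r before scaling by tau.
    """
    log_p = np.asarray(r, dtype=float)
    if rng is not None:
        u = rng.uniform(low=1e-10, high=1.0, size=log_p.shape)
        log_p = log_p - np.log(-np.log(u))        # add Gumbel noise: -log(-log u)
    log_p = log_p / tau                            # log S^0 = (R + eps) / tau
    for _ in range(n_iters):
        log_p = log_p - logsumexp(log_p, axis=1)   # F_r: each row sums to 1
        log_p = log_p - logsumexp(log_p, axis=0)   # F_c: each column sums to 1
    return np.exp(log_p)


rng = np.random.default_rng(0)
r = rng.normal(size=(4, 4))
p = gumbel_sinkhorn(r, rng=rng)   # noisy, approximately doubly-stochastic sample
print(p.sum(axis=0))              # each column sums to 1 up to rounding
```

Because the last step is a column normalization, columns sum to 1 exactly (up to floating point) while rows are only approximately normalized after finitely many iterations; lowering `tau` pushes the result toward a hard permutation.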
Causal Sinkhorn Balancing preserves the autoregressive property when the model is used for decoding by masking future blocks: during column normalization, each row $i$ only attends to columns $j \le i$, implemented via masking and an adjusted normalization.
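One plausible reading of "masking and adjusted normalization" is sketched below; it is an assumption about the mechanism, not the paper's verified implementation, and `causal_sinkhorn` is a hypothetical name. Entries with $j > i$ are masked, row normalization runs over the permitted columns, and the column normalization for entry $(i, j)$ divides by the running column sum over rows $0, \dots, i$ only, so row $i$ never depends on the scores of later blocks.

```python
import numpy as np

NEG = -1e9  # finite stand-in for -inf so masked arithmetic never produces NaN

def causal_sinkhorn(r, tau=0.75, n_iters=8):
    """Assumed causal Sinkhorn variant over N blocks: block i is routed only from j <= i."""
    n = r.shape[0]
    allowed = np.tril(np.ones((n, n), dtype=bool))            # permitted pairs: j <= i
    log_p = np.where(allowed, np.asarray(r, dtype=float) / tau, NEG)
    for _ in range(n_iters):
        # Row normalization in the log domain; masked entries contribute exp(NEG) = 0.
        m = log_p.max(axis=1, keepdims=True)
        log_p = log_p - (m + np.log(np.exp(log_p - m).sum(axis=1, keepdims=True)))
        log_p = np.where(allowed, log_p, NEG)
        # Causal column normalization: divide by the running sum over rows 0..i.
        log_cum = np.logaddexp.accumulate(log_p, axis=0)
        log_p = np.where(allowed, log_p - log_cum, NEG)
    return np.exp(log_p)


r = np.random.default_rng(1).normal(size=(4, 4))
p = causal_sinkhorn(r)
print(np.triu(p, k=1).max())   # 0.0: no block is routed from a later block
```

Using a large finite negative constant instead of `-inf` keeps the masked subtraction well defined; the mask is re-applied after each step so those entries stay exactly zero after exponentiation.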
3. Attention Routing and SortCut Truncation
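After block sorting, each query token attends only to key tokens within its block in the new sorted order, and attention is computed locally:
$$A_{ij} = \begin{cases} \dfrac{Q_i\, \psi_S(K)_j^\top}{\sqrt{d}} & \text{if } \lfloor i/b \rfloor = \lfloor j/b \rfloor, \\[4pt] -\infty & \text{otherwise,} \end{cases}$$
where $\psi_S(\cdot)$ denotes applying the learned block sort, with the output $Y = \operatorname{softmax}(A)\,\psi_S(V)$.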
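The local attention step can be sketched in NumPy by reshaping the sorted sequence into $N_B$ blocks and computing one $b \times b$ score matrix per block. This is equivalent to masking every cross-block pair of the full matrix but never materializes the $\ell \times \ell$ scores. The sketch assumes queries, sorted keys and sorted values are already given and that $b$ divides the sequence length; `block_local_attention` is a hypothetical name.

```python
import numpy as np

def block_local_attention(q, k_sorted, v_sorted, b):
    """Query i attends only to sorted keys j with i // b == j // b.

    q, k_sorted, v_sorted: (seq_len, d); seq_len must be divisible by b.
    """
    seq_len, d = q.shape
    if seq_len % b != 0:
        raise ValueError("seq_len must be divisible by b")
    n_blocks = seq_len // b
    qb = q.reshape(n_blocks, b, d)
    kb = k_sorted.reshape(n_blocks, b, d)
    vb = v_sorted.reshape(n_blocks, b, d)
    # One b x b score matrix per block: shape (N_B, b, b) instead of (seq_len, seq_len).
    scores = np.einsum("ntd,nsd->nts", qb, kb) / np.sqrt(d)
    scores = scores - scores.max(axis=-1, keepdims=True)     # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.einsum("nts,nsd->ntd", weights, vb)              # softmax(A) psi_S(V), per block
    return out.reshape(seq_len, d)


rng = np.random.default_rng(2)
q, k, v = (rng.normal(size=(8, 4)) for _ in range(3))
y = block_local_attention(q, k, v, b=4)   # two blocks of four tokens
print(y.shape)                            # (8, 4)
```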
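The SortCut scheme further improves efficiency by keeping, after sorting, only the leading (most important) blocks up to a budget of $N_k$ key tokens and discarding the remainder. Every query then attends to this truncated key set, giving attention complexity $O(\ell N_k)$, which is linear in $\ell$ for constant $N_k$, versus $O(\ell b)$ (block-local) or $O(\ell^2)$ (full). The importance ranking is induced by the learned sorting, with attention computed only over the truncated block set.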
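A SortCut sketch in NumPy, assuming the budget is expressed in tokens (a whole number of leading sorted blocks) and that keys and values are already sorted; `sortcut_attention` and `budget` are hypothetical names. Every query attends to the same truncated set, so the score matrix has shape $(\ell, N_k)$ rather than $(\ell, \ell)$.

```python
import numpy as np

def sortcut_attention(q, k_sorted, v_sorted, budget):
    """Attend from every query to only the first `budget` sorted key/value tokens."""
    d = q.shape[1]
    k_cut, v_cut = k_sorted[:budget], v_sorted[:budget]     # truncate after sorting
    scores = q @ k_cut.T / np.sqrt(d)                       # (seq_len, budget)
    scores = scores - scores.max(axis=1, keepdims=True)     # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ v_cut                                  # (seq_len, d)


rng = np.random.default_rng(3)
q, k, v = (rng.normal(size=(16, 4)) for _ in range(3))
y = sortcut_attention(q, k, v, budget=4)   # keep one leading block of b = 4 tokens
print(y.shape)                             # (16, 4)
```

Tokens beyond the budget have no influence on the output, which is exactly what makes the cost scale with $N_k$ instead of $\ell$.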
4. Computational Complexity and Memory Usage
The Sinkhorn-Routed Encoder achieves significant reductions in time and memory complexity compared to vanilla Transformers:
| Approach | Time / Memory Complexity | Principal Parameters |
|---|---|---|
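| Full Attention | $O(\ell^2)$ | Sequence length $\ell$ |
| Block-local | $O(\ell b)$ | Block size $b$ |
| Sinkhorn Attention | $O(\ell b + N_B^2)$ | Num. blocks $N_B = \ell / b$, Sinkhorn iterations $K$ |
| Sinkhorn+SortCut | $O(\ell N_k + N_B^2)$ | Truncation budget $N_k$ |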
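For example, a sequence of $\ell$ tokens split into $N_B$ blocks has block size $b = \ell / N_B$, so the block-local attention term $\ell b$ is a factor $\ell / b = N_B$ smaller than the $\ell^2$ of full attention. The sorting network and Sinkhorn normalization add a term that scales with $N_B^2$ (and $O(K N_B^2)$ time over $K$ iterations). Because $N_B = \ell / b$, shrinking $b$ lowers the attention cost but raises the sorting cost, so $b$ is chosen to balance the two terms (Tay et al., 2020).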
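The trade-off can be checked with simple arithmetic. The helper below is a hypothetical illustration that counts score-matrix entries held at once for all blocks, a rough proxy for memory that ignores constants, attention heads and batch size; $\ell = 1024$ and the candidate block counts are assumed example values.

```python
def attention_memory_terms(seq_len, n_blocks):
    """Score-matrix entries held at once (all blocks), as a rough memory proxy."""
    if seq_len % n_blocks != 0:
        raise ValueError("seq_len must be divisible by n_blocks")
    b = seq_len // n_blocks
    return {
        "full": seq_len * seq_len,                       # one ell x ell matrix
        "block_local": n_blocks * b * b,                 # N_B blocks of b x b = ell * b
        "sinkhorn": n_blocks * b * b + n_blocks ** 2,    # plus the N_B x N_B sort matrix
    }


terms = attention_memory_terms(1024, 32)
print(terms)   # {'full': 1048576, 'block_local': 32768, 'sinkhorn': 33792}
print(round(terms["full"] / terms["sinkhorn"], 1))   # 31.0, i.e. about 31x fewer entries

# Sweep the number of blocks: the ell*b term falls while the N_B**2 term grows.
candidates = [8, 16, 32, 64, 128, 256]
best = min(candidates, key=lambda nb: attention_memory_terms(1024, nb)["sinkhorn"])
print(best)    # 64, i.e. b = 16: ell*b = 16384 and N_B**2 = 4096
```

Under this all-blocks count the sum $\ell b + (\ell / b)^2$ is minimized at an intermediate block size; the exact optimum depends on which constant factors and per-block conventions are included.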
5. Training and Optimization
The Sinkhorn-Routed Encoder is trained end-to-end with conventional, task-dependent loss functions (e.g., cross-entropy for language modeling), without auxiliary objectives for sorting. The Gumbel-Sinkhorn reparameterization enables differentiable sampling of approximate permutations, so gradients flow through the sorting step. Automatic differentiation is applied through the unrolled Sinkhorn loop, commonly in the log domain for stability. Gradient clipping (at norm 1 or 5) is applied at both the whole-Transformer and SortNet levels, and Adam is the standard optimizer. The best-performing reported settings use a sorting network of depth one (a single linear layer), a tuned temperature $\tau$, and 5–10 Sinkhorn iterations; too many iterations, or a non-causal Sinkhorn in decoders, degrades performance.
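The two-level clipping can be sketched as clipping by global L2 norm applied separately to each parameter group. The source does not specify the clipping variant, so global-norm clipping, the helper name `clip_by_global_norm`, and the pairing of thresholds to groups in the usage lines are all assumptions.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, norm_before_clipping).
    """
    if max_norm <= 0:
        raise ValueError("max_norm must be positive")
    total = float(np.sqrt(sum(float(np.sum(g * g)) for g in grads)))
    if total <= max_norm:
        return [g.copy() for g in grads], total      # already within budget
    scale = max_norm / total
    return [g * scale for g in grads], total          # shrink all arrays uniformly


# Hypothetical parameter groups; thresholds are illustrative, not the paper's assignment.
transformer_grads = [np.ones((2, 2)), np.ones(3)]
sortnet_grads = [np.full((4, 4), 0.5)]
clipped_tf, norm_tf = clip_by_global_norm(transformer_grads, max_norm=5.0)
clipped_sn, norm_sn = clip_by_global_norm(sortnet_grads, max_norm=1.0)
print(norm_sn)   # 2.0 before clipping, so the SortNet gradients are halved
```

Scaling every array by the same factor preserves the gradient direction within a group, which is why global-norm clipping is a common default.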
6. Empirical Performance
Benchmarks show that the Sinkhorn-Routed Encoder is competitive with, and in several settings surpasses, both full and sparse Transformer variants on diverse tasks. On sequence-to-sequence sorting, the Sinkhorn model attains lower edit distance (0.4054) and higher exact match (49.2%) than a Sparse Transformer. For language modeling (LM1B, base model, 50M params), perplexity improves from 41.57 (Transformer) to 40.79 (Sinkhorn), and further to 40.11 with the mixture model. On the larger word-level LM1B setting (430M params), Sinkhorn alone reaches 28.39, trailing the Transformer's 27.59, while the mixture model improves to 27.34. Results on character-level LM1B, pixel-wise CIFAR-10 image generation, and document classification and natural language inference with SortCut encoders (IMDb, SNLI) demonstrate competitive accuracy and efficiency. Ablation studies indicate that disabling the Sinkhorn permutation severely degrades performance (LM1B perplexity rises from 40.8 to 52.4) and confirm that causal Sinkhorn balancing is necessary for decoders (Tay et al., 2020).
7. Significance and Related Directions
The Sinkhorn-Routed Encoder introduces a paradigm of data-dependent, differentiable sequence reordering to unlock quasi-global attention with only local computation. By learning the permutation via end-to-end training and employing a principled Sinkhorn relaxation, the model combines the benefits of global receptive fields and scalability. This method relates to broader trends in efficient attention, sparse and block-based attention methods, and neural sorting. Innovations such as Causal Sinkhorn Balancing and SortCut truncation further extend its utility in both encoding and decoding contexts. The approach has become a reference point in subsequent work on learnable routing and efficient Transformer architectures (Tay et al., 2020).