
Causal Sinkhorn Balancing in Transformers

Updated 7 February 2026
  • Causal Sinkhorn Balancing is a technique that adapts the Sinkhorn normalization to enforce lower-triangular, causal constraints for autoregressive models.
  • It employs masked row and column normalization to restrict computations to present and past blocks, enabling efficient quasi-global attention.
  • Its integration within Sparse Sinkhorn Attention improves memory efficiency and model performance, as demonstrated by reduced perplexity in language modeling benchmarks.

Causal Sinkhorn Balancing is a modification of the Sinkhorn–Knopp normalization procedure that enables memory-efficient, autoregressive sorting of sequence blocks within the Sparse Sinkhorn Attention framework. It enforces causality constraints during normalization, so the soft permutation only uses information from present and past blocks and never "peeks into the future". It also keeps the soft permutation matrix supported on its lower triangle. This causally balanced matrix is central to quasi-global attention with local computational patterns in sequence models, especially in autoregressive transformer decoders (Tay et al., 2020).

1. Motivation and Context

Sparse Sinkhorn Attention uses a small, block-wise, doubly stochastic matrix $S$ to permute or "softly sort" the $B$ blocks of a full sequence, each of length $\ell_B = \ell/B$. This sorting happens before local attention is applied. The strategy uses less attention memory than vanilla self-attention: each token gains indirect access to a broader context while the computational cost stays manageable.

In autoregressive transformer architectures, enforcing causality in the sorting stage is critical: only past and present inputs may influence the prediction at the current timestep. Standard Sinkhorn normalization alternates row and column normalization of $\exp(R)$, where $R$ holds the sorting scores. This mixes information from the entire sequence, so future (not yet available) blocks contribute to the normalization constants.

Causal Sinkhorn Balancing remedies this by restricting each step of the iterative normalization to the currently available (past and present) blocks. Concretely, $S$ is kept lower-triangular, so the sorted block at position $i$ can only draw on input blocks $j \leq i$. This causal masking is essential for correct autoregressive decoding in block-reordered transformers (Tay et al., 2020).

2. Mathematical Formulation and Notation

Let $\ell$ denote the token sequence length, $B$ the number of blocks, and $X \in \mathbb{R}^{\ell \times d}$ the token embeddings. The block-pooling function $\psi_P(X) \in \mathbb{R}^{B \times d}$ produces a pooled representation for each block. A feedforward scoring network $P$ then maps these representations to a matrix of sorting scores $R \in \mathbb{R}^{B \times B}$ with $R_{ij} \geq 0$.
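Here is a minimal NumPy sketch of the scoring step. It assumes that $\psi_P$ is block sum-pooling and that $P$ is a two-layer ReLU network mapping each pooled $d$-vector to $B$ scores. The weight names `W1` and `W2`, the hidden width and the random initialization are hypothetical, chosen only for illustration. The final ReLU is one simple way to satisfy $R_{ij} \geq 0$.

```python
import numpy as np

def block_sum_pool(X, num_blocks):
    """Encoder-style psi_P: sum the tokens inside each block -> shape (B, d)."""
    seq_len, d = X.shape
    assert seq_len % num_blocks == 0, "sequence length must be divisible by B"
    block_len = seq_len // num_blocks  # ell_B = ell / B
    return X.reshape(num_blocks, block_len, d).sum(axis=1)

def sorting_scores(X, num_blocks, W1, W2):
    """Hypothetical scoring network P: two ReLU layers mapping (B, d) -> (B, B)."""
    pooled = block_sum_pool(X, num_blocks)
    hidden = np.maximum(pooled @ W1, 0.0)
    return np.maximum(hidden @ W2, 0.0)  # final ReLU keeps every R_ij >= 0

rng = np.random.default_rng(0)
seq_len, d, num_blocks, hidden_dim = 16, 8, 4, 32
X = rng.normal(size=(seq_len, d))
W1 = rng.normal(size=(d, hidden_dim)) / np.sqrt(d)  # hypothetical weights
W2 = rng.normal(size=(hidden_dim, num_blocks)) / np.sqrt(hidden_dim)
R = sorting_scores(X, num_blocks, W1, W2)
print(R.shape)  # (4, 4)
```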

The iterative normalization proceeds as follows:

  • $S^{(0)} = \exp(R/\tau + G)$, with temperature $\tau$ and i.i.d. Gumbel noise $G$.
  • At each iteration $k = 1, \dots, K$:
    • $S^{(k)} = F_c(F_r(S^{(k-1)}))$
    • Standard row normalization: $F_r(X)_{ij} = X_{ij} / \sum_{j'} X_{ij'}$
    • Standard column normalization: $F_c(X)_{ij} = X_{ij} / \sum_{i'} X_{i'j}$
  • For numerical stability, the normalization is usually performed in the log domain.
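The standard (unmasked) procedure can be sketched in the log domain as follows. This is an illustrative NumPy version, not a reference implementation. The function names are hypothetical. The `add_noise` flag is an assumption added so the Gumbel term can be switched off for a deterministic check.

```python
import numpy as np

def logsumexp(x, axis):
    """Stable log(sum(exp(x))) along one axis, keeping dims for broadcasting."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def gumbel_sinkhorn(R, tau=0.75, n_iters=10, add_noise=True, rng=None):
    """Standard (unmasked) Sinkhorn in the log domain: S^(k) = F_c(F_r(S^(k-1)))."""
    rng = np.random.default_rng() if rng is None else rng
    L = R / tau                                # log S^(0) without noise
    if add_noise:
        L = L + rng.gumbel(size=R.shape)       # i.i.d. Gumbel(0, 1) noise G
    for _ in range(n_iters):
        L = L - logsumexp(L, axis=1)           # F_r: each row sums to 1
        L = L - logsumexp(L, axis=0)           # F_c: each column sums to 1
    return np.exp(L)

rng = np.random.default_rng(0)
R = rng.normal(size=(4, 4))
S = gumbel_sinkhorn(R, tau=0.75, n_iters=10, rng=rng)
print(S.sum(axis=0))  # column sums: 1 up to rounding
print(S.sum(axis=1))  # row sums: approach 1 as n_iters grows
```

The last update in each iteration is $F_c$, so the column sums are exact. The row sums only approach 1 as the number of iterations grows.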

For Causal Sinkhorn Balancing, a mask $M \in \{0,1\}^{B \times B}$ is used, where $M_{ij} = 1$ if $j \leq i$ and $0$ otherwise. The normalization updates become:

  • Masked row normalization:

$$F_r^c(X) = X - \log\left([M \odot \exp(X)]\, 1_B\right) 1_B^T$$

  • Masked column normalization:

$$F_c^c(X) = X - 1_B \log\left(1_B^T [M \odot \exp(X)]\right)$$

Here $\odot$ denotes the element-wise product and $1_B$ is the length-$B$ all-ones column vector. Entries with $M_{ij} = 0$ are held at $-\infty$ after each update.

Row normalization only sums over columns $j \leq i$, and column normalization only sums over rows $i \geq j$, so the resulting $S$ has lower-triangular support. The Sinkhorn iterations project $\exp(R/\tau)$ into the intersection of the Birkhoff polytope and the lower-triangular cone defined by $M$.
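The two masked operators translate almost directly into vectorized NumPy. In this sketch, entries outside $M$ are replaced by $-\infty$ before the log-sum-exp, so they contribute zero mass, and they stay at $-\infty$ afterwards. The helper names are hypothetical.

```python
import numpy as np

def causal_mask(B):
    """M[i, j] = True iff j <= i (lower triangle including the diagonal)."""
    return np.tril(np.ones((B, B), dtype=bool))

def masked_logsumexp(L, M, axis):
    """log of sum(M * exp(L)) along an axis; entries outside M contribute no mass."""
    Lm = np.where(M, L, -np.inf)
    m = np.max(Lm, axis=axis, keepdims=True)  # finite: the diagonal is always in M
    return m + np.log(np.sum(np.exp(Lm - m), axis=axis, keepdims=True))

def row_norm_causal(L, M):
    """F_r^c: L - log([M * exp(L)] 1_B) 1_B^T, with future entries held at -inf."""
    return np.where(M, L - masked_logsumexp(L, M, axis=1), -np.inf)

def col_norm_causal(L, M):
    """F_c^c: L - 1_B log(1_B^T [M * exp(L)]), with future entries held at -inf."""
    return np.where(M, L - masked_logsumexp(L, M, axis=0), -np.inf)

rng = np.random.default_rng(0)
B = 4
M = causal_mask(B)
L = rng.normal(size=(B, B))
S_row = np.exp(row_norm_causal(L, M))  # each row sums to 1 over j <= i
S_col = np.exp(col_norm_causal(L, M))  # each column sums to 1 over i >= j
print(np.round(S_row, 3))
print(np.round(S_col, 3))
```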

3. Algorithmic Details

The causal Sinkhorn algorithm proceeds as follows:

  1. Initialization: in the log domain, set $L \leftarrow R/\tau + G$, where $G$ is i.i.d. Gumbel noise.
  2. Iterative normalization (for $k = 1, \dots, K$):
    • Masked row normalization. For each row $i$:
      • $a_i \leftarrow \sum_{j=1}^{i} \exp(L_{ij})$
      • For $j = 1, \dots, B$: if $j \leq i$, set $L_{ij} \leftarrow L_{ij} - \log a_i$; otherwise set $L_{ij} \leftarrow -\infty$ (zeroing out future mass).
    • Masked column normalization. For each column $j$:
      • $b_j \leftarrow \sum_{i=j}^{B} \exp(L_{ij})$
      • For $i = 1, \dots, B$: if $i \geq j$, set $L_{ij} \leftarrow L_{ij} - \log b_j$; otherwise set $L_{ij} \leftarrow -\infty$.
  3. Finalization: $S \leftarrow \exp(L)$.
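The steps above can be transcribed into an explicit-loop sketch. It assumes NumPy's `Generator.gumbel` for the noise $G$. It computes $\log a_i$ and $\log b_j$ with a stable log-sum-exp instead of summing raw exponentials. The names `causal_sinkhorn` and `_lse` are hypothetical.

```python
import numpy as np

def _lse(v):
    """Stable log(sum(exp(v))) for a 1-D array."""
    m = np.max(v)
    return m + np.log(np.sum(np.exp(v - m)))

def causal_sinkhorn(R, tau=0.75, n_iters=8, rng=None):
    """Causal Sinkhorn balancing written as explicit loops over rows and columns.

    Indices are 0-based, so "j <= i" is the slice [: i + 1] and "i >= j" is [j:].
    """
    rng = np.random.default_rng() if rng is None else rng
    B = R.shape[0]
    L = R / tau + rng.gumbel(size=R.shape)       # 1. L <- R / tau + G
    for _ in range(n_iters):                     # 2. iterative normalization
        for i in range(B):                       # masked row normalization
            L[i, : i + 1] -= _lse(L[i, : i + 1]) # subtract log a_i (sum over j <= i)
            L[i, i + 1 :] = -np.inf              # zero out future mass
        for j in range(B):                       # masked column normalization
            L[j:, j] -= _lse(L[j:, j])           # subtract log b_j (sum over i >= j)
            L[:j, j] = -np.inf
    return np.exp(L)                             # 3. S <- exp(L)

rng = np.random.default_rng(0)
S = causal_sinkhorn(rng.normal(size=(5, 5)), tau=0.75, n_iters=8, rng=rng)
print(np.round(S, 3))
```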

Convergence is reached when all row and column sums, computed under $M$, are within a small tolerance $\epsilon$ of $1$.
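A stopping test of this kind might look like the following. The function name and the default $\epsilon = 10^{-3}$ are illustrative assumptions, not values from the source.

```python
import numpy as np

def is_causally_balanced(S, eps=1e-3):
    """Stopping test: are all row and column sums, taken under M, within eps of 1?

    Entries above the diagonal (j > i) are excluded, matching the causal mask M.
    """
    M = np.tril(np.ones(S.shape, dtype=bool))
    Sm = np.where(M, S, 0.0)
    row_dev = np.abs(Sm.sum(axis=1) - 1.0).max()
    col_dev = np.abs(Sm.sum(axis=0) - 1.0).max()
    return bool(max(row_dev, col_dev) <= eps)

print(is_causally_balanced(np.eye(3)))                           # True
print(is_causally_balanced(np.array([[1.0, 0.0], [0.5, 0.5]])))  # False: column 0 sums to 1.5
```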

4. Complexity and Theoretical Properties

The computational cost per iteration is $O(B^2)$, from evaluating and applying the masked row and column sums. The total runtime over $K$ iterations is therefore $O(K \cdot B^2)$. Memory demand is $O(B^2)$ for $S$ and $O(\ell \cdot d)$ for the sequence itself.

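Convergence follows from the Sinkhorn–Knopp theorem, because the masked matrix $M \odot \exp(R/\tau)$ has support. Its diagonal is always admissible ($M_{ii} = 1$) and $\exp(\cdot)$ is strictly positive. Causal masking restricts the support to the lower-triangular region, and the iterations still converge. The limit, however, is trivial: the only doubly stochastic matrix with lower-triangular support is the identity. To see this, note that row 1 has a single admissible entry, which forces $S_{11} = 1$ and hence $S_{i1} = 0$ for $i > 1$. The same argument then applies row by row. The useful soft sorting therefore comes from running a finite number $K$ of iterations rather than from the exact fixed point.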
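The fixed-point behavior can be checked numerically. The sketch below runs vectorized masked iterations without Gumbel noise, an assumption made so the run is deterministic. It prints the largest off-diagonal entry of $S$ for increasing $K$. The only doubly stochastic matrix with lower-triangular support is the identity. The off-diagonal mass should therefore decay toward zero as $K$ grows, while a small $K$ leaves a genuinely soft matrix. All names are hypothetical.

```python
import numpy as np

def logsumexp(x, axis):
    """Stable log-sum-exp; -inf entries (masked out) contribute zero mass."""
    m = np.max(x, axis=axis, keepdims=True)  # finite: every row/column keeps its diagonal
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def causal_sinkhorn_log(R, n_iters, tau=1.0):
    """Vectorized masked Sinkhorn without noise; returns S = exp(L)."""
    B = R.shape[0]
    M = np.tril(np.ones((B, B), dtype=bool))
    L = np.where(M, R / tau, -np.inf)        # future entries start at -inf
    for _ in range(n_iters):
        L = L - logsumexp(L, axis=1)         # masked row normalization
        L = L - logsumexp(L, axis=0)         # masked column normalization
    return np.exp(L)

def max_off_diagonal(S):
    """Largest absolute entry outside the diagonal."""
    return np.max(np.abs(S - np.diag(np.diag(S))))

rng = np.random.default_rng(0)
R = rng.normal(size=(4, 4))
for K in (5, 50, 2000):
    print(K, round(max_off_diagonal(causal_sinkhorn_log(R, K)), 4))
```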

5. Integration within Sparse Sinkhorn Attention

After $K$ causal Sinkhorn iterations, the resulting matrix $S \in \mathbb{R}^{B \times B}$ softly permutes the $B$ blocks of $X$. The sorted representation is $X' = S\,B(X)$, where $B(X) \in \mathbb{R}^{B \times (\ell_B \cdot d)}$ reassembles the tokens into one row per block. Local attention is then applied within each block, operating on tokens that are now quasi-globally reordered.
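The block sorting and the subsequent local attention can be sketched as below. The sketch makes three assumptions:
  • a single head with no learned projections, so the sorted sequence is used directly as queries, keys and values;
  • a causal mask inside each block;
  • a column-normalized random lower-triangular matrix standing in for the output of causal Sinkhorn balancing.
The function names are hypothetical.

```python
import numpy as np

def sort_blocks(X, S):
    """X' = S B(X): stack each block's tokens into one row, mix rows with S, unstack."""
    num_blocks = S.shape[0]
    seq_len, d = X.shape
    block_len = seq_len // num_blocks
    blocks = X.reshape(num_blocks, block_len * d)  # B(X), shape (B, ell_B * d)
    return (S @ blocks).reshape(seq_len, d)

def local_attention(Q, K, V, num_blocks, causal=True):
    """Single-head softmax attention restricted to tokens inside the same block."""
    seq_len, d = Q.shape
    block_len = seq_len // num_blocks
    q, k, v = (A.reshape(num_blocks, block_len, d) for A in (Q, K, V))
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (B, ell_B, ell_B)
    if causal:  # a token may attend only to itself and earlier tokens in its block
        future = np.triu(np.ones((block_len, block_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return (weights @ v).reshape(seq_len, d)

rng = np.random.default_rng(0)
seq_len, d, num_blocks = 16, 8, 4
X = rng.normal(size=(seq_len, d))
S = np.tril(rng.uniform(size=(num_blocks, num_blocks)))
S = S / S.sum(axis=0, keepdims=True)  # stand-in for a causal Sinkhorn output
X_sorted = sort_blocks(X, S)
out = local_attention(X_sorted, X_sorted, X_sorted, num_blocks)  # no Q/K/V projections
print(out.shape)  # (16, 8)
```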

During decoding, $S$ is recomputed at each timestep using cumulative-sum pooling for $\psi_P$. This ensures that the scores $R(i)$ for block $i$ depend only on blocks $\leq i$, which preserves causality. For encoding, blockwise sum-pooling is used. In all configurations, the queries, keys and values ($Q$, $K$, $V$; here $K$ denotes the keys, not the iteration count) are block-permuted by $S$ prior to attention.
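A small check illustrates the difference between the two pooling modes. It assumes the cumulative sum is taken over whole blocks, which is one possible reading of cumulative-sum pooling; the helper names are hypothetical. Changing only the last block leaves the causal pooled rows for the earlier blocks untouched.

```python
import numpy as np

def block_sum_pool(X, num_blocks):
    """Encoder psi_P: row i is the sum of the tokens in block i only."""
    seq_len, d = X.shape
    return X.reshape(num_blocks, seq_len // num_blocks, d).sum(axis=1)

def cumulative_block_pool(X, num_blocks):
    """Causal psi_P (assumed form): row i is the sum of all tokens in blocks 0..i."""
    return np.cumsum(block_sum_pool(X, num_blocks), axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(16, 8))  # 4 blocks of 4 tokens each
X_future = X.copy()
X_future[12:] += 100.0        # change only the last block (tokens 12..15)
P = cumulative_block_pool(X, 4)
P_future = cumulative_block_pool(X_future, 4)
print(np.allclose(P[:3], P_future[:3]))  # True: rows 0..2 never see block 3
print(np.allclose(P[3], P_future[3]))    # False: row 3 includes block 3
```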

6. Empirical Performance and Observed Effects

Empirical ablation shows that omitting Sinkhorn normalization entirely ($K = 0$) significantly reduces performance, raising perplexity by $10$–$11$ on LM1B (see Table 9 in Tay et al., 2020). The computational and memory overhead of causal Sinkhorn is low. Each layer incurs only $O(K \cdot B^2)$ additional cost on top of standard local attention, which costs $O(\ell \cdot \ell_B)$. With $B \ll \ell$ and $K \approx 5$–$10$, this overhead is minor.
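The arithmetic below makes the overhead claim concrete. It compares the leading terms $K \cdot B^2$ and $\ell \cdot \ell_B$ for sample sizes chosen only for illustration ($\ell = 4096$, $B = 32$, $K = 10$). It ignores constant factors and the $d$-dimensional dot product behind each attention score; including that dot product would make the ratio smaller still.

```python
def sinkhorn_overhead_ratio(seq_len, num_blocks, n_iters):
    """Leading-term ratio of K * B^2 (causal Sinkhorn) to ell * ell_B (local attention).

    Constant factors and the per-score cost of d-dimensional dot products are ignored.
    """
    block_len = seq_len // num_blocks           # ell_B = ell / B
    sinkhorn_cost = n_iters * num_blocks ** 2   # O(K * B^2)
    local_cost = seq_len * block_len            # O(ell * ell_B)
    return sinkhorn_cost / local_cost

print(sinkhorn_overhead_ratio(4096, 32, 10))  # 0.01953125, i.e. about 2%
```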

Compared with vanilla Transformer and Sparse Transformer baselines, applying causal Sinkhorn in the decoder matches or exceeds their accuracy on language modeling, sorting and generation benchmarks. It does so with a much smaller attention-memory footprint: $O(B^2 + (\ell/B)^2)$ in total vs. $O(\ell^2)$. Recommended settings are $\tau \approx 0.5$–$0.75$ and $K \approx 5$–$10$. An excessively low temperature ($\tau \to 0$) or a high iteration count ($K > 20$) can moderately degrade perplexity (cf. Figures 4–5 in Tay et al., 2020).
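Here is a quick evaluation of the memory terms for an illustrative configuration ($\ell = 4096$, $B = 64$), counting leading terms only:

```python
def attention_memory(seq_len, num_blocks):
    """Leading-term attention-memory counts: B^2 + (ell/B)^2 versus ell^2."""
    sparse = num_blocks ** 2 + (seq_len // num_blocks) ** 2
    dense = seq_len ** 2
    return sparse, dense

sparse, dense = attention_memory(4096, 64)
print(sparse, dense, dense // sparse)  # 8192 16777216 2048
```

For fixed $\ell$, $B^2 + (\ell/B)^2 \geq 2\ell$, with equality at $B = \sqrt{\ell}$. That is why this example picks $B = 64$; with that choice, the sorted-attention footprint grows only linearly in $\ell$.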

