Tensor Product Attention Is All You Need

Published 11 Jan 2025 in cs.CL, cs.AI, and cs.LG | (2501.06425v4)

Abstract: Scaling LLMs to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components and seamlessly integrating with Rotary Position Embedding (RoPE), TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor Product Attention Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 surpasses or matches the performance of standard Transformer baselines, including Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA), across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory efficiency and computational efficiency at the decoding stage enable processing longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern LLMs. The code is available at https://github.com/tensorgi/T6.

Summary

  • The paper introduces Tensor Product Attention (TPA), a novel mechanism that uses tensor decompositions to factorize Q, K, and V, achieving a 10× or greater reduction in KV cache size during LLM inference.
  • TPA underpins the T6 sequence-modeling architecture and integrates seamlessly with RoPE, allowing it to directly replace MHA layers in existing LLMs such as LLaMA and Gemma.
  • Experiments show that TPA converges as fast as or faster than MHA and GQA baselines and reaches lower loss, indicating improved performance and scalability for LLMs.

The paper introduces Tensor Product Attention (TPA), a novel attention mechanism designed to mitigate the memory overhead associated with key-value (KV) caches in LLMs during inference. The core idea is to factorize queries ($\mathbf{Q}$), keys ($\mathbf{K}$), and values ($\mathbf{V}$) with tensor decompositions, which yields a compact representation of these entities and a significant reduction in KV cache size. Based on TPA, the authors introduce the Tensor Product Attention Transformer (T6), a model architecture for sequence modeling.

TPA employs contextual low-rank factorization: queries, keys, and values are each decomposed into low-rank components that are computed from the current token's hidden state. Because the activations are factorized dynamically, rather than the weight matrices being compressed statically, only the small factors need to be cached, which substantially reduces KV cache memory usage. TPA is natively compatible with rotary positional embeddings (RoPE), so it can directly replace the multi-head attention (MHA) layers of existing LLM architectures such as LLaMA and Gemma.

The authors summarize their primary contributions as follows:

  • Proposing TPA, a mechanism that factorizes the $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ activations using contextual tensor decompositions, achieving a 10× or greater reduction in inference-time KV cache size relative to standard attention while improving performance over previous methods such as MHA, MQA, GQA, and MLA.
  • Proposing the Tensor Product Attention Transformer (T6), a new TPA-based model architecture for sequence modeling. In language modeling experiments, T6 consistently improves validation perplexity and downstream evaluation performance while using a smaller KV cache.
  • Showing that TPA integrates seamlessly with RoPE, facilitating easy adoption in popular foundation model architectures such as LLaMA and Gemma.

The paper also provides background on scaled dot-product attention, MHA, multi-query attention (MQA), grouped-query attention (GQA), RoPE, and multi-head latent attention (MLA). Notation: bold uppercase letters denote matrices, bold lowercase letters denote vectors, and italic uppercase letters denote learnable parameter matrices. The tensor product of two vectors $\mathbf{a} \in \mathbb{R}^m$ and $\mathbf{b} \in \mathbb{R}^n$ is defined as $\mathbf{a} \otimes \mathbf{b} = \mathbf{C} \in \mathbb{R}^{m \times n}$ with $C_{ij} = a_i b_j$, and the vectorization of a matrix $\mathbf{C} \in \mathbb{R}^{m \times n}$ is defined as $\text{vec}(\mathbf{C}) = \mathbf{d} \in \mathbb{R}^{mn}$ with $d_{i \cdot n + j} = C_{ij}$ (indices starting at zero). Scaled dot-product attention is given by:

$\operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{Softmax}\Bigl(\tfrac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}}\Bigr)\,\mathbf{V},$

where $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{n \times d_k}$.
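The formula can be checked with a minimal pure-Python sketch. It uses no external libraries, stores matrices as lists of rows, and the helper names `matmul`, `softmax`, and `attention` are our own, not taken from the paper's code:

```python
import math


def matmul(A, B):
    """Multiply an (n x k) matrix by a (k x m) matrix, both stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax(row):
    m = max(row)  # subtract the row maximum for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """Softmax(Q K^T / sqrt(d_k)) V, with the softmax taken over each row of scores."""
    d_k = len(Q[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = attention(Q, K, V)
# Each output row is a convex combination of the rows of V; query 0 matches
# key 0 best, so its output leans toward V[0].
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged but avoids overflow when scores are large.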

MHA computes each head $i$ for a token embedding $\mathbf{x}_t \in \mathbb{R}^{d_{\text{model}}}$ as:

$\mathbf{Q}_{t,i} = (\boldsymbol{W}_i^Q)^{\top}\,\mathbf{x}_t \in \mathbb{R}^{d_h}, \quad \mathbf{K}_{t,i} = (\boldsymbol{W}_i^K)^{\top}\,\mathbf{x}_t \in \mathbb{R}^{d_h}, \quad \mathbf{V}_{t,i} = (\boldsymbol{W}_i^V)^{\top}\,\mathbf{x}_t \in \mathbb{R}^{d_h},$

where $\boldsymbol{W}_i^Q, \boldsymbol{W}_i^K, \boldsymbol{W}_i^V \in \mathbb{R}^{d_{\text{model}} \times d_h}$ are learnable projection matrices.
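Implementations commonly store the $h$ per-head matrices as one fused $d_{\text{model}} \times h d_h$ projection and split its output into heads. The pure-Python sketch below uses random stand-in weights and helper names of our own choosing. It checks that the fused form gives the same result as computing $(\boldsymbol{W}_i^Q)^{\top}\mathbf{x}_t$ head by head:

```python
import random


def project(W, x):
    """Return W^T x for W of shape (d_model x d_out), stored as a list of rows."""
    return [sum(W[r][c] * x[r] for r in range(len(x))) for c in range(len(W[0]))]


rng = random.Random(0)
d_model, h, d_h = 4, 2, 3
x_t = [rng.gauss(0, 1) for _ in range(d_model)]
# One matrix W_i^Q of shape (d_model x d_h) per head (random stand-ins for learned weights).
W_Q = [[[rng.gauss(0, 1) for _ in range(d_h)] for _ in range(d_model)] for _ in range(h)]

# Per-head form from the equation: Q_{t,i} = (W_i^Q)^T x_t.
q_heads = [project(W_Q[i], x_t) for i in range(h)]

# Fused form: place the W_i^Q side by side as one (d_model x h*d_h) matrix,
# project once, then cut the result into h chunks of length d_h.
W_fused = [sum((W_Q[i][r] for i in range(h)), []) for r in range(d_model)]
q_fused = project(W_fused, x_t)
q_split = [q_fused[i * d_h:(i + 1) * d_h] for i in range(h)]
```

Fusing changes only the memory layout; each head still sees exactly its own $d_h$ columns.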

MQA shares keys and values across heads, expressed as:

$\mathbf{Q}_{i} = \mathbf{X}\boldsymbol{W}^Q_{i}, \quad \mathbf{K}_{\text{shared}} = \mathbf{X}\boldsymbol{W}^K_{\text{shared}}, \quad \mathbf{V}_{\text{shared}} = \mathbf{X}\boldsymbol{W}^V_{\text{shared}},$

with $\boldsymbol{W}^Q_{i}, \boldsymbol{W}^K_{\text{shared}}, \boldsymbol{W}^V_{\text{shared}} \in \mathbb{R}^{d_{\text{model}} \times d_k}$.

GQA partitions the $h$ total heads into $G$ groups, each with a single set of keys and values:

$\mathbf{K}_{g(i)} = \mathbf{X}\,\boldsymbol{W}^K_{g(i)}, \quad \mathbf{V}_{g(i)} = \mathbf{X}\,\boldsymbol{W}^V_{g(i)}, \quad \mathbf{Q}_{i} = \mathbf{X}\,\boldsymbol{W}^Q_{i},$

where $g(i)$ is the group that head $i$ belongs to, $\boldsymbol{W}^K_{g}, \boldsymbol{W}^V_{g} \in \mathbb{R}^{d_{\text{model}} \times d_k}$, and $\boldsymbol{W}^Q_{i} \in \mathbb{R}^{d_{\text{model}} \times d_k}$.
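A minimal pure-Python sketch of how GQA maps query heads to shared key groups follows. It assumes that consecutive heads form a group, so $g(i) = \lfloor i / (h/G) \rfloor$. The helper names are hypothetical and the weights are random stand-ins. Values would be handled identically. Setting $G = 1$ recovers MQA's single shared key, and $G = h$ would recover MHA:

```python
import random


def project(W, x):
    """Return W^T x for W of shape (d_model x d_k), stored as a list of rows."""
    return [sum(W[r][c] * x[r] for r in range(len(x))) for c in range(len(W[0]))]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


def group_of(i, h, G):
    """g(i) for head i, assuming consecutive heads are grouped together."""
    return i // (h // G)


def gqa_keys(x_t, W_K_groups, h):
    """Compute one key per group and map each of the h query heads to its group's key."""
    G = len(W_K_groups)
    assert h % G == 0, "h must be divisible by the number of groups G"
    group_keys = [project(W, x_t) for W in W_K_groups]  # only G vectors per token are cached
    return group_keys, [group_keys[group_of(i, h, G)] for i in range(h)]


rng = random.Random(0)
d_model, d_k, h = 8, 4, 4
x_t = [rng.gauss(0, 1) for _ in range(d_model)]

# GQA with G = 2: heads 0-1 share one key vector, heads 2-3 share another.
cached, per_head = gqa_keys(x_t, [random_matrix(d_model, d_k, rng) for _ in range(2)], h)

# MQA is the special case G = 1 (W^K_shared): every head sees the same key.
cached_mqa, per_head_mqa = gqa_keys(x_t, [random_matrix(d_model, d_k, rng)], h)
```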

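RoPE uses a rotation operator $\mathbf{T}_t \in \mathbb{R}^{d_h \times d_h}$ corresponding to the $t$-th position, and $\operatorname{RoPE}(\mathbf{Q}_t) \triangleq \mathbf{Q}_t\mathbf{T}_t$, where $\mathbf{Q}_t \in \mathbb{R}^{h \times d_h}$, so every head's query row is rotated by the same $\mathbf{T}_t$.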
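For intuition, the sketch below applies one concrete choice of $\mathbf{T}_t$ to a row vector: independent 2×2 rotations of consecutive coordinate pairs, with angle $t \cdot 10000^{-2j/d_h}$ for pair $j$. This interleaved-pair layout and the base 10000 follow the original RoPE formulation, but they are assumptions here, since some implementations rotate the two halves of the vector instead. The sketch then checks the relative property that matters for attention: the rotated query-key dot product depends only on the offset between positions.

```python
import math


def rope(vec, t, base=10000.0):
    """Rotate pairs (vec[2j], vec[2j+1]) by angle t * base**(-2j/d), i.e. the row vector vec @ T_t."""
    d = len(vec)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = []
    for j in range(d // 2):
        theta = t * base ** (-2 * j / d)
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = vec[2 * j], vec[2 * j + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.6, 2.0]

# The same offset (t - s = 2) at different absolute positions gives the same score.
score_a = dot(rope(q, 5), rope(k, 3))
score_b = dot(rope(q, 105), rope(k, 103))
```

Because each pair is only rotated, RoPE also preserves vector norms.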

MLA introduces a low-rank compression of the keys and values to reduce the KV caching cost at inference:

$\mathbf{C}^{KV} = \mathbf{X}\boldsymbol{W}^{DKV}, \quad \boldsymbol{W}^{DKV} \in \mathbb{R}^{d_{\text{model}} \times d_c},$

$\operatorname{Concat}\bigl(\mathbf{K}^{C}_{1}, \mathbf{K}^{C}_{2}, \ldots, \mathbf{K}^{C}_{h}\bigr) = \mathbf{K}^{C} = \mathbf{C}^{KV}\boldsymbol{W}^{UK}, \quad \boldsymbol{W}^{UK} \in \mathbb{R}^{d_c \times d_h h}.$

In TPA, let $\mathbf{x}_t \in \mathbb{R}^{d_{\text{model}}}$ be the hidden-state vector for the $t$-th token in a sequence of length $T$. TPA factorizes each of $\mathbf{Q}_t, \mathbf{K}_t, \mathbf{V}_t$ into a sum of tensor products:

$\mathbf{Q}_t = \frac{1}{R_Q}\sum_{r=1}^{R_Q} \mathbf{a}^{Q}_{r}(\mathbf{x}_t) \otimes \mathbf{b}^{Q}_{r}(\mathbf{x}_t), \quad \mathbf{K}_t = \frac{1}{R_K}\sum_{r=1}^{R_K} \mathbf{a}^{K}_{r}(\mathbf{x}_t) \otimes \mathbf{b}^{K}_{r}(\mathbf{x}_t), \quad \mathbf{V}_t = \frac{1}{R_V}\sum_{r=1}^{R_V} \mathbf{a}^{V}_{r}(\mathbf{x}_t) \otimes \mathbf{b}^{V}_{r}(\mathbf{x}_t),$

where $\mathbf{a}^{Q}_{r}(\mathbf{x}_t), \mathbf{a}^{K}_{r}(\mathbf{x}_t), \mathbf{a}^{V}_{r}(\mathbf{x}_t) \in \mathbb{R}^{h}$ and $\mathbf{b}^{Q}_{r}(\mathbf{x}_t), \mathbf{b}^{K}_{r}(\mathbf{x}_t), \mathbf{b}^{V}_{r}(\mathbf{x}_t) \in \mathbb{R}^{d_h}$, so each of $\mathbf{Q}_t, \mathbf{K}_t, \mathbf{V}_t$ lies in $\mathbb{R}^{h \times d_h}$.
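To make the factorization concrete, here is a small pure-Python sketch that builds one $h \times d_h$ matrix from rank-$R$ factors in the same way as the $\mathbf{Q}_t$ formula, together with the row-major `vec` from the notation section. The factors are fixed numbers chosen for illustration; in TPA they are computed from $\mathbf{x}_t$. The function names are hypothetical.

```python
def outer(a, b):
    """Tensor product a ⊗ b: the len(a) x len(b) matrix with entries a[i] * b[j]."""
    return [[ai * bj for bj in b] for ai in a]


def vec(C):
    """Row-major vectorization, so vec(C)[i * n + j] == C[i][j] with zero-based indices."""
    return [c for row in C for c in row]


def tpa_compose(a_factors, b_factors):
    """(1/R) * sum_r a_r ⊗ b_r for R head factors a_r (length h) and token factors b_r (length d_h)."""
    R = len(a_factors)
    h, d_h = len(a_factors[0]), len(b_factors[0])
    M = [[0.0] * d_h for _ in range(h)]
    for a, b in zip(a_factors, b_factors):
        for i, row in enumerate(outer(a, b)):
            for j, value in enumerate(row):
                M[i][j] += value / R
    return M


# Rank R = 2 with h = 3 heads and head dimension d_h = 2.
a_factors = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]]
b_factors = [[2.0, 4.0], [6.0, 8.0]]
Q_t = tpa_compose(a_factors, b_factors)
print(Q_t)       # [[1.0, 2.0], [3.0, 4.0], [5.0, 8.0]]
print(vec(Q_t))  # [1.0, 2.0, 3.0, 4.0, 5.0, 8.0]
```

Each term $\mathbf{a}_r \otimes \mathbf{b}_r$ has rank one, so the result has rank at most $R$. This is why storing the $R(h + d_h)$ factor entries can stand in for the $h d_h$ entries of the full matrix.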

The latent factor maps are linear functions of the hidden state:

$\mathbf{a}^{Q}_{r}(\mathbf{x}_t) = \boldsymbol{W}^{a^Q}_{r}\,\mathbf{x}_t \in \mathbb{R}^{h}, \quad \mathbf{b}^{Q}_{r}(\mathbf{x}_t) = \boldsymbol{W}^{b^Q}_{r}\,\mathbf{x}_t \in \mathbb{R}^{d_h},$

with learnable $\boldsymbol{W}^{a^Q}_{r} \in \mathbb{R}^{h \times d_{\text{model}}}$ and $\boldsymbol{W}^{b^Q}_{r} \in \mathbb{R}^{d_h \times d_{\text{model}}}$, and analogous maps for keys and values.

After $\mathbf{Q}, \mathbf{K}, \mathbf{V}$ are factorized, multi-head attention proceeds as in a standard Transformer:

$\text{head}_i = \operatorname{Softmax}\Bigl(\tfrac{1}{\sqrt{d_h}}\,\mathbf{Q}_{i}\,(\mathbf{K}_{i})^{\top}\Bigr)\,\mathbf{V}_{i},$

where $\mathbf{Q}_{i}, \mathbf{K}_{i}, \mathbf{V}_{i} \in \mathbb{R}^{T \times d_h}$ are the slices along the head dimension.
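Putting the pieces together, the following pure-Python sketch runs one TPA attention layer over a short sequence. Linear factor maps produce $\mathbf{Q}_t, \mathbf{K}_t, \mathbf{V}_t$ for every token, the results are sliced per head, and each head applies the softmax formula. This is a sketch under stated assumptions, not the authors' implementation:

- the weights are random stand-ins for learned parameters;
- the ranks $R_Q = 6$ and $R_K = R_V = 2$ are illustrative;
- there is no causal mask or output projection;
- all function names are hypothetical.

```python
import math
import random


def matvec(W, x):
    """W x for W of shape (out x in), stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0, 1) / math.sqrt(cols) for _ in range(cols)] for _ in range(rows)]


def factor_maps(R, h, d_h, d_model, rng):
    """R pairs (W^a_r: h x d_model, W^b_r: d_h x d_model); random stand-ins for learned weights."""
    return [(random_matrix(h, d_model, rng), random_matrix(d_h, d_model, rng)) for _ in range(R)]


def tpa_token(x_t, maps):
    """(1/R) sum_r a_r(x_t) ⊗ b_r(x_t) with a_r = W^a_r x_t, b_r = W^b_r x_t; returns h x d_h."""
    a = [matvec(Wa, x_t) for Wa, _ in maps]
    b = [matvec(Wb, x_t) for _, Wb in maps]
    R, h, d_h = len(maps), len(a[0]), len(b[0])
    return [[sum(a[r][i] * b[r][j] for r in range(R)) / R for j in range(d_h)] for i in range(h)]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def tpa_attention(X, maps_q, maps_k, maps_v):
    """Return [head_1, ..., head_h] (each T x d_h) together with the per-token V_t."""
    Q = [tpa_token(x, maps_q) for x in X]  # T matrices, each h x d_h
    K = [tpa_token(x, maps_k) for x in X]
    V = [tpa_token(x, maps_v) for x in X]
    h, d_h = len(Q[0]), len(Q[0][0])
    heads = []
    for i in range(h):
        # Slices along the head dimension: row t of Q_i is row i of Q_t.
        Q_i, K_i, V_i = [q[i] for q in Q], [k[i] for k in K], [v[i] for v in V]
        head = []
        for q in Q_i:
            # No causal mask, matching the formula; a decoder would mask future tokens.
            w = softmax([sum(qj * kj for qj, kj in zip(q, k)) / math.sqrt(d_h) for k in K_i])
            head.append([sum(w_s * v[j] for w_s, v in zip(w, V_i)) for j in range(d_h)])
        heads.append(head)
    return heads, V


rng = random.Random(0)
T, d_model, h, d_h = 5, 8, 4, 6
X = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(T)]
heads, V = tpa_attention(
    X,
    factor_maps(6, h, d_h, d_model, rng),  # R_Q = 6 (illustrative)
    factor_maps(2, h, d_h, d_model, rng),  # R_K = 2 (illustrative)
    factor_maps(2, h, d_h, d_model, rng),  # R_V = 2 (illustrative)
)
```

Once $\mathbf{Q}_t, \mathbf{K}_t, \mathbf{V}_t$ are assembled, the per-head loop is ordinary attention. The only TPA-specific part is `tpa_token`.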

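For RoPE integration, the paper suggests pre-rotating the token-dimension factors (the $\mathbf{b}_r$, which have size $d_h$):

$\widetilde{\mathbf{B}}_K(\mathbf{x}_t) \longleftarrow \operatorname{RoPE}_t\bigl(\mathbf{B}_K(\mathbf{x}_t)\bigr),$

where $\mathbf{B}_K(\mathbf{x}_t) \in \mathbb{R}^{R_K \times d_h}$ stacks the factors $\mathbf{b}^{K}_{r}(\mathbf{x}_t)^{\top}$ as rows (and $\mathbf{A}_K(\mathbf{x}_t) \in \mathbb{R}^{R_K \times h}$ likewise stacks the $\mathbf{a}^{K}_{r}(\mathbf{x}_t)^{\top}$), and $\operatorname{RoPE}_t$ rotates each row.

A key theorem states that RoPE's relative translational property is preserved in TPA. If $\mathbf{Q}_t$ is factorized by TPA, then

$\operatorname{RoPE}(\mathbf{Q}_t) = \frac{1}{R_Q}\,\mathbf{A}_{Q}(\mathbf{x}_t)^{\top}\,\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_t), \quad \text{where } \widetilde{\mathbf{B}}_{Q}(\mathbf{x}_t) = \operatorname{RoPE}_t\bigl(\mathbf{B}_{Q}(\mathbf{x}_t)\bigr).$

This holds because $\mathbf{Q}_t = \frac{1}{R_Q}\mathbf{A}_Q^{\top}\mathbf{B}_Q$, so $\mathbf{Q}_t\mathbf{T}_t = \frac{1}{R_Q}\mathbf{A}_Q^{\top}(\mathbf{B}_Q\mathbf{T}_t)$. Rotating the $R_Q$ rows of $\mathbf{B}_Q$ is therefore equivalent to rotating the full $h \times d_h$ query.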
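The theorem amounts to associativity of matrix products, and it can be checked numerically. The pure-Python sketch below uses hypothetical names throughout. It stacks the factors $\mathbf{a}_r^{\top}$ and $\mathbf{b}_r^{\top}$ as rows of $\mathbf{A}_Q \in \mathbb{R}^{R_Q \times h}$ and $\mathbf{B}_Q \in \mathbb{R}^{R_Q \times d_h}$. It also assumes an interleaved-pair RoPE layout with base 10000.

```python
import math


def rope(vec, t, base=10000.0):
    """Rotate consecutive pairs of vec by angle t * base**(-2j/d); linear in vec."""
    d = len(vec)
    out = []
    for j in range(d // 2):
        theta = t * base ** (-2 * j / d)
        c, s = math.cos(theta), math.sin(theta)
        x0, x1 = vec[2 * j], vec[2 * j + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


def compose(A, B):
    """(1/R) A^T B for A (R x h) and B (R x d_h), both stored as lists of rows."""
    R, h, d_h = len(A), len(A[0]), len(B[0])
    return [[sum(A[r][i] * B[r][j] for r in range(R)) / R for j in range(d_h)] for i in range(h)]


A_Q = [[1.0, -2.0, 0.5], [0.3, 0.0, 1.5]]              # R_Q = 2, h = 3
B_Q = [[0.2, 1.0, -0.7, 0.4], [1.1, -0.3, 0.9, 0.0]]   # R_Q = 2, d_h = 4
t = 7

# Left side: assemble Q_t first, then rotate each of its h rows.
lhs = [rope(row, t) for row in compose(A_Q, B_Q)]
# Right side: rotate only the R_Q rows of B_Q, then assemble.
rhs = compose(A_Q, [rope(row, t) for row in B_Q])
```

For keys, the same identity lets $\mathbf{B}_K$ be rotated once, when the token is first processed, as in the pre-rotation step above.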

The memory cost per token (per layer) in TPA is $(R_K + R_V)(h + d_h)$, since only the factors $\mathbf{A}_K, \mathbf{B}_K, \mathbf{A}_V, \mathbf{B}_V$ are cached. When the ranks are small, this can be significantly lower than the standard caching cost of $2hd_h$.
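As a worked example, take illustrative sizes of $h = 32$ heads, $d_h = 128$, and $R_K = R_V = 2$. These are our choice, not necessarily the paper's configuration. The formula gives $(2+2)(32+128) = 640$ cached scalars per token and layer, versus $2 \cdot 32 \cdot 128 = 8192$ for MHA, about 12.8× smaller. A tiny helper makes it easy to try other settings:

```python
def tpa_cache_per_token(R_K, R_V, h, d_h):
    """Cached scalars per token and layer for TPA: the factors A_K, B_K, A_V, B_V."""
    return (R_K + R_V) * (h + d_h)


def mha_cache_per_token(h, d_h):
    """Cached scalars per token and layer for MHA: full keys and values."""
    return 2 * h * d_h


h, d_h = 32, 128   # illustrative head count and head dimension
R_K, R_V = 2, 2    # illustrative TPA ranks
tpa = tpa_cache_per_token(R_K, R_V, h, d_h)
mha = mha_cache_per_token(h, d_h)
print(tpa, mha, mha / tpa)  # 640 8192 12.8
```

The ratio $2hd_h / ((R_K + R_V)(h + d_h))$ shrinks as the ranks grow, so the savings depend on keeping $R_K$ and $R_V$ small.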

The paper shows how MHA, MQA, and GQA can be unified as non-contextual variants of TPA. Specifically, standard MHA can be viewed as an instance of TPA in which: 1) the rank equals the number of heads; 2) the head-dimension factor is non-contextual, i.e., independent of $\mathbf{x}_t$; 3) the token-dimension factor is a linear function of $\mathbf{x}_t$.
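The three conditions can be verified numerically. In the pure-Python sketch below, the rank is $R = h$, head $i$ gets the fixed head factor $h\,\mathbf{e}_i$, and its token factor is $(\boldsymbol{W}_i^Q)^{\top}\mathbf{x}_t$. The scale $h$ on the one-hot vector $\mathbf{e}_i$ is our choice to cancel the $1/R$ normalization; it could equally be absorbed into the token factor. The weights are random stand-ins and the names are hypothetical.

```python
import random


def tpa_compose(a_factors, b_factors):
    """(1/R) * sum_r a_r ⊗ b_r, returned as an h x d_h list of rows."""
    R = len(a_factors)
    h, d_h = len(a_factors[0]), len(b_factors[0])
    return [[sum(a_factors[r][i] * b_factors[r][j] for r in range(R)) / R for j in range(d_h)]
            for i in range(h)]


rng = random.Random(1)
d_model, h, d_h = 5, 3, 4
x_t = [rng.gauss(0, 1) for _ in range(d_model)]
# Per-head MHA projections W_i^Q, each d_model x d_h (random stand-ins).
W_Q = [[[rng.gauss(0, 1) for _ in range(d_h)] for _ in range(d_model)] for _ in range(h)]

# Standard MHA: row i of Q_t is (W_i^Q)^T x_t.
mha_Q = [[sum(W_Q[i][r][c] * x_t[r] for r in range(d_model)) for c in range(d_h)]
         for i in range(h)]

# TPA view: rank R = h; head factor a_i = h * e_i is fixed (does not depend on x_t),
# and token factor b_i(x_t) = (W_i^Q)^T x_t is linear in x_t.
a_factors = [[float(h) if k == i else 0.0 for k in range(h)] for i in range(h)]
b_factors = [[sum(W_Q[i][r][c] * x_t[r] for r in range(d_model)) for c in range(d_h)]
             for i in range(h)]
tpa_Q = tpa_compose(a_factors, b_factors)
```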

In MQA, all heads share a single set of keys and values, which corresponds to $R_K = R_V = 1$ along the head dimension, while GQA partitions the $h$ heads into $G$ groups, each sharing keys and values within that group.
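A minimal sketch of both cases as non-contextual TPA, in pure Python with made-up key vectors standing in for the projected keys, follows. For MQA, a single all-ones head factor copies the shared key to every head. For GQA, head factor $g$ is $G$ times the indicator of the heads in group $g$, again scaled to cancel $1/R$. The sketch assumes consecutive heads form a group.

```python
def tpa_compose(a_factors, b_factors):
    """(1/R) * sum_r a_r ⊗ b_r, returned as an h x d_h list of rows."""
    R = len(a_factors)
    h, d_h = len(a_factors[0]), len(b_factors[0])
    return [[sum(a_factors[r][i] * b_factors[r][j] for r in range(R)) / R for j in range(d_h)]
            for i in range(h)]


h = 4
k_shared = [0.5, -1.0, 2.0]  # stands in for the projected key (W^K_shared)^T x_t
# MQA: R_K = 1 and the head factor is the all-ones vector, so every head row is k_shared.
mqa_K = tpa_compose([[1.0] * h], [k_shared])

# GQA with G = 2 (assumed layout: heads 0-1 in group 0, heads 2-3 in group 1).
G = 2
k_groups = [[1.0, 0.0, 3.0], [-2.0, 4.0, 0.5]]  # one projected key per group
a_groups = [[float(G) if i // (h // G) == g else 0.0 for i in range(h)] for g in range(G)]
gqa_K = tpa_compose(a_groups, k_groups)
```

In both cases the head factors never look at $\mathbf{x}_t$. That is the sense in which these schemes are non-contextual special cases of TPA.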

The paper also details the Tensor Product Attention Transformer (T6) architecture, which uses TPA in place of standard MHA or GQA. Its feed-forward network (FFN) is a SwiGLU layer, and RoPE is applied to the queries and keys.
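The source states only that T6's FFN is a SwiGLU layer. Below is a minimal pure-Python sketch of the standard SwiGLU formulation, $\boldsymbol{W}_{\text{down}}\bigl(\operatorname{SiLU}(\boldsymbol{W}_{\text{gate}}\mathbf{x}) \odot \boldsymbol{W}_{\text{up}}\mathbf{x}\bigr)$, with bias-free projections as in LLaMA-style models. The weight names and sizes are our own assumptions, not T6's configuration.

```python
import math
import random


def matvec(W, x):
    """W x for W stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def silu(z):
    """SiLU / Swish: z * sigmoid(z), with sigmoid written via tanh to avoid exp overflow."""
    return z * 0.5 * (1.0 + math.tanh(0.5 * z))


def swiglu_ffn(x, W_gate, W_up, W_down):
    """W_down (SiLU(W_gate x) * (W_up x)), where * is elementwise."""
    gate = [silu(g) for g in matvec(W_gate, x)]
    up = matvec(W_up, x)
    return matvec(W_down, [g * u for g, u in zip(gate, up)])


rng = random.Random(0)
d_model, d_ff = 4, 8  # illustrative sizes
W_gate = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(d_ff)]
W_up = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(d_ff)]
W_down = [[rng.gauss(0, 1) for _ in range(d_ff)] for _ in range(d_model)]
y = swiglu_ffn([0.5, -1.0, 0.25, 2.0], W_gate, W_up, W_down)
```

The gate branch decides, per hidden unit, how much of the linear "up" branch passes through. The output has the model width again, so the layer drops into the residual stream unchanged.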

Experiments were conducted on the FineWeb-Edu 100B dataset, comparing T6 against the baseline Llama architecture with SwiGLU activation and RoPE embeddings, as well as Llama variants that replace MHA with MQA, GQA, or MLA. Models were trained at small (124M parameters), medium (353M), and large (773M) scales using the AdamW optimizer.

Results indicate that TPA and its variant TPA-KVonly, which factorizes only the keys and values, converge as fast as or faster than the baselines while reaching noticeably lower final losses. In downstream evaluations on standard benchmarks, TPA generally ties or outperforms all competing methods.

The paper concludes that TPA offers a flexible, memory-efficient alternative to standard multi-head attention, advancing the scalability of modern LLMs.
