
Differential Transformer

Published 7 Oct 2024 in cs.CL and cs.LG (arXiv:2410.05258v2)

Abstract: Transformer tends to overallocate attention to irrelevant context. In this work, we introduce Diff Transformer, which amplifies attention to the relevant context while canceling noise. Specifically, the differential attention mechanism calculates attention scores as the difference between two separate softmax attention maps. The subtraction cancels noise, promoting the emergence of sparse attention patterns. Experimental results on language modeling show that Diff Transformer outperforms Transformer in various settings of scaling up model size and training tokens. More intriguingly, it offers notable advantages in practical applications, such as long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. By being less distracted by irrelevant context, Diff Transformer can mitigate hallucination in question answering and text summarization. For in-context learning, Diff Transformer not only enhances accuracy but is also more robust to order permutation, which was considered as a chronic robustness issue. The results position Diff Transformer as a highly effective and promising architecture to advance LLMs.

Summary

  • The paper introduces a novel differential attention mechanism that subtracts dual softmax maps to cancel noise and emphasize relevant context.
  • The approach achieves similar performance to standard Transformers with roughly 65% of the model size or training tokens, offering efficiency gains.
  • Experimental results demonstrate improved language modeling, more accurate key information retrieval, and reduced activation outliers through GroupNorm in multi-head settings.

The paper introduces the Differential Transformer (Diff Transformer), an architecture designed to mitigate the issue of over-allocation of attention to irrelevant context in standard Transformers. The core innovation lies in the differential attention mechanism, which computes attention scores as the difference between two separate softmax attention maps.

The differential attention mechanism involves partitioning the query and key vectors into two groups and computing two separate softmax attention maps. The attention scores are then calculated as the difference between these two maps. This subtraction aims to cancel out noise and promote sparse attention patterns, allowing the model to focus on relevant context.

The mathematical formulation of the differential attention operator $\operatorname{DiffAttn}(\cdot)$ is given by:

$[Q_1; Q_2] = X W^Q,\quad [K_1; K_2] = X W^K,\quad V = X W^V,$

$\operatorname{DiffAttn}(X) = \left(\operatorname{softmax}\!\left(\frac{Q_1 K_1^{T}}{\sqrt{d}}\right) - \lambda\,\operatorname{softmax}\!\left(\frac{Q_2 K_2^{T}}{\sqrt{d}}\right)\right)V,$

where:

  • $X \in \mathbb{R}^{N \times d_{\text{model}}}$ is the input
  • $Q_1, Q_2, K_1, K_2 \in \mathbb{R}^{N \times d}$ are the query and key projections
  • $V \in \mathbb{R}^{N \times 2d}$ is the value projection
  • $W^Q, W^K, W^V \in \mathbb{R}^{d_{\text{model}} \times 2d}$ are parameter matrices
  • $\lambda$ is a learnable scalar
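
The DiffAttn equations above can be sketched in NumPy. This is a minimal single-head, unmasked version with toy sizes chosen for illustration; the causal mask, multi-head structure, and the learnable re-parameterization of $\lambda$ are omitted, and the weight initialization here is an arbitrary assumption, not the paper's.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def diff_attn(X, Wq, Wk, Wv, lam, d):
    """Single-head differential attention (no causal mask):
    the difference of two softmax maps, scaled by lam, applied to V."""
    Q = X @ Wq                       # (N, 2d)
    K = X @ Wk                       # (N, 2d)
    V = X @ Wv                       # (N, 2d)
    Q1, Q2 = Q[:, :d], Q[:, d:]      # split queries into two groups
    K1, K2 = K[:, :d], K[:, d:]      # split keys into two groups
    A1 = softmax(Q1 @ K1.T / np.sqrt(d))
    A2 = softmax(Q2 @ K2.T / np.sqrt(d))
    # Subtracting the second map cancels attention noise common to both.
    return (A1 - lam * A2) @ V       # (N, 2d)

# Toy usage with hypothetical dimensions.
rng = np.random.default_rng(0)
N, d_model, d = 4, 8, 4
X = rng.standard_normal((N, d_model))
Wq, Wk, Wv = (0.1 * rng.standard_normal((d_model, 2 * d)) for _ in range(3))
out = diff_attn(X, Wq, Wk, Wv, lam=0.8, d=d)
print(out.shape)  # (4, 8)
```

Note that with $\lambda = 0$ the second map drops out and the operator reduces to ordinary softmax attention over the first query/key group.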

To stabilize learning dynamics, $\lambda$ is re-parameterized as: $\lambda = \exp(\boldsymbol{\lambda}_{q_1} \cdot \boldsymbol{\lambda}_{k_1}) - \exp(\boldsymbol{\lambda}_{q_2} \cdot \boldsymbol{\lambda}_{k_2}) + \lambda_{\text{init}}$,

where:

  • $\boldsymbol{\lambda}_{q_1}, \boldsymbol{\lambda}_{k_1}, \boldsymbol{\lambda}_{q_2}, \boldsymbol{\lambda}_{k_2} \in \mathbb{R}^{d}$ are learnable vectors
  • $\lambda_{\text{init}} \in (0,1)$ is a constant used for initialization.
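
As a sketch, the re-parameterization is just two dot products passed through exponentials plus the constant offset. The vector initialization below (small Gaussian values) is an assumption for illustration only:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16
# Four learnable vectors; small random init is a hypothetical choice.
lam_q1, lam_k1, lam_q2, lam_k2 = (0.1 * rng.standard_normal(d) for _ in range(4))
lam_init = 0.8  # constant in (0, 1)

# lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init
lam = np.exp(lam_q1 @ lam_k1) - np.exp(lam_q2 @ lam_k2) + lam_init
print(float(lam))
```

Because both exponentials start near $\exp(0) = 1$ under a small init, the two terms roughly cancel at the beginning of training and $\lambda \approx \lambda_{\text{init}}$.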

In multi-head differential attention, the outputs of individual heads are normalized using $\operatorname{GroupNorm}(\cdot)$ and scaled by $(1 - \lambda_{\text{init}})$ to align gradients with the standard Transformer architecture.
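
A minimal sketch of the per-head normalization and $(1 - \lambda_{\text{init}})$ scaling, assuming each head's differential-attention output is already computed. The RMS-style normalization stands in for GroupNorm over each head, and the output projection `Wo` is a hypothetical addition following the standard Transformer pattern:

```python
import numpy as np

def head_norm(x, eps=1e-5):
    # Stand-in for per-head GroupNorm: normalize each head's output
    # independently (RMS-style; a simplifying assumption).
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def multi_head_combine(head_outputs, lam_init, Wo):
    """Normalize each head, scale by (1 - lam_init), concatenate,
    then apply a (hypothetical) output projection Wo."""
    scaled = [(1.0 - lam_init) * head_norm(h) for h in head_outputs]
    return np.concatenate(scaled, axis=-1) @ Wo

# Toy usage: 3 heads, each emitting an (N, 2d) output.
rng = np.random.default_rng(1)
N, two_d, h = 4, 8, 3
heads = [rng.standard_normal((N, two_d)) for _ in range(h)]
Wo = 0.1 * rng.standard_normal((h * two_d, 12))
out = multi_head_combine(heads, lam_init=0.8, Wo=Wo)
print(out.shape)  # (4, 12)
```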

Experimental results demonstrate that Diff Transformer outperforms Transformer in various language modeling tasks. Scaling experiments indicate that Diff Transformer requires approximately 65% of the model size or training tokens compared to Transformer to achieve comparable performance.

The paper presents results on downstream tasks, including long-context modeling, key information retrieval, hallucination mitigation, in-context learning, and reduction of activation outliers. Diff Transformer exhibits notable advantages in these practical applications. For instance, in key information retrieval, Diff Transformer shows superior accuracy in retrieving information from long contexts, particularly when the relevant information is located in the first half of the context. The paper also evaluates contextual hallucination in text summarization and question answering, finding that Diff Transformer mitigates hallucination compared to Transformer. For in-context learning, Diff Transformer enhances accuracy and demonstrates greater robustness to order permutations in demonstration examples.

Furthermore, Diff Transformer reduces outliers in model activations, offering potential benefits for quantization. Attention logits and hidden states exhibit lower top activation values compared to Transformer, indicating fewer activation outliers.

Ablation studies validate the design choices of Diff Transformer. Removing GroupNorm degrades performance, highlighting its importance in normalizing diverse statistics between heads. Performance is robust to different initialization strategies for $\lambda$.

