
Reducing Transformer Key-Value Cache Size with Cross-Layer Attention

Published 21 May 2024 in cs.LG and cs.CL | (2405.12981v1)

Abstract: Key-value (KV) caching plays an essential role in accelerating decoding for transformer-based autoregressive LLMs. However, the amount of memory required to store the KV cache can become prohibitive at long sequence lengths and large batch sizes. Since the invention of the transformer, two of the most effective interventions discovered for reducing the size of the KV cache have been Multi-Query Attention (MQA) and its generalization, Grouped-Query Attention (GQA). MQA and GQA both modify the design of the attention block so that multiple query heads can share a single key/value head, reducing the number of distinct key/value heads by a large factor while only minimally degrading accuracy. In this paper, we show that it is possible to take Multi-Query Attention a step further by also sharing key and value heads between adjacent layers, yielding a new attention design we call Cross-Layer Attention (CLA). With CLA, we find that it is possible to reduce the size of the KV cache by another 2x while maintaining nearly the same accuracy as unmodified MQA. In experiments training 1B- and 3B-parameter models from scratch, we demonstrate that CLA provides a Pareto improvement over the memory/accuracy tradeoffs which are possible with traditional MQA, enabling inference with longer sequence lengths and larger batch sizes than would otherwise be possible.

Summary

  • The paper proposes Cross-Layer Attention (CLA), which shares key-value activations across layers to achieve a 2× reduction in memory usage.
  • It demonstrates through experiments on 1B and 3B scale models that CLA integrated with MQA maintains accuracy while considerably reducing the KV cache size.
  • The findings offer practical implications for deploying memory-efficient transformers in resource-constrained environments.

Introduction

The study presented in the paper "Reducing Transformer Key-Value Cache Size with Cross-Layer Attention" addresses a significant challenge in the application of autoregressive LLMs: managing the memory footprint of the key-value (KV) cache during decoding. This component is crucial for efficient model performance, especially at long sequence lengths and large batch sizes, where memory demands can become prohibitive. Traditional approaches like Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) have laid the groundwork for reducing memory use by sharing key/value heads across query heads within an attention layer. This paper proposes an advancement by introducing Cross-Layer Attention (CLA), which extends this concept of sharing across adjacent layers, thus aiming to achieve further reductions in memory usage while maintaining accuracy.

Cross-Layer Attention Overview

Cross-Layer Attention (CLA) modifies the transformer architecture by sharing KV activations across layers rather than only within a layer, as in MQA. This sharing reduces the number of unique KV heads cached per token, yielding significant memory savings. CLA is compatible with standard transformer designs and can be combined with existing MQA or GQA architectures.

Figure 1: Schematic of two consecutive layers in a transformer using a traditional attention design and Cross-Layer Attention.

Figure 2: Schematic of KV cache structures under different attention configurations in a 10-layer transformer.
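A minimal sketch of the layer-sharing idea, using toy single-head (MQA-style) attention in NumPy. All dimensions, weight names, and the omission of a causal mask are illustrative simplifications, not the paper's implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d_model, head_dim, seq = 16, 8, 5
x1 = rng.standard_normal((seq, d_model))   # input to layer i

# Layer i: owns its K/V projections; the resulting K and V are cached.
Wq1, Wk, Wv = (rng.standard_normal((d_model, head_dim)) for _ in range(3))
K, V = x1 @ Wk, x1 @ Wv
out1 = softmax((x1 @ Wq1) @ K.T / np.sqrt(head_dim)) @ V

# Layer i+1 under CLA: has only a query projection and reuses layer i's
# cached K and V, so no second K/V pair is ever computed or stored.
Wq2 = rng.standard_normal((d_model, head_dim))
x2 = rng.standard_normal((seq, d_model))   # stand-in for layer i+1's input
out2 = softmax((x2 @ Wq2) @ K.T / np.sqrt(head_dim)) @ V

print(out1.shape, out2.shape)  # one K/V pair serves both layers
```

With a sharing factor of 2 (adjacent layer pairs sharing one K/V cache, as in the schematic), only half the layers contribute KV entries to the cache.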

Experimental Results

Extensive experimentation at the 1B- and 3B-parameter scales demonstrates that CLA achieves notable improvements on the memory/accuracy Pareto frontier, considerably reducing KV cache size. Specifically, combining CLA with MQA yields a 2× reduction in storage compared to MQA alone, without significant accuracy loss. This advantage holds consistently across model scales and learning rates.

Figure 3: The accuracy/memory Pareto frontier for models with CLA (red) and without CLA (blue). Lower is better on both axes.
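The batch-size benefit of the 2× cache reduction can be illustrated with a rough calculation. The sharing factor of 2 matches the paper's headline result; the memory budget and model shape below are hypothetical:

```python
import math

def kv_cache_gib(n_layers, sharing_factor, kv_heads=1, head_dim=128,
                 seq_len=4096, batch=1, bytes_per_elem=2):
    # Under layer sharing, only ceil(n_layers / sharing_factor) layers
    # contribute fresh K/V entries to the cache.
    kv_layers = math.ceil(n_layers / sharing_factor)
    return 2 * kv_layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem / 2**30

mqa     = kv_cache_gib(32, 1)  # plain MQA: all 32 layers cache K/V
mqa_cla = kv_cache_gib(32, 2)  # MQA + CLA: every other layer caches K/V

# With a fixed memory budget, halving the per-sequence cache doubles
# the number of concurrent sequences that fit.
budget_gib = 8.0
print(int(budget_gib // mqa), int(budget_gib // mqa_cla))
```

The same halving can instead be spent on longer sequences at a fixed batch size, which is the other axis of the Pareto improvement reported above.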

Theoretical and Practical Implications

The introduction of CLA holds practical implications for deploying memory-efficient LLMs capable of handling larger batch sizes and longer sequence lengths. It suggests a viable path for optimizing transformer models in resource-constrained environments, especially where memory bandwidth and latency are critical concerns. Theoretically, CLA broadens the scope for future research into attention mechanisms that balance computational efficiency with model performance.

Conclusion

Cross-Layer Attention represents a crucial step forward in the evolution of efficient transformer architectures, providing significant memory savings with minimal impact on model accuracy. By advancing the memory/accuracy tradeoffs of transformer models, CLA offers a promising direction for optimizing large-scale LLMs, potentially enabling more capable and efficient AI systems in the future.

The study's detailed analysis and experiments demonstrate CLA's potential for improving transformer efficiency, with relevance to both theoretical advances in attention design and practical AI deployments.
