Published 9 Feb 2025 in cs.CL and cs.AI | (2502.06049v1)
Abstract: This paper introduces the Large Memory Model (LM2), a decoder-only Transformer architecture enhanced with an auxiliary memory module that aims to address the limitations of standard Transformers in multi-step reasoning, relational argumentation, and synthesizing information distributed over long contexts. The proposed LM2 incorporates a memory module that acts as a contextual representation repository, interacting with input tokens via cross attention and updating through gating mechanisms. To preserve the Transformer's general-purpose capabilities, LM2 maintains the original information flow while integrating a complementary memory pathway. Experimental results on the BABILong benchmark demonstrate that the LM2 model outperforms the memory-augmented RMT model by 37.1% and the baseline Llama-3.2 model by 86.3% on average across tasks. LM2 exhibits exceptional capabilities in multi-hop inference, numerical reasoning, and large-context question-answering. On the MMLU dataset, it achieves a 5.0% improvement over a pre-trained vanilla model, demonstrating that its memory module does not degrade performance on general tasks. Further, in our analysis, we explore memory interpretability, the effectiveness of memory modules, and test-time behavior. Our findings emphasize the importance of explicit memory in enhancing Transformer architectures.
The paper introduces LM2, a Large Memory Model that enhances Transformer architectures with a dynamic memory module to better handle long-term dependencies in extended contexts.
LM2 demonstrates superior performance on long-context reasoning tasks, including multi-hop inference and numerical reasoning on BABILong, outperforming baselines like RMT and Llama-3.2.
Integrating the memory module not only improves performance on specific long-context tasks but also enhances or maintains performance on general tasks, as shown by experiments on the MMLU dataset.
This paper introduces a new way to improve how LLMs deal with long pieces of text by giving them a better memory.
Background and Relevance
Transformer-based models, like GPT-3 (Generative Pre-trained Transformer 3) and BERT (Bidirectional Encoder Representations from Transformers), have been very successful in many language tasks. However, they struggle when they need to understand and reason about very long contexts. This is because they sometimes have trouble finding the important information in a sea of irrelevant data.
To solve this problem, the researchers created the Large Memory Model (LM2), which adds a special memory module to the standard Transformer architecture. This memory module acts like a separate storage space where the model can keep track of important information it has seen.
Here's how the LM2 works:
Memory Initialization: The memory module starts with a "memory bank," which is like a collection of slots where information can be stored. Each slot is initially set to a neutral state.
The memory bank is represented by M ∈ R^{N×d×d}, where:
* N is the number of memory slots.
* d is the hidden dimension of each slot.
* R denotes the real numbers.
Each memory slot is initialized as an identity matrix: M_r = I_{d×d} for r ∈ {1, …, N}, where I_{d×d} is the d×d identity matrix.
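As a concrete sketch of this initialization (a NumPy toy with illustrative sizes, not the paper's code):

```python
import numpy as np

N, d = 4, 8  # illustrative sizes: 4 memory slots, hidden dimension 8

# Memory bank M in R^{N x d x d}: each of the N slots starts as a
# d x d identity matrix (the "neutral" initial state).
M = np.stack([np.eye(d) for _ in range(N)])

print(M.shape)  # (4, 8, 8)
```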
Memory Information Flow: When the model processes new input, it uses a technique called "cross attention" to compare the input to the information stored in the memory bank. This helps the model find the memory slots that contain the most relevant information.
Input embeddings E act as the query, while the memory bank M serves as both the key and the value store.
The input embeddings E_t ∈ R^{T×d} (where T is the sequence length) and the memory bank M_t ∈ R^{N×d} are projected into query (Q), key (K), and value (V) spaces:
Q = E_t W_Q,  K = M_t W_K,  V = M_t W_V,
where W_Q, W_K, W_V ∈ R^{d×d} are learnable query, key, and value projection matrices, E_t is the input embedding at decoder block t, and M_t is the memory bank at decoder block t.
* The attention scores are computed as the scaled dot product of the query and key matrices:
A = softmax(QK^T / √d),
where A ∈ R^{T×N} represents the alignment between the T input tokens and the N memory slots.
* The resultant attention output is
E_mem = AV,
where E_mem ∈ R^{T×d} integrates information from the input and memory, A is the alignment matrix, and V is the value matrix.
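The cross-attention read described above can be sketched as follows (NumPy, with random stand-ins for the learned weights; treating each memory slot as a single d-dimensional vector is a simplifying assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
T, N, d = 6, 4, 8  # sequence length, memory slots, hidden dimension

E_t = rng.normal(size=(T, d))  # input embeddings at decoder block t
M_t = rng.normal(size=(N, d))  # memory bank, one d-dim vector per slot
# Learnable projections W_Q, W_K, W_V (random here for illustration).
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

Q = E_t @ W_Q                      # queries come from the input
K = M_t @ W_K                      # keys come from memory
V = M_t @ W_V                      # values come from memory
A = softmax(Q @ K.T / np.sqrt(d))  # (T, N): token-to-slot alignment
E_mem = A @ V                      # (T, d): memory-informed representation
```

Each row of A is a probability distribution over the N memory slots, so every token reads a convex combination of the memory values.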
* To control how much the memory influences the model's output, an "output gate" is used. This gate decides how much of the information retrieved from memory should be passed on to the next layer:
g_out = σ(E_mem W_out),
where W_out ∈ R^{d×d} is a learnable parameter matrix and σ is the sigmoid activation function.
* The gated memory output is then computed as:
E_gated = g_out ⋅ M_t,
where M_t is the memory bank at decoder block t.
* The gated memory output is integrated into the standard attention flow of the Transformer decoder through a skip connection. Specifically, the output of the self-attention mechanism, E_attn, is combined with the gated memory output as
E_next = E_attn + E_gated,
where E_next is the combined output passed to the next decoder layer.
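Continuing the sketch, the output gate and skip connection might look like this in NumPy (broadcasting a single d-dimensional memory summary for M_t is an assumption made for illustration, since the summary leaves the exact shape of the gated product unspecified):

```python
import numpy as np

rng = np.random.default_rng(1)
T, d = 6, 8
E_mem = rng.normal(size=(T, d))   # cross-attention output (input/memory mix)
E_attn = rng.normal(size=(T, d))  # standard self-attention output
m_t = rng.normal(size=(d,))       # memory content, flattened to d dims (assumption)
W_out = rng.normal(size=(d, d))   # learnable in practice

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

g_out = sigmoid(E_mem @ W_out)  # (T, d) gate, entries in (0, 1)
E_gated = g_out * m_t           # gate modulates the memory content
E_next = E_attn + E_gated       # skip connection into the next decoder layer
```

Because the memory enters only through this additive skip path, the original self-attention flow is left intact, which is how the paper preserves the base model's general-purpose behavior.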
Memory Updates: The memory module also updates its contents to store new information and discard information that is no longer relevant. This is done using three "gates":
Input Gate: This gate decides how much of the new input should be written into the memory:
g_in = σ(E_t W_in),
where W_in ∈ R^{d×d} is a learnable parameter matrix, E_t is the current input representation, and σ is the sigmoid activation function.
* Forget Gate: This gate decides which parts of the existing memory should be erased:
g_forget = σ(E_mem W_forget),
where W_forget ∈ R^{d×d} is a learnable parameter matrix.
* Output Gate: This gate, described earlier, controls how much of the memory content is used to generate the final output.
* The updated memory state is:
M_{t+1} = g_in ⋅ tanh(E_mem) + g_forget ⋅ M_t,
where the tanh keeps the newly written memory content bounded, g_in and g_forget are the input and forget gates defined above, and M_t is the current memory state.
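The gated update above can be sketched as follows (NumPy; the shapes are illustrative assumptions, with the memory viewed as token-aligned d-dimensional rows rather than the paper's full slot layout):

```python
import numpy as np

rng = np.random.default_rng(2)
T, d = 6, 8
E_t = rng.normal(size=(T, d))       # current input representation
E_mem = rng.normal(size=(T, d))     # cross-attention output
M_t = rng.normal(size=(T, d))       # current memory content (illustrative shape)
W_in = rng.normal(size=(d, d))      # learnable input-gate weights
W_forget = rng.normal(size=(d, d))  # learnable forget-gate weights

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

g_in = sigmoid(E_t @ W_in)            # how much new content to write
g_forget = sigmoid(E_mem @ W_forget)  # how much old content to keep
# tanh bounds the newly written content; the gates blend new and old memory.
M_next = g_in * np.tanh(E_mem) + g_forget * M_t
```

The structure is LSTM-like: the input gate admits bounded new content while the forget gate decays the old state, so the memory can both accumulate and overwrite information across decoder blocks.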
Experiments and Results
The researchers tested the LM2 on a dataset called BABILong, which is designed to test how well models can reason about long contexts. The LM2 outperformed other models, including a memory-augmented model called RMT (Recurrent Memory Transformer) and a baseline model called Llama-3.2. The LM2 was better at multi-hop inference (answering questions that require multiple steps of reasoning), numerical reasoning, and question-answering in long contexts.
To make sure that the memory module didn't hurt the model's ability to perform general tasks, the researchers also tested it on the MMLU dataset, which covers a wide range of academic subjects. The LM2 performed better than a standard Llama model, showing that the memory module can actually improve performance on general tasks as well.
Key Findings
The LM2's memory module helps it to better understand and reason about long contexts.
The memory module does not degrade performance on general tasks and can even improve it.
The way the memory module is integrated into the Transformer architecture is important for achieving the best performance.
The memory module stores and retrieves information in a way that is relevant to the task at hand.
The memory module adapts its contents during testing to focus on the most important information.
Contributions
The paper introduces a new memory-augmented Transformer architecture that can capture and use long-term dependencies in data.
The paper proposes a new way to integrate memory into the Transformer architecture, which allows the model to maintain its original capabilities while also benefiting from the memory module.
The paper shows that the LM2 outperforms existing models on long-context reasoning tasks.
In summary, this paper presents a promising new approach to improving the ability of LLMs to handle long contexts by giving them a better memory.