- The paper introduces minLSTM and minGRU, simplified RNNs that remove state dependencies in gates to allow full parallel training.
- The methodology eliminates non-linear range restrictions and leverages parallel prefix scan, achieving up to 1361× faster training on long sequences.
- Empirical results demonstrate competitive performance in language modeling, selective copying, and reinforcement learning despite an 88% increase in memory usage.
The paper simplifies traditional Recurrent Neural Networks (RNNs), specifically Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs), to create minimal versions named minLSTMs and minGRUs that can be parallelized during training. These simplified models use fewer parameters than their traditional counterparts while achieving competitive performance on various sequence modeling tasks.
The paper begins by revisiting the landscape of sequence modeling, dominated by RNNs for two decades before the advent of Transformers. The inherent sequential nature of RNNs limited parallelization, making them computationally inefficient for long sequences. Transformers, introduced in 2017, enabled parallel training through self-attention but suffer from quadratic computational complexity with respect to sequence length. This limitation has sparked renewed interest in parallelizable recurrent models that scale more efficiently. Recent methods such as state-space models, linearized attention, and linear recurrent neural networks have shown promise in addressing these scalability issues.
The authors focus on LSTMs and GRUs as early examples of input-dependent recurrent models. By removing dependencies on previous states from the gates of these models, they enable parallel training. Further simplification leads to minLSTMs and minGRUs, which use fewer parameters, are fully parallelizable, and achieve competitive performance. Implementations of minGRU and minLSTM are provided in plain PyTorch.
The paper reviews traditional RNNs, highlighting their suitability for tasks involving sequential data but also noting the challenges related to vanishing and exploding gradients. The LSTM is defined by the following equations:
- $\Gamma_o = \sigma(W_{io} x_t + U_{ho} h_{t-1} + b_o)$
  - $\Gamma_o$ is the output gate.
  - $\sigma$ is the sigmoid function.
  - $W_{io}$ is the weight matrix for the input $x_t$.
  - $U_{ho}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_o$ is the bias vector for the output gate.
- $\Gamma_f = \sigma(W_{if} x_t + U_{hf} h_{t-1} + b_f)$
  - $\Gamma_f$ is the forget gate.
  - $W_{if}$ is the weight matrix for the input $x_t$.
  - $U_{hf}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_f$ is the bias vector for the forget gate.
- $\Gamma_i = \sigma(W_{ii} x_t + U_{hi} h_{t-1} + b_i)$
  - $\Gamma_i$ is the input gate.
  - $W_{ii}$ is the weight matrix for the input $x_t$.
  - $U_{hi}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_i$ is the bias vector for the input gate.
- $\tilde{C}_t = \tanh(W_{ic} x_t + U_{hc} h_{t-1} + b_c)$
  - $\tilde{C}_t$ is the candidate cell state.
  - $\tanh$ is the hyperbolic tangent function.
  - $W_{ic}$ is the weight matrix for the input $x_t$.
  - $U_{hc}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_c$ is the bias vector for the candidate cell state.
- $C_t = \Gamma_f \odot C_{t-1} + \Gamma_i \odot \tilde{C}_t$
  - $C_t$ is the cell state.
  - $\odot$ denotes element-wise multiplication.
- $h_t = \Gamma_o \odot \tanh(C_t)$
  - $h_t$ is the hidden state.
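The update equations above can be made concrete with a single LSTM step. This is a minimal scalar sketch in plain Python (not the paper's PyTorch code): the parameter dictionary `p` and its keys are illustrative, and real implementations use weight matrices acting on $d_h$-dimensional states.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x_t, h_prev, c_prev, p):
    # p holds illustrative scalar weights; real models use matrices.
    o = sigmoid(p["W_io"] * x_t + p["U_ho"] * h_prev + p["b_o"])  # output gate
    f = sigmoid(p["W_if"] * x_t + p["U_hf"] * h_prev + p["b_f"])  # forget gate
    i = sigmoid(p["W_ii"] * x_t + p["U_hi"] * h_prev + p["b_i"])  # input gate
    c_tilde = math.tanh(p["W_ic"] * x_t + p["U_hc"] * h_prev + p["b_c"])  # candidate
    c_t = f * c_prev + i * c_tilde   # cell state update
    h_t = o * math.tanh(c_t)         # hidden state
    return h_t, c_t
```

Note that every gate reads $h_{t-1}$, which is exactly the dependency the paper later removes to enable parallel training.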
The GRU, a simplification of the LSTM, uses only two gates and a single state. The GRU equations are:
- $\Gamma_u = \sigma(W_{iu} x_t + U_{hu} h_{t-1} + b_u)$
  - $\Gamma_u$ is the update gate.
  - $W_{iu}$ is the weight matrix for the input $x_t$.
  - $U_{hu}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_u$ is the bias vector for the update gate.
- $\Gamma_r = \sigma(W_{ir} x_t + U_{hr} h_{t-1} + b_r)$
  - $\Gamma_r$ is the reset gate.
  - $W_{ir}$ is the weight matrix for the input $x_t$.
  - $U_{hr}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_r$ is the bias vector for the reset gate.
- $\tilde{h}_t = \tanh(W_{ih} x_t + U_{hh}(\Gamma_r \odot h_{t-1}) + b_h)$
  - $\tilde{h}_t$ is the candidate hidden state.
  - $W_{ih}$ is the weight matrix for the input $x_t$.
  - $U_{hh}$ is the weight matrix for the previous hidden state $h_{t-1}$.
  - $b_h$ is the bias vector for the candidate hidden state.
- $h_t = (1 - \Gamma_u) \odot h_{t-1} + \Gamma_u \odot \tilde{h}_t$
The paper reviews the parallel prefix scan algorithm and its application in computing recurrence relations of the form $v_t = a_t v_{t-1} + b_t$.
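The key property behind the scan is that each step $v_t = a_t v_{t-1} + b_t$ is an affine map, and composing affine maps is associative, so the prefix of maps can be combined in any grouping. A minimal sketch in plain Python (using `itertools.accumulate` as a stand-in for a true parallel scan, which would evaluate the same `combine` in $O(\log T)$ depth):

```python
from itertools import accumulate

def combine(left, right):
    # Compose two affine maps v -> a*v + b (left applied first).
    # Associativity of this operation is what enables a parallel scan.
    a1, b1 = left
    a2, b2 = right
    return (a2 * a1, a2 * b1 + b2)

def scan_recurrence(a, b):
    # Inclusive scan over the affine maps (a_t, b_t).
    # With v_0 = 0, applying each prefix map gives v_t = b component.
    pairs = accumulate(zip(a, b), combine)
    return [b_t for (_, b_t) in pairs]
```

For example, `scan_recurrence([0.5, 0.5], [1.0, 1.0])` computes $v_1 = 1.0$ and $v_2 = 0.5 \cdot 1.0 + 1.0 = 1.5$, matching the sequential recurrence.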
The methodology section details the simplification of GRUs and LSTMs to enable training via parallel scan. For the minGRU, the first step involves dropping previous state dependencies from the gates. The GRU's hidden state recurrence is $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$, where $z_t = \sigma(W_{iz} x_t + U_{hz} h_{t-1} + b_z)$ and $\tilde{h}_t = \tanh(W_{ih} x_t + U_{hh}(r_t \odot h_{t-1}) + b_h)$. To enable parallel scan, the dependencies on $h_{t-1}$ are removed, simplifying the equations to $z_t = \sigma(W_{iz} x_t + b_z)$ and $\tilde{h}_t = \tanh(W_{ih} x_t + b_h)$.
The second step involves dropping the range restriction of candidate states. The hyperbolic tangent function, which restricts the range of hidden states, is removed, further simplifying the model. The resulting minGRU equations are $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$, $z_t = \sigma(W_{iz} x_t + b_z)$, and $\tilde{h}_t = W_{ih} x_t + b_h$. The minGRU requires $O(2 d_h d_x)$ parameters, compared to the GRU's $O(3 d_h (d_x + d_h))$.
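Because $z_t$ and $\tilde{h}_t$ now depend only on $x_t$, the recurrence has exactly the form $v_t = a_t v_{t-1} + b_t$ with $a_t = 1 - z_t$ and $b_t = z_t \tilde{h}_t$, so it is scan-friendly. A minimal scalar sketch in plain Python, shown in sequential mode for clarity (the parameter names are illustrative, not the paper's code; real models use matrices):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def min_gru(xs, w_z, b_z, w_h, b_h, h0=0.0):
    hs, h = [], h0
    for x in xs:
        z = sigmoid(w_z * x + b_z)      # gate depends only on x_t, not h_{t-1}
        h_tilde = w_h * x + b_h         # linear candidate: tanh dropped
        h = (1 - z) * h + z * h_tilde   # a_t = 1 - z_t, b_t = z_t * h_tilde
        hs.append(h)
    return hs
```

Since all the `z` and `h_tilde` values can be computed from the inputs up front, the loop itself is the only sequential part, and it is precisely what the parallel prefix scan replaces at training time.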
For the minLSTM, the first step involves dropping previous state dependencies from the gates. Similar to the minGRU, the hidden state dependencies are removed from the input, forget, and candidate cell state equations. The second step involves dropping the range restriction of candidate states, removing the hyperbolic tangent function. The third step involves simplifying the scaling of the output by dropping the output gate. The resulting minLSTM equations are $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$, $f_t = \sigma(W_{if} x_t + b_f)$, $i_t = \sigma(W_{ii} x_t + b_i)$, and $\tilde{C}_t = W_{ic} x_t + b_c$. The minLSTM requires $O(3 d_h d_x)$ parameters compared to the LSTM's $O(4 d_h (d_x + d_h))$.
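The minLSTM recurrence likewise fits $v_t = a_t v_{t-1} + b_t$, with $a_t = f_t$ and $b_t = i_t \tilde{C}_t$. A minimal scalar sketch in plain Python, again in sequential mode (the parameter dictionary `p` is illustrative, not the paper's implementation):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def min_lstm(xs, p, c0=0.0):
    cs, c = [], c0
    for x in xs:
        f = sigmoid(p["w_f"] * x + p["b_f"])  # forget gate, input-only
        i = sigmoid(p["w_i"] * x + p["b_i"])  # input gate, input-only
        c_tilde = p["w_c"] * x + p["b_c"]     # linear candidate: tanh dropped
        c = f * c + i * c_tilde               # a_t = f_t, b_t = i_t * c_tilde
        cs.append(c)
    return cs
```

With no output gate and no tanh on the cell state, the cell state $C_t$ serves directly as the hidden state, which is what makes the three simplification steps yield the compact form above.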
The paper then presents empirical results, comparing the minimal versions with their traditional counterparts and modern sequence models. For a sequence length of $512$, the runtimes of minLSTM, minGRU, and Mamba were $2.97$, $2.72$, and $2.71$ milliseconds respectively; for a sequence length of $4096$, the runtimes were $3.41$, $3.25$, and $3.15$ milliseconds respectively. At sequence length $512$, minGRUs and minLSTMs were 175× and 235× faster per training step than GRUs and LSTMs on a T4 GPU, and the gap widens with sequence length: at length $4096$, minGRUs and minLSTMs were 1324× and 1361× faster. The minimal variants use ∼88% more memory than their traditional counterparts. minLSTM and minGRU solve the Selective Copying task, achieving performance comparable to S6 and surpassing other modern baselines. In reinforcement learning tasks, minLSTM and minGRU outperform Decision S4 and achieve performance competitive with Decision Transformer, Aaren, and Mamba. In language modeling tasks, minGRU, minLSTM, Mamba, and Transformers achieved comparable test losses of $1.548$, $1.555$, $1.575$, and $1.547$ respectively.
The related work section provides an overview of recent efficient recurrent sequence models, categorized into deep state-space models, recurrent versions of attention, and parallelizable RNNs.
The paper concludes by highlighting the parallel training enabled by removing gate dependencies on previous states. The minimal versions offer fewer parameters, full parallelizability, and competitive performance. The authors suggest a reevaluation of simpler foundational models like LSTM and GRU.
The limitations section acknowledges the hardware constraints that impacted the scale of the experiments, including the use of older GPUs with limited memory. Gradient accumulation was used to accommodate memory limitations, reducing the effective batch size and slowing down training.