Were RNNs All We Needed?

Published 2 Oct 2024 in cs.LG and cs.AI | (2410.01201v3)

Abstract: The introduction of Transformers in 2017 reshaped the landscape of deep learning. Originally proposed for sequence modelling, Transformers have since achieved widespread success across various domains. However, the scalability limitations of Transformers - particularly with respect to sequence length - have sparked renewed interest in novel recurrent models that are parallelizable during training, offer comparable performance, and scale more effectively. In this work, we revisit sequence modelling from a historical perspective, focusing on Recurrent Neural Networks (RNNs), which dominated the field for two decades before the rise of Transformers. Specifically, we examine LSTMs (1997) and GRUs (2014). We demonstrate that by simplifying these models, we can derive minimal versions (minLSTMs and minGRUs) that (1) use fewer parameters than their traditional counterparts, (2) are fully parallelizable during training, and (3) achieve surprisingly competitive performance on a range of tasks, rivalling recent models including Transformers.

Summary

  • The paper introduces minLSTM and minGRU, simplified RNNs that remove state dependencies in gates to allow full parallel training.
  • The methodology eliminates non-linear range restrictions and leverages parallel prefix scan, achieving up to 1361× faster training on long sequences.
  • Empirical results demonstrate competitive performance in language modeling, selective copying, and reinforcement learning despite an 88% increase in memory usage.

The paper simplifies traditional Recurrent Neural Networks (RNNs), specifically Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs), to create minimal versions named minLSTMs and minGRUs that can be parallelized during training. These simplified models use fewer parameters than their traditional counterparts while achieving competitive performance on various sequence modeling tasks.

The paper begins by revisiting the landscape of sequence modeling, dominated by RNNs for two decades before the advent of Transformers. The inherent sequential nature of RNNs limited parallelization, making them computationally inefficient for long sequences. Transformers, introduced in 2017, enabled parallel training through self-attention but suffer from quadratic computational complexity with respect to sequence length. This limitation has sparked renewed interest in parallelizable recurrent models that scale more efficiently. Recent methods such as state-space models, linearized attention, and linear recurrent neural networks have shown promise in addressing these scalability issues.

The authors focus on LSTMs and GRUs as early examples of input-dependent recurrent models. By removing dependencies on previous states from the gates of these models, they enable parallel training. Further simplification leads to minLSTMs and minGRUs, which use fewer parameters, are fully parallelizable, and achieve competitive performance. Implementations of minGRU and minLSTM are provided in plain PyTorch.

The paper reviews traditional RNNs, highlighting their suitability for tasks involving sequential data but also noting the challenges related to vanishing and exploding gradients. The LSTM is defined by the following equations:

  • $\Gamma_o = \sigma(W_{io} x_t + U_{ho} h_{t-1} + b_o)$
    • $\Gamma_o$ is the output gate.
    • $\sigma$ is the sigmoid function.
    • $W_{io}$ is the weight matrix for the input $x_t$.
    • $U_{ho}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_o$ is the bias vector for the output gate.
  • $\Gamma_f = \sigma(W_{if} x_t + U_{hf} h_{t-1} + b_f)$
    • $\Gamma_f$ is the forget gate.
    • $W_{if}$ is the weight matrix for the input $x_t$.
    • $U_{hf}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_f$ is the bias vector for the forget gate.
  • $\Gamma_i = \sigma(W_{ii} x_t + U_{hi} h_{t-1} + b_i)$
    • $\Gamma_i$ is the input gate.
    • $W_{ii}$ is the weight matrix for the input $x_t$.
    • $U_{hi}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_i$ is the bias vector for the input gate.
  • $\tilde{C}_t = \tanh(W_{ic} x_t + U_{hc} h_{t-1} + b_c)$
    • $\tilde{C}_t$ is the candidate cell state.
    • $\tanh$ is the hyperbolic tangent function.
    • $W_{ic}$ is the weight matrix for the input $x_t$.
    • $U_{hc}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_c$ is the bias vector for the candidate cell state.
  • $C_t = \Gamma_f \odot C_{t-1} + \Gamma_i \odot \tilde{C}_t$
    • $C_t$ is the cell state.
    • $\odot$ denotes element-wise multiplication.
  • $h_t = \Gamma_o \odot \tanh(C_t)$
    • $h_t$ is the hidden state.
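The equations above can be turned directly into code. Below is a minimal NumPy sketch of a single LSTM step, not the paper's released implementation; the parameter-dictionary layout and key names (mirroring the $W$, $U$, $b$ symbols) are illustrative assumptions:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, C_prev, p):
    """One LSTM step; p maps symbol names ("Wio", "Uho", "bo", ...) to arrays."""
    gamma_o = sigmoid(p["Wio"] @ x_t + p["Uho"] @ h_prev + p["bo"])  # output gate
    gamma_f = sigmoid(p["Wif"] @ x_t + p["Uhf"] @ h_prev + p["bf"])  # forget gate
    gamma_i = sigmoid(p["Wii"] @ x_t + p["Uhi"] @ h_prev + p["bi"])  # input gate
    C_tilde = np.tanh(p["Wic"] @ x_t + p["Uhc"] @ h_prev + p["bc"])  # candidate cell
    C_t = gamma_f * C_prev + gamma_i * C_tilde  # mix old cell and candidate
    h_t = gamma_o * np.tanh(C_t)                # gated, range-restricted hidden state
    return h_t, C_t
```

Note the serial dependence: every gate reads $h_{t-1}$, so step $t$ cannot begin until step $t-1$ has finished, which is exactly what blocks parallel training.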

The GRU, a simplification of the LSTM, uses only two gates and a single state. The GRU equations are:

  • $\Gamma_u = \sigma(W_{iu} x_t + U_{hu} h_{t-1} + b_u)$
    • $\Gamma_u$ is the update gate.
    • $W_{iu}$ is the weight matrix for the input $x_t$.
    • $U_{hu}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_u$ is the bias vector for the update gate.
  • $\Gamma_r = \sigma(W_{ir} x_t + U_{hr} h_{t-1} + b_r)$
    • $\Gamma_r$ is the reset gate.
    • $W_{ir}$ is the weight matrix for the input $x_t$.
    • $U_{hr}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_r$ is the bias vector for the reset gate.
  • $\tilde{h}_t = \tanh(W_{ih} x_t + U_{hh}(\Gamma_r \odot h_{t-1}) + b_h)$
    • $\tilde{h}_t$ is the candidate hidden state.
    • $W_{ih}$ is the weight matrix for the input $x_t$.
    • $U_{hh}$ is the weight matrix for the previous hidden state $h_{t-1}$.
    • $b_h$ is the bias vector for the candidate hidden state.
  • $h_t = (1 - \Gamma_u) \odot h_{t-1} + \Gamma_u \odot \tilde{h}_t$
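A single GRU step follows the same pattern as the LSTM sketch above. The parameter names here mirror the $W$, $U$, $b$ symbols and are illustrative assumptions, not the paper's code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, p):
    """One GRU step; p maps symbol names ("Wiu", "Uhu", "bu", ...) to arrays."""
    gamma_u = sigmoid(p["Wiu"] @ x_t + p["Uhu"] @ h_prev + p["bu"])  # update gate
    gamma_r = sigmoid(p["Wir"] @ x_t + p["Uhr"] @ h_prev + p["br"])  # reset gate
    h_tilde = np.tanh(p["Wih"] @ x_t + p["Uhh"] @ (gamma_r * h_prev) + p["bh"])
    # convex mix of the old state and the candidate state
    return (1 - gamma_u) * h_prev + gamma_u * h_tilde
```

As with the LSTM, both gates and the candidate read $h_{t-1}$, so computation is inherently sequential across time steps.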

The paper reviews the parallel prefix scan algorithm and its application in computing recurrence relations of the form $v_t = a_t v_{t-1} + b_t$.
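To illustrate why such recurrences parallelize: unrolling gives the closed form $v_t = \left(\prod_{k \le t} a_k\right) v_0 + \sum_{j \le t} \left(\prod_{j < k \le t} a_k\right) b_j$, whose prefix products and sums can each be computed in logarithmic depth. The NumPy sketch below uses `cumprod`/`cumsum` as stand-ins for the scan primitives; it is a clarity sketch and can underflow when the $a_k$ are small, which is why practical implementations work in log space:

```python
import numpy as np

def sequential_scan(a, b, v0=0.0):
    """Reference: v_t = a_t * v_{t-1} + b_t, one step at a time."""
    v, out = v0, []
    for a_t, b_t in zip(a, b):
        v = a_t * v + b_t
        out.append(v)
    return np.array(out)

def prefix_scan(a, b, v0=0.0):
    """Same values via prefix operations:
    v_t = A_t * (v0 + sum_{j<=t} b_j / A_j), with A_t = prod_{k<=t} a_k."""
    A = np.cumprod(a)                 # prefix products of the decay terms
    return A * (v0 + np.cumsum(b / A))
```

Both functions compute identical trajectories; only the second admits a parallel evaluation, because it contains no step-to-step loop dependence.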

The methodology section details the simplification of GRUs and LSTMs to enable training via parallel scan. For the minGRU, the first step is dropping previous-state dependencies from the gates. The GRU's hidden-state recurrence is $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$, where $z_t = \sigma(W_{iz} x_t + U_{hz} h_{t-1} + b_z)$ and $\tilde{h}_t = \tanh(W_{ih} x_t + U_{hh}(r_t \odot h_{t-1}) + b_h)$. To enable parallel scan, the dependencies on $h_{t-1}$ are removed, simplifying the equations to $z_t = \sigma(W_{iz} x_t + b_z)$ and $\tilde{h}_t = \tanh(W_{ih} x_t + b_h)$.

The second step drops the range restriction on candidate states: the hyperbolic tangent, which constrains hidden states to $(-1, 1)$, is removed. The resulting minGRU equations are $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$, $z_t = \sigma(W_{iz} x_t + b_z)$, and $\tilde{h}_t = W_{ih} x_t + b_h$. The minGRU requires $O(2 d_h d_x)$ parameters, compared to the GRU's $O(3 d_h (d_x + d_h))$.
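Because $z_t$ and $\tilde{h}_t$ now depend only on $x_t$, the recurrence has the scan form $v_t = a_t v_{t-1} + b_t$ with $a_t = 1 - z_t$ and $b_t = z_t \odot \tilde{h}_t$: every coefficient can be precomputed for the whole sequence, then combined with a prefix scan. A NumPy sketch follows; the paper's released code is PyTorch and uses a numerically stable log-space scan, whereas the `cumprod`/`cumsum` form here is for clarity only:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def min_gru(X, Wz, bz, Wh, bh, h0):
    """minGRU hidden states for a full (T, d_x) sequence X, computed at once."""
    Z = sigmoid(X @ Wz.T + bz)       # (T, d_h) update gates, input-only
    H_tilde = X @ Wh.T + bh          # candidate states, no tanh restriction
    A = np.cumprod(1.0 - Z, axis=0)  # prefix products of a_t = 1 - z_t
    B = Z * H_tilde                  # b_t = z_t * h~_t
    # closed form of the scan: h_t = A_t * (h0 + sum_{j<=t} B_j / A_j)
    return A * (h0 + np.cumsum(B / A, axis=0))
```

No line above reads the previous time step's output, so the gate and candidate computations are a pair of batched matrix multiplies followed by a scan.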

For the minLSTM, the first step likewise drops previous-state dependencies from the gates: the hidden-state terms are removed from the input gate, forget gate, and candidate cell state equations. The second step drops the range restriction on candidate states by removing the hyperbolic tangent. The third step simplifies the output scaling by dropping the output gate. The resulting minLSTM equations are $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$, $f_t = \sigma(W_{if} x_t + b_f)$, $i_t = \sigma(W_{ii} x_t + b_i)$, and $\tilde{C}_t = W_{ic} x_t + b_c$. The minLSTM requires $O(3 d_h d_x)$ parameters, compared to the LSTM's $O(4 d_h (d_x + d_h))$.
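The minLSTM fits the same scan form with $a_t = f_t$ and $b_t = i_t \odot \tilde{C}_t$. The NumPy sketch below implements the equations exactly as stated in this summary, again using `cumprod`/`cumsum` in place of the log-space scan used in practice:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def min_lstm(X, Wf, bf, Wi, bi, Wc, bc, C0):
    """minLSTM cell states for a full (T, d_x) sequence X via prefix scan."""
    F = sigmoid(X @ Wf.T + bf)   # forget gates, input-only
    I = sigmoid(X @ Wi.T + bi)   # input gates, input-only
    C_tilde = X @ Wc.T + bc      # candidate cells, no tanh restriction
    A = np.cumprod(F, axis=0)    # prefix products of a_t = f_t
    # closed form of the scan: C_t = A_t * (C0 + sum_{j<=t} i_j * C~_j / A_j)
    return A * (C0 + np.cumsum(I * C_tilde / A, axis=0))
```

With the output gate dropped, no further transformation is applied after the scan, which is the source of the parameter savings relative to the full LSTM.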

The paper then presents empirical results, comparing the minimal versions with their traditional counterparts and modern sequence models. At a sequence length of $512$, the runtimes for minLSTM, minGRU, and Mamba were $2.97$, $2.72$, and $2.71$ milliseconds respectively; at length $4096$, they were $3.41$, $3.25$, and $3.15$ milliseconds. At a sequence length of $512$, minGRUs and minLSTMs were $175\times$ and $235\times$ faster per training step than GRUs and LSTMs on a T4 GPU. The improvement grows with sequence length: minGRUs and minLSTMs are $1324\times$ and $1361\times$ faster at length $4096$. The minimal variants use $\sim 88\%$ more memory than their traditional counterparts. minLSTM and minGRU solve the Selective Copying task, achieving performance comparable to S6 and surpassing other modern baselines. In reinforcement learning tasks, minLSTM and minGRU outperform Decision S4 and are competitive with Decision Transformer, Aaren, and Mamba. In language modeling, minGRU, minLSTM, Mamba, and Transformers achieved comparable test losses of $1.548$, $1.555$, $1.575$, and $1.547$ respectively.

The related work section provides an overview of recent efficient recurrent sequence models, categorized into deep state-space models, recurrent versions of attention, and parallelizable RNNs.

The paper concludes by highlighting the parallel training enabled by removing gate dependencies on previous states. The minimal versions offer fewer parameters, full parallelizability, and competitive performance. The authors suggest a reevaluation of simpler foundational models like LSTM and GRU.

The limitations section acknowledges the hardware constraints that impacted the scale of the experiments, including the use of older GPUs with limited memory. Gradient accumulation was used to accommodate memory limitations, reducing the effective batch size and slowing down training.
