- The paper presents Comba as a novel bilinear RNN that integrates closed-loop control mechanisms to enhance memory management and inference efficiency.
- It employs a Scalar-Plus-Low-Rank (SPLR) state update with state and output feedback corrections to address Transformer limitations such as quadratic complexity.
- Experimental results demonstrate up to 40% faster forward passes and superior performance on language modeling and vision benchmarks compared to baselines.
The paper "Comba: Improving Bilinear RNNs with Closed-loop Control" (2506.02475) introduces a novel recurrent neural network architecture called Comba, designed to address the limitations of Transformers, particularly the quadratic complexity and unbounded memory growth for long sequences. Comba builds upon recent advancements in efficient sequence modeling, specifically Nonlinear RNNs, by incorporating principles from closed-loop control theory.
Understanding the Context: Linear vs. Nonlinear RNNs
The paper categorizes efficient sequence models into two main groups:
- Linear RNNs (e.g., Mamba, GLA): These models function as linear key-value associative memory registers, characterized by state updates of the form S_t = (α_t, β_t) ⊗ (S_{t-1}, k_t v_t^T) and reads o_t = S_t q_t. They offer constant memory and O(L) inference time but rely on heuristic, data-dependent gating for memory management. Examples include Mamba [36], GLA [101], and RetNet [88]. Their state transition is often approximated by diagonal or scalar matrices for efficiency.
- Nonlinear RNNs (e.g., DeltaNet, RWKV-7, TTT): These models introduce richer interactions between the state S and the input k, moving beyond simple linear memory registers. They often employ the Delta learning rule for supervised memory control, formulating state updates that resemble affine bilinear systems. Examples include Gated-DeltaNet [100] and RWKV-7 [71]. While more expressive, efficient parallelization during training can be challenging.
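The contrast between the two families can be made concrete in a few lines. The sketch below shows a single recurrence step for one head (no batching, illustrative gate values); it is a minimal analogue of the two update rules, not any specific model's full parameterization:

```python
import torch

D = 4
S = torch.zeros(D, D)                      # associative memory state
k, v, q = torch.randn(3, D).unbind(0)      # key, value, query for one step
alpha, beta = 0.9, 0.5                     # forget / input gates (data-dependent in practice)

# Linear RNN (Mamba/GLA-style): purely additive gated update
S_linear = alpha * S + beta * torch.outer(v, k)

# Nonlinear RNN (DeltaNet-style): the state corrects the value before writing;
# S @ k is the memory's current "prediction" for key k
v_corrected = v - S @ k
S_delta = alpha * S + beta * torch.outer(v_corrected, k)

o = S_delta @ q                            # associative read
```

The Delta-rule update only differs from the linear one once the state is non-zero, which is exactly the state-input interaction that makes the system bilinear.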
Comba is proposed as a novel variant of Nonlinear RNNs, explicitly designed with closed-loop control mechanisms.
Comba Architecture and Closed-loop Control
Comba's core innovation lies in adopting a closed-loop control perspective for memory management. Unlike previous models that might use state feedback (like the Delta rule correcting the input v), Comba introduces feedback at two stages:
- State-based feedback on the input (v): The input v_t is corrected based on the previous state S_{t-1} and key k_t. This is conceptually similar to the Delta rule, v_new = v_t − S_{t-1} k_t. In Comba's formulation, this correction is handled implicitly within the state update via the scalar α_t and the feedback factor b.
- Output feedback correction: The query q_t is corrected based on the key k_t using a scalar factor d, and the output is computed as o_t = S_t (q_t − d k_t). From an optimization perspective, the correction term d k_t encourages similarity optimization between q and k, which the authors find crucial for improving memory retrieval and recall.
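Because the read is linear in the query, the corrected output decomposes into the plain associative read minus a d-scaled key read. A minimal check of this identity (the value of d here is just the small-model initialization mentioned below):

```python
import torch

D, d = 4, 0.02                      # d: output-feedback scalar
S = torch.randn(D, D)               # current state
q, k = torch.randn(2, D).unbind(0)

o = S @ (q - d * k)                 # Comba's corrected read
# Equivalent decomposition: plain read minus a d-scaled key read
o_decomposed = S @ q - d * (S @ k)
```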
The state transition in Comba adopts a Scalar-Plus-Low-Rank (SPLR) form:
S_t = S_{t-1} (α_t I − β_t k_t k_t^T) + β_t v_t k_t^T
o_t = S_t (q_t − d k_t)
Here:
- S_t ∈ R^{D×D} is the state (memory).
- q_t, k_t, v_t ∈ R^D are the query, key, and value vectors (for a single head).
- α_t is a data-dependent scalar forget gate, typically initialized close to 1 so the model starts by retaining memory and learns when to forget.
- β_t is a data-dependent scalar input gate, taking values in (0, 1).
- k_t k_t^T is the low-rank term, enabling interaction between the state and the key k_t.
- d is a trainable scalar for output feedback, initialized to a small value (0.02) for smaller models or to 1 for larger models (1.3B+).
The feedback factors {b, c, d} are implemented as simple trainable scalars without data dependency, contributing to efficiency. The feedback strength in the state update is controlled by scaling the input gate as β̃_t = b · β_t, where b is constrained to (0, 1) via a sigmoid so that the feedback remains weaker than the incremental information. Short convolutions are applied to the q, k, v projections to improve local feature interactions.
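A plausible gate parameterization consistent with this description might look as follows; the projection layers, shapes, and initialization constants here are illustrative assumptions, not the paper's exact code:

```python
import torch
import torch.nn as nn

class CombaGates(nn.Module):
    """Hypothetical gate parameterization: data-dependent alpha/beta per head,
    plus trainable scalars b and d as described in the text."""
    def __init__(self, dim, heads, d_init=0.02):
        super().__init__()
        self.alpha_proj = nn.Linear(dim, heads)
        self.beta_proj = nn.Linear(dim, heads)
        # Bias pushes sigmoid toward 1 so alpha starts near 1 (retain memory)
        nn.init.constant_(self.alpha_proj.bias, 4.0)
        self.b = nn.Parameter(torch.zeros(1))            # sigmoid(b) lies in (0, 1)
        self.d = nn.Parameter(torch.full((1,), d_init))  # output feedback scalar

    def forward(self, x):                                # x: (B, T, dim)
        alpha = torch.sigmoid(self.alpha_proj(x))        # forget gate, near 1 at init
        beta = torch.sigmoid(self.beta_proj(x))          # input gate in (0, 1)
        b = torch.sigmoid(self.b)                        # feedback strength < 1
        return alpha, beta, b, self.d
```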
Comba can be implemented recursively for inference with O(1) time and memory per step (detailed in Appendix B.1 pseudocode).
```python
import torch

def recurrent_comba(q, k, v, alpha, beta, b, d):
    """Recurrent (inference-time) Comba pass, O(1) time and memory per step.

    q, k, v:      (B, T, H, D)  query/key/value per head
    alpha, beta:  (B, T, H)     data-dependent scalar gates
    b, d:         trainable scalars (state / output feedback strengths)
    """
    B, T, H, D = q.shape
    o = torch.zeros_like(v)
    S = q.new_zeros(B, H, D, D)                    # DxD state per batch and head
    for t in range(T):
        _q, _k, _v = q[:, t], k[:, t], v[:, t]     # (B, H, D)
        _alpha = alpha[:, t, :, None, None]        # (B, H, 1, 1) for broadcasting
        _beta = beta[:, t, :, None, None]
        # State feedback on the input: v_new = v_t - b * S_{t-1} k_t (Delta-rule style)
        Sk = torch.einsum('bhdm,bhm->bhd', S, _k)  # S_{t-1} k_t
        v_new = _v - b * Sk
        # SPLR update: S_t = alpha_t S_{t-1} + beta_t v_new k_t^T
        #            = S_{t-1}(alpha_t I - b beta_t k_t k_t^T) + beta_t v_t k_t^T
        S = _alpha * S + _beta * (v_new.unsqueeze(-1) * _k.unsqueeze(-2))
        # Output feedback correction: o_t = S_t (q_t - d k_t)
        o[:, t] = torch.einsum('bhdm,bhm->bhd', S, _q - d * _k)
    return o
```
Note: the paper's equation layout leaves the multiplication order slightly ambiguous, but expanding the SPLR form S_t = S_{t-1}(α_t I − β_t k_t k_t^T) + β_t v_t k_t^T recovers the familiar Delta-rule shape S_t = α_t S_{t-1} + β_t (v_t − S_{t-1} k_t) k_t^T, so the two readings coincide.
Hardware-Efficient Training (Chunk-wise Parallelism)
While recursive inference is efficient, training requires parallelization. Comba utilizes a chunk-wise parallel kernel implemented in Triton [90]. This approach breaks the sequence into chunks and combines intra-chunk parallel computation with inter-chunk recurrence. The paper employs WY representation [12] and UT transform [49] techniques, similar to DeltaNet [100], to formulate the chunk-wise updates efficiently. A key optimization mentioned is computing the inverse matrix in the UT transform step only once, compared to twice in Gated-DeltaNet, contributing to speed improvements. The parallel training formulation involves matrix operations over chunks (Eq. 11 and 12 in the paper).
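The chunk-wise idea (parallelizable work inside a chunk, a single carried state across chunk boundaries) can be illustrated on a simplified scalar-decay recurrence. This toy sketch omits the WY representation, the UT transform, and the Triton kernel entirely; it only shows how the intra-/inter-chunk split works:

```python
import torch

def chunked_scan(x, decay, chunk=64):
    """Illustrative chunk-wise evaluation of the gated linear recurrence
    h_t = decay_t * h_{t-1} + x_t (diagonal/scalar state per feature).
    Work inside a chunk can be batched; only the carried state h crosses
    chunk boundaries. A toy analogue, not the paper's Triton kernel."""
    B, T, D = x.shape
    h = torch.zeros(B, D)
    out = torch.zeros_like(x)
    for s in range(0, T, chunk):
        xs, ds = x[:, s:s + chunk], decay[:, s:s + chunk]
        C = xs.shape[1]
        # Inter-chunk: the carried state decayed through every prefix of the chunk
        cum = torch.cumprod(ds, dim=1)              # prefix products of decays
        carried = cum * h.unsqueeze(1)              # (B, C, D)
        # Intra-chunk: contributions of this chunk's inputs (naive loop for clarity)
        local = torch.zeros_like(xs)
        acc = torch.zeros(B, D)
        for t in range(C):
            acc = ds[:, t] * acc + xs[:, t]
            local[:, t] = acc
        out[:, s:s + C] = carried + local
        h = out[:, s + C - 1]                       # state entering the next chunk
    return out
```

The real kernel replaces the inner loop with matrix-form intra-chunk computation (Eq. 11 and 12 in the paper), which is what makes training hardware-efficient.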
Experimental Results and Practical Implications
The authors evaluate Comba on language modeling (SlimPajama, LongBench) and vision tasks (ImageNet classification, object tracking) using models with 340M and 1.3B parameters.
- Operator Efficiency: The Triton-based chunk-wise Comba kernel shows significant speed improvements (nearly 40% faster forward pass) over Gated-DeltaNet, demonstrating its hardware efficiency for training.
- Language Modeling: Comba-SPLR outperforms other variants (IPLR, DPLR) and several baselines (RWKV-7, Gated-DeltaNet) on commonsense reasoning and recall-intensive tasks. Ablation studies highlight the critical role of the output correction term (q_t − d k_t) in improving perplexity and recall by optimizing memory retrieval. Initializing the output feedback factor d appropriately (0.02 for 340M, 1 for 1.3B) matters for performance. The standard MLP-based architecture (Figure 2a) outperforms a Mamba-like architecture (Figure 2b) for language modeling.
- Long-context Modeling: Comba generally performs well on LongBench tasks, showing strong capability in QA and Few-shot settings with 10K context length.
- Vision Modeling: Comba-T and Comba-S variants achieve state-of-the-art accuracy-efficiency trade-offs on ImageNet classification compared to Transformer (DeiT), sparse attention (Agent Attention), and Linear RNN (Vision Mamba) baselines, with lower FLOPs and parameter counts at similar accuracy levels. On object tracking, Comba variants consistently outperform baselines, demonstrating strong temporal modeling and long-range dependency capture.
- Hybrid Architectures: Replacing some Comba layers with softmax attention layers (Figure 2c) improves recall, suggesting that hybrid approaches leveraging the strengths of different mechanisms can be beneficial.
Implementation Considerations
- Computational Resources: Training the 340M model required 8x A800 GPUs for 10 hours, and the 1.3B model required 32x A800 GPUs for 48 hours, indicating substantial but feasible resource requirements for models up to 1.3B parameters. Scaling to larger models (e.g., 2.7B) is computationally intensive (estimated 32x A800 for 120+ hours).
- Kernel Implementation: The use of custom Triton kernels is crucial for achieving the reported training efficiency. Implementing these kernels requires expertise in low-level GPU programming. The paper provides pseudocode for the recurrent inference pass, which is simpler to implement for deployment once the model is trained.
- Hyperparameters: Learning rate (3e-4), optimizer (AdamW), weight decay (0.01), and gradient clipping (1.0) are standard, but the initialization strategies for the gates and the output feedback factor d require tuning per model scale.
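The listed hyperparameters map directly onto a standard PyTorch training setup. A minimal sketch (the helper names are ours, and the model and loss are placeholders):

```python
import torch

def build_optimizer(model):
    """Optimizer matching the reported hyperparameters:
    AdamW, lr 3e-4, weight decay 0.01."""
    return torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

def train_step(model, opt, loss_fn, batch, targets):
    """One training step with gradient clipping at 1.0, as reported."""
    opt.zero_grad()
    loss = loss_fn(model(batch), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    return loss.item()
```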
- Limitations: Direct comparison to some recent models (Titans, Lattice, MIRAS) was not possible due to lack of open-source implementations. Further exploration is needed to understand Comba's performance characteristics on summarization and code tasks where it lagged slightly behind Gated-DeltaNet.
In summary, Comba provides a practical and performant alternative to Transformers for sequence modeling, particularly in resource-constrained or long-sequence scenarios. Its novel integration of closed-loop control principles, the efficient SPLR state transition, and a highly optimized Triton kernel make it competitive across language and vision domains. Implementing Comba requires leveraging the provided Triton kernel code or developing similar optimized kernels for efficient training.