- The paper introduces the Continuous Thought Machine (CTM), an architecture that leverages neuron-level temporal dynamics and neural synchronization to achieve adaptive compute.
- It employs an internal tick mechanism to iteratively refine states, balancing biological realism with computational efficiency.
- Evaluations on tasks like ImageNet classification, maze navigation, and sorting demonstrate improved calibration, emergent strategies, and robust generalization.
The paper "Continuous Thought Machines" (2505.05522) introduces a novel neural network architecture designed to explicitly incorporate neuron-level temporal dynamics and neural synchronization, aiming for a balance between computational efficiency and biological realism. The core idea is to leverage the timing and interplay of neural activity as a fundamental computational element, contrasting with standard deep learning models that typically abstract away temporal dynamics.
The Continuous Thought Machine (CTM) operates along an internal dimension referred to as "internal ticks," decoupled from the input data's inherent sequence. This allows the model to iteratively refine its internal state and computations. The architecture consists of several key components:
- Internal Sequence Dimension: The CTM progresses through discrete internal ticks ($t \in \{1, \dots, T\}$), allowing computation to unfold over time, independent of the data's structure. This is analogous to a self-generated timeline for "thought."
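The tick loop can be sketched end to end. Everything below (module shapes, the mean-pooled stand-in for attention, the `readout` projection) is illustrative scaffolding under assumed dimensions, not the paper's implementation:

```python
import torch
import torch.nn as nn

# Illustrative sketch of the internal-tick loop (all modules are stand-ins).
B, D, T, feat_dim = 2, 8, 5, 8
torch.manual_seed(0)

synapse = nn.Linear(2 * D, D)   # stand-in for the U-Net-style synapse MLP
nlm = nn.Linear(D, D)           # stand-in for the per-neuron models
readout = nn.Linear(D, 3)       # stand-in output projection

features = torch.randn(B, 4, feat_dim)  # pre-extracted input features
z = torch.zeros(B, D)                   # initial post-activations
outputs = []
for t in range(T):                      # internal ticks, decoupled from the data
    o = features.mean(dim=1)            # (B, feat_dim): stand-in for cross-attention
    a = synapse(torch.cat([z, o], dim=-1))  # pre-activations for the next tick
    z = torch.tanh(nlm(a))              # next post-activations
    outputs.append(readout(z))          # one prediction per tick
```

The key point the sketch illustrates is that `T` is an internal hyperparameter: the loop runs for a self-chosen number of ticks regardless of the input's shape.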
- Synapse Model: A recurrent model, $f_{\theta_{\text{syn}}}$, interconnects neurons in a shared latent space. It takes the current neuron states ($z_t$) and the external input (attention output $o_t$) to produce "pre-activations" ($a_t$) for the next tick. The paper uses a U-Net-style MLP for this, suggesting deeper synaptic connections are beneficial.
```python
a_t = synapse_model(torch.cat([z_t, o_t], dim=-1))
```

A history of the $M$ most recent pre-activations ($A_t \in \mathbb{R}^{D \times M}$) is maintained.
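One plausible way to maintain this rolling window (tensor names here are illustrative, not the paper's code) is a fixed-size FIFO buffer updated each tick:

```python
import torch

# Illustrative rolling window of the M most recent pre-activations.
B, D, M = 2, 8, 4
A = torch.zeros(B, D, M)          # history buffer, initialized to zeros
a_t = torch.zeros(B, D)
for step in range(6):
    a_t = torch.randn(B, D)       # new pre-activations for this tick
    # Drop the oldest column, append the newest: memory stays O(D * M).
    A = torch.cat([A[:, :, 1:], a_t.unsqueeze(-1)], dim=-1)
```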
- Neuron-Level Models (NLMs): Each neuron has its own private parameterized model ($g_{\theta_d}$) that processes the history of its incoming pre-activations ($A_t^d$) to produce its next "post-activation" ($z_{t+1}^d$). This allows each neuron to have unique temporal processing characteristics. The full set of post-activations ($z_{t+1}$) is concatenated with the attention output for the next internal tick.
```python
# Simplified NLM application (einsum runs all per-neuron models in parallel)
# history: (batch_size, D, M)
# nlm_weights_1: (D, M, hidden_dim); nlm_weights_2: (D, hidden_dim, 1)
# Keeping the 'd' index un-summed is what makes each neuron's weights private.
# This is a simplified view; actual NLMs are MLPs with biases and a nonlinearity.
latent_h = torch.einsum('bdm,dmh->bdh', pre_activation_history, nlm_weights_1)
post_activations = torch.einsum('bdh,dho->bdo', latent_h, nlm_weights_2)
```
A history of post-activations ($Z_t \in \mathbb{R}^{D \times t}$) is maintained, growing with the number of internal ticks.
- Neural Synchronization: This is the core latent representation. It's computed from the post-activation history $Z_t$. The full synchronization matrix is $S_t = Z_t \cdot (Z_t)^\intercal \in \mathbb{R}^{D \times D}$. This matrix captures the temporal correlation between every pair of neurons over the entire history of computation up to the current tick.
```python
# Simplified synchronization calculation (ignoring learned decay for clarity)
# Z_t: (batch_size, D, t)
S_t = torch.einsum('bdt,bet->bde', Z_t, Z_t)  # (batch_size, D, D)
```
To manage complexity and focus, a subset of neuron pairs is selected from $S_t$ to form action synchronization ($S_t^{\text{action}}$) and output synchronization ($S_t^{\text{out}}$) vectors. The paper explores different sampling strategies (dense, semi-dense, random).
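A random-pair strategy might look like the following sketch (the indexing scheme here is a guess at one plausible variant, not the paper's exact sampler); the pair indices would be drawn once at initialization and reused at every tick:

```python
import torch

# Illustrative random sampling of neuron pairs from the synchronization matrix.
B, D, n_pairs = 2, 8, 5
torch.manual_seed(0)
S_t = torch.randn(B, D, D)          # stand-in full synchronization matrix

g = torch.Generator().manual_seed(0)
i_idx = torch.randint(0, D, (n_pairs,), generator=g)  # fixed once at init,
j_idx = torch.randint(0, D, (n_pairs,), generator=g)  # reused every tick

# Gather one scalar per selected (i, j) pair: a compact (B, n_pairs) vector.
S_action = S_t[:, i_idx, j_idx]
```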
The computation of synchronization incorporates a learnable exponential decay ($r_{ij}$) for each neuron pair, allowing the model to learn the influence of past activity at different timescales. This can be computed recursively for efficiency, avoiding recomputing the entire history's dot product at each step.
```python
# Recursive update of synchronization (simplified for one pair ij)
# alpha_t_ij and beta_t_ij are running states maintained across ticks
# z_t_plus_1_i and z_t_plus_1_j are the current post-activations of neurons i and j
# r_ij is the learnable decay rate for pair ij
alpha_t_plus_1_ij = torch.exp(-r_ij) * alpha_t_ij + z_t_plus_1_i * z_t_plus_1_j
beta_t_plus_1_ij = torch.exp(-r_ij) * beta_t_ij + 1.0
S_t_plus_1_ij = alpha_t_plus_1_ij / torch.sqrt(beta_t_plus_1_ij)
```
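As a sanity check (not from the paper), the recursion can be verified against the direct exponentially decayed dot product over a full history; unrolling the recursion gives a weight of $e^{-r_{ij}(t-\tau)}$ on tick $\tau$:

```python
import torch

torch.manual_seed(0)
T, r_ij = 6, torch.tensor(0.5)
z_i, z_j = torch.randn(T), torch.randn(T)

# Recursive form: constant memory, one update per tick.
alpha, beta = torch.tensor(0.0), torch.tensor(0.0)
for t in range(T):
    alpha = torch.exp(-r_ij) * alpha + z_i[t] * z_j[t]
    beta = torch.exp(-r_ij) * beta + 1.0
recursive = alpha / torch.sqrt(beta)

# Direct form: weight tick tau by exp(-r * (T-1 - tau)) over the whole history.
w = torch.exp(-r_ij * torch.arange(T - 1, -1, -1, dtype=torch.float32))
direct = (w * z_i * z_j).sum() / torch.sqrt(w.sum())
```

The two forms agree, which is what makes the recursive update a pure efficiency win rather than an approximation.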
- Data Interaction (Attention and Output): The CTM interacts with inputs and produces outputs using the synchronization representation.
- Attention: Action synchronization ($S_t^{\text{action}}$) is projected to an attention query ($q_t$). This query is used in standard cross-attention with features extracted from the input data (keys and values). The attention output ($o_t$) is then concatenated with the post-activations for the next tick.
```python
q_t = W_in(S_t_action)
o_t = Attention(Q=q_t, KV=feature_extractor(data))
```
- Output: Output synchronization ($S_t^{\text{out}}$) is projected to the final output space (e.g., logits $y_t$).
- Loss Function: The CTM produces an output at each internal tick, and a loss ($L_t$) is computed per tick. The final training loss aggregates the losses from two dynamically chosen points: the internal tick with minimum loss ($t_1$) and the internal tick with maximum certainty ($t_2$, where certainty is 1 minus normalized entropy). This dynamic aggregation encourages the CTM to develop meaningful representations across ticks and naturally facilitates adaptive compute.
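The per-tick selection can be sketched as follows; the averaging of the two selected losses is one plausible aggregation, and shapes and the targets tensor are illustrative:

```python
import torch
import torch.nn.functional as F

# Illustrative per-tick loss aggregation over T internal ticks.
torch.manual_seed(0)
T, C = 8, 10
logits = torch.randn(T, C)            # one prediction per internal tick
targets = torch.full((T,), 3)         # same ground-truth label at every tick
losses = F.cross_entropy(logits, targets, reduction='none')  # (T,)

# Certainty = 1 - normalized entropy of each tick's predictive distribution.
p = F.softmax(logits, dim=-1)
entropy = -(p * torch.log(p)).sum(dim=-1)
certainty = 1.0 - entropy / torch.log(torch.tensor(float(C)))

t1 = losses.argmin()                  # tick with minimum loss
t2 = certainty.argmax()               # tick with maximum certainty
final_loss = (losses[t1] + losses[t2]) / 2.0
```

Because `t1` and `t2` move from example to example, no single tick is privileged, which is what lets useful predictions emerge early for easy inputs.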
The paper evaluates the CTM on a diverse set of tasks to showcase its capabilities:
- ImageNet-1K Classification: While not state-of-the-art in raw accuracy (72.47% top-1), the CTM demonstrates novel behaviors. It exhibits rich, diverse, and complex neural dynamics. The internal thought dimension enables adaptive compute, allowing the CTM to stop processing early for easier images based on certainty. Prediction analysis shows that certainty increases over internal ticks, leading to good calibration without specific post-training adjustments. Visualizations show attention patterns that smoothly shift and focus on salient features over time, demonstrating an emergent observational process. Emergent traveling waves in the neuron activation space are also observed.
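The certainty-based early stopping described above can be sketched at inference time; the thresholding rule and function below are illustrative (the paper reports adaptive compute emerging from the training setup rather than prescribing this exact mechanism):

```python
import torch

def adaptive_inference(per_tick_probs, threshold=0.9):
    """Illustrative certainty-gated early exit.
    per_tick_probs: (T, num_classes) softmax outputs, one row per internal tick.
    Returns (tick_stopped_at, predicted_class)."""
    C = per_tick_probs.shape[-1]
    for t, p in enumerate(per_tick_probs):
        entropy = -(p * torch.log(p.clamp_min(1e-12))).sum()
        certainty = 1.0 - entropy / torch.log(torch.tensor(float(C)))
        if certainty >= threshold:
            return t, p.argmax()          # stop early on easy inputs
    return len(per_tick_probs) - 1, per_tick_probs[-1].argmax()
```

A confident first tick exits immediately, while a maximally uncertain (uniform) trajectory runs to the final tick.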
- 2D Mazes: A challenging maze navigation task requiring outputting a sequence of moves (not just the path) and trained without positional embeddings in attention. The CTM significantly outperforms LSTM and feed-forward baselines, suggesting it is more capable of building and utilizing an internal "world model" or cognitive map. Demonstrations show the CTM's attention tracing the solution path iteratively, and the model generalizes well to longer paths and larger mazes by reapplying the trained process. This suggests the CTM learns a general problem-solving procedure.
- CIFAR-10: Comparison to human performance and baselines. The CTM achieves better test accuracy and significantly better calibration than LSTMs and a feed-forward model. Its performance correlates well with human difficulty ratings, and its uncertainty trend aligns with human reaction times. Neural activity traces show much richer dynamics for the CTM compared to the LSTM baseline on this task.
- CIFAR-100 Ablations: Experiments varying the number of neurons (width) and internal ticks. Increasing width improves accuracy up to a point and increases the diversity/dissimilarity of neuron activity patterns. Increasing internal ticks generally improves accuracy and reveals an emergent tendency for the CTM to have two phases of high certainty during its thought process, regardless of the total number of ticks available.
- Sorting: Sorting 30 real numbers, a task used to study adaptive compute. The CTM is trained to output the sorted sequence over its internal ticks using CTC loss. The analysis shows clear patterns in the "wait times" (number of ticks before outputting an element), correlating with the element's position in the sequence and the difference from the previous value, indicating the CTM uses a data-dependent internal algorithm. It also generalizes to data from different distributions.
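A hedged sketch of how per-tick outputs can be trained against a sorted target with CTC, using PyTorch's built-in `nn.CTCLoss` (the class encoding, shapes, and blank convention here are illustrative and may differ from the paper's setup; the blank token is what lets the model "wait" for several ticks before committing to the next element):

```python
import torch
import torch.nn as nn

# Illustrative CTC setup: T internal ticks, 1 example, 6 classes
# (5 value-classes plus blank at index 0).
torch.manual_seed(0)
T, B, C = 12, 1, 6
log_probs = torch.randn(T, B, C).log_softmax(-1)  # one distribution per tick
targets = torch.tensor([[3, 1, 4, 2, 5]])         # the sorted sequence as classes
ctc = nn.CTCLoss(blank=0)
loss = ctc(log_probs, targets,
           torch.tensor([T]),    # input lengths: ticks available
           torch.tensor([5]))    # target lengths: elements to emit
```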
- Parity: Computing cumulative parity of a 64-length binary sequence. The CTM learns this sequential task better than LSTMs, and performance improves with more internal ticks. Analysis of attention patterns reveals different problem-solving strategies emerge in different training runs (e.g., sequential scanning vs. reverse scanning/planning), showcasing the CTM's ability to form diverse strategies.
- Q&A MNIST: A task testing memory, retrieval, and arithmetic. The CTM observes MNIST digits and a sequence of index/operator embeddings, then outputs the result of modular arithmetic operations. The CTM outperforms LSTMs, especially with more internal ticks. It demonstrates memory retrieval even when observed digits fall outside the NLM's memory window, suggesting that synchronization underpins this capability. The CTM learns to compute the operations sequentially as embeddings are observed, generalizing to more operations than seen during training.
- Reinforcement Learning: Applying the CTM to POMDPs (CartPole, Acrobot, MiniGrid Four Rooms). The CTM maintains continuous neural dynamics across environment steps (using a sliding window history). It achieves comparable performance to LSTM baselines on these sequential decision-making tasks, demonstrating its ability to function as a stateful recurrent agent and interact continuously with an environment while exhibiting richer neural dynamics than LSTMs.
Practical Implementation Details & Considerations:
- Computational Cost: The primary limitation is sequential processing across internal ticks, which cannot be parallelized like layers in a feed-forward network. This makes training longer than standard feed-forward models. The cost per internal tick is manageable ($O(D^2)$ for full synchronization, reduced to $O(D_{\text{sub}})$ with pair sampling).
- Parameter Count: NLMs add parameters scaling with $D \times M \times d_{\text{hidden}}$, which increases the parameter count compared to models with static activation functions. However, the paper shows the CTM can achieve strong results with parameter counts comparable to or less than LSTMs on some tasks.
- Memory: Maintaining activation histories ($A_t$, $Z_t$) adds memory overhead, linear in $T$ and $M$. Recursive computation of synchronization helps mitigate this for the dot product itself. For RL, a sliding window history is used to bound memory.
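The sliding-window bound for RL can be sketched with a bounded deque (an illustrative mechanism; the paper's exact implementation may differ):

```python
from collections import deque
import torch

# Illustrative sliding window: the history never exceeds M entries,
# so memory stays constant across an arbitrarily long episode.
D, M = 8, 4
history = deque(maxlen=M)             # oldest entries fall off automatically
for env_step in range(10):
    a_t = torch.randn(D)              # activations at this environment step
    history.append(a_t)
A_t = torch.stack(list(history), dim=-1)  # (D, <=M) window for the NLMs
```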
- Stability: The certainty-based loss and the synchronization representation seem to contribute to training stability on tasks where standard RNNs struggled (mazes, Q&A MNIST, parity).
- Generalization: The CTM demonstrates impressive generalization on mazes (longer paths, larger maps) and sorting (different distributions), suggesting it learns underlying procedures or world models rather than just memorizing. Generalization on Q&A MNIST is also observed for increased task length.
- Modularity: The core CTM recurrent block (synapse + NLMs + synchronization) is reusable, with task-specific input feature extractors and output projections (attention, logits).
The paper concludes by highlighting the CTM's novel use of neural synchronization and temporal dynamics as a fundamental representation, leading to emergent properties like adaptive compute, improved calibration, and interpretable processing strategies. It suggests that drawing inspiration from biological principles, even without strict adherence, can lead to powerful and qualitatively different AI capabilities, paving the way for more flexible and intelligent systems.