
Continuous Thought Machines

Published 8 May 2025 in cs.LG and cs.AI | (2505.05522v3)

Abstract: Biological brains demonstrate complex neural activity, where the timing and interplay between neurons is critical to how brains process information. Most deep learning architectures simplify neural activity by abstracting away temporal dynamics. In this paper we challenge that paradigm. By incorporating neuron-level processing and synchronization, we can effectively reintroduce neural timing as a foundational element. We present the Continuous Thought Machine (CTM), a model designed to leverage neural dynamics as its core representation. The CTM has two core innovations: (1) neuron-level temporal processing, where each neuron uses unique weight parameters to process a history of incoming signals; and (2) neural synchronization employed as a latent representation. The CTM aims to strike a balance between oversimplified neuron abstractions that improve computational efficiency, and biological realism. It operates at a level of abstraction that effectively captures essential temporal dynamics while remaining computationally tractable for deep learning. We demonstrate the CTM's strong performance and versatility across a range of challenging tasks, including ImageNet-1K classification, solving 2D mazes, sorting, parity computation, question-answering, and RL tasks. Beyond displaying rich internal representations and offering a natural avenue for interpretation owing to its internal process, the CTM is able to perform tasks that require complex sequential reasoning. The CTM can also leverage adaptive compute, where it can stop earlier for simpler tasks, or keep computing when faced with more challenging instances. The goal of this work is to share the CTM and its associated innovations, rather than pushing for new state-of-the-art results. To that end, we believe the CTM represents a significant step toward developing more biologically plausible and powerful artificial intelligence systems.

Summary

  • The paper introduces a CTM architecture that leverages neuron-level temporal dynamics and neural synchronization to achieve adaptive compute.
  • It employs an internal tick mechanism to iteratively refine states, balancing biological realism with computational efficiency.
  • Evaluations on tasks like ImageNet classification, maze navigation, and sorting demonstrate improved calibration, emergent strategies, and robust generalization.

The paper "Continuous Thought Machines" (2505.05522) introduces a novel neural network architecture designed to explicitly incorporate neuron-level temporal dynamics and neural synchronization, aiming for a balance between computational efficiency and biological realism. The core idea is to leverage the timing and interplay of neural activity as a fundamental computational element, contrasting with standard deep learning models that typically abstract away temporal dynamics.

The Continuous Thought Machine (CTM) operates along an internal dimension referred to as "internal ticks," decoupled from the input data's inherent sequence. This allows the model to iteratively refine its internal state and computations. The architecture consists of several key components:

  1. Internal Sequence Dimension: The CTM progresses through discrete internal ticks ($t \in \{1, \dots, T\}$), allowing computation to unfold over time, independent of the data's structure. This is analogous to a self-generated timeline for "thought."
  2. Synapse Model: A recurrent model, $f_{\theta_\text{syn}}$, interconnects neurons in a shared latent space. It takes the current neuron states ($z^t$) and external input (the attention output $o^t$) to produce "pre-activations" ($a^t$) for the next tick. The paper uses a U-Net-style MLP for this, suggesting deeper synaptic connections are beneficial.
    
    a_t = synapse_model(torch.cat([z_t, o_t], dim=-1))
    A history of the $M$ most recent pre-activations ($\mathbf{A}^t \in \mathbb{R}^{D \times M}$) is maintained.
  3. Neuron-Level Models (NLMs): Each neuron has its own private parameterized model ($g_{\theta_d}$) that processes the history of its incoming pre-activations ($\mathbf{A}_d^t$) to produce its next "post-activation" ($z_d^{t+1}$). This allows each neuron to have unique temporal processing characteristics. The full set of post-activations ($z^{t+1}$) is concatenated with the attention output for the next internal tick.
    
    # Simplified NLM application (using einsum for efficiency, as in the paper)
    # history: (batch_size, D, M); per-neuron weights: (D, M, hidden_dim) and (D, hidden_dim, 1)
    # This is a simplified view; the actual NLMs are small MLPs
    # Note that the neuron dimension d is preserved, so each neuron keeps its own weights
    latent_h = torch.einsum('bdm,dmh->bdh', pre_activation_history, nlm_weights_1)
    post_activations = torch.einsum('bdh,dho->bdo', latent_h, nlm_weights_2)
    A history of post-activations ($\mathbf{Z}^t \in \mathbb{R}^{D \times t}$) is maintained, growing with the number of internal ticks.
  4. Neural Synchronization: This is the core latent representation, computed from the post-activation history $\mathbf{Z}^t$. The full synchronization matrix is $S^t = \mathbf{Z}^t \cdot (\mathbf{Z}^t)^\intercal \in \mathbb{R}^{D \times D}$. This matrix captures the temporal correlation between every pair of neurons over the entire history of computation up to the current tick.
    
    # Simplified synchronization calculation (ignoring learned decay for clarity)
    # Z_t: (batch_size, D, t)
    S_t = torch.einsum('bdt,bet->bde', Z_t, Z_t) # (batch_size, D, D)
    To manage complexity, a subset of neuron pairs is selected from $S^t$ to form action synchronization ($S^t_\text{action}$) and output synchronization ($S^t_\text{out}$) vectors. The paper explores different sampling strategies (dense, semi-dense, random). The computation of synchronization incorporates a learnable exponential decay ($r_{ij}$) for each neuron pair, allowing the model to learn the influence of past activity at different timescales. This can be computed recursively for efficiency, avoiding recomputing the entire history's dot product at each step.
    
    # Recursive update of synchronization (simplified for one pair ij)
    # alpha_t_ij and beta_t_ij are states maintained recursively
    # z_t_plus_1_i and z_t_plus_1_j are current post-activations for neurons i and j
    # r_ij is the learnable decay rate for pair ij
    alpha_t_plus_1_ij = torch.exp(-r_ij) * alpha_t_ij + z_t_plus_1_i * z_t_plus_1_j
    beta_t_plus_1_ij = torch.exp(-r_ij) * beta_t_ij + 1.0
    S_t_plus_1_ij = alpha_t_plus_1_ij / torch.sqrt(beta_t_plus_1_ij)
  5. Data Interaction (Attention and Output): The CTM interacts with inputs and produces outputs using the synchronization representation.
    • Attention: Action synchronization ($S^t_\text{action}$) is projected to an attention query ($q^t$). This query is used in standard cross-attention over features extracted from the input data (keys and values). The attention output ($o^t$) is then concatenated with the post-activations for the next tick.
      
      q_t = W_in(S_t_action)
      o_t = Attention(Q=q_t, KV=feature_extractor(data))
    • Output: Output synchronization ($S^t_\text{out}$) is projected to the final output space (e.g., logits $y^t$).
      
      y_t = W_out(S_t_out)
  6. Loss Function: The CTM produces outputs at each internal tick. The loss is computed per tick ($\mathcal{L}^t$). The final loss used for training is an aggregation of losses from two dynamic points: the internal tick with minimum loss ($t_1$) and the internal tick with maximum certainty ($t_2$, where certainty is 1 minus normalized entropy). This dynamic aggregation encourages the CTM to develop meaningful representations across ticks and naturally facilitates adaptive compute.
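
The two-point loss aggregation above can be sketched as follows. This is a minimal interpretation, not the paper's implementation: the exact reductions, certainty normalization, and tie-breaking may differ, and all names here are illustrative.

```python
import torch
import torch.nn.functional as F

def ctm_loss(logits_per_tick, targets):
    # logits_per_tick: (T, batch, num_classes); targets: (batch,)
    T, B, C = logits_per_tick.shape
    # Per-tick, per-example cross-entropy losses
    losses = torch.stack([F.cross_entropy(logits_per_tick[t], targets, reduction='none')
                          for t in range(T)])                      # (T, B)
    # Certainty = 1 - normalized entropy of the predictive distribution
    probs = torch.softmax(logits_per_tick, dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(-1)  # (T, B)
    certainty = 1.0 - entropy / torch.log(torch.tensor(float(C)))
    t1 = losses.argmin(dim=0)         # tick with minimum loss, per example
    t2 = certainty.argmax(dim=0)      # tick with maximum certainty, per example
    idx = torch.arange(B)
    return 0.5 * (losses[t1, idx] + losses[t2, idx]).mean()
```

Because $t_1$ and $t_2$ are selected per example, easy inputs can be supervised at early ticks while harder inputs draw supervision from later ticks.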

The paper evaluates the CTM on a diverse set of tasks to showcase its capabilities:

  • ImageNet-1K Classification: While not state-of-the-art in raw accuracy (72.47% top-1), the CTM demonstrates novel behaviors. It exhibits rich, diverse, and complex neural dynamics. The internal thought dimension enables adaptive compute, allowing the CTM to stop processing early for easier images based on certainty. Prediction analysis shows that certainty increases over internal ticks, leading to good calibration without specific post-training adjustments. Visualizations show attention patterns that smoothly shift and focus on salient features over time, demonstrating an emergent observational process. Emergent traveling waves in the neuron activation space are also observed.
  • 2D Mazes: A challenging maze navigation task requiring outputting a sequence of moves (not just the path) and trained without positional embeddings in attention. The CTM significantly outperforms LSTM and feed-forward baselines, suggesting it is more capable of building and utilizing an internal "world model" or cognitive map. Demonstrations show the CTM's attention tracing the solution path iteratively, and the model generalizes well to longer paths and larger mazes by reapplying the trained process. This suggests the CTM learns a general problem-solving procedure.
  • CIFAR-10: Comparison to human performance and baselines. The CTM achieves better test accuracy and significantly better calibration than LSTMs and a feed-forward model. Its performance correlates well with human difficulty ratings, and its uncertainty trend aligns with human reaction times. Neural activity traces show much richer dynamics for the CTM compared to the LSTM baseline on this task.
  • CIFAR-100 Ablations: Experiments varying the number of neurons (width) and internal ticks. Increasing width improves accuracy up to a point and increases the diversity/dissimilarity of neuron activity patterns. Increasing internal ticks generally improves accuracy and reveals an emergent tendency for the CTM to have two phases of high certainty during its thought process, regardless of the total number of ticks available.
  • Sorting: Sorting 30 real numbers, a task used to study adaptive compute. The CTM is trained to output the sorted sequence over its internal ticks using CTC loss. The analysis shows clear patterns in the "wait times" (number of ticks before outputting an element), correlating with the element's position in the sequence and the difference from the previous value, indicating the CTM uses a data-dependent internal algorithm. It also generalizes to data from different distributions.
  • Parity: Computing cumulative parity of a 64-length binary sequence. The CTM learns this sequential task better than LSTMs, and performance improves with more internal ticks. Analysis of attention patterns reveals different problem-solving strategies emerge in different training runs (e.g., sequential scanning vs. reverse scanning/planning), showcasing the CTM's ability to form diverse strategies.
  • Q&A MNIST: A task testing memory, retrieval, and arithmetic. The CTM observes MNIST digits and a sequence of index/operator embeddings, then outputs the result of modular arithmetic operations. The CTM outperforms LSTMs, especially with more internal ticks. It demonstrates memory retrieval even when observed digits fall outside the NLM's memory window, suggesting synchronization is key to this capability. The CTM learns to compute the operations sequentially as embeddings are observed, generalizing to more operations than seen during training.
  • Reinforcement Learning: Applying the CTM to POMDPs (CartPole, Acrobot, MiniGrid Four Rooms). The CTM maintains continuous neural dynamics across environment steps (using a sliding window history). It achieves comparable performance to LSTM baselines on these sequential decision-making tasks, demonstrating its ability to function as a stateful recurrent agent and interact continuously with an environment while exhibiting richer neural dynamics than LSTMs.
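
The certainty-based early stopping mentioned for ImageNet can be emulated with a simple halting loop. This is a hypothetical sketch: `step_fn` stands in for one CTM internal tick, and the 0.9 threshold is illustrative, not a value from the paper.

```python
import torch

def certainty(logits):
    # Certainty = 1 - normalized entropy of the predictive distribution
    probs = torch.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(-1)
    return 1.0 - entropy / torch.log(torch.tensor(float(logits.shape[-1])))

def run_with_early_stop(step_fn, state, max_ticks, threshold=0.9):
    # Run internal ticks until every example in the batch is sufficiently
    # certain, or the tick budget is exhausted.
    for t in range(max_ticks):
        state, logits = step_fn(state)
        if certainty(logits).min() >= threshold:
            break
    return logits, t + 1
```

Easier inputs cross the threshold sooner and consume fewer ticks, which is the adaptive-compute behavior described above.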

Practical Implementation Details & Considerations:

  • Computational Cost: The primary limitation is sequential processing across internal ticks, which cannot be parallelized like layers in a feed-forward network, so training takes longer than for standard feed-forward models. The cost per internal tick is manageable ($\mathcal{O}(D^2)$ for full synchronization, reduced to $\mathcal{O}(D_\text{sub})$ with pair sampling).
  • Parameter Count: NLMs add parameters scaling with $D \times M \times d_\text{hidden}$, which increases the parameter count compared to models with static activation functions. However, the paper shows the CTM can achieve strong results with parameter counts comparable to or less than LSTMs on some tasks.
  • Memory: Maintaining activation histories ($\mathbf{A}^t$, $\mathbf{Z}^t$) adds memory overhead, linear in $T$ and $M$. Recursive computation of synchronization mitigates this for the dot product itself. For RL, a sliding-window history is used to bound memory.
  • Stability: The certainty-based loss and the synchronization representation seem to contribute to training stability on tasks where standard RNNs struggled (mazes, Q&A MNIST, parity).
  • Generalization: The CTM demonstrates impressive generalization on mazes (longer paths, larger maps) and sorting (different distributions), suggesting it learns underlying procedures or world models rather than just memorizing. Generalization on Q&A MNIST is also observed for increased task length.
  • Modularity: The core CTM recurrent block (synapse + NLMs + synchronization) is reusable, with task-specific input feature extractors and output projections (attention, logits).
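
As a rough illustration of that modularity, the recurrent core can be sketched as a single reusable module. This is a toy stand-in with illustrative dimensions: the real synapse model is a U-Net-style MLP, the real NLMs are applied as batched einsums rather than a Python loop, and the attention output would come from a task-specific feature extractor.

```python
import torch
import torch.nn as nn

class CTMBlockSketch(nn.Module):
    def __init__(self, D=8, M=4):
        super().__init__()
        self.D, self.M = D, M
        self.synapse = nn.Linear(2 * D, D)  # toy stand-in for the U-Net-style synapse MLP
        # One tiny private model per neuron, acting on its length-M pre-activation history
        self.nlms = nn.ModuleList(
            nn.Sequential(nn.Linear(M, 16), nn.ReLU(), nn.Linear(16, 1)) for _ in range(D))

    def forward(self, z, o, history):
        # z: (B, D) post-activations; o: (B, D) attention output; history: (B, D, M)
        a = self.synapse(torch.cat([z, o], dim=-1))                        # pre-activations
        history = torch.cat([history[:, :, 1:], a.unsqueeze(-1)], dim=-1)  # keep last M
        z_next = torch.cat([self.nlms[d](history[:, d]) for d in range(self.D)], dim=-1)
        return z_next, history
```

Calling this block once per internal tick, with task-specific attention in front and a synchronization-based readout behind, reproduces the overall loop structure described above.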

The paper concludes by highlighting the CTM's novel use of neural synchronization and temporal dynamics as a fundamental representation, leading to emergent properties like adaptive compute, improved calibration, and interpretable processing strategies. It suggests that drawing inspiration from biological principles, even without strict adherence, can lead to powerful and qualitatively different AI capabilities, paving the way for more flexible and intelligent systems.
