Causal-JEPA: Object-Level Causal Reasoning
- Causal-JEPA is an object-centric world model that extends JEPA by applying object-level masking, compelling the system to perform counterfactual predictions.
- It employs a masked-transformer predictor to reconstruct hidden object trajectories and isolate minimal influence neighborhoods for robust causal inference.
- Empirical results demonstrate significant gains in visual question answering and model predictive control, highlighting improved interaction-dependent reasoning and planning efficiency.
Causal-JEPA (C-JEPA) is an object-centric world model that imposes a causal inductive bias through object-level masking of latent representations, extending the Joint Embedding Predictive Architecture (JEPA) paradigm from patch-based to object-based domains. By intervening at the level of entire object trajectories within latent space, C-JEPA forces the model to solve counterfactual-style prediction tasks, yielding improved robustness in interaction-dependent reasoning, visual question answering, and data-efficient planning. The framework shows, both empirically and with formal guarantees, that object-level masking induces stable discovery of minimal “influence neighborhoods,” i.e., the sets of state variables essential for predicting an object’s future, thereby enforcing genuine relational reasoning between objects (Nam et al., 11 Feb 2026).
1. Motivation and Conceptual Foundations
C-JEPA addresses key shortcomings in conventional object-centric world models, which typically utilize frozen object encoders such as Slot Attention or SAVi to factor video frames into a set of entity-specific latents (“slots”). When trained solely on pixel reconstruction or autoregressive prediction, such models tend to exploit trivial dynamics—memorizing self-trajectories or spurious correlations—without attending to genuine object–object interactions. Consequently, these models often succeed on short-term rollouts but underperform on tasks requiring causal dependence between distinct entities.
C-JEPA resolves this limitation by generalizing JEPA from masked patch prediction to the object slot level. By masking entire object trajectories throughout a history window (while preserving an identity “anchor” at the earliest step), the model is compelled to estimate a masked object’s latent state exclusively from the temporal evolution of other objects and any provided exogenous signals. This induces a counterfactual-style prediction problem: “What would this object’s embedding have been had it been unobserved, given the rest of the scene?” The training objective is thus explicitly constructed to make relational reasoning between entities necessary, eliminating shortcut solutions and trivial local dynamics (Nam et al., 11 Feb 2026).
2. Object-Level Masking and Latent Intervention Mechanism
The observation at each time step is processed by an object encoder, producing a fixed number of entity-specific slot latents. Auxiliary variables (e.g., action vectors or proprioception features) are concatenated to each slot, yielding per-object entity tokens.
Mask selection operates by sampling a subset of slots to be masked at each timestep within the history window. Each masked slot’s token is replaced by a learned linear projection of a shared mask token plus a learned temporal embedding, so the predictor retains positional information but no object content. Stacking the masked tokens over the history window and the planning horizon yields the full masked trajectory.
Crucially, this object-level masking alters only the observability of tokens for the predictor, not the underlying transition mechanism of the environment. The predictor is thus tasked with reconstructing the masked slots’ states as counterfactuals, given only partially observed historical contexts, enforcing interaction-dependent and relational prediction (Nam et al., 11 Feb 2026).
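The trajectory-level masking described above can be sketched in a few lines; this is a minimal NumPy illustration rather than the paper's implementation, and the function name `mask_object_trajectories`, the fixed mask vector (standing in for the learned projection of a mask token), and the tensor shapes are assumptions for the sketch.

```python
import numpy as np

def mask_object_trajectories(tokens, masked_slots, mask_token, temporal_emb,
                             anchor_step=0):
    """Hide entire object trajectories: replace each masked slot's token with
    a shared mask embedding plus a per-timestep embedding, while keeping an
    identity anchor visible at the earliest step.

    tokens:       [T, K, d] entity tokens (slot latents + auxiliaries)
    masked_slots: slot indices to hide for the whole window
    mask_token:   [d] mask embedding (learned in the real model)
    temporal_emb: [T, d] temporal embeddings (learned in the real model)
    """
    masked = tokens.copy()
    T = tokens.shape[0]
    for k in masked_slots:
        for t in range(T):
            if t == anchor_step:          # identity anchor stays observable
                continue
            masked[t, k] = mask_token + temporal_emb[t]
    return masked

rng = np.random.default_rng(0)
T, K, d = 6, 4, 8
X = rng.normal(size=(T, K, d))            # toy entity tokens
m = rng.normal(size=d)                    # toy mask embedding
e = rng.normal(size=(T, d))               # toy temporal embeddings
Xm = mask_object_trajectories(X, masked_slots=[2], mask_token=m, temporal_emb=e)
```

The predictor must then recover slot 2's trajectory purely from the visible slots, which is what removes the self-dynamics shortcut.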
3. Training Objective and Loss Formulation
The C-JEPA framework utilizes a masked-transformer predictor, a Vision Transformer (ViT) with bidirectional attention, that consumes the masked token sequence and outputs predicted slot latents. The training loss is the sum of two masked-latent prediction terms: a history-reconstruction term, forcing correct completion of masked slots in the past, and a future-prediction term, enforcing accurate rollout over the planning horizon; both are mean-squared errors between predicted and target latents. Notably, C-JEPA employs no pixel-level reconstruction loss; history masking explicitly suppresses trivial self-dynamics, requiring the model to utilize interaction pathways for imputation, while future prediction constrains the system to forecast correct long-term trajectories.
Auxiliary variables (action or proprioception features) are included in the entity tokens but are never masked, ensuring that control-relevant state information is always available to the planner (Nam et al., 11 Feb 2026).
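The two-term objective can be written as a short function; this is a minimal sketch assuming latents of shape [T, K, d], a boolean mask marking hidden history slots, and mean-squared error, with the name `cjepa_loss` and the averaging scheme as illustrative assumptions.

```python
import numpy as np

def cjepa_loss(pred, target, mask, history_len):
    """Masked-latent MSE: a history-completion term over slots hidden in the
    past window plus a future-prediction term over the rollout horizon.
    No pixel-level reconstruction is involved.

    pred, target: [T, K, d] predicted / encoder slot latents
    mask:         [T, K] bool, True where a slot was hidden from the predictor
    history_len:  number of timesteps belonging to the history window
    """
    sq = ((pred - target) ** 2).sum(-1)                 # [T, K] per-token error
    hist = sq[:history_len][mask[:history_len]].mean()  # masked past slots only
    fut = sq[history_len:].mean()                       # full rollout horizon
    return hist + fut

T, K, d = 4, 3, 8
target = np.zeros((T, K, d))
mask = np.zeros((T, K), dtype=bool)
mask[0, 1] = True                     # one slot hidden in the history window
perfect = cjepa_loss(target, target, mask, history_len=2)
off_by_one = cjepa_loss(target + 1.0, target, mask, history_len=2)
```

A perfect prediction gives zero loss; a constant offset of 1 in every dimension contributes d to each term, so `off_by_one` evaluates to 16 here.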
4. Formal Causal Inductive Bias and Theoretical Guarantees
Under standard causality assumptions (temporally directed dependencies, a shared transition mechanism, coherence of slot representations, and finite history sufficiency), object-level masking induces a specific causal inductive bias. For masked completion, define the “influence neighborhood” of a slot as the minimal subset of history tokens, excluding that slot’s own trajectory, such that conditioning on this subset yields the same conditional expectation of the slot’s latent as conditioning on the entire unmasked history. The following holds for the mean-squared error (MSE) loss:
- Interaction Necessity Theorem: The Bayes-optimal predictor for masked latent completion is the conditional expectation of the masked slot’s latent given its influence neighborhood; any predictor ignoring relevant variables in that neighborhood incurs strictly higher expected error.
- Stable Influence Neighborhood Corollary: Repeated object-level masking concentrates predictor attention on the intervention-stable influence neighborhood of each slot. This mechanism parallels Invariant Risk Minimization and Invariant Causal Prediction, identifying stable predictive dependencies in the dataset without requiring explicit graph discovery (Nam et al., 11 Feb 2026).
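The interaction-necessity claim can be illustrated with a toy example: when a masked slot's latent is driven by a neighboring slot, a predictor that conditions on that neighbor (its influence neighborhood) achieves strictly lower MSE than one that ignores it. The linear coefficient and noise scale below are arbitrary illustrative choices, not values from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 10_000
z_other = rng.normal(size=n)             # visible neighboring slot
noise = rng.normal(scale=0.1, size=n)
z_masked = 0.8 * z_other + noise         # masked slot driven by the interaction

# Predictor ignoring the neighbor: best it can do is the unconditional mean.
mse_ignore = np.mean((z_masked - z_masked.mean()) ** 2)

# Predictor using the influence neighborhood {z_other}: linear regression.
w = np.polyfit(z_other, z_masked, 1)
mse_use = np.mean((z_masked - np.polyval(w, z_other)) ** 2)
```

Here `mse_ignore` is close to the full variance of the masked slot (about 0.65), while `mse_use` drops to roughly the noise floor (about 0.01), mirroring the theorem's statement that discarding relevant neighborhood variables raises expected error.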
5. Model Architecture and Model Predictive Control
The object encoder in C-JEPA utilizes VideoSAUR or SAVi to aggregate frozen DINOv2 patch features into 4–7 slots using Slot Attention. The masked-transformer predictor comprises six transformer layers, each with 16 attention heads and an MLP dimension of 2048, processing a sequence of object and auxiliary tokens.
For planning, C-JEPA applies Model Predictive Control (MPC). At each time step, the system observes the current frame, encodes it into slot latents, masks only the future slots, and predicts their latent trajectories. Optimal action sequences are found with the Cross-Entropy Method (CEM), minimizing a cost computed on the predicted future latents. CEM is configured with 300 samples, 30 elite candidates, and 30 iterations. Only 4 slots of 128-dimensional features are utilized (≈1.02% of the latent tokens required by a DINO-WM patch-based world model), yielding significantly more efficient planning: over 8× faster on a single GPU (Nam et al., 11 Feb 2026).
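A generic CEM loop with the reported 300-sample / 30-elite / 30-iteration configuration can be sketched as below; the quadratic toy cost stands in for a latent-space cost evaluated on C-JEPA rollouts, and the function name and shapes are illustrative assumptions.

```python
import numpy as np

def cem_plan(cost_fn, horizon, act_dim, n_samples=300, n_elite=30,
             n_iters=30, seed=0):
    """Cross-Entropy Method over action sequences: sample candidates from a
    Gaussian, keep the lowest-cost elites, refit the Gaussian, repeat."""
    rng = np.random.default_rng(seed)
    mu = np.zeros((horizon, act_dim))
    sigma = np.ones((horizon, act_dim))
    for _ in range(n_iters):
        noise = rng.normal(size=(n_samples, horizon, act_dim))
        candidates = mu + sigma * noise
        costs = np.array([cost_fn(a) for a in candidates])
        elites = candidates[np.argsort(costs)[:n_elite]]
        mu, sigma = elites.mean(axis=0), elites.std(axis=0) + 1e-6
    return mu

# Toy cost: squared distance to a fixed target action sequence, standing in
# for the distance between predicted rollout latents and a goal encoding.
target_seq = np.linspace(0.0, 1.0, 5)[:, None]
plan = cem_plan(lambda a: float(np.sum((a - target_seq) ** 2)),
                horizon=5, act_dim=1)
```

In a real MPC loop only the first action of the returned mean sequence would be executed before replanning at the next observation.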
6. Empirical Evaluation and Comparative Studies
Visual Question Answering (CLEVRER)
C-JEPA is evaluated on CLEVRER by rolling out from frame 128 to 160 and feeding slot trajectories into ALOE for various question types. Compared to an identical architecture with no history masking (OC-JEPA), C-JEPA achieves an absolute ∼20% gain in counterfactual question accuracy and 1–6% gains on descriptive, predictive, and explanatory categories. Peak performance occurs with a masking budget of 3–4 objects per frame.
Model Predictive Control (Push-T)
With four object slots and auxiliary action/proprioception embeddings, C-JEPA achieves 88.7% planning success versus 60.7% for OC-DINO-WM (object encoder + patch predictor) and 91.3% for DINO-WM (patch-based), while using only 1.02% of the input tokens. Planning is also markedly faster (673 s vs. 5763 s for 50 trajectories on an L40S GPU).
Comparative and Ablation Analysis
Several baselines are compared:
| Model | VQA (QA accuracy) | Push-T Planning Success | Latent Token Usage |
|---|---|---|---|
| C-JEPA | High (with masking) | 88.7% | 1.02% |
| OC-JEPA (no masking) | Lower (∼20% below) | — | — |
| SlotFormer | Degrades without recon | — | — |
| OCVP-Seq | Moderate loss if no recon | — | — |
| DINO-WM | — | 91.3% | 100% |
| OC-DINO-WM | — | 60.7% | — |
Removing pixel-reconstruction losses has a substantial negative impact on SlotFormer and OCVP-Seq but not on C-JEPA. Performance in C-JEPA increases with the fraction of masked objects to an optimal threshold (≈3/7 on CLEVRER, 1/4 on Push-T), then degrades. Object-level masking is more stable than token- or tube-level masking, which can destabilize planning at higher budgets (Nam et al., 11 Feb 2026).
7. Limitations and Future Directions
C-JEPA’s effectiveness is contingent on the quality of the frozen object encoder: collapse or misalignment in Slot Attention restricts the attainable performance ceiling. Direct validation of the model’s inferred influence neighborhoods against ground-truth causal graphs remains necessary for deeper understanding of its discovery capabilities.
Possible future developments include joint training of the encoder and predictor to enhance slot representations, extension to multi-modal or physics-parameter auxiliaries, and application to environments with greater complexity and numbers of interacting entities (Nam et al., 11 Feb 2026).