TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning
Abstract: Latent prediction, where agents learn by predicting their own latents, has emerged as a powerful paradigm for training general representations in machine learning. In reinforcement learning (RL), this approach has been explored to define auxiliary losses for a variety of settings, including reward-based and unsupervised RL, behavior cloning, and world modeling. While existing methods are typically limited to single-task learning, one-step prediction, or on-policy trajectory data, we show that temporal difference (TD) learning enables learning representations predictive of long-term latent dynamics across multiple policies from offline, reward-free transitions. Building on this, we introduce TD-JEPA, which integrates TD-based latent-predictive representations into unsupervised RL. TD-JEPA trains explicit state and task encoders, a policy-conditioned multi-step predictor, and a set of parameterized policies directly in latent space. This enables zero-shot optimization of any reward function at test time. Theoretically, we show that an idealized variant of TD-JEPA avoids collapse with proper initialization, and learns encoders that capture a low-rank factorization of long-term policy dynamics, while the predictor recovers their successor features in latent space. Empirically, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets in ExoRL and OGBench, especially in the challenging setting of zero-shot RL from pixels.
Glossary
- Action-value function: The expected cumulative discounted reward of taking an action in a state under a policy. "the action-value function measures the cumulative discounted reward obtained by the policy over an infinite horizon"
- Bellman equation: A recursive relation expressing a quantity (e.g., value or successor features) in terms of immediate outcomes and its future values. "successor features admit a Bellman equation $F_\phi^{\pi_z}(s,a) = \mathbb{E}_{s'\sim P(\cdot | s,a),\, a'\sim\pi_z(s')}[\phi(s') + \gamma F_\phi^{\pi_z}(s',a')]$"
- Bootstrapped: Using the model’s own prediction as part of the target during training. "plus a bootstrapped version of itself"
- Covariance regularization: A stabilization technique that penalizes feature covariance to encourage well-conditioned representations. "stabilization strategies, e.g. target networks and covariance regularization."
- Exponential Moving Average (EMA): A smoothing update that maintains target networks as exponentially weighted averages of online networks. "Update target networks […] via EMA of […]"
- Expectile regression: An asymmetric least-squares regression that fits expectiles, which generalize the mean in the way quantiles generalize the median; used here to decompose successor measures. "learns a multilinear decomposition of the successor measure via expectile regression"
- Forward and backward TD losses: Temporal-difference losses that approximate successor measures using forward and backward dynamics. "related to forward and backward TD losses for approximating the successor measure"
- Joint-embedding predictive architecture (JEPA): A learning paradigm where encoders are trained to predict each other’s embeddings without reconstruction. "an instance of the joint-embedding predictive architecture (JEPA; LeCun, 2022) paradigm."
- Latent dynamics model: A predictor that estimates the latent representation of future states. "a latent dynamics model estimating the representation of a future state"
- Latent space: The feature space produced by an encoder where learning and prediction are performed. "perform self-supervised learning entirely in latent space"
- Latent-predictive representation learning: Learning representations by predicting future latent embeddings rather than raw observations. "latent-predictive (a.k.a. self-predictive) representation learning"
- Least-squares TD: A TD method that finds the fixed point of the projected Bellman equation via least-squares. "solve a least-squares TD problem"
- MC-JEPA (Monte-Carlo JEPA loss): A loss that trains predictors to match future task embeddings using samples from successor measures. "MC-JEPA stands for Monte-Carlo (MC) JEPA loss"
- Oblique projection: A projection onto a subspace along non-orthogonal directions, arising from projected TD fixed points. "yielding the fixed point of a projected Bellman operator whose closed-form expression is an oblique projection"
- Off-policy: Learning from data generated by policies different from the one being evaluated or optimized. "it can thus be estimated from off-policy, offline datasets."
- On-policy: Sampling data under the same policy that is being evaluated or trained. "as on-policy samples are needed"
- Orthogonal projection: A projection onto a subspace along perpendicular directions. "where […] (resp. […]) is an orthogonal projection on the span of […] (resp. […])"
- Orthonormality regularization: A loss encouraging learned features to be orthonormal across a batch. "Compute orthonormality regularization losses"
- Policy-conditioned: Conditioned on the identity or parameters of a policy, influencing prediction or representation. "we introduce a policy-conditioned, multi-step formulation"
- Policy-conditional successor measures: Successor measures that explicitly depend on the policy under consideration. "directly modeling policy-conditional successor measures is on average beneficial"
- Projected Bellman operator: The Bellman operator projected onto an approximation subspace; its fixed point defines TD solutions. "the fixed point of a projected Bellman operator"
- Q-values: Action-values; expected discounted returns for actions in states under a policy. "Q-values for any reward function can be written as"
- Self-supervised learning: Learning signals derived from the data itself without external labels or rewards. "perform self-supervised learning entirely in latent space without any reward"
- State encoder: A function that maps raw states to latent features used by downstream models. "These algorithms jointly learn a state encoder and a predictor"
- Stop-gradient: An operation that prevents gradients from flowing through a tensor during backpropagation. "and $\mathrm{sg}(\phi)$ denotes stop-gradient."
- Successor features: Expected discounted sums of feature vectors over future states under a policy. "approximates successor features for each policy"
- Successor measure: The unnormalized discounted distribution over states visited under a policy. "is referred to as the successor measure"
- Successor measure approximation loss: An objective that fits a low-rank decomposition to successor measures. "We define a (non-latent-predictive) successor measure approximation loss"
- Target networks: Delayed network copies used to stabilize training by providing slowly moving targets. "stabilization strategies, e.g. target networks"
- Task encoder: An encoder defining the space of linear rewards or tasks to be solved. "we will refer to as a task encoder."
- Temporal-difference (TD) loss: A bootstrapped training objective whose target combines the outcome of a one-step transition with the model's own prediction at the next state. "a novel off-policy temporal-difference loss"
- Zero-shot policy optimization: Deriving policies for new rewards without additional training using pre-learned structures. "perform zero-shot policy optimization for any downstream reward"
- Zero-shot unsupervised RL: Learning from reward-free data to enable immediate policy deployment on downstream tasks. "a zero-shot unsupervised RL algorithm"
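To make the Bellman equation, successor features, and Q-values entries above concrete, here is a minimal tabular sketch in NumPy. It is not the TD-JEPA algorithm: it assumes a small random MDP with known transitions, a fixed policy, and fixed features, and solves the successor-feature Bellman equation in closed form. The function name `successor_features` and all array names are hypothetical. It also assumes a reward that is linear in the next-state features, $r(s') = \phi(s')^\top w$, in which case the Q-values are $F(s,a)^\top w$.

```python
import numpy as np


def successor_features(P, pi, phi, gamma):
    """Solve F(s,a) = E[phi(s') + gamma * F(s',a')] exactly for a tabular MDP.

    P:   (S, A, S) transition probabilities P(s' | s, a)
    pi:  (S, A) policy probabilities pi(a | s)
    phi: (S, d) state features
    Returns F with shape (S, A, d).
    """
    S, A, _ = P.shape
    d = phi.shape[1]
    # Transition matrix over state-action pairs: (s, a) -> (s', a').
    P_pi = (P[:, :, :, None] * pi[None, None, :, :]).reshape(S * A, S * A)
    # Expected next-state feature for every (s, a).
    phi_next = (P @ phi).reshape(S * A, d)
    # Fixed point of the Bellman equation: F = phi_next + gamma * P_pi @ F.
    F = np.linalg.solve(np.eye(S * A) - gamma * P_pi, phi_next)
    return F.reshape(S, A, d)


rng = np.random.default_rng(0)
S, A, d, gamma = 5, 3, 4, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=-1, keepdims=True)
pi = rng.random((S, A))
pi /= pi.sum(axis=-1, keepdims=True)
phi = rng.normal(size=(S, d))

F = successor_features(P, pi, phi, gamma)

# A reward that is linear in the features (assumption): r(s') = phi(s') @ w.
w = rng.normal(size=d)
q_from_sf = F @ w  # shape (S, A), no further learning needed

# Reference: ordinary iterative policy evaluation for the same reward.
r = phi @ w
q_ref = np.zeros((S, A))
for _ in range(2000):
    v = (pi * q_ref).sum(axis=1)   # state values under pi
    q_ref = P @ (r + gamma * v)    # one Bellman backup

print(np.allclose(q_from_sf, q_ref))
```

Once `F` is available, changing the reward only changes `w`, which is the sense in which successor features support evaluating new linear rewards without retraining.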
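The EMA and target-network entries above can be illustrated with a short sketch. The update rule below, `target = (1 - tau) * target + tau * online`, is the common form of an EMA target update; the step size `tau`, the function name `ema_update`, and the dictionary-of-arrays parameter layout are assumptions for illustration, not values or structures taken from the paper.

```python
import numpy as np


def ema_update(target, online, tau=0.01):
    """Return new target parameters as an EMA of the online parameters."""
    return {k: (1.0 - tau) * target[k] + tau * online[k] for k in target}


# Toy parameters: the online network is held fixed at 1, the target starts at 0.
online = {"w": np.ones((2, 2)), "b": np.ones(2)}
target = {"w": np.zeros((2, 2)), "b": np.zeros(2)}

n_steps, tau = 100, 0.05
for _ in range(n_steps):
    target = ema_update(target, online, tau=tau)

# With a constant online value of 1, the target equals 1 - (1 - tau)^n.
print(target["w"][0, 0], 1.0 - (1.0 - tau) ** n_steps)
```

A small `tau` makes the target move slowly, which is what gives bootstrapped losses a more stable regression target.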
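For the expectile regression entry, the following sketch shows the asymmetric squared loss in its standard form and fits a single expectile to a one-dimensional sample. It illustrates the general technique only, not how the cited method applies it to successor measures; `expectile_loss`, `fit_expectile`, and the level `tau` are hypothetical names.

```python
import numpy as np


def expectile_loss(residual, tau):
    """Asymmetric squared loss: weight tau on positive residuals, 1 - tau otherwise."""
    weight = np.where(residual > 0, tau, 1.0 - tau)
    return np.mean(weight * residual**2)


def fit_expectile(x, tau, iters=100):
    """Fit the tau-expectile of sample x by iteratively reweighted averaging."""
    e = x.mean()
    for _ in range(iters):
        w = np.where(x > e, tau, 1.0 - tau)
        e = np.sum(w * x) / np.sum(w)
    return e


rng = np.random.default_rng(0)
x = rng.normal(size=1000)

e_mid = fit_expectile(x, tau=0.5)   # symmetric weights recover the mean
e_high = fit_expectile(x, tau=0.9)  # upweights samples above the estimate

print(e_mid, x.mean(), e_high)
```

With `tau = 0.5` the loss reduces to ordinary least squares, so the fitted value is the sample mean; larger `tau` pushes the estimate toward the upper tail.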