
TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning

Published 1 Oct 2025 in cs.LG | (2510.00739v1)

Abstract: Latent prediction--where agents learn by predicting their own latents--has emerged as a powerful paradigm for training general representations in machine learning. In reinforcement learning (RL), this approach has been explored to define auxiliary losses for a variety of settings, including reward-based and unsupervised RL, behavior cloning, and world modeling. While existing methods are typically limited to single-task learning, one-step prediction, or on-policy trajectory data, we show that temporal difference (TD) learning enables learning representations predictive of long-term latent dynamics across multiple policies from offline, reward-free transitions. Building on this, we introduce TD-JEPA, which leverages TD-based latent-predictive representations into unsupervised RL. TD-JEPA trains explicit state and task encoders, a policy-conditioned multi-step predictor, and a set of parameterized policies directly in latent space. This enables zero-shot optimization of any reward function at test time. Theoretically, we show that an idealized variant of TD-JEPA avoids collapse with proper initialization, and learns encoders that capture a low-rank factorization of long-term policy dynamics, while the predictor recovers their successor features in latent space. Empirically, TD-JEPA matches or outperforms state-of-the-art baselines on locomotion, navigation, and manipulation tasks across 13 datasets in ExoRL and OGBench, especially in the challenging setting of zero-shot RL from pixels.

Summary

  • The paper introduces TD-JEPA, a zero-shot RL framework that leverages latent-predictive representations to learn multi-step, policy-conditioned predictions.
  • It employs an off-policy temporal-difference loss to pre-train state and task encoders, enhancing value estimation and policy generalization.
  • Empirical evaluations across 65 tasks demonstrate that TD-JEPA consistently matches or exceeds state-of-the-art baselines, particularly in pixel-based environments.


Introduction

The paper "TD-JEPA: Latent-predictive Representations for Zero-Shot Reinforcement Learning" (2510.00739) introduces a novel approach to zero-shot reinforcement learning (RL) that leverages latent-predictive representations. A key challenge in RL is learning state representations that effectively capture the dynamics of the environment. Such representations should facilitate efficient value estimation and policy optimization across various tasks. Previous work has utilized latent-predictive representation learning, where algorithms jointly learn a state encoder ϕ(s) and a predictor P that estimates the representation of a future state s′. This paradigm, known as the joint-embedding predictive architecture (JEPA), operates entirely in latent space without rewards or state reconstructions.
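As a minimal sketch of this one-step JEPA objective, the encoder ϕ and predictor P can be stood in for by random linear maps (all shapes and maps below are illustrative, not the paper's networks):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dimensions; the paper uses neural networks instead
# of the random linear maps below.
state_dim, latent_dim = 8, 4
W_phi = rng.normal(size=(latent_dim, state_dim))    # state encoder phi
W_pred = rng.normal(size=(latent_dim, latent_dim))  # latent predictor P

def jepa_loss(s, s_next):
    """One-step latent-prediction loss: predict phi(s') from phi(s).

    No rewards and no reconstruction of s' itself; everything happens
    in latent space, with the target latent held fixed (stop-gradient).
    """
    z = W_phi @ s             # phi(s)
    z_target = W_phi @ s_next # phi(s'), treated as a fixed target
    z_pred = W_pred @ z       # P(phi(s)) should approximate phi(s')
    return float(np.mean((z_pred - z_target) ** 2))

s = rng.normal(size=state_dim)
s_next = rng.normal(size=state_dim)
loss = jepa_loss(s, s_next)
```

In practice a collapse-avoidance mechanism (e.g. target networks or regularization) keeps the trivial solution ϕ ≡ 0 from minimizing this loss; the sketch omits that machinery.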

Several RL methods have incorporated latent prediction as an auxiliary loss to enhance sample efficiency and generalization. In unsupervised settings, these losses provide a framework for learning latent world models that can solve goal-reaching tasks via test-time planning. The proposed method extends this approach by introducing policy-conditioned, multi-step formulations based on a novel off-policy temporal-difference loss. This objective focuses on representations that are predictive of long-term features critical for value estimation across multiple policies.

Methodology

TD-JEPA instantiates temporal-difference latent-predictive representation learning as a zero-shot unsupervised RL algorithm. The approach pre-trains four components end-to-end from offline, reward-free transitions: a state encoder, a policy-conditioned multi-step predictor, a task encoder, and a set of parameterized policies. Unlike previous methods, latent prediction is the core objective rather than an auxiliary loss, allowing TD-JEPA to learn all the components needed to distill zero-shot policies.

The predictor approximates successor features, which enables extracting policies that map encoded observations to optimal actions for any reward function represented within the span of the learned features. Architecturally, TD-JEPA comprises state and task encoders that feed into policy and action-value predictors. Training leverages off-policy data, using a temporal-difference loss that updates state encoders and task representations separately (Figure 1).

Figure 1: TD-JEPA trains policies π_z parameterized by latents z. The predictor, conditioned on z, predicts the representations of future states visited by π_z.
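The TD loss bootstraps the predictor's own future output, following the successor-feature Bellman equation F(s,a,z) = ϕ(s′) + γ F(s′,a′,z) in expectation. A minimal sketch of the target computation, with an arbitrary discount γ (the right-hand side is held fixed via stop-gradient / target networks in practice):

```python
import numpy as np

gamma = 0.98  # illustrative discount factor

def td_target(phi_next, F_next):
    """Bootstrapped TD target for latent successor features:
        F(s, a, z)  <-  phi(s') + gamma * F(s', a', z)
    where phi_next = phi(s') and F_next = F(s', a', z) are computed
    with frozen (target) networks.
    """
    return phi_next + gamma * F_next

phi_next = np.ones(4)
F_next = np.zeros(4)
target = td_target(phi_next, F_next)  # equals phi(s') when F(s',a',z)=0
```

Because the target only requires a sampled transition (s, a, s′) and an action a′ drawn from π_z, it can be estimated from off-policy, reward-free data.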

Theoretical Analysis

The paper provides theoretical guarantees for TD-JEPA's representation learning. In an idealized linear setting, the representations avoid collapse with appropriate initialization, recover a low-rank factorization of successor measures, and approximate successor features in latent space. These properties minimize policy evaluation error for any reward, rendering zero-shot optimization feasible.
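In idealized form, these two properties can be summarized as follows (the notation, with ρ denoting the data distribution, is assumed here and follows forward-backward-style analyses rather than being quoted from the paper):

```latex
% Low-rank factorization of the long-term (successor) dynamics:
M^{\pi_z}(s, a, \mathrm{d}s') \;\approx\; F(s, a, z)^\top \psi(s')\, \rho(\mathrm{d}s')
% Zero-shot evaluation: for any reward of the form r(s) = \psi(s)^\top z_r,
Q_r^{\pi_z}(s, a) \;=\; F(s, a, z)^\top z_r
```

The second identity is what makes zero-shot optimization possible: once F and ψ are learned, evaluating any reward in the span of ψ reduces to a dot product in latent space.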

The theoretical framework extends existing analyses by integrating a novel "gradient matching" argument, linking TD-JEPA with other unsupervised RL methodologies such as forward-backward learning and intention-conditioned value functions.

Empirical Evaluation

TD-JEPA was evaluated on 65 tasks across 13 datasets from ExoRL and OGBench, encompassing locomotion, navigation, and manipulation domains with both proprioceptive and pixel-based observations. Results indicate that TD-JEPA matches or surpasses state-of-the-art zero-shot baselines, particularly in challenging settings such as learning from pixels. The architecture's separation of state and task representations adds flexibility and improves multi-step, policy-dependent dynamics prediction (Figure 2).

Figure 2: Probabilities of improvement: how likely is method X to outperform method Y on a random domain? We report symmetrized 95% simple bootstrap confidence intervals.

Conclusion

The paper presents TD-JEPA as a robust zero-shot reinforcement learning framework that operates entirely in latent space, producing predictors that align with successor features. Empirical evidence and theoretical analysis show promising results in diverse settings, suggesting that TD-JEPA's latent-predictive methodology effectively integrates with unsupervised RL. Future work may explore asymmetric successor measures and large-scale robotic datasets to further expand the utility of latent-predictive objectives.


Glossary

  • Action-value function: The expected cumulative discounted reward of taking an action in a state under a policy. "the action-value function Q_r^π(s,a) measures the cumulative discounted reward obtained by the policy over an infinite horizon"
  • Bellman equation: A recursive relation expressing a quantity (e.g., value or successor features) in terms of immediate outcomes and its future values. "successor features admit a Bellman equation $F_\phi^{\pi_z}(s,a) = \mathbb{E}_{s'\sim P(\cdot | s,a),\, a'\sim\pi_z(s')}[\phi(s') + \gamma F_\phi^{\pi_z}(s',a')]$"
  • Bootstrapped: Using the model’s own prediction as part of the target during training. "plus a bootstrapped version of itself"
  • Covariance regularization: A stabilization technique that penalizes feature covariance to encourage well-conditioned representations. "stabilization strategies, e.g. target networks and covariance regularization."
  • Exponential Moving Average (EMA): A smoothing update that maintains target networks as exponentially weighted averages of online networks. "Update target networks ϕ⁻, T_ϕ⁻, {ψ⁻, T_ψ⁻} via EMA of ϕ, T_ϕ, {ψ, T_ψ}"
  • Expectile regression: A regression method that fits expectiles, used here to decompose successor measures. "learns a multilinear decomposition of the successor measure via expectile regression"
  • Forward and backward TD losses: Temporal-difference losses that approximate successor measures using forward and backward dynamics. "related to forward and backward TD losses for approximating the successor measure"
  • Joint-embedding predictive architecture (JEPA): A learning paradigm where encoders are trained to predict each other's embeddings without reconstruction. "an instance of the joint-embedding predictive architecture (JEPA) paradigm (LeCun, 2022)"
  • Latent dynamics model: A predictor that estimates the latent representation of future states. "a latent dynamics model estimating the representation of a future state s′"
  • Latent space: The feature space produced by an encoder where learning and prediction are performed. "perform self-supervised learning entirely in latent space"
  • Latent-predictive representation learning: Learning representations by predicting future latent embeddings rather than raw observations. "latent-predictive (a.k.a. self-predictive) representation learning"
  • Least-squares TD: A TD method that finds the fixed point of the projected Bellman equation via least-squares. "solve a least-squares TD problem"
  • MC-JEPA (Monte-Carlo JEPA loss): A loss that trains predictors to match future task embeddings using samples from successor measures. "MC-JEPA stands for Monte-Carlo (MC) JEPA loss"
  • Oblique projection: A projection onto a subspace along non-orthogonal directions, arising from projected TD fixed points. "yielding the fixed point of a projected Bellman operator whose closed-form expression is an oblique projection"
  • Off-policy: Learning from data generated by policies different from the one being evaluated or optimized. "it can thus be estimated from off-policy, offline datasets."
  • On-policy: Sampling data under the same policy that is being evaluated or trained. "as on-policy samples s⁺ ∼ M^{π_z}(·|s,a) are needed"
  • Orthogonal projection: A projection onto a subspace along perpendicular directions. "where Π_ϕ (resp. Π_ψ) is an orthogonal projection on the span of ϕ (resp. ψ)"
  • Orthonormality regularization: A loss encouraging learned features to be orthonormal across a batch. "Compute orthonormality regularization losses"
  • Policy-conditioned: Conditioned on the identity or parameters of a policy, influencing prediction or representation. "we introduce a policy-conditioned, multi-step formulation"
  • Policy-conditional successor measures: Successor measures that explicitly depend on the policy under consideration. "directly modeling policy-conditional successor measures is on average beneficial"
  • Projected Bellman operator: The Bellman operator projected onto an approximation subspace; its fixed point defines TD solutions. "the fixed point of a projected Bellman operator"
  • Q-values: Action-values; expected discounted returns for actions in states under a policy. "Q-values for any reward function r(s) = ψ(s)ᵀz_r can be written as"
  • Self-supervised learning: Learning signals derived from the data itself without external labels or rewards. "perform self-supervised learning entirely in latent space without any reward"
  • State encoder: A function that maps raw states to latent features used by downstream models. "These algorithms jointly learn a state encoder ϕ(s) and a predictor P"
  • Stop-gradient: An operation that prevents gradients from flowing through a tensor during backpropagation. "and sg(ϕ) denotes stop-gradient."
  • Successor features: Expected discounted sums of feature vectors over future states under a policy. "approximates successor features for each policy"
  • Successor measure: The unnormalized discounted distribution over states visited under a policy. "is referred to as the successor measure"
  • Successor measure approximation loss: An objective that fits a low-rank decomposition to successor measures. "We define a (non-latent-predictive) successor measure approximation loss"
  • Target networks: Delayed network copies used to stabilize training by providing slowly moving targets. "stabilization strategies, e.g. target networks"
  • Task encoder: An encoder defining the space of linear rewards or tasks to be solved. "we will refer to ψ as a task encoder."
  • Temporal-difference (TD) loss: A bootstrapped training objective that predicts targets using one-step transitions. "a novel off-policy temporal-difference loss"
  • Zero-shot policy optimization: Deriving policies for new rewards without additional training using pre-learned structures. "perform zero-shot policy optimization for any downstream reward"
  • Zero-shot unsupervised RL: Learning from reward-free data to enable immediate policy deployment on downstream tasks. "a zero-shot unsupervised RL algorithm"
