
VariBAD: Bayesian Meta-Learning in Deep RL

Updated 17 January 2026
  • VariBAD is a meta-learning framework that applies variational inference and Bayes-adaptive principles to model latent task dynamics for efficient exploration.
  • It jointly optimizes an encoder, decoder, and policy network to update beliefs over unknown environment parameters, enhancing adaptation.
  • Empirical evaluations in gridworld and MuJoCo tasks show that variBAD approximates Bayes-optimal performance and outperforms several leading meta-RL methods.

Variational Bayes-Adaptive Deep Reinforcement Learning (variBAD) is a meta-learning framework for performing approximate Bayes-adaptive reinforcement learning in environments with unknown dynamics and rewards. It enables an agent to maintain a belief distribution over latent task parameters, adaptively trading off exploration and exploitation via a structured uncertainty-driven policy. The method achieves this by incorporating variational inference principles into the RL loop, learning both a posterior over task variables and a policy conditioned on the inferred latent state. Empirical results demonstrate that variBAD outperforms previous meta-RL algorithms on both discrete gridworld and continuous control tasks, closely approximating Bayes-optimal performance in key domains (Zintgraf et al., 2019).

1. Bayes-Adaptive MDP Formulation

The Bayes-Adaptive Markov Decision Process (BAMDP) framework addresses the optimal exploration-exploitation tradeoff by augmenting the state space with a posterior belief over hidden task parameters. Let $S$ denote the state space and $A$ the action space. Each environment is parameterized by a latent variable $\theta \in \Theta$ that influences both the transition dynamics $T_\theta(s' \mid s, a)$ and the reward function $R_\theta(r \mid s, a, s')$. The agent maintains a prior belief $b_0(\theta) = p(\theta)$ and, using its trajectory $\tau_{:t} = (s_0, a_0, r_1, s_1, \ldots, s_t)$, updates its posterior $b_t(\theta) = p(\theta \mid \tau_{:t})$.
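To make the belief update concrete, here is a minimal sketch of an exact Bayes update over a finite set of candidate goal cells, in the spirit of the gridworld experiments below; the cell indexing and reward-observation model are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def update_belief(belief, cell, reward, p_hit=1.0):
    """Exact Bayes update of a belief over candidate goal cells.

    belief: array with P(goal = g) for each candidate cell g.
    cell:   index of the cell the agent just visited.
    reward: 1 if the sparse goal reward was observed there, else 0.
    p_hit:  probability the goal cell emits its reward when visited.
    """
    is_goal = np.arange(len(belief)) == cell
    # Likelihood of the observation under each goal hypothesis.
    like = np.where(is_goal,
                    p_hit if reward else 1.0 - p_hit,
                    0.0 if reward else 1.0)
    post = belief * like
    return post / post.sum()

# Uniform prior over 4 candidate goals; visiting cells 0 and 1 without seeing
# a reward rules them out, concentrating the belief on the remaining two.
b = np.full(4, 0.25)
b = update_belief(b, cell=0, reward=0)
b = update_belief(b, cell=1, reward=0)
print(b)  # → [0.  0.  0.5 0.5]
```

This exact tabular update is exactly what becomes intractable for high-dimensional $\theta$, motivating the amortized posterior below.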

In this framework, the "hyper-state" $s_t^+ = (s_t, b_t)$ induces an augmented BAMDP $M^+$ whose transition and reward kernels are given by:

$$T^+(s_{t+1}, b_{t+1} \mid s_t, b_t, a_t) = \mathbb{E}_{\theta \sim b_t}\left[ T_\theta(s_{t+1} \mid s_t, a_t) \right] \cdot \delta\left[ b_{t+1} = p(\theta \mid \tau_{:t+1}) \right]$$

$$R^+(s_t, b_t, a_t, s_{t+1}, b_{t+1}) = \mathbb{E}_{\theta \sim b_{t+1}}\left[ R_\theta(r \mid s_t, a_t, s_{t+1}) \right]$$

The Bayes-optimal policy $\pi^{+*}$ maximizes the expected discounted return in $M^+$ over horizon $H^+$:

$$J^+(\pi) = \mathbb{E}_{b_0, T^+, \pi}\left[ \sum_{t=0}^{H^+-1} \gamma^t R^+(\cdot) \right]$$

Exact inference and planning are intractable for high-dimensional $\theta$, motivating the variational approximation used by variBAD.

2. Generative Model and Variational Approximation

variBAD posits a joint generative model $p_\theta(\theta, \tau_{:H^+}) = p(\theta)\, p_\theta(\tau_{:H^+} \mid \theta)$, where:

$$p_\theta(\tau_{:H^+} \mid \theta) = p(s_0) \prod_{t=0}^{H^+-1} p_\theta(s_{t+1} \mid s_t, a_t, \theta)\, p_\theta(r_{t+1} \mid s_t, a_t, s_{t+1}, \theta)$$

Actions $a_t$ are sampled from a policy conditioned on the current belief.

The method introduces an amortized variational posterior $q_\phi(\theta \mid \tau_{:t})$, parameterized as a diagonal Gaussian with parameters $(\mu_t, \sigma_t)$ output by a recurrent inference network. The training objective is the evidence lower bound (ELBO) at each time step $t$:

$$\log p_\theta(\tau_{:H^+}) \ge \mathbb{E}_{\theta \sim q_\phi(\cdot \mid \tau_{:t})}\left[ \log p_\theta(\tau_{:H^+} \mid \theta) \right] - \mathrm{KL}\left[ q_\phi(\theta \mid \tau_{:t}) \,\|\, p(\theta) \right]$$

The ELBO enables tractable meta-learning of both the task posterior and the dynamics decoder.
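The two mechanical ingredients of this objective — a reparameterized latent sample and the closed-form KL between diagonal Gaussians — can be sketched in numpy as follows (the encoder-output values here are made up for illustration):

```python
import numpy as np

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL[N(mu_q, diag(sigma_q^2)) || N(mu_p, diag(sigma_p^2))]."""
    return float(np.sum(np.log(sigma_p / sigma_q)
                        + (sigma_q**2 + (mu_q - mu_p)**2) / (2.0 * sigma_p**2)
                        - 0.5))

def sample_latent(mu, sigma, rng):
    """Reparameterized draw theta = mu + sigma * eps, differentiable in (mu, sigma)."""
    return mu + sigma * rng.standard_normal(mu.shape)

rng = np.random.default_rng(0)
mu_t, sigma_t = np.array([0.3, -0.1]), np.array([0.8, 0.5])  # illustrative encoder output
theta = sample_latent(mu_t, sigma_t, rng)                     # latent fed to policy/decoder

# KL term of the ELBO against a standard-normal prior p(theta) = N(0, I).
kl = kl_diag_gaussians(mu_t, sigma_t, np.zeros(2), np.ones(2))
print(theta.shape, round(kl, 4))
```

In a real implementation both quantities would be computed on autodiff tensors so their gradients flow into the encoder parameters $\phi$.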

3. Meta-Training Algorithm

variBAD’s meta-training jointly optimizes three parameter sets: the encoder parameters $\phi$ of $q_\phi$, the decoder parameters $\theta$ of $p_\theta$, and the policy parameters $\psi$ of $\pi_\psi(a \mid s, z)$, where $z$ denotes the sampled latent. The total loss at each meta-training iteration is:

$$\mathcal{L}(\phi, \theta, \psi) = -\mathbb{E}_{M \sim p(M)}\left[ J_{RL}(\psi, \phi) \right] + \lambda\, \mathbb{E}_{M, \tau}\left[ \sum_{t=0}^{H^+} \mathrm{KL}\left[ q_\phi(\theta \mid \tau_{:t}) \,\|\, q_\phi(\theta \mid \tau_{:t-1}) \right] - \mathbb{E}_{\theta \sim q_\phi}\left[ \log p_\theta(\tau_{:H^+} \mid \theta) \right] \right]$$

Here $J_{RL}$ is the standard expected RL return, and $\lambda \in [0, 1]$ balances the RL and ELBO terms. Training proceeds with policy-gradient updates (PPO or A2C) for $\psi$ and Adam updates on $(\phi, \theta)$ using ELBO gradients.
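As an illustration of how the KL chain in this loss accumulates, the sketch below evaluates $\sum_t \mathrm{KL}[q_t \,\|\, q_{t-1}]$ over hypothetical per-step posterior parameters; the reconstruction and return values are placeholders, not outputs of a real decoder or policy:

```python
import numpy as np

def kl_gauss(mu_q, s_q, mu_p, s_p):
    """Closed-form KL between diagonal Gaussians N(mu_q, s_q^2) and N(mu_p, s_p^2)."""
    return float(np.sum(np.log(s_p / s_q)
                        + (s_q**2 + (mu_q - mu_p)**2) / (2 * s_p**2) - 0.5))

# Hypothetical per-step posterior parameters along one trajectory; index 0 plays
# the role of the prior N(0, 1), and the posterior tightens as evidence arrives.
mus    = [np.array([0.0]), np.array([0.4]), np.array([0.55]), np.array([0.6])]
sigmas = [np.array([1.0]), np.array([0.6]), np.array([0.35]), np.array([0.2])]

# Sum of the KL[q(.|tau_:t) || q(.|tau_:t-1)] terms from the loss above.
kl_sum = sum(kl_gauss(mus[t], sigmas[t], mus[t - 1], sigmas[t - 1])
             for t in range(1, len(mus)))

recon_ll  = -1.7   # placeholder for E[log p(tau | theta)] from the decoder
rl_return = 3.2    # placeholder for the policy's expected return J_RL
lam = 1.0
loss = -rl_return + lam * (kl_sum - recon_ll)
print(round(kl_sum, 4), round(loss, 4))
```

Penalizing each posterior against its predecessor (rather than only against the prior) encourages smooth, incremental belief updates over the trajectory.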

The meta-training workflow is summarized as follows:

| Step | Operation | Update type |
|---|---|---|
| Sample task $M_i$ | Reset environment and encoder hidden state | Initialization |
| Collect trajectory | Encode $(s_t, a_t, r_{t+1})$ via GRU into $(\mu_t, \sigma_t)$ | Forward pass |
| Sample latent | $\theta_t = \mu_t + \sigma_t \odot \epsilon$, with $\epsilon \sim \mathcal{N}(0, I)$ | Forward pass |
| Condition policy | $s_t^z = [s_t; \theta_t]$, $a_t \sim \pi_\psi(a \mid s_t^z)$ | Action selection |
| Compute ELBO | $\mathrm{ELBO}_t$ over the batch | Loss evaluation |
| Optimizer step | Adam/PPO update on $(\phi, \theta, \psi)$ | Learning |

4. Online Adaptation and Uncertainty-Driven Action Selection

During evaluation, only the encoder $q_\phi$ and policy $\pi_\psi$ are retained. As each new trajectory $\tau_{:t}$ unfolds, the encoder updates the latent posterior $(\mu_t, \sigma_t)$, yielding:

$$\theta_t \sim \mathcal{N}(\mu_t, \sigma_t^2), \quad s_t^z = [s_t; \theta_t], \quad a_t \sim \pi_\psi(a \mid s_t^z)$$

As data accumulates, the posterior variance $\sigma_t$ collapses toward zero, smoothly annealing the policy from exploration to exploitation. This produces structured, uncertainty-driven exploration that approximates Bayes-optimal online adaptation.
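The exploration-to-exploitation annealing can be illustrated with a toy conjugate-Gaussian belief over a scalar task variable (a simplifying assumption made for tractability — variBAD's posterior is amortized by a trained network rather than updated in closed form):

```python
import numpy as np

rng = np.random.default_rng(1)
true_theta, obs_noise = 1.5, 0.5   # hypothetical 1-D task variable and noise level
mu, var = 0.0, 1.0                 # prior belief N(0, 1) over theta

variances = []
for t in range(20):
    y = true_theta + obs_noise * rng.standard_normal()   # noisy task-relevant observation
    # Conjugate Gaussian posterior update: precisions add.
    prec = 1.0 / var + 1.0 / obs_noise**2
    mu = (mu / var + y / obs_noise**2) / prec
    var = 1.0 / prec
    variances.append(var)
    theta_t = mu + np.sqrt(var) * rng.standard_normal()  # latent the policy conditions on

# The posterior variance shrinks monotonically: early samples of theta_t spread
# widely (exploration), late samples pin down the task (exploitation).
assert all(a > b for a, b in zip(variances, variances[1:]))
```

The same qualitative behavior — $\sigma_t \to 0$ as evidence accumulates — is what the learned encoder exhibits in the experiments below.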

5. Architecture and Implementation Specifications

  • Encoder qĻ•q_\phi: MLP embedding, one layer of size 32 (ReLU); GRU (hidden size 64–128); final linear mapping to (μ,log⁔σ)(\mu,\log\sigma) for a dd-dimensional Gaussian (d=5d=5 typical).
  • Decoder pĪøp_\theta: Transition model TĪøT_\theta – MLP (64,32), ReLU; output Gaussian/categorical for s′s'. Reward model RĪøR_\theta – similar MLP, scalar output.
  • Policy network Ļ€Ļˆ\pi_\psi: MLP (32 for grid, 128 for MuJoCo), TanH activation, input [s,Īø][s,\theta]; critic head of similar dimensions.
  • Optimization: PPO/A2C, Adam (lr=1eāˆ’3lr=1e^{-3} grid, 7eāˆ’47e^{-4} MuJoCo), clipping $0.1$, value coefficient $0.5$, entropy $0.01$, γ=0.95\gamma=0.95–$0.99$, GAE Ļ„=0.95\tau=0.95. VAE: Adam 1eāˆ’31e^{-3}; ELBO Ī»=1.0\lambda=1.0 (grid), $0.1$ (MuJoCo). Max grad norm $0.5$. No extra dropout used; KL regularizes latent Īø\theta.
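A rough structural sketch of the encoder described above — a hand-written GRU cell plus linear heads for $(\mu, \log\sigma)$ in plain numpy, with untrained random weights; this mirrors the stated sizes (hidden 64, $d = 5$) but is not the reference implementation:

```python
import numpy as np

class BeliefEncoder:
    """Minimal sketch of a variBAD-style recurrent encoder: one GRU cell over
    per-step (s, a, r) features, with linear heads producing the (mu, log_sigma)
    of a d-dimensional Gaussian posterior over the latent task variable."""

    def __init__(self, in_dim, hidden=64, d=5, seed=0):
        rng = np.random.default_rng(seed)
        scale = 0.1
        # GRU weights for the update (z), reset (r), and candidate (n) gates.
        self.Wz = rng.normal(0, scale, (hidden, in_dim + hidden))
        self.Wr = rng.normal(0, scale, (hidden, in_dim + hidden))
        self.Wn = rng.normal(0, scale, (hidden, in_dim + hidden))
        self.W_mu = rng.normal(0, scale, (d, hidden))
        self.W_ls = rng.normal(0, scale, (d, hidden))
        self.h = np.zeros(hidden)

    @staticmethod
    def _sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def step(self, x):
        """Consume one (s, a, r) feature vector; return the posterior (mu, sigma)."""
        xh = np.concatenate([x, self.h])
        z = self._sigmoid(self.Wz @ xh)
        r = self._sigmoid(self.Wr @ xh)
        n = np.tanh(self.Wn @ np.concatenate([x, r * self.h]))
        self.h = (1 - z) * n + z * self.h
        mu = self.W_mu @ self.h
        sigma = np.exp(self.W_ls @ self.h)  # exp keeps sigma strictly positive
        return mu, sigma

enc = BeliefEncoder(in_dim=8)
mu, sigma = enc.step(np.ones(8))
print(mu.shape, sigma.shape)  # (5,) (5,)
```

Predicting $\log\sigma$ and exponentiating is the standard trick for keeping the Gaussian scale positive without constrained optimization.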

6. Empirical Evaluation

  • Gridworld: $5 \times 5$ grid with an unknown goal among 24 candidate cells. Actions: {up, right, down, left, stay}; episode horizon $H = 15$, BAMDP horizon $H^+ = 60$; sparse reward. variBAD matches Bayes-optimal returns by episode 3, outperforming posterior sampling. The decoder's $p(\text{reward} \mid \text{cell})$ belief concentrates on the true goal, and the rapid collapse of the latent $\sigma$ confirms uncertainty-driven exploration.
  • MuJoCo meta-RL: tasks include AntDir (forward/backward, 2 tasks), HalfCheetahDir (left/right, 2 tasks), HalfCheetahVel (varied target speeds, ~10 tasks), and Walker (randomized body parameters, ~20 tasks). The evaluation metric is first-episode (online) return on new tasks. In all domains, variBAD's first-rollout returns exceed those of RL², PEARL (off-policy posterior sampling), E-MAML, and ProMP. For example:
| Task | variBAD | RL² | PEARL | E-MAML | ProMP |
|---|---|---|---|---|---|
| AntDir | ~2150 | 2000 | 1500 | 400 | 600 |
| HalfCheetahDir | ~4000 | 3500 | 1000 | 500 | 800 |
| HalfCheetahVel | ~3200 | 2800 | 1100 | 300 | 500 |
| Walker | ~5000 | 4500 | 2000 | 600 | 700 |

PEARL converges in under $2 \times 10^6$ frames (off-policy); variBAD and RL² require roughly $5 \times 10^7$ frames (on-policy). At convergence, variBAD matches or exceeds the returns of an oracle PPO policy that is given the true task. The posterior mean flips sign (e.g., in the direction tasks) within roughly 20 steps, and the variance $\sigma$ declines rapidly, enabling early exploitation.

7. Limitations and Future Directions

variBAD is the first scalable deep-RL algorithm to leverage an explicit approximate Bayesian belief over latent task variables for structured exploration. Its variational inference framework delivers a low-dimensional latent state on which to condition policies, along with a quantifiable uncertainty estimate. However, the approach requires meta-training on a task distribution $p(M)$ representative of the test tasks, and neural-network approximation means formal Bayes-optimality is not guaranteed. Training cost is substantial due to recurrent inference and on-policy learning; off-policy variants are left for future work. Further research directions include exploiting the decoder $p_\theta$ for model-based planning at test time, learning a faster-adapting prior $p(\theta)$, and handling out-of-distribution (OOD) tasks via continual encoder fine-tuning.

Relevant experimental data, exact hyperparameters, and architecture specifications are available in the original codebase (Zintgraf et al., 2019).

References

Zintgraf, L., Shiarlis, K., Igl, M., Schulze, S., Gal, Y., Hofmann, K., & Whiteson, S. (2019). VariBAD: A Very Good Method for Bayes-Adaptive Deep RL via Meta-Learning. arXiv:1910.08348.