VariBAD: Bayesian Meta-Learning in Deep RL
- VariBAD is a meta-learning framework that applies variational inference and Bayes-adaptive principles to model latent task dynamics for efficient exploration.
- It jointly optimizes an encoder, decoder, and policy network to update beliefs over unknown environment parameters, enhancing adaptation.
- Empirical evaluations in gridworld and MuJoCo tasks show that variBAD approximates Bayes-optimal performance and outperforms several leading meta-RL methods.
Variational Bayes-Adaptive Deep Reinforcement Learning (variBAD) is a meta-learning framework for performing approximate Bayes-adaptive reinforcement learning in environments with unknown dynamics and rewards. It enables an agent to maintain a belief distribution over latent task parameters, adaptively trading off exploration and exploitation via a structured uncertainty-driven policy. The method achieves this by incorporating variational inference principles into the RL loop, learning both a posterior over task variables and a policy conditioned on the inferred latent state. Empirical results demonstrate that variBAD outperforms previous meta-RL algorithms on both discrete gridworld and continuous control tasks, closely approximating Bayes-optimal performance in key domains (Zintgraf et al., 2019).
1. Bayes-Adaptive MDP Formulation
The Bayes-Adaptive Markov Decision Process (BAMDP) framework addresses the optimal exploration-exploitation tradeoff by augmenting the state space with a posterior belief over hidden task parameters. Let $\mathcal{S}$ denote the state space and $\mathcal{A}$ the action space. Each environment is parameterized by a latent variable $m$, influencing both the transition dynamics $T(s_{t+1} \mid s_t, a_t; m)$ and the reward function $R(r_{t+1} \mid s_t, a_t, s_{t+1}; m)$. The agent maintains a prior belief $b_0(m) = p(m)$ and, using its trajectory $\tau_{:t} = (s_0, a_0, r_1, s_1, \dots, s_t)$, updates its posterior $b_t(m) = p(m \mid \tau_{:t})$.
In this framework, the "hyper-state" $s^+_t = (s_t, b_t)$ produces an augmented BAMDP whose transition and reward kernels are given by:

$$T^+(s^+_{t+1} \mid s^+_t, a_t) = \mathbb{E}_{m \sim b_t}\big[T(s_{t+1} \mid s_t, a_t; m)\big], \qquad R^+(s^+_t, a_t, s^+_{t+1}) = \mathbb{E}_{m \sim b_{t+1}}\big[R(r_{t+1} \mid s_t, a_t, s_{t+1}; m)\big],$$

where the belief component of $s^+_{t+1}$ follows the deterministic Bayesian update $b_{t+1}(m) = p(m \mid \tau_{:t+1})$. The Bayes-optimal policy $\pi^*$ maximizes the expected discounted return in the BAMDP over horizon $H^+$:

$$\pi^* = \arg\max_{\pi} \; \mathbb{E}_{\pi}\left[\sum_{t=0}^{H^+-1} \gamma^t\, r_{t+1}\right].$$

Exact inference and planning are intractable for high-dimensional $m$, motivating the variational-approximate approach of variBAD.
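To make the belief-update step concrete, here is a minimal sketch (an illustration, not the paper's code) of exact Bayesian filtering when the latent task $m$ ranges over a small finite set, the regime in which BAMDP belief updates remain tractable:

```python
import numpy as np

def belief_update(belief, likelihoods):
    """One Bayes-filter step: posterior is proportional to
    prior * likelihood of the observed (s, a, r, s') transition
    under each candidate task m."""
    posterior = belief * likelihoods
    return posterior / posterior.sum()

# Two candidate tasks; task 0 explains the observed transition far better.
prior = np.array([0.5, 0.5])
obs_likelihood = np.array([0.9, 0.1])  # p(transition | m) for each m
posterior = belief_update(prior, obs_likelihood)  # -> [0.9, 0.1]
```

With continuous, high-dimensional $m$ this exact update has no closed form, which is precisely the gap variBAD's amortized variational posterior fills.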
2. Generative Model and Variational Approximation
variBAD leverages a joint generative model

$$p_\theta(\tau, m) = p(m)\, p(s_0) \prod_{t=0}^{H-1} T_\theta(s_{t+1} \mid s_t, a_t, m)\, R_\theta(r_{t+1} \mid s_t, a_t, s_{t+1}, m),$$

where actions are sampled from a policy conditioned on the current belief.

The method introduces an amortized variational posterior $q_\phi(m \mid \tau_{:t})$, parameterized as a diagonal Gaussian output by a recurrent inference network. The training objective employs the evidence lower bound (ELBO) at each time step $t$:

$$\mathrm{ELBO}_t(\phi, \theta) = \mathbb{E}_{q_\phi(m \mid \tau_{:t})}\big[\log p_\theta(\tau \mid m)\big] - \mathrm{KL}\big(q_\phi(m \mid \tau_{:t}) \,\|\, p(m)\big),$$

where the decoder reconstructs the full trajectory (past and future transitions and rewards) from the belief at time $t$. The ELBO enables tractable meta-learning of the underlying task posterior and dynamics decoder.
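The two ELBO terms can be sketched numerically. The snippet below (an illustrative sketch, not the paper's implementation) computes the closed-form KL of a diagonal Gaussian posterior against a standard-normal prior, and a Monte Carlo estimate of the reconstruction term via the reparameterization trick; `recon_log_prob` is a hypothetical stand-in for the decoder's log-likelihood:

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_diag_gauss_to_std_normal(mu, log_sigma):
    # KL( N(mu, sigma^2 I) || N(0, I) ) in closed form for diagonal Gaussians.
    return 0.5 * np.sum(np.exp(2 * log_sigma) + mu**2 - 1 - 2 * log_sigma)

def elbo(mu, log_sigma, recon_log_prob, n_samples=32):
    """Monte Carlo ELBO_t: E_q[log p(tau | m)] - KL(q || p)."""
    sigma = np.exp(log_sigma)
    eps = rng.standard_normal((n_samples, mu.size))
    m = mu + sigma * eps                              # reparameterization trick
    recon = np.mean([recon_log_prob(mi) for mi in m]) # decoder stand-in
    return recon - kl_diag_gauss_to_std_normal(mu, log_sigma)
```

Because the latent sample is written as `mu + sigma * eps`, gradients can flow through `mu` and `log_sigma` in an autodiff framework, which is what makes the encoder trainable by ELBO gradients.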
3. Meta-Training Algorithm
variBAD's meta-training jointly optimizes three parameter sets: the encoder parameters $\phi$ for $q_\phi(m \mid \tau_{:t})$, the decoder parameters $\theta$ for $p_\theta(\tau \mid m)$, and the policy parameters $\psi$ for $\pi_\psi(a_t \mid s_t, m)$, where $m$ denotes the sampled latent. The total loss at each meta-training iteration is:

$$\mathcal{L}(\phi, \theta, \psi) = \mathcal{J}(\psi) + \lambda \sum_{t=0}^{H^+-1} \mathrm{ELBO}_t(\phi, \theta),$$

where $\mathcal{J}(\psi)$ is the standard expected RL return and $\lambda$ controls the balance between the RL and ELBO terms. Training proceeds with policy-gradient updates (PPO/A2C) for $\psi$ and Adam updates on $(\phi, \theta)$ using ELBO gradients.
The meta-training workflow is summarized as follows:
| Step | Operation | Update Type |
|---|---|---|
| Sample task | Draw $m \sim p(m)$; reset environment and encoder hidden state | Initialization |
| Collect trajectory | Encode $\tau_{:t}$ via GRU to $(\mu_t, \sigma_t)$ | Forward pass |
| Sample latent | Reparameterized draw $m \sim q_\phi(m \mid \tau_{:t})$ | Forward pass |
| Condition policy | $a_t \sim \pi_\psi(\cdot \mid s_t, \mu_t, \sigma_t)$ | Action |
| Compute ELBO | Sum $\mathrm{ELBO}_t$ over the batch | Loss evaluation |
| Optimizer step | Adam/PPO update on $(\phi, \theta, \psi)$ | Learning |
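The loop implied by the table can be sketched as follows. All callables (`encode`, `policy`, `env_step`, and the two update steps) are hypothetical stand-ins; a real implementation would use a GRU encoder and a PPO or A2C learner:

```python
import numpy as np

rng = np.random.default_rng(1)

def meta_train_iteration(encode, policy, env_step, elbo_step, rl_step, horizon=10):
    """One meta-training iteration mirroring the table above (sketch only).

    encode(traj, h) -> (mu, sigma, h): recurrent posterior over the latent.
    policy(s, mu, sigma) -> action:   belief-conditioned policy.
    env_step(s, a) -> next state:     stand-in environment interface.
    """
    h = None                    # encoder hidden state, reset per task
    traj = []
    s = env_step(None, None)    # reset
    for _ in range(horizon):
        mu, sigma, h = encode(traj, h)
        m = mu + sigma * rng.standard_normal(mu.shape)  # sample latent
        a = policy(s, mu, sigma)
        s = env_step(s, a)
        traj.append((s, a))
    elbo_step(traj)             # Adam update on encoder/decoder via ELBO
    rl_step(traj)               # PPO/A2C update on the policy
    return traj
```

Note that the encoder hidden state is reset only when a new task is sampled, so the belief persists across episode boundaries within a task, as the BAMDP formulation requires.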
4. Online Adaptation and Uncertainty-Driven Action Selection
During evaluation, only the encoder and policy are retained. With each new trajectory prefix $\tau_{:t}$, the encoder updates the latent posterior, yielding $q_\phi(m \mid \tau_{:t}) = \mathcal{N}(\mu_t, \mathrm{diag}(\sigma_t^2))$. As data accumulates, the posterior variance collapses ($\sigma_t^2 \to 0$), smoothly annealing the policy from exploration to exploitation. This enables dynamically structured "uncertainty-driven" exploration, matching Bayes-optimal online adaptation.
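The variance-collapse behavior can be illustrated with a conjugate Gaussian belief over a scalar task parameter (a toy analogue of the learned encoder, not the paper's method): each observation shrinks the posterior variance monotonically, mirroring how a shrinking posterior anneals exploration into exploitation:

```python
import numpy as np

def gaussian_posterior(prior_mu, prior_var, obs, obs_noise_var):
    """Conjugate update of a Gaussian belief over a scalar task parameter.
    Returns the final mean and the sequence of posterior variances,
    which shrinks monotonically toward zero as evidence accumulates."""
    mu, var = prior_mu, prior_var
    variances = [var]
    for y in obs:
        precision = 1 / var + 1 / obs_noise_var
        mu = (mu / var + y / obs_noise_var) / precision
        var = 1 / precision
        variances.append(var)
    return mu, variances

# Noisy observations of a task parameter near 1.0:
mu, variances = gaussian_posterior(0.0, 1.0, obs=[1.1, 0.9, 1.0, 1.05],
                                   obs_noise_var=0.25)
```

Here the belief mean converges toward the true parameter while the variance drops after every step; a belief-conditioned policy reading $(\mu_t, \sigma_t)$ would see its uncertainty input shrink accordingly.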
5. Architecture and Implementation Specifications
- Encoder $q_\phi$: MLP embedding layer of size 32 (ReLU); GRU (hidden size 64–128); final linear mapping to $(\mu, \log \sigma)$ for a $d$-dimensional diagonal Gaussian (a small $d$ is typical).
- Decoder $p_\theta$: transition model is an MLP (64, 32) with ReLU and a Gaussian/categorical output head over $s_{t+1}$; the reward model is a similar MLP with a scalar output.
- Policy network $\pi_\psi$: MLP (32 units for gridworld, 128 for MuJoCo), tanh activations, input $(s_t, \mu_t, \sigma_t)$; critic head of similar dimensions.
- Optimization: PPO/A2C with Adam (separate learning rates for gridworld and MuJoCo), PPO clipping $0.1$, value-loss coefficient $0.5$, entropy coefficient $0.01$, discount $\gamma = 0.99$, GAE. VAE: Adam, with ELBO weight $\lambda = 0.1$ for MuJoCo. Max gradient norm $0.5$. No dropout is used; the KL term regularizes the latent $m$.
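As an illustration of the encoder's output head, the sketch below (illustrative weights and shapes, not the trained model) maps a GRU hidden state to the mean and log-standard-deviation of the diagonal Gaussian posterior and draws a reparameterized latent sample:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_head(h, W_mu, b_mu, W_ls, b_ls):
    """Final linear layer of the encoder: map a GRU hidden state h to
    (mu, log_sigma) of a d-dimensional diagonal Gaussian, then draw a
    reparameterized latent sample m = mu + sigma * eps."""
    mu = W_mu @ h + b_mu
    log_sigma = W_ls @ h + b_ls
    m = mu + np.exp(log_sigma) * rng.standard_normal(mu.shape)
    return mu, log_sigma, m

hidden, d = 64, 5                    # assumed GRU size and latent dim
h = rng.standard_normal(hidden)
mu, log_sigma, m = gaussian_head(
    h,
    rng.standard_normal((d, hidden)) * 0.01, np.zeros(d),
    rng.standard_normal((d, hidden)) * 0.01, np.zeros(d),
)
```

Predicting $\log \sigma$ rather than $\sigma$ keeps the standard deviation positive without constrained optimization, the usual choice for diagonal-Gaussian VAE heads.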
6. Empirical Evaluation
- Gridworld: $5 \times 5$ grid with an unknown goal in one of 24 cells. Actions: {up, right, down, left, stay}; sparse reward; the BAMDP horizon spans multiple episodes. variBAD achieves returns matching Bayes-optimal behavior by episode 3, outperforming posterior sampling. The decoder's per-cell reward belief concentrates on the true goal, and the rapid collapse of the latent posterior confirms uncertainty-driven exploration.
- MuJoCo Meta-RL: tasks include AntDir (forward/back, 2), HalfCheetahDir (left/right, 2), HalfCheetahVel (varied speeds, 10), and Walker (randomized body, 20). The evaluation metric is the first-episode (online) return on new tasks. In all domains, variBAD's first-rollout returns exceed those of RL², PEARL (off-policy posterior sampling), E-MAML, and ProMP. For example:
| Task | variBAD | RL² | PEARL | E-MAML | ProMP |
|---|---|---|---|---|---|
| AntDir | ~2150 | 2000 | 1500 | 400 | 600 |
| HalfCheetahDir | ~4000 | 3500 | 1000 | 500 | 800 |
| HalfCheetahVel | ~3200 | 2800 | 1100 | 300 | 500 |
| Walker | ~5000 | 4500 | 2000 | 600 | 700 |
PEARL, being off-policy, converges in far fewer environment frames; variBAD and RL², being on-policy, require substantially more. At convergence, variBAD matches or exceeds the returns of oracle PPO agents that know the true task. The posterior mean flips sign (e.g., in direction tasks) within about 20 steps, and the posterior variance declines rapidly, enabling early exploitation.
7. Limitations and Future Directions
variBAD is the first scalable deep-RL algorithm to leverage an explicit approximate Bayesian belief over latent task variables for structured exploration. Its variational inference framework delivers a low-dimensional belief representation on which to condition the policy, along with a quantifiable uncertainty estimate. However, the approach requires meta-training on a task distribution representative of test tasks, and neural-network approximation means formal Bayes-optimality is not guaranteed. Training cost is substantial due to recurrent inference and on-policy learning; off-policy variants are left for future work. Further research directions include exploiting the decoder for model-based planning at test time, learning a faster-adapting prior $p(m)$, and handling out-of-distribution (OOD) tasks via continual encoder fine-tuning.
Relevant experimental data, exact hyperparameters, and architecture specifications are available in the original codebase (Zintgraf et al., 2019).