Gradient-Based Meta-Learning via Hierarchical Bayes
- The framework recasts gradient-based meta-learning as empirical Bayes inference in a hierarchical probabilistic model, with inner-loop adaptation acting as MAP approximation.
- It leverages techniques such as Laplace approximations, K-FAC, and Gradient-EM to efficiently estimate curvature and improve scalability.
- The framework motivates extensions like Dirichlet process mixtures to handle non-stationary task distributions and enhance model flexibility.
Gradient-based meta-learning as hierarchical Bayes formalizes fast adaptation to new tasks in terms of empirical Bayes inference within multi-level probabilistic models. This probabilistic lens illuminates both the theoretical underpinnings and practical extensions of algorithms such as Model-Agnostic Meta-Learning (MAML), and motivates extensions based on approximate inference, mixture models, and curvature-aware optimization (Grant et al., 2018, Jerfel et al., 2018, Zou et al., 2020).
1. Hierarchical Bayesian Formulation of Meta-Learning
Meta-learning considers data collected from a distribution over tasks. The hierarchical Bayesian model posits (i) a global "meta"-parameter $\theta$, and (ii) for each task $j$, task-specific parameters $\phi_j$, generating data via a likelihood model. Denote:
- $\theta$: global meta-parameter,
- $\phi_j$: task-specific parameter for task $j$,
- $\mathcal{D}_j = (\mathcal{D}_j^{\text{train}}, \mathcal{D}_j^{\text{val}})$: per-task train and validation datasets.
The generative process is:
- $p(\theta)$: prior over $\theta$,
- $p(\phi_j \mid \theta)$: prior over $\phi_j$ induced by $\theta$,
- $p(\mathcal{D}_j \mid \phi_j)$: likelihood of task data given $\phi_j$.
The joint distribution is
$$p(\theta, \{\phi_j\}, \{\mathcal{D}_j\}) = p(\theta) \prod_j p(\phi_j \mid \theta)\, p(\mathcal{D}_j \mid \phi_j).$$
Empirical Bayes seeks $\theta^\ast$ maximizing the marginal likelihood of the observed data:
$$\theta^\ast = \arg\max_\theta \prod_j \int p(\mathcal{D}_j \mid \phi_j)\, p(\phi_j \mid \theta)\, d\phi_j,$$
which is generally intractable for high-dimensional $\phi_j$ (e.g., neural networks) (Grant et al., 2018).
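The generative process above can be sketched with a toy Gaussian-Gaussian instantiation; the model choice, dimensions, and variances here are illustrative assumptions, not taken from the cited papers:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical Gaussian-Gaussian instance of the hierarchical model:
#   theta          ~ N(0, tau^2 I)          (prior over the meta-parameter)
#   phi_j | theta  ~ N(theta, sigma^2 I)    (task-specific parameters)
#   y | x, phi_j   ~ N(phi_j . x, noise^2)  (per-task linear-regression likelihood)
d, n_tasks, n_points = 3, 5, 10
tau, sigma, noise = 1.0, 0.3, 0.1

theta = rng.normal(0.0, tau, size=d)                     # global meta-parameter
tasks = []
for _ in range(n_tasks):
    phi = theta + rng.normal(0.0, sigma, size=d)         # draw from p(phi_j | theta)
    X = rng.normal(size=(n_points, d))
    y = X @ phi + rng.normal(0.0, noise, size=n_points)  # draw from p(D_j | phi_j)
    tasks.append((X, y, phi))

# Each task's phi_j stays close to theta, with spread controlled by sigma
devs = [np.linalg.norm(phi - theta) for _, _, phi in tasks]
print(len(tasks), max(devs))
```

The key structural point is that tasks are related only through the shared $\theta$: integrating out the $\phi_j$ couples all tasks in the marginal likelihood.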
2. Gradient-Based Meta-Learning as Marginal Likelihood Approximation
MAML approximates the intractable integrals in the hierarchical model by using maximum a posteriori (MAP) or Laplace approximations. Concretely:
- For each task, the integral $\int p(\mathcal{D}_j \mid \phi_j)\, p(\phi_j \mid \theta)\, d\phi_j$ is replaced by evaluating the integrand at the MAP estimate $\hat{\phi}_j$:
$$\hat{\phi}_j = \arg\max_{\phi} \log p(\phi \mid \mathcal{D}_j^{\text{train}}, \theta).$$
- With a Gaussian prior $p(\phi \mid \theta) = \mathcal{N}(\phi; \theta, \sigma^2 I)$, this becomes ridge-regularized minimization:
$$\hat{\phi}_j = \arg\min_{\phi} \left[ -\log p(\mathcal{D}_j^{\text{train}} \mid \phi) + \frac{\lambda}{2}\, \|\phi - \theta\|^2 \right],$$
where $\lambda = 1/\sigma^2$.
Initializing at $\theta$ and taking $K$ gradient steps yields $\phi_j^{(K)}$ approximating $\hat{\phi}_j$. The meta-objective is then
$$\mathcal{L}(\theta) = \sum_j -\log p(\mathcal{D}_j^{\text{val}} \mid \phi_j^{(K)}(\theta)),$$
which matches the negative log marginal likelihood under the point estimate $p(\mathcal{D}_j \mid \theta) \approx p(\mathcal{D}_j \mid \phi_j^{(K)})$ (Grant et al., 2018, Zou et al., 2020).
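This inner/outer structure can be sketched on scalar linear regression. Everything below is an illustrative simplification, not the papers' implementation: the inner loop takes a single truncated gradient step (early stopping playing the role of the prior term), and the meta-gradient is computed by finite differences rather than by unrolled backpropagation:

```python
import numpy as np

rng = np.random.default_rng(1)
alpha = 0.1  # inner-loop step size (assumed hyperparameter)

def make_task(slope):
    # A 1-D regression task: y = slope * x + noise, split into train/val
    X = rng.normal(size=20)
    y = slope * X + rng.normal(0.0, 0.05, size=20)
    return (X[:10], y[:10]), (X[10:], y[10:])

def inner_adapt(theta, Xtr, ytr):
    # One gradient step from theta on the train loss: a truncated MAP estimate
    grad = 2 * np.mean((theta * Xtr - ytr) * Xtr)
    return theta - alpha * grad

def meta_loss(theta, tasks):
    # Sum of validation NLLs (up to constants) at the adapted parameters
    total = 0.0
    for (Xtr, ytr), (Xva, yva) in tasks:
        phi = inner_adapt(theta, Xtr, ytr)
        total += np.mean((phi * Xva - yva) ** 2)
    return total / len(tasks)

tasks = [make_task(s) for s in (0.8, 1.0, 1.2)]
theta = 0.0
for _ in range(200):
    # Outer loop: finite-difference meta-gradient, for clarity only
    eps = 1e-4
    g = (meta_loss(theta + eps, tasks) - meta_loss(theta - eps, tasks)) / (2 * eps)
    theta -= 0.05 * g

print(round(theta, 2))  # should settle near the average task slope
```

The learned initialization sits where one inner step reaches every task well, which for symmetric slopes is close to their mean.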
3. MAP and Laplace Approximations in the Inner Loop
The inner gradient-based update acts as a MAP or truncated Laplace approximation to the task posterior
$$p(\phi_j \mid \mathcal{D}_j^{\text{train}}, \theta) \propto p(\mathcal{D}_j^{\text{train}} \mid \phi_j)\, p(\phi_j \mid \theta).$$
A full Laplace approximation expands around the mode $\hat{\phi}_j$, with local curvature $H_j$ (the Hessian of the negative log posterior at $\hat{\phi}_j$), yielding a Gaussian posterior
$$p(\phi_j \mid \mathcal{D}_j^{\text{train}}, \theta) \approx \mathcal{N}(\phi_j;\, \hat{\phi}_j,\, H_j^{-1}).$$
This leads to the curvature-corrected meta-objective (up to additive constants):
$$\mathcal{L}(\theta) = \sum_j \left[ -\log p(\mathcal{D}_j^{\text{val}} \mid \hat{\phi}_j(\theta)) + \tfrac{1}{2} \log \det H_j \right],$$
where the $\tfrac{1}{2}\log\det H_j$ term penalizes overly sharp, high-curvature posterior modes (Grant et al., 2018).
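In a conjugate Gaussian model the Laplace approximation is exact, which makes the construction easy to sanity-check against the closed-form marginal likelihood. The model and numbers below are purely illustrative:

```python
import numpy as np

# Model (assumed for illustration): phi ~ N(theta, s2), y_i | phi ~ N(phi, n2)
theta, s2, n2 = 0.5, 0.4, 0.1
y = np.array([0.9, 1.1, 1.0, 0.8])
n = len(y)

def neg_log_joint(phi):
    # -log p(y, phi | theta) = -log p(y | phi) - log p(phi | theta)
    return (0.5 * np.sum((y - phi) ** 2) / n2 + 0.5 * n * np.log(2 * np.pi * n2)
            + 0.5 * (phi - theta) ** 2 / s2 + 0.5 * np.log(2 * np.pi * s2))

H = n / n2 + 1.0 / s2                        # curvature of neg log joint (constant here)
phi_map = (np.sum(y) / n2 + theta / s2) / H  # posterior mode (MAP estimate)

# Laplace: log p(y|theta) ~ -neg_log_joint(phi_map) + (1/2)log(2*pi) - (1/2)log H
laplace = -neg_log_joint(phi_map) + 0.5 * np.log(2 * np.pi) - 0.5 * np.log(H)

# Exact marginal: y is jointly Gaussian with mean theta*1 and cov n2*I + s2*11^T
cov = n2 * np.eye(n) + s2 * np.ones((n, n))
resid = y - theta
exact = (-0.5 * resid @ np.linalg.solve(cov, resid)
         - 0.5 * np.linalg.slogdet(2 * np.pi * cov)[1])

print(np.isclose(laplace, exact))  # True: Laplace is exact for Gaussian joints
```

For non-Gaussian likelihoods the two quantities differ, and the $\tfrac{1}{2}\log\det H$ term is exactly what the point-estimate meta-objective of the previous section discards.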
4. Scalable Inference: K-FAC and Gradient-EM
Computing the full Hessian $H_j$ is computationally prohibitive for neural networks; it is replaced by a block-diagonal K-FAC Fisher approximation $\hat{F}_j$, where each block is a Kronecker factorization of the layerwise Fisher matrix:
$$\hat{F}_l \approx A_{l-1} \otimes G_l,$$
with $A_{l-1}$ the second-moment matrix of layer inputs and $G_l$ that of output gradients. The log-determinant $\log\det\hat{F}_j$ then decomposes over layers and Kronecker factors, so it can be computed efficiently, and the overall meta-algorithm optimizes the curvature-corrected loss via autodiff (Grant et al., 2018).
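The efficiency rests on the Kronecker determinant identity: for $A$ ($p \times p$) and $G$ ($q \times q$), $\log\det(A \otimes G) = q \log\det A + p \log\det G$, so the $pq \times pq$ block never has to be formed. A small numerical check (factor sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(2)

def random_spd(k):
    # Random symmetric positive-definite matrix, standing in for a Kronecker factor
    M = rng.normal(size=(k, k))
    return M @ M.T + k * np.eye(k)

A, G = random_spd(3), random_spd(4)  # p = 3, q = 4
full = np.kron(A, G)                 # explicit 12 x 12 block, built only to verify

lhs = np.linalg.slogdet(full)[1]                                   # direct log-det
rhs = 4 * np.linalg.slogdet(A)[1] + 3 * np.linalg.slogdet(G)[1]    # factorized form
print(np.isclose(lhs, rhs))  # True
```

Per layer, this replaces an $O((pq)^3)$ determinant with two small ones, which is what makes the curvature penalty affordable at neural-network scale.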
Gradient-EM (GEM) further improves computational and memory efficiency by decoupling the inner updates (E-step) from the meta-update (M-step). Instead of unrolling gradients through the inner update trajectory, the meta-gradient is computed with the gradient-EM (Fisher's) identity
$$\nabla_\theta \log p(\mathcal{D}_j \mid \theta) = \mathbb{E}_{p(\phi_j \mid \mathcal{D}_j, \theta)}\!\left[ \nabla_\theta \log p(\phi_j \mid \theta) \right],$$
and posterior expectations are approximated via a small number of inner gradient steps within variational families (Zou et al., 2020). This sidesteps the expensive backpropagation through inner updates required in standard MAML.
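The identity behind this decoupling can be verified in a conjugate Gaussian model, where both sides are available in closed form (model and numbers are illustrative):

```python
import numpy as np

# Assumed model: phi ~ N(theta, s2), y_i | phi ~ N(phi, n2)
theta, s2, n2 = 0.3, 0.5, 0.2
y = np.array([1.2, 0.9, 1.1])
n = len(y)

# Left side: gradient of the exact log marginal, y ~ N(theta*1, n2*I + s2*11^T)
cov = n2 * np.eye(n) + s2 * np.ones((n, n))
lhs = np.ones(n) @ np.linalg.solve(cov, y - theta)

# Right side: posterior expectation of grad_theta log p(phi|theta) = (phi - theta)/s2,
# which only needs the posterior mean of phi
post_prec = n / n2 + 1.0 / s2
post_mean = (np.sum(y) / n2 + theta / s2) / post_prec
rhs = (post_mean - theta) / s2

print(np.isclose(lhs, rhs))  # True: Fisher's identity holds exactly
```

The practical consequence is that the meta-gradient needs only (approximate) posterior expectations, not differentiation through the optimization path that produced them.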
5. Mixture Extensions: Dirichlet Process Priors for Latent Task Clusters
A limitation of a single global prior $p(\phi \mid \theta)$ is that it can be suboptimal in the presence of heterogeneous or non-stationary task distributions. Using a Dirichlet process (DP) mixture within the hierarchical Bayesian framework allows flexible, nonparametric modeling of task structure (Jerfel et al., 2018). Each cluster parameter $\theta_k$ corresponds to a "good" initialization; tasks select a cluster via a Chinese restaurant process (CRP) prior and adapt within that cluster:
- For each task $j$: $z_j \sim \mathrm{CRP}(\alpha)$, $\phi_j \sim p(\phi \mid \theta_{z_j})$, $\mathcal{D}_j \sim p(\mathcal{D} \mid \phi_j)$.
Stochastic MAP-EM approximations are used, recapitulating the inner–outer loop pattern. When the DP reduces to a single component (e.g., as the concentration $\alpha \to 0$), this setup coincides with standard MAML. Otherwise, the meta-learner can track existing clusters and instantiate new ones, adapting to task-distribution shifts and mitigating negative transfer.
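The cluster-assignment step of the generative process can be sketched with plain CRP sampling; this is illustrative only, since the cited work uses stochastic MAP-EM rather than sampling, and the hyperparameters here are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(3)

def crp_assign(n_tasks, alpha):
    """Sample cluster labels z_1..z_n from a CRP with concentration alpha."""
    counts = []   # number of tasks seated at each existing cluster
    labels = []
    for _ in range(n_tasks):
        # Join existing cluster k with prob proportional to counts[k];
        # open a new cluster with prob proportional to alpha
        probs = np.array(counts + [alpha], dtype=float)
        probs /= probs.sum()
        k = rng.choice(len(probs), p=probs)
        if k == len(counts):
            counts.append(1)  # new cluster instantiated
        else:
            counts[k] += 1
        labels.append(k)
    return labels

labels = crp_assign(50, alpha=1.0)
print(len(set(labels)))  # typically grows like alpha * log(n_tasks)
```

As `alpha` shrinks toward zero, every task joins the first cluster and the single-initialization (MAML) setting is recovered; larger `alpha` lets the meta-learner spawn new initializations for dissimilar tasks.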
6. Algorithmic Summaries
An overview of the implementation, highlighting differences between key algorithms:
| Approach | Inner Update | Meta-Update | Posterior Approximation |
|---|---|---|---|
| MAML | GD steps | Unrolled gradient through inner | Point (mode/MAP) |
| HMAML | GD steps + K-FAC | Unrolled curvature penalty | Laplace around MAP |
| GEM-Bayes | S variational steps | Gradient-EM identity | Variational (Gaussian) |
| DP-MAML | K GD/EM per cluster | EM over clusters | DP mixture MAP |
- In HMAML, K-FAC is used for scalable curvature estimation (Grant et al., 2018).
- GEM-Bayes applies the EM identity for meta-gradients, decoupling inner optimization and improving memory/computational efficiency (Zou et al., 2020).
- DP-MAML enables online mixture adaptation for non-stationary tasks, robust to latent distribution shifts (Jerfel et al., 2018).
7. Implications and Theoretical Significance
Viewing gradient-based meta-learning as hierarchical Bayes unifies disparate algorithms through a probabilistic perspective. This connection justifies inner gradient steps as MAP or approximate Bayesian inference, grounds meta-objectives as marginal likelihood maximization, and directly motivates extensions:
- Curvature corrections via Laplace/K-FAC penalize overly sharp posterior modes and sharpen uncertainty quantification (Grant et al., 2018).
- Mixture/DPM models allow for flexible, task-adaptive priors and effective continual learning (Jerfel et al., 2018).
- Gradient-EM techniques improve scalability for deep neural applications and enable richer variational inference with controlled overfitting (Zou et al., 2020).
This framework suggests further directions including mixture-of-modes, improved curvature approximations, and variational extensions, all operating within a coherent probabilistic meta-learning paradigm.