MuSGD: Multi-Target Sampling Gradient Descent
- MuSGD is a sampling algorithm that employs composite Stein directions to simultaneously target multiple unnormalized distributions.
- It integrates RKHS-based Stein gradients with a convex QP to determine optimal weightings, ensuring descent for all KL divergences.
- Empirical results show that MuSGD outperforms traditional methods like MGDA and linear scalarization in accuracy and calibration on datasets such as CelebA and multi-MNIST.
MuSGD (Stochastic Multiple Target Sampling Gradient Descent) is a sampling algorithm designed to simultaneously handle multiple unnormalized target distributions. It extends Stein Variational Gradient Descent (SVGD) into the multi-target/multi-objective domain by iteratively updating a population of particles via composite Stein directions. MuSGD is primarily intended for probabilistic inference and multi-task learning, with theoretical guarantees and empirical benefits over classical approaches such as linear scalarization and multi-gradient descent algorithms (Phan et al., 2022).
1. Mathematical Formulation
Given unnormalized target densities $\pi_1, \dots, \pi_K$ on $\mathbb{R}^d$, with each $\pi_k$ known only up to its normalizing constant, the goal is to construct a sequence of intermediate distributions $q_0, q_1, \dots$ that progressively move closer to the joint high-density region of all $\pi_k$. Each step is achieved by a push-forward update:

$$q_{t+1} = \big(\mathrm{id} + \eta_t\,\phi_t\big)_{\#}\, q_t,$$

where $\phi_t$ is a transport field constructed to minimize all KL divergences simultaneously, treating the task as a multi-objective problem over the space of densities. The optimization objective is:

$$\min_{\phi \in \mathcal{H}^d}\ \big(\mathrm{KL}(q_{[\phi]}\,\|\,\pi_1),\ \dots,\ \mathrm{KL}(q_{[\phi]}\,\|\,\pi_K)\big),$$

understood in the Pareto sense: the chosen $\phi$ must not increase any individual KL term.
2. Gradient Flow, Stein Directions, and Update Rule
Continuous-Time Gradient Flow
For each target $\pi_k$, define the Stein-variational direction in the reproducing kernel Hilbert space (RKHS) $\mathcal{H}^d$ induced by a kernel $\kappa(\cdot,\cdot)$:

$$\psi_k(\cdot) = \mathbb{E}_{x\sim q}\big[\kappa(x,\cdot)\,\nabla_x \log \pi_k(x) + \nabla_x \kappa(x,\cdot)\big].$$

To ensure descent for all KL objectives, MuSGD finds an optimal convex weight vector $w^\ast \in \Delta_{K-1}$ by solving the quadratic program:

$$w^\ast = \arg\min_{w\in\Delta_{K-1}} \Big\|\sum_{k=1}^K w_k\,\psi_k\Big\|_{\mathcal{H}^d}^2 = \arg\min_{w\in\Delta_{K-1}} w^\top U\, w, \qquad U_{kl} = \langle \psi_k, \psi_l\rangle_{\mathcal{H}^d}.$$

Construct the composite descent direction:

$$\phi^\ast = \sum_{k=1}^K w_k^\ast\, \psi_k.$$
In the mean-field limit, the particles follow the deterministic flow:

$$\frac{dx_t}{dt} = \phi^\ast(x_t;\, q_t),$$

where $q_t$ denotes the law of $x_t$.
Discrete-Time MuSGD Update
Particles $\{x_i\}_{i=1}^M$ represent the empirical distribution $q^M = \frac{1}{M}\sum_{i=1}^M \delta_{x_i}$. The discrete update executes:
- For all $k = 1,\dots,K$ and $i = 1,\dots,M$, compute
  $$\hat\psi_k(x_i) = \frac{1}{M}\sum_{j=1}^M \big[\kappa(x_j,x_i)\,\nabla_{x_j}\log\pi_k(x_j) + \nabla_{x_j}\kappa(x_j,x_i)\big].$$
- Build the Gram matrix $U \in \mathbb{R}^{K\times K}$, $U_{kl} \approx \langle \hat\psi_k, \hat\psi_l\rangle$, using Monte Carlo inner products.
- Solve $w^{(t)} = \arg\min_{w\in\Delta_{K-1}} w^\top U w$ (a small QP).
- Compute the composite direction $\phi^\ast(x_i) = \sum_{k=1}^K w_k^{(t)}\,\hat\psi_k(x_i)$.
- Update each particle: $x_i \leftarrow x_i + \eta_t\,\phi^\ast(x_i)$.
Expanded form:

$$x_i \leftarrow x_i + \frac{\eta_t}{M}\sum_{k=1}^K w_k^{(t)} \sum_{j=1}^M \big[\kappa(x_j,x_i)\,\nabla_{x_j}\log\pi_k(x_j) + \nabla_{x_j}\kappa(x_j,x_i)\big].$$
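The update above can be sketched in a few lines of NumPy. The helper names (`stein_fields`, `simplex_qp`, `musgd_step`) and the two-Gaussian toy targets are illustrative assumptions, not the authors' reference implementation, and the RKHS inner products are replaced by a common Euclidean Monte Carlo surrogate:

```python
# Illustrative single-step MuSGD with an RBF kernel (hypothetical helper names).
import numpy as np
from scipy.optimize import minimize

def rbf(X, h):
    """Pairwise RBF kernel K[i, j] = exp(-||x_i - x_j||^2 / (2 h^2))."""
    sq = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / (2 * h**2))

def stein_fields(X, score_fns, h):
    """psi[k, i] = (1/M) sum_j [ K[j,i] score_k(x_j) + grad_{x_j} K[j,i] ]."""
    M, d = X.shape
    K = rbf(X, h)
    diff = X[:, None, :] - X[None, :, :]           # diff[j, i] = x_j - x_i
    gradK = -diff / h**2 * K[:, :, None]           # grad_{x_j} K[j, i]
    repulsion = gradK.sum(0) / M                   # same for every target
    return np.stack([(K.T @ s(X)) / M + repulsion for s in score_fns])

def simplex_qp(U):
    """argmin_{w in simplex} w^T U w; U is only K x K, so this QP is cheap."""
    K = U.shape[0]
    res = minimize(lambda w: w @ U @ w, np.ones(K) / K,
                   constraints={"type": "eq", "fun": lambda w: w.sum() - 1},
                   bounds=[(0, 1)] * K, method="SLSQP")
    return res.x

def musgd_step(X, score_fns, h=1.0, eta=0.1):
    psi = stein_fields(X, score_fns, h)                   # (K, M, d)
    U = np.einsum("kmd,lmd->kl", psi, psi) / X.shape[0]   # MC inner products
    w = simplex_qp(U)
    return X + eta * np.einsum("k,kmd->md", w, psi)       # composite update

# Two Gaussian targets N(+1, I) and N(-1, I); scores are x -> grad log pi_k(x).
rng = np.random.default_rng(0)
scores = [lambda X: -(X - 1.0), lambda X: -(X + 1.0)]
X = rng.normal(size=(50, 2))
for _ in range(100):
    X = musgd_step(X, scores, h=0.5, eta=0.2)
print(np.round(X.mean(0), 2))  # particles settle between the two modes
```

For neural-network targets, the score functions would return minibatch estimates of $\nabla\log\pi_k$; the QP stays negligible because $U$ is only $K\times K$.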
3. Theoretical Properties and Connections
In the limit of infinite RBF kernel bandwidth (or with a single particle), $\kappa(x,x') \to 1$ and $\nabla\kappa \to 0$, so the kernel repulsive term is eliminated and $\hat\psi_k(x_i) \to \frac{1}{M}\sum_{j=1}^M \nabla\log\pi_k(x_j)$, yielding:

$$\phi^\ast(x_i) = \sum_{k=1}^K w_k^\ast\,\frac{1}{M}\sum_{j=1}^M \nabla\log\pi_k(x_j).$$
If $\pi_k(x) \propto \exp(-f_k(x))$, this matches the multi-gradient descent (MGDA) direction $-\sum_k w_k^\ast \nabla f_k$. The paper proves that with a single particle, kernel bandwidth $\to \infty$, and step size $\eta_t \to 0$, MuSGD exactly recovers the classical MGDA update (Phan et al., 2022).
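The bandwidth limit is easy to verify numerically: with a very large RBF bandwidth the kernel matrix is effectively all ones and its gradient vanishes, so the estimated Stein field collapses to the average score. A minimal sketch under a toy Gaussian target (all names here are illustrative):

```python
# Check: at huge bandwidth, the Stein field equals the mean score (MGDA-style term).
import numpy as np

rng = np.random.default_rng(1)
M, d, h = 20, 3, 1e6                        # h: very large RBF bandwidth
X = rng.normal(size=(M, d))
score = -(X - 2.0)                          # grad log pi(x) for pi = N(2, I)

diff = X[:, None, :] - X[None, :, :]        # diff[j, i] = x_j - x_i
K = np.exp(-(diff**2).sum(-1) / (2 * h**2)) # ~ all-ones matrix
gradK = -diff / h**2 * K[:, :, None]        # ~ 0: repulsion vanishes

psi = (K.T @ score) / M + gradK.sum(0) / M  # Stein field at each particle
avg_grad = score.mean(0)                    # (1/M) sum_j grad log pi(x_j)
print(np.abs(psi - avg_grad).max())         # ~ 0
```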
Under standard smoothness and Lipschitz continuity of $\nabla\log\pi_k$ and the kernel, both the continuous-time flow and the discrete MuSGD trajectories converge as $t \to \infty$. All KL divergences decrease at each iteration:

$$\mathrm{KL}(q_{t+1}\,\|\,\pi_k) \le \mathrm{KL}(q_t\,\|\,\pi_k), \qquad k = 1,\dots,K,$$

up to $O(\eta_t^2)$ discretization error.
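The per-objective descent can be seen in two steps, sketched here from the standard SVGD identity for the KL directional derivative together with MGDA-type optimality of the simplex QP:

```latex
% Directional derivative of each KL along a transport field \phi (SVGD identity):
\frac{d}{d\eta}\,\mathrm{KL}\big(q_{[\eta\phi]}\,\|\,\pi_k\big)\Big|_{\eta=0}
  = -\,\langle \phi,\ \psi_k \rangle_{\mathcal{H}^d}.
% First-order optimality of w^* on the simplex implies, for every k,
\langle \phi^\ast,\ \psi_k \rangle_{\mathcal{H}^d} \;\ge\; \|\phi^\ast\|_{\mathcal{H}^d}^2,
% hence every objective decreases simultaneously:
\frac{d}{d\eta}\,\mathrm{KL}\big(q_{[\eta\phi^\ast]}\,\|\,\pi_k\big)\Big|_{\eta=0}
  \;\le\; -\,\|\phi^\ast\|_{\mathcal{H}^d}^2 \;\le\; 0.
```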
4. Algorithmic Workflow and Computational Complexity
The standard MuSGD pseudocode is:
```
Input:  Unnormalized densities {π_k}_{k=1}^K, kernel κ,
        M particles {x_i}_{i=1}^M, step sizes {η_t}, #iterations T
Output: Particles approximating the joint high-density region

for t = 0 ... T-1 do
    For each k = 1...K, compute ψ_k at the particles:
        ∀i, ψ_k(x_i) ← (1/M) ∑_{j=1}^M [κ(x_j,x_i) ∇log π_k(x_j) + ∇_{x_j} κ(x_j,x_i)]
    Build U ∈ ℝ^{K×K}, U_{kl} ← ⟨ψ_k, ψ_l⟩ via the Monte Carlo formula
    Solve w^{(t)} = argmin_{w ∈ Δ^{K-1}} wᵀ U w      (QP on the simplex)
    φ*(x_i) ← ∑_{k=1}^K w_k^{(t)} ψ_k(x_i)
    x_i ← x_i + η_t φ*(x_i),  for i = 1...M
end for
return {x_i}
```
- Kernel computations and Stein terms: $O(K M^2 d)$ per iteration ($M^2$ kernel pairs for each of the $K$ targets in dimension $d$).
- Building $U$: $O(K^2 M d)$ from the evaluated Stein fields.
- QP on the simplex: $O(K^3)$, negligible since $K$ is typically small.
Memory footprint includes $O(Md)$ for the particles, $O(KMd)$ for the Stein fields at the $M$ points, and $O(K^2)$ for the matrix $U$.
5. Empirical Performance and Applications
Sampling Accuracy
On synthetic problems (e.g., low-dimensional mixtures of Gaussians), MuSGD particles concentrate in the true joint high-density region, outperforming methods such as MOO-SVGD, whose particles scatter across separate modes.
Multi-Task Learning
MuSGD has been evaluated on multi-MNIST, multi-FashionMNIST, CelebA (10 attributes), and the SARCOS regression dataset. Using architectures such as LeNet or ResNet-18, MuSGD alternates between particle-based updates of the shared parameters via the multi-target composite direction and updates of the task-specific parameters via standard SVGD. Metrics considered are ensemble accuracy, Brier score, and Expected Calibration Error (ECE).
Extracted empirical results:
- On CelebA, MuSGD attains highest mean accuracy (89.0% vs. 88.2% for MOO-SVGD) and lowest ECE (2.0% vs. 2.5%).
- On SARCOS regression, MuSGD yields the lowest RMSE for all outputs (0.0428 vs. 0.0515 for MOO-SVGD).
- MuSGD consistently outperforms single-task SGD, linear scalarization, MGDA, Pareto MTL, and MOO-SVGD in both accuracy (+1–2%) and calibration (lower ECE), across varied tasks (Phan et al., 2022).
6. Interpretation and Practical Guidelines
MuSGD generalizes SVGD to the multi-objective setting by:
- Combining multiple Stein directions through simplex QP solving.
- Retaining kernel-based repulsion for diversity among particles.
- Guaranteeing descent of all KL objectives simultaneously, under standard smoothness conditions.
- Asymptotically coinciding with classical multi-gradient descent.
- Empirically improving joint sampling quality and downstream multi-task generalization.
A plausible implication is that MuSGD represents the first kernelized variant of multi-gradient descent, yielding improved empirical sampling efficiency and predictive calibration in multi-objective scenarios. The algorithm is readily adaptable to modern deep learning architectures, provided the score functions $\nabla\log\pi_k$ and kernel evaluations are available.