MuSGD: Multi-Target Sampling Gradient Descent
- MuSGD is a sampling algorithm that employs composite Stein directions to simultaneously target multiple unnormalized distributions.
- It integrates RKHS-based Stein gradients with a convex QP to determine optimal weightings, ensuring descent for all KL divergences.
- Empirical results show that MuSGD outperforms traditional methods like MGDA and linear scalarization in accuracy and calibration on datasets such as CelebA and multi-MNIST.
MuSGD (Stochastic Multiple Target Sampling Gradient Descent) is a sampling algorithm designed to simultaneously handle multiple unnormalized target distributions. It extends Stein Variational Gradient Descent (SVGD) into the multi-target/multi-objective domain by iteratively updating a population of particles via composite Stein directions. MuSGD is primarily intended for probabilistic inference and multi-task learning, with theoretical guarantees and empirical benefits over classical approaches such as linear scalarization and multi-gradient descent algorithms (Phan et al., 2022).
1. Mathematical Formulation
Given unnormalized target densities $\pi_1, \dots, \pi_K$, each known only up to a normalizing constant, the goal is to construct a sequence of intermediate distributions $q_0, q_1, q_2, \dots$ that progressively move closer to the joint high-density region of all $\pi_i$. Each step is achieved by a push-forward update:

$$q_{t+1} = \big(\mathrm{id} + \epsilon\,\phi_t\big)_{\#}\, q_t,$$

where $\phi_t$ is a transport field constructed to minimize all KL divergences simultaneously, treating the task as a multi-objective problem over the space of densities. The optimization objective is:

$$\min_{\phi_t \in \mathcal{H}^d} \;\Big(\mathrm{KL}(q_{t+1} \,\|\, \pi_1),\; \dots,\; \mathrm{KL}(q_{t+1} \,\|\, \pi_K)\Big),$$

understood in the multi-objective sense: the chosen $\phi_t$ must not increase any of the $K$ divergences.
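For background (this is standard SVGD theory from Liu and Wang, 2016, not specific to MuSGD), the effect of such a push-forward on each KL divergence is governed by the Stein operator, which is what makes the RKHS directions of the next section optimal:

$$\frac{d}{d\epsilon}\,\mathrm{KL}\big((\mathrm{id}+\epsilon\phi)_{\#}\, q \,\|\, \pi_i\big)\Big|_{\epsilon=0} = -\,\mathbb{E}_{x\sim q}\big[\operatorname{trace}\big(\mathcal{A}_{\pi_i}\phi(x)\big)\big], \qquad \mathcal{A}_{\pi_i}\phi(x) = \nabla_x\log\pi_i(x)\,\phi(x)^{\top} + \nabla_x\phi(x).$$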
2. Gradient Flow, Stein Directions, and Update Rule
Continuous-Time Gradient Flow
For each target $\pi_i$, define the Stein-variational direction in the reproducing kernel Hilbert space (RKHS) $\mathcal{H}^d$ induced by a kernel $k(\cdot,\cdot)$:

$$\phi_i^*(\cdot) \;=\; \mathbb{E}_{x \sim q}\big[\,k(x, \cdot)\,\nabla_x \log \pi_i(x) + \nabla_x k(x, \cdot)\,\big].$$

This is the steepest-descent direction of $\mathrm{KL}(q \,\|\, \pi_i)$ within the unit ball of $\mathcal{H}^d$, exactly as in single-target SVGD.
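As an illustrative sketch (not the authors' reference code), the Stein direction for one target can be estimated from a set of particles with an RBF kernel; `score_i` below is an assumed callable standing in for $\nabla_x \log \pi_i$:

```python
import numpy as np

def rbf_kernel(X, h):
    """RBF kernel matrix K[m, j] = exp(-||x_m - x_j||^2 / (2 h^2))
    together with its gradient w.r.t. the first argument."""
    diff = X[:, None, :] - X[None, :, :]        # (M, M, d) pairwise differences
    sq = np.sum(diff ** 2, axis=-1)             # (M, M) squared distances
    K = np.exp(-sq / (2.0 * h ** 2))
    grad_K = -diff / h ** 2 * K[:, :, None]     # grad_K[m, j] = d/dx_m k(x_m, x_j)
    return K, grad_K

def stein_direction(X, score_i, h=1.0):
    """Monte Carlo estimate of phi_i^* at every particle:
    phi_i^*(x_j) ~= (1/M) sum_m [k(x_m, x_j) score_i(x_m) + grad_{x_m} k(x_m, x_j)]."""
    M = X.shape[0]
    K, grad_K = rbf_kernel(X, h)
    S = score_i(X)                              # (M, d) scores of target i at each particle
    return (K.T @ S + grad_K.sum(axis=0)) / M   # (M, d)
```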
To ensure descent for all KL objectives, MuSGD finds an optimal convex weight vector $w^* \in \Delta_K = \{w \in \mathbb{R}^K : w_i \ge 0,\ \sum_i w_i = 1\}$ by solving the quadratic program:

$$w^* = \arg\min_{w \in \Delta_K} \Big\| \sum_{i=1}^{K} w_i\, \phi_i^* \Big\|_{\mathcal{H}^d}^2 = \arg\min_{w \in \Delta_K} w^\top A\, w, \qquad A_{il} = \langle \phi_i^*, \phi_l^* \rangle_{\mathcal{H}^d}.$$
Construct the composite descent direction:

$$\phi^* = \sum_{i=1}^{K} w_i^*\, \phi_i^*.$$
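A minimal sketch of the weight solve, assuming the Gram matrix $A$ has already been estimated; projected gradient descent is one simple way to handle the simplex constraint and is an assumption here, not necessarily the solver used in the paper:

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u) - 1.0
    rho = np.nonzero(u - css / np.arange(1, len(v) + 1) > 0)[0][-1]
    return np.maximum(v - css[rho] / (rho + 1.0), 0.0)

def solve_weights(A, n_steps=200, lr=None):
    """min_{w in simplex} w^T A w via projected gradient descent."""
    K = A.shape[0]
    w = np.full(K, 1.0 / K)                              # start from uniform weights
    if lr is None:
        lr = 1.0 / (2.0 * np.linalg.norm(A, 2) + 1e-12)  # step from curvature bound
    for _ in range(n_steps):
        w = project_simplex(w - lr * (2.0 * A @ w))
    return w
```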
Within the mean-field limit, the particle dynamics are governed by:

$$\frac{dx_t}{dt} = \phi^*_{q_t}(x_t), \qquad \partial_t q_t = -\nabla \cdot \big(q_t\, \phi^*_{q_t}\big),$$

where $q_t$ is the law of $x_t$ and $\phi^*_{q_t}$ is the composite direction computed at $q_t$.
Discrete-Time MuSGD Update
Particles $\{x_j\}_{j=1}^{M}$ represent the empirical distribution $\hat{q} = \frac{1}{M} \sum_{j=1}^{M} \delta_{x_j}$. The discrete update executes:
- For all $i = 1, \dots, K$ and $j = 1, \dots, M$, compute
$$\hat{\phi}_i^*(x_j) = \frac{1}{M} \sum_{m=1}^{M} \big[\, k(x_m, x_j)\, \nabla_{x_m} \log \pi_i(x_m) + \nabla_{x_m} k(x_m, x_j) \,\big].$$
- Build the Gram matrix $\hat{A}_{il} = \frac{1}{M} \sum_{j=1}^{M} \hat{\phi}_i^*(x_j)^{\top} \hat{\phi}_l^*(x_j)$ using Monte Carlo inner products.
- Solve $w^* = \arg\min_{w \in \Delta_K} w^\top \hat{A}\, w$ (a small $K$-dimensional QP).
- Compute the composite direction $\hat{\phi}^* = \sum_{i} w_i^* \hat{\phi}_i^*$ as above.
- Update each particle (a runnable sketch follows the expanded form below):
$$x_j \leftarrow x_j + \epsilon\, \hat{\phi}^*(x_j).$$
Expanded form:
$$x_j \leftarrow x_j + \frac{\epsilon}{M} \sum_{m=1}^{M} \Big[\, k(x_m, x_j) \sum_{i=1}^{K} w_i^*\, \nabla_{x_m} \log \pi_i(x_m) + \nabla_{x_m} k(x_m, x_j) \,\Big].$$
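Putting the pieces together, one discrete MuSGD step might look like the following sketch (reusing the hypothetical `stein_direction` and `solve_weights` helpers from above; `scores` is an assumed list of per-target score functions):

```python
import numpy as np

def musgd_step(X, scores, h=1.0, eps=1e-2):
    """One discrete MuSGD update for particles X of shape (M, d).
    scores[i](X) -> (M, d) array of grad log pi_i evaluated at each particle."""
    M = X.shape[0]
    # Stein field per target: Phi[i] has shape (M, d)
    Phi = np.stack([stein_direction(X, s, h) for s in scores])
    # Monte Carlo Gram matrix A[i, l] = <phi_i, phi_l> averaged over particles
    A = np.einsum('imd,lmd->il', Phi, Phi) / M
    w = solve_weights(A)                    # optimal simplex weights
    phi = np.einsum('i,imd->md', w, Phi)    # composite direction at each particle
    return X + eps * phi
```

Iterating this step moves the particle cloud toward the joint high-density region while the kernel term maintains diversity among particles.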
3. Theoretical Properties and Connections
In the limit of infinite RBF kernel bandwidth (or a single particle), the kernel repulsive term is eliminated and $\phi_i^*(x) \to \nabla_x \log \pi_i(x)$, yielding:

$$\phi^*(x) = \sum_{i=1}^{K} w_i^*\, \nabla_x \log \pi_i(x).$$

If $w^*$ solves the min-norm problem $\min_{w \in \Delta_K} \big\| \sum_i w_i \nabla_x \log \pi_i(x) \big\|^2$, this matches the multi-gradient descent (MGDA) direction. The paper proves that with a single particle, infinite kernel bandwidth, and matching step size, MuSGD exactly recovers the classical MGDA update (Phan et al., 2022).
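A quick numerical check of this reduction under the sketches' assumptions: with one particle and a very large bandwidth, the MuSGD step coincides with the MGDA-style weighted-gradient step (the two Gaussian targets below are hypothetical):

```python
import numpy as np

# Two hypothetical Gaussian targets N(mu, I); grad log pi(x) = mu - x
means = [np.array([2.0, 0.0]), np.array([-2.0, 0.0])]
scores = [lambda X, mu=mu: mu - X for mu in means]

X = np.array([[0.5, 1.0]])                 # a single particle, M = 1
X_new = musgd_step(X, scores, h=1e6, eps=0.1)

# MGDA-style reference: min-norm weighted sum of the plain gradients
G = np.stack([s(X)[0] for s in scores])    # (K, d) gradients at the particle
w = solve_weights(G @ G.T)
print(X_new[0] - X[0])                     # MuSGD step direction
print(0.1 * w @ G)                         # weighted-gradient step; should match
```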
Under standard smoothness and Lipschitz continuity of the scores $\nabla_x \log \pi_i$ and the kernel, both the continuous-time flow and the discrete MuSGD trajectories converge as $t \to \infty$. All KL divergences decrease at each iteration:

$$\mathrm{KL}(q_{t+1} \,\|\, \pi_i) \;\le\; \mathrm{KL}(q_t \,\|\, \pi_i) \quad \text{for all } i = 1, \dots, K.$$
4. Algorithmic Workflow and Computational Complexity
The standard MuSGD loop repeats the discrete update of Section 2 until convergence (a runnable sketch follows the complexity summary below). Complexity per iteration:
- Kernel computations and Stein terms: $O(K M^2 d)$
- Building $\hat{A}$: $O(K^2 M d)$
- QP on the simplex: $O(K^3)$, negligible for small $K$
Memory footprint includes the $M$ particles in $\mathbb{R}^d$, the $K$ Stein fields evaluated at $M$ points ($O(KMd)$), and the $K \times K$ matrix $\hat{A}$.
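For concreteness, a minimal driver loop under the same assumptions as the earlier sketches (hypothetical Gaussian targets; `musgd_step` as defined above):

```python
import numpy as np

rng = np.random.default_rng(0)

# Three hypothetical Gaussian targets whose high-density regions overlap
means = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([0.5, 0.5])]
scores = [lambda X, mu=mu: mu - X for mu in means]   # grad log N(mu, I)

X = rng.normal(size=(50, 2))             # M = 50 particles in d = 2
for t in range(500):
    X = musgd_step(X, scores, h=0.5, eps=0.05)

print(X.mean(axis=0))                    # should land near the joint high-density region
```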
5. Empirical Performance and Applications
Sampling Accuracy
On synthetic problems (e.g., mixtures of Gaussians), MuSGD particles concentrate in the true joint high-density region, whereas methods such as MOO-SVGD scatter particles across separate modes.
Multi-Task Learning
MuSGD has been evaluated on the multi-MNIST, multi-FashionMNIST, CelebA (10 attributes), and SARCOS regression datasets. Using architectures such as LeNet or ResNet-18, MuSGD alternates between updating the shared parameters with the multi-target update and the task-specific parameters with standard SVGD. The metrics considered are ensemble accuracy, Brier score, and Expected Calibration Error (ECE).
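A toy illustration of applying the multi-target update to shared parameters (hypothetical linear-regression tasks; the per-task SVGD updates on task-specific heads are omitted for brevity):

```python
import numpy as np

rng = np.random.default_rng(1)
d, K, N = 3, 2, 50
W_true = rng.normal(size=d)
X_data = rng.normal(size=(N, d))
Ys = [X_data @ W_true + rng.normal(size=N) for _ in range(K)]   # K related tasks

def make_score(y, sigma2=1.0):
    """Score of the per-task posterior over shared weights, N(0, I) prior."""
    def score(P):                            # P: (M, d) particles over shared weights
        R = y[None, :] - P @ X_data.T        # (M, N) residuals
        return R @ X_data / sigma2 - P       # grad log likelihood + grad log prior
    return score

scores = [make_score(y) for y in Ys]
P = rng.normal(size=(20, d))                 # 20 particles over shared parameters
for t in range(500):
    P = musgd_step(P, scores, h=0.5, eps=1e-3)
```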
Extracted empirical results:
- On CelebA, MuSGD attains the highest mean accuracy (89.0% vs. 88.2% for MOO-SVGD) and the lowest ECE (2.0% vs. 2.5%).
- On SARCOS regression, MuSGD yields the lowest RMSE for all outputs (0.0428 vs. 0.0515 for MOO-SVGD).
- MuSGD consistently outperforms single-task SGD, linear scalarization, MGDA, Pareto MTL, and MOO-SVGD in both accuracy (+1–2%) and calibration (lower ECE), across varied tasks (Phan et al., 2022).
6. Interpretation and Practical Guidelines
MuSGD generalizes SVGD to the multi-objective setting by:
- Combining multiple Stein directions through simplex QP solving.
- Retaining kernel-based repulsion for diversity among particles.
- Guaranteeing simultaneous descent of all KL objectives under standard smoothness conditions.
- Asymptotically coinciding with classical multi-gradient descent.
- Demonstrating improved joint sampling and downstream multi-task generalization in practice.
A plausible implication is that MuSGD represents the first kernelized variant of multi-gradient descent, yielding improved empirical sampling efficiency and predictive calibration in multi-objective scenarios. The algorithm is readily adaptable to modern deep learning architectures, provided access to the Stein gradients and kernel functions.