MuSGD: Multi-Target Sampling Gradient Descent

Updated 26 January 2026
  • MuSGD is a sampling algorithm that employs composite Stein directions to simultaneously target multiple unnormalized distributions.
  • It integrates RKHS-based Stein gradients with a convex QP to determine optimal weightings, ensuring descent for all KL divergences.
  • Empirical results show that MuSGD outperforms traditional methods like MGDA and linear scalarization in accuracy and calibration on datasets such as CelebA and multi-MNIST.

MuSGD (Stochastic Multiple Target Sampling Gradient Descent) is a sampling algorithm designed to simultaneously handle multiple unnormalized target distributions. It extends Stein Variational Gradient Descent (SVGD) into the multi-target/multi-objective domain by iteratively updating a population of particles via composite Stein directions. MuSGD is primarily intended for probabilistic inference and multi-task learning, with theoretical guarantees and empirical benefits over classical approaches such as linear scalarization and multi-gradient descent algorithms (Phan et al., 2022).

1. Mathematical Formulation

Given $K$ unnormalized target densities $\{\pi_k(x)\}_{k=1}^K$ with $x \in \mathbb{R}^d$, the goal is to construct a sequence of intermediate distributions $q_0 \to q_1 \to \cdots \to q_L$ that progressively moves closer to the joint high-density region of all $\{\pi_k\}$. Each step is achieved by a push-forward update:

$$q_{t+1} = T_t \# q_t, \qquad T_t(x) = x + \epsilon_t \phi_t(x),$$

where $\phi_t(x)$ is a transport field constructed to minimize all KL divergences $[\mathrm{KL}(q \| \pi_1), \dots, \mathrm{KL}(q \| \pi_K)]$ simultaneously, treating the task as a multi-objective problem over the space of densities. The optimization objective is:

$$\min_{q \in \mathcal{Q}} \left[\mathrm{KL}(q \| \pi_1), \dots, \mathrm{KL}(q \| \pi_K)\right]$$

2. Gradient Flow, Stein Directions, and Update Rule

Continuous-Time Gradient Flow

For each target $k$, define the Stein-variational direction in the reproducing kernel Hilbert space (RKHS) $\mathcal{H}_k^d$:

$$\psi_k(x) = \mathbb{E}_{y \sim q}\left[ k(y, x)\, \nabla_y \log \pi_k(y) + \nabla_y k(y, x) \right]$$

To ensure descent for all KL objectives, MuSGD finds an optimal convex weight vector $w^* = (w_1^*, \dots, w_K^*) \in \Delta^{K-1}$ by solving the quadratic program:

$$w^* = \arg\min_{w \geq 0,\ \sum_i w_i = 1} w^\top U w, \qquad U_{ij} = \langle \psi_i, \psi_j \rangle_{\mathcal{H}_k^d}$$
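This QP has only $K$ variables, so any off-the-shelf constrained solver suffices. A minimal sketch using scipy (an illustrative choice; the paper does not prescribe a particular solver):

```python
import numpy as np
from scipy.optimize import minimize

def solve_simplex_qp(U):
    """Solve w* = argmin_{w >= 0, sum_i w_i = 1} w^T U w for a K x K Gram matrix U."""
    K = U.shape[0]
    res = minimize(
        fun=lambda w: w @ U @ w,
        x0=np.full(K, 1.0 / K),                 # start at the simplex center
        jac=lambda w: 2.0 * (U @ w),
        bounds=[(0.0, 1.0)] * K,                # enforces w >= 0
        constraints=[{"type": "eq", "fun": lambda w: w.sum() - 1.0}],
        method="SLSQP",
    )
    return res.x

# With U = I the Stein directions are orthogonal with equal norm,
# so the min-norm combination weights them uniformly.
w = solve_simplex_qp(np.eye(2))
```

For the identity Gram matrix this returns weights of approximately $[0.5, 0.5]$; when one direction has much larger norm, the solution shifts weight toward the smaller one, as the min-norm objective dictates.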

Construct the composite descent direction:

$$\phi^*(x) = \sum_{k=1}^K w_k^* \psi_k(x)$$

In the mean-field limit, the particle dynamics follow the deterministic ODE:

$$\frac{d}{dt} X_i(t) = \phi^*(X_i(t)) = \sum_{k=1}^K w_k^*\, \mathbb{E}_{y \sim q}\left[ k(y, X_i)\, \nabla_y \log \pi_k(y) + \nabla_y k(y, X_i) \right]$$

Discrete-Time MuSGD Update

Particles $\{x_i\}_{i=1}^M$ represent the empirical distribution $q_t$. The discrete update executes:

  1. For all $k$ and $i$, compute

$$\psi_k(x_i) \approx \frac{1}{M} \sum_{j=1}^M \left[ k(x_j, x_i)\, \nabla \log \pi_k(x_j) + \nabla_{x_j} k(x_j, x_i) \right]$$
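This empirical Stein estimate vectorizes directly. A sketch with an RBF kernel (the bandwidth `h` is an illustrative choice, not one prescribed by the paper):

```python
import numpy as np

def rbf_stein_direction(X, score_fn, h=1.0):
    """Monte Carlo estimate of the Stein direction psi_k at every particle.

    X        : (M, d) array of particles.
    score_fn : maps an (M, d) array to the (M, d) array of grad-log pi_k values.
    h        : RBF bandwidth, k(y, x) = exp(-||y - x||^2 / (2 h^2)).
    """
    diff = X[None, :, :] - X[:, None, :]                  # diff[i, j] = x_j - x_i
    sq = np.sum(diff ** 2, axis=-1)                       # (M, M) squared distances
    K = np.exp(-sq / (2.0 * h ** 2))                      # K[i, j] = k(x_j, x_i)
    drift = K @ score_fn(X)                               # sum_j k(x_j, x_i) score(x_j)
    repulse = -np.einsum("ij,ijd->id", K, diff) / h ** 2  # sum_j grad_{x_j} k(x_j, x_i)
    return (drift + repulse) / X.shape[0]

# For a standard-normal target the score is -x; a lone particle at x = 2
# feels pure drift toward the mode, so psi = -2.
psi = rbf_stein_direction(np.array([[2.0]]), lambda X: -X)
```

The `drift` term pulls particles toward high-density regions of $\pi_k$, while `repulse` (the kernel-gradient term) pushes nearby particles apart, maintaining diversity.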

  2. Build the Gram matrix $U$ using Monte Carlo inner products.
  3. Solve $w^{(t)} = \arg\min_{w \in \Delta^{K-1}} w^\top U w$ (a small QP).
  4. Compute the composite direction $\phi^*(x_i)$ as above.
  5. Update each particle:

$$x_i^{(t+1)} = x_i^{(t)} + \eta_t\, \phi^*(x_i^{(t)})$$

Expanded form:

$$x_i^{(t+1)} = x_i^{(t)} + \eta_t \sum_{k=1}^K w_k^{(t)}\, \frac{1}{M} \sum_{j=1}^M \left[ k(x_j, x_i)\, \nabla \log \pi_k(x_j) + \nabla_{x_j} k(x_j, x_i) \right]$$
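Putting the steps together, a toy run on two 1-D Gaussian targets $\mathcal{N}(-1, 1)$ and $\mathcal{N}(+1, 1)$ illustrates the full loop; the particle cloud stays concentrated between the two modes. The bandwidth, step size, Gram-matrix proxy, and the closed-form $K = 2$ simplex QP are illustrative simplifications, not the paper's exact configuration:

```python
import numpy as np

rng = np.random.default_rng(0)
M, h, eta = 50, 1.0, 0.1
X = rng.normal(0.0, 1.0, size=(M, 1))          # particle cloud in R^1

# Score functions (grad log density) of the two unnormalized targets.
scores = [lambda X: -(X + 1.0),                # N(-1, 1)
          lambda X: -(X - 1.0)]                # N(+1, 1)

def stein(X, score):
    diff = X[None, :, :] - X[:, None, :]
    K = np.exp(-np.sum(diff ** 2, -1) / (2 * h ** 2))
    return (K @ score(X) - np.einsum("ij,ijd->id", K, diff) / h ** 2) / len(X)

for t in range(300):
    psis = [stein(X, s) for s in scores]
    # Crude Monte Carlo proxy for the RKHS inner products <psi_i, psi_j>.
    U = np.array([[np.mean(np.sum(a * b, axis=1)) for b in psis] for a in psis])
    # Closed-form solution of the K = 2 QP on the simplex.
    denom = U[0, 0] - 2 * U[0, 1] + U[1, 1]
    w1 = float(np.clip((U[1, 1] - U[0, 1]) / denom, 0, 1)) if denom > 1e-12 else 0.5
    X = X + eta * (w1 * psis[0] + (1 - w1) * psis[1])
```

With balanced weights the composite score is $0.5\,(-(x+1)) + 0.5\,(-(x-1)) = -x$, i.e. the score of a standard Gaussian centered between the two modes, which is where the particles settle.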

3. Theoretical Properties and Connections

In the limit of infinite RBF kernel bandwidth (or with a single particle), the kernel repulsive term is eliminated and $k(x, y) \to 1$, yielding:

$$\psi_k(x) \to \nabla \log \pi_k(x), \qquad \phi^*(x) \to \sum_{k=1}^K w_k^* \nabla \log \pi_k(x)$$

If $\pi_k \propto e^{-\ell_k}$, this matches the multiple-gradient descent algorithm (MGDA) direction. The paper proves that as $M \to \infty$, kernel bandwidth $\sigma \to \infty$, and step size $\eta \to 0$, MuSGD exactly recovers the classical MGDA update (Phan et al., 2022).
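In this limit the update is plain MGDA on the losses $\ell_k$. A toy sketch for two hypothetical quadratic losses $\ell_1 = (x+1)^2$ and $\ell_2 = (x-1)^2$ (chosen purely for illustration); the min-norm iterate descends both losses until it reaches the Pareto set $[-1, 1]$ and then stops:

```python
import numpy as np

def mgda_step(x, grads, eta=0.1):
    """One MGDA step for K = 2: move along the min-norm convex combination."""
    g1, g2 = grads
    denom = float(np.dot(g1 - g2, g1 - g2))
    # argmin_w ||w g1 + (1 - w) g2||^2 over w in [0, 1]; closed form for K = 2.
    w = float(np.clip(np.dot(g2 - g1, g2) / denom, 0.0, 1.0)) if denom > 1e-12 else 0.5
    return x - eta * (w * g1 + (1.0 - w) * g2)

x = np.array([3.0])
for _ in range(200):
    g1 = 2.0 * (x + 1.0)          # gradient of (x + 1)^2
    g2 = 2.0 * (x - 1.0)          # gradient of (x - 1)^2
    x = mgda_step(x, (g1, g2))
# x converges to the nearest Pareto point, x = 1.
```

Inside $[-1, 1]$ the two gradients point in opposite directions and the min-norm combination is zero, so any Pareto-optimal point is stationary; this mirrors the role of the composite Stein direction $\phi^*$ vanishing at a multi-target fixed point.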

Under standard smoothness and Lipschitz assumptions on $\nabla \log \pi_k$ and the kernel, both the continuous-time flow and the discrete MuSGD trajectories converge as $\eta \to 0$. Every KL divergence decreases at each iteration, by at least the squared RKHS norm of the composite direction:

$$D_{\mathrm{KL}}(q^{[T]} \| \pi_k) \leq D_{\mathrm{KL}}(q \| \pi_k) - \epsilon\, \|\phi^*\|_{\mathcal{H}_k^d}^2 + O(\epsilon^2), \qquad \forall\, k$$

4. Algorithmic Workflow and Computational Complexity

The standard MuSGD pseudocode is:

Input:   Unnormalized densities {π_k}_{k=1}^K, kernel k,
         M particles {x_i}_{i=1}^M, step sizes {η_t}, #iterations T
Output:  Particles approximating the joint high–density region

for t = 0 ... T-1 do
    For each k = 1...K, compute ψ_k at particles:
        ∀i, ψ_k(x_i) ← (1/M) ∑_{j=1}^M [k(x_j,x_i)∇log π_k(x_j) + ∇_{x_j} k(x_j,x_i)]
    Build U∈ℝ^{K×K}, U_{ij}← ⟨ψ_i,ψ_j⟩ via the MC formula
    Solve w^{(t)}=argmin_{w∈Δ^{K-1}} wᵀ U w (QP on simplex)
    φ^*(x_i) ← ∑_{k=1}^K w_k^{(t)} ψ_k(x_i)
    x_i ← x_i + η_t φ^*(x_i), for i=1…M
end for
return {x_i}

Complexity per iteration:

  • Kernel computations and Stein terms: $O(K M^2 d)$
  • Building $U$: $O(K^2 M^2 d)$
  • QP on the simplex: $O(K^3)$

The memory footprint includes the $M \times d$ particle array, the $K$ Stein fields evaluated at $M$ points, and the $K \times K$ matrix $U$.
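Once the $K$ Stein fields have been evaluated at the particles, a simple way to assemble $U$ is to average pointwise inner products over the particle cloud (a simplified estimator used here for illustration; the paper's exact Monte Carlo estimator may differ):

```python
import numpy as np

def gram_matrix(psis):
    """Approximate U_ij = <psi_i, psi_j> from Stein fields evaluated at particles.

    psis: list of K arrays, each of shape (M, d), holding psi_k at every particle.
    """
    K = len(psis)
    U = np.empty((K, K))
    for i in range(K):
        for j in range(i, K):
            # Average the Euclidean inner product psi_i(x_m) . psi_j(x_m) over particles.
            U[i, j] = U[j, i] = np.mean(np.sum(psis[i] * psis[j], axis=1))
    return U
```

The resulting matrix is symmetric by construction, which the simplex QP solver can exploit.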

5. Empirical Performance and Applications

Sampling Accuracy

On synthetic problems (e.g., mixtures of Gaussians in $\mathbb{R}^2$), MuSGD particles concentrate in the true joint high-density region, outperforming methods such as MOO-SVGD, whose particles scatter across separate modes.

Multi-Task Learning

MuSGD has been evaluated on multi-MNIST, multi-FashionMNIST, CelebA (10 attributes), and the SARCOS regression dataset. Using architectures such as LeNet or ResNet-18, the method alternates particle-based updates: shared parameters via the multi-target (MT-SGD) step and task-specific parameters via standard SVGD. Metrics considered are ensemble accuracy, Brier score, and Expected Calibration Error (ECE).

Extracted empirical results:

  • On CelebA, MuSGD attains the highest mean accuracy (89.0% vs. 88.2% for MOO-SVGD) and the lowest ECE (2.0% vs. 2.5%).
  • On SARCOS regression, MuSGD yields the lowest RMSE across all outputs (0.0428 vs. 0.0515 for MOO-SVGD).
  • MuSGD consistently outperforms single-task SGD, linear scalarization, MGDA, Pareto MTL, and MOO-SVGD in both accuracy (+1–2%) and calibration (lower ECE) across varied tasks (Phan et al., 2022).

6. Interpretation and Practical Guidelines

MuSGD generalizes SVGD to the multi-objective setting by:

  • Combining multiple Stein directions through a QP over the simplex.
  • Retaining kernel-based repulsion for diversity among particles.
  • Guaranteeing descent for all KL objectives under standard smoothness conditions.
  • Asymptotically coinciding with classical multi-gradient descent.
  • Demonstrably enhancing joint sampling and downstream multi-task generalization.

A plausible implication is that MuSGD represents the first kernelized variant of multi-gradient descent, yielding improved empirical sampling efficiency and predictive calibration in multi-objective scenarios. The algorithm is readily adaptable to modern deep learning architectures, provided access to the Stein gradients and kernel functions.
