TREAD: Token Routing for Efficient Architecture-agnostic Diffusion Training
Published 8 Jan 2025 in cs.CV and cs.AI | (2501.04765v2)
Abstract: Diffusion models have emerged as the mainstream approach for visual generation. However, these models typically suffer from sample inefficiency and high training costs. Consequently, methods for efficient finetuning, inference and personalization were quickly adopted by the community. However, training these models in the first place remains very costly. While several recent approaches - including masking, distillation, and architectural modifications - have been proposed to improve training efficiency, each of these methods comes with a tradeoff: they achieve enhanced performance at the expense of increased computational cost or vice versa. In contrast, this work aims to improve training efficiency as well as generative performance at the same time through routes that act as a transport mechanism for randomly selected tokens from early layers to deeper layers of the model. Our method is not limited to the common transformer-based model - it can also be applied to state-space models and achieves this without architectural modifications or additional parameters. Finally, we show that TREAD reduces computational cost and simultaneously boosts model performance on the standard ImageNet-256 benchmark in class-conditional synthesis. Both of these benefits multiply to a convergence speedup of 14x at 400K training iterations compared to DiT and 37x compared to the best benchmark performance of DiT at 7M training iterations. Furthermore, we achieve a competitive FID of 2.09 in a guided and 3.93 in an unguided setting, which improves upon the DiT, without architectural changes.
The paper introduces TREAD, a token routing method that enhances diffusion model training without requiring architectural changes.
It achieves significant convergence speedup—up to 25.39× faster—and improves FID scores on benchmarks like ImageNet-1K.
The approach is validated on both transformer-based and state-space models, offering practical efficiency gains in class-conditional synthesis.
This paper introduces TREAD (Token Routing for Efficient Architecture-agnostic Diffusion Training), a method to improve the training efficiency of diffusion models by using predefined routes that store token information until it is reintroduced to deeper layers of the model. TREAD is applicable to both transformer-based and state-space models without architectural modifications. The authors show that TREAD reduces the computational cost and boosts model performance on ImageNet-1K 256×256 in class-conditional synthesis. The method achieves a 9.55× convergence speedup at 400K training iterations compared to DiT and a 25.39× speedup compared to the best benchmark performance of DiT at 7M training iterations.
The main contributions of this paper are:
The paper investigates token-specific computational bypass techniques for diffusion models and introduces a training strategy that requires no architectural modifications while simultaneously improving both generative performance and training speed.
The approach is extended from a single-route to a multi-route framework, enhancing diffusion models across various architectures and improving convergence speed for both DiT and state space models.
The method matches the FID of 19.47 that DiT reaches after 400K iterations and then improves upon it, achieving an FID of 10.63 after the same number of iterations under identical settings on the standard class-conditional synthesis benchmark on ImageNet-1K 256×256. This corresponds to a 9.55× convergence speedup. Using the best FID of 9.62 reported for DiT-XL/2 (at 7M iterations) as the baseline, the method achieves a 25.39× speedup and reaches a better FID of 9.32 within 41 hours.
The authors adopt the framework established by Song et al. \cite{song2021scorebased_sde} and use stochastic differential equations (SDEs) to define the forward diffusion process, which gradually transforms real data $x_0 \sim p_{\text{data}}(x_0)$ into a noise distribution $x_T \sim \mathcal{N}(0, \sigma_{\max}^2 I)$ via the SDE:
$$dx = f(x, t)\,dt + g(t)\,dW,$$
where:
x is the data
t is the time variable, 0≤t≤T
f is the drift coefficient
g is the diffusion coefficient
W is a standard Wiener process
The reverse process generates samples $x_0$ through a reverse-time SDE:
$$dx = \left[f(x, t) - g(t)^2 \nabla_x \log p_t(x)\right]dt + g(t)\,d\bar{W},$$
where $\bar{W}$ denotes a reverse-time Wiener process and $dt$ is an infinitesimal negative timestep. This reverse SDE can be reformulated as a probability flow ordinary differential equation (ODE), which has the same marginal distributions $p_t(x)$ as the forward SDE at every time $t$:
$$dx = \left[f(x, t) - \tfrac{1}{2} g(t)^2 \nabla_x \log p_t(x)\right]dt.$$
Following the formulation of Karras et al. \cite{karras2022elucidating} (EDM), the authors set the drift to $f(x, t) := 0$ and the diffusion coefficient to $g(t) := \sqrt{2t}$, so that the noise variance at time $t$ is $\int_0^t 2\tau\,d\tau = t^2$. Consequently, the forward process reduces to:
$$x_t = x_0 + n,$$
where:
$x_t$ is the noisy data at time $t$
$x_0$ is the original data
$n \sim \mathcal{N}(0, t^2 I)$ is Gaussian noise
The corresponding probability flow ODE can be written in terms of the score function $s(x, t) := \nabla_x \log p_t(x)$:
$$dx = -t\, s(x, t)\,dt.$$
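As a numerical sanity check of this ODE, the sketch below assumes one-dimensional Gaussian data $x_0 \sim \mathcal{N}(0, \sigma_{\text{data}}^2)$, an illustrative choice that is not from the paper but makes the score available in closed form: $p_t = \mathcal{N}(0, \sigma_{\text{data}}^2 + t^2)$, so $s(x, t) = -x / (\sigma_{\text{data}}^2 + t^2)$. Integrating $dx = -t\,s(x,t)\,dt$ from $t = T$ down to $0$ then has the exact solution $x(0) = x(T)\,\sigma_{\text{data}} / \sqrt{\sigma_{\text{data}}^2 + T^2}$. The function names are hypothetical, and plain Euler steps are used for simplicity (EDM itself uses a second-order Heun sampler).

```python
import numpy as np


def gaussian_score(x, t, sigma_data):
    # Assumed toy data: x0 ~ N(0, sigma_data^2), x_t = x0 + n with n ~ N(0, t^2),
    # hence p_t = N(0, sigma_data^2 + t^2) and the score is -x / (sigma_data^2 + t^2).
    return -x / (sigma_data**2 + t**2)


def pf_ode_sample(x_T, t_max, sigma_data, num_steps=2000):
    """Euler integration of the EDM probability flow ODE dx = -t * s(x, t) dt
    from t = t_max down to t = 0."""
    ts = np.linspace(t_max, 0.0, num_steps + 1)
    x = np.asarray(x_T, dtype=float)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dx_dt = -t_cur * gaussian_score(x, t_cur, sigma_data)
        # t_next - t_cur is negative: we step backward in time.
        x = x + (t_next - t_cur) * dx_dt
    return x
```

With $T = 10$ and $\sigma_{\text{data}} = 0.5$, every starting point is scaled by roughly $0.5/\sqrt{100.25} \approx 0.0499$, mapping the wide noise distribution back onto the narrow data distribution deterministically.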
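To estimate $s(x_t, t)$, the EDM approach parameterizes a denoising function $D_\theta(x_t, t)$, trained to minimize the denoising score matching loss
$$\mathcal{L}_{\text{dsm}} = \mathbb{E}_{x_0 \sim p_{\text{data}},\, n \sim \mathcal{N}(0, t^2 I)} \left\| D_\theta(x_0 + n, t) - x_0 \right\|_2^2,$$
from which the score is recovered as $s(x, t) \approx \left(D_\theta(x, t) - x\right)/t^2$.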
To address the challenge of training a diffusion model on only a subset of tokens, the authors propose to decompose the loss into two parts: 1) the denoising score matching loss and 2) the Masked AutoEncoder (MAE) loss. The former is applied only to the visible tokens, while the MAE loss acts as an auxiliary task for reconstructing masked tokens from visible ones.
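The split between visible and masked tokens can be sketched with a per-token binary mask. This is a minimal NumPy sketch under stated assumptions: the function name `split_token_losses` is hypothetical, both losses are unweighted squared errors averaged over tokens, and the MAE reconstruction target is passed in as `mae_target` because the summary does not specify what it is.

```python
import numpy as np


def split_token_losses(denoised, x0, mae_target, visible):
    """Hypothetical split of a per-token loss into a DSM part and an MAE part.

    denoised:   (num_tokens, dim) network output, one row per token
    x0:         (num_tokens, dim) clean tokens, the DSM target
    mae_target: (num_tokens, dim) reconstruction target for masked tokens (assumed)
    visible:    (num_tokens,) boolean mask, True where the token was visible
    Assumes at least one visible and one masked token (otherwise a mean is empty).
    """
    per_token_dsm = np.sum((denoised - x0) ** 2, axis=-1)
    per_token_mae = np.sum((denoised - mae_target) ** 2, axis=-1)
    # Denoising score matching loss: only visible tokens contribute.
    l_dsm = per_token_dsm[visible].mean()
    # Auxiliary MAE loss: only masked (non-visible) tokens contribute.
    l_mae = per_token_mae[~visible].mean()
    return l_dsm, l_mae
```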
The authors define a route as:
$r = \{\, (D_\theta^{l_i}, D_\theta^{l_j}) \mid 0 \leq i < j \leq B \,\}$,
where:
$B+1$ is the total number of layers in the network $D_\theta$,
$L = \{ D_\theta^{l_1}, D_\theta^{l_2}, \ldots, D_\theta^{l_B} \}$ is the set of layers in $D_\theta$.
Each pair $(D_\theta^{l_i}, D_\theta^{l_j}) \in r$ represents a connection from the start layer $D_\theta^{l_i}$ to the end layer $D_\theta^{l_j}$.
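To make the route definition concrete, here is a hypothetical single-route forward pass in NumPy. It assumes, for illustration only, that a random fraction `selection_rate` of tokens skips the layers strictly between $D_\theta^{l_i}$ and $D_\theta^{l_j}$ and is reinserted unchanged right before $D_\theta^{l_j}$. The name `routed_forward`, its arguments, and the plain reinsertion (instead of the linear combination with $x_t$ used for multiple routes) are simplifications, not the authors' code.

```python
import numpy as np


def routed_forward(tokens, layers, start, end, selection_rate, rng):
    """Sketch of one route (D^{l_start}, D^{l_end}) over a list of layers.

    tokens: (num_tokens, dim) array; layers: callables mapping (n, dim) -> (n, dim).
    Returns the output tokens and the indices of the routed (bypassing) tokens.
    """
    h = tokens
    for layer in layers[: start + 1]:  # layers up to and including the start layer
        h = layer(h)
    num_tokens = h.shape[0]
    num_routed = int(round(selection_rate * num_tokens))
    perm = rng.permutation(num_tokens)  # random token selection
    routed_idx, kept_idx = perm[:num_routed], perm[num_routed:]
    kept = h[kept_idx]
    for layer in layers[start + 1 : end]:  # only kept tokens pass through these layers
        kept = layer(kept)
    h = h.copy()
    h[kept_idx] = kept  # routed tokens still hold their start-layer representation
    for layer in layers[end:]:  # end layer onward sees the full token set again
        h = layer(h)
    return h, routed_idx


# Usage with token-wise toy layers that each add 1.0:
layers = [lambda x: x + 1.0 for _ in range(6)]
out, routed = routed_forward(np.zeros((4, 3)), layers, start=1, end=4,
                             selection_rate=0.5, rng=np.random.default_rng(0))
```

Because the layers inside the route only see the kept tokens, their cost shrinks with the selection rate; with real attention blocks, the kept tokens also attend only to one another inside the route.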
For multiple routes, the authors use a linear combination of $x_t$ and the representation produced by the direct predecessor of $r_{j,k}$, which is layer $D_\theta^{l_{j-1}}$. Specifically, this combination is defined as:
The authors extend the loss functions $\mathcal{L}_{\text{dsm}}$ and $\mathcal{L}_{\text{mae}}$ to a set of routes $R$. Here, $r_k$ denotes the $k$-th route in the set, and each route $r_k \in R$ ($k = 1, \ldots, N$) has its own binary mask.
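The final multi-route ensemble loss $\mathcal{L}$ is formulated as:
$$\mathcal{L} = \frac{1}{N} \sum_{k=1}^{N} \left( \mathcal{L}_{\text{dsm}}^{k} + \lambda\, \mathcal{L}_{\text{mae}}^{k} \right),$$
where $N$ denotes the total number of sequential routes employed and $\lambda$ weights the MAE term relative to the denoising score matching term.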
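The ensemble loss is a plain average over routes. A direct transcription of the formula, assuming the per-route losses have already been computed as scalars; `multi_route_loss` and `lam` (standing in for $\lambda$, whose value the summary does not give) are illustrative names:

```python
def multi_route_loss(dsm_losses, mae_losses, lam):
    """L = (1/N) * sum_k (L_dsm^k + lam * L_mae^k) over N routes.

    dsm_losses, mae_losses: sequences with one scalar per route, each route
    having been evaluated with its own binary mask.
    """
    if len(dsm_losses) != len(mae_losses) or len(dsm_losses) == 0:
        raise ValueError("need one (dsm, mae) loss pair per route and N >= 1")
    n = len(dsm_losses)
    return sum(d + lam * m for d, m in zip(dsm_losses, mae_losses)) / n
```

For two routes, `multi_route_loss([1.0, 3.0], [2.0, 4.0], lam=0.5)` averages $1 + 0.5 \cdot 2 = 2$ and $3 + 0.5 \cdot 4 = 5$, giving 3.5.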
The experiments were conducted on ImageNet-1K 256×256 with a batch size of $256$, using the AdamW optimizer with a learning rate of $10^{-4}$, no weight decay, and $(\beta_1, \beta_2) = (0.9, 0.999)$. Each route used a selection rate of 50%, and an EMA of the model weights was maintained with a decay rate of $0.9999$. The authors use the Fréchet Inception Distance (FID), computed on 50,000 samples, as the main metric of model quality.
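The reported hyperparameters can be collected in a small configuration object. This is an illustrative sketch: `TrainConfig` and `ema_update` are hypothetical names, and the EMA rule shown is the standard exponential moving average, assumed here to be what the 0.9999 update rate refers to.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    # Values as reported for the ImageNet-1K 256x256 experiments.
    batch_size: int = 256
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    adam_betas: tuple = (0.9, 0.999)
    selection_rate: float = 0.5  # fraction of tokens selected per route
    ema_decay: float = 0.9999
    fid_num_samples: int = 50_000


def ema_update(ema_params, params, decay):
    """Standard EMA step, per parameter: ema <- decay * ema + (1 - decay) * param."""
    return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in ema_params}
```

In PyTorch, the optimizer settings would map to `torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=0.0)`; passing `weight_decay=0.0` explicitly matters because PyTorch's AdamW defaults to a weight decay of 0.01.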
The authors compared TREAD-trained DiT-S, DiT-B, and DiT-XL models with their vanilla counterparts, and compared XL-sized models with other efficiency-oriented methods. The improvements carry over to larger models: DiT-XL/2-1 reaches an FID of $12.47$, and DiT-XL/2-2 almost halves the baseline FID (FID $19.47$ → $10.63$) at 400K steps. They also report DiT-B/2-1 (random), a baseline that selects the end layer of the route $r$ at random, with an FID of $36.80$. Lastly, the authors extend their method to SSMs, using RWKV as a representative model, and obtain an improvement (FID=53.79) over their own Diffusion-RWKV baseline (FID=59.84).
The authors measured training speed in iterations per second (it/s) on a single A100 with a batch size of $64$ as an indicator of cost per iteration. Under identical settings, TREAD is faster than all baselines: 3.03 it/s, compared with 2.98 it/s for MaskDiT, 1.86 it/s for the baseline DiT, and 1.02 it/s for MDT. Overall, the authors calculate a convergence speedup of 9.55× relative to a vanilla DiT.
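The reported throughputs translate into per-iteration speed ratios as shown below. The sketch only does arithmetic on the numbers above. `wall_clock_speedup` encodes the assumption, taken from the abstract's statement that the two benefits multiply, that needing fewer iterations and running each iteration faster combine multiplicatively; it is not a reconstruction of how the 9.55× figure was computed.

```python
# Reported throughput in iterations per second (single A100, batch size 64).
THROUGHPUT = {"TREAD": 3.03, "MaskDiT": 2.98, "DiT": 1.86, "MDT": 1.02}


def relative_speed(method, baseline, table=THROUGHPUT):
    """How many times faster one training iteration of `method` is than `baseline`."""
    return table[method] / table[baseline]


def wall_clock_speedup(iteration_ratio, per_iteration_speed_ratio):
    # Assumption: fewer required iterations and faster iterations multiply.
    return iteration_ratio * per_iteration_speed_ratio
```

For example, `relative_speed("TREAD", "DiT")` is about 1.63, so each TREAD iteration is roughly 63% faster than a vanilla DiT iteration in this setup, while the gap to MaskDiT is under 2%.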
To show that the method scales to larger numbers of tokens, the authors evaluate on ImageNet-512 in addition to their main ImageNet-256 results. They further move to a text-to-image setting using the MS-COCO dataset. For both tasks, the authors use the B/2 model size and compute FID on 10,000 samples without classifier-free guidance after 400K training iterations. In both settings, TREAD achieves substantially better performance while offering a more efficient training setup.
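Moving from 256×256 to 512×512 quadruples the number of tokens. Assuming the standard DiT latent setup (a VAE that downsamples images by a factor of 8, and a patch size of 2 as the "/2" in B/2 indicates), the token counts can be computed with the hypothetical helper below:

```python
def num_tokens(image_size, vae_downsample=8, patch_size=2):
    """Sequence length seen by the transformer for a square image.

    Assumes a latent VAE with the given downsampling factor followed by
    non-overlapping patches of size patch_size x patch_size.
    """
    latent_size = image_size // vae_downsample
    return (latent_size // patch_size) ** 2


print(num_tokens(256), num_tokens(512))
```

This gives 256 tokens at 256×256 and 1024 tokens at 512×512. Token-wise layers therefore cost about 4× more at the higher resolution, and self-attention, which is quadratic in sequence length, about 16× more.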
Multiple ablation studies were conducted to examine the effect that each component has on the performance of TREAD. All experiments were conducted using a DiT-B/2 model and evaluated using 50,000 samples.
The authors ablated the selection ratio and observed a clear improvement with random selection rates of $0.25$ (FID=39.30) and $0.5$ (FID=35.78) over the vanilla DiT-B/2 (FID=43.47), in addition to the reduced cost per iteration; the medium rate of $0.5$ gave the better FID of the two. They also varied the layers chosen as the start and end of the token route $r$. The choice of $i$ and $j$ appeared largely arbitrary, with the exception of $r_{8,10}$, which gave the worst performance (FID=232.53); the near-equal performance of DiT-B/2-1 (FID=35.78) and DiT-B/2-1 (random) (FID=36.80) supports this.
Removing the multi-route loss showed that combining all routes $r$ and their corresponding binary masks in the loss is vital to the model's performance. Finally, the linear combination of $x_t$ and the output of $D_\theta^{l_{j-1}}$ outperforms using either $x_t$ or $D_\theta^{l_{j-1}}(\cdot, \cdot)$ alone.
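A rough way to see how the selection rate trades off against cost is the estimate below. It rests on simplifying assumptions that are not from the paper: every layer costs the same and scales linearly with the number of tokens it processes (ignoring attention's quadratic term and fixed overheads), and routed tokens skip exactly the layers strictly between the route endpoints. `routed_cost_fraction` is a hypothetical helper.

```python
def routed_cost_fraction(num_layers, start, end, selection_rate):
    """Estimated fraction of per-iteration token-layer work that remains when a
    fraction `selection_rate` of tokens skips the layers strictly between
    layer indices `start` and `end` (0-based)."""
    if not (0 <= start < end <= num_layers and 0.0 <= selection_rate <= 1.0):
        raise ValueError("invalid route or selection rate")
    skipped_layers = end - start - 1
    saved = selection_rate * skipped_layers  # layer-equivalents of work avoided
    return 1.0 - saved / num_layers
```

For a hypothetical route that skips 9 of DiT-B's 12 blocks, `routed_cost_fraction(12, 1, 11, 0.5)` gives 0.625, versus 0.8125 at a selection rate of 0.25, which illustrates why the higher rate is cheaper per iteration as well as better in FID in this ablation.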