Auto-Regressive Masked Diffusion Models (ARMD)

Updated 28 January 2026
  • ARMD models are a class of generative models that combine the flexibility of masked diffusion with the training efficiency of autoregressive methods.
  • They implement blockwise-parallel decoding using permutation-equivariant Transformers, achieving competitive performance across benchmarks.
  • Their design supports efficient inference and extends naturally to modalities like text, images, video, and 3D structures.

Auto-Regressive Masked Diffusion Models (ARMD) are a contemporary class of generative models that integrate the flexible, partially-observed context and permutation equivariance of masked diffusion models (MDMs) with the explicit conditional structure and training efficiency of autoregressive models (ARMs). ARMD enables efficient, blockwise-parallel generation of text, images, and structured sequences, supporting both sequential and parallel decoding. Empirically, these models approach or match the best autoregressive baselines while offering new algorithmic capabilities for flexible inference, robust training, and fine-grained architectural control (Karami et al., 23 Jan 2026, Garg et al., 24 Nov 2025, Huang et al., 30 Apr 2025, Hoogeboom et al., 2021).

1. Mathematical Foundations and Generative Principle

ARMD models are built on a discrete-time Markov masking process over vector- or token-sequence data. Let $x_0 \in \mathcal{V}^L$ be the initial clean data vector. The forward 'noising' process $q(x_{1:T} \mid x_0) = \prod_{t=1}^T q(x_t \mid x_{t-1})$ corrupts tokens by randomly replacing coordinates with a designated $\text{MASK}$ symbol, typically according to independent Bernoulli steps per coordinate. The learned reverse (denoising) process $p_\theta(x_{t-1} \mid x_t)$ attempts to reconstruct masked tokens, parameterized by a deep model (e.g., a Transformer) with bidirectional attention.
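As a toy illustration of the forward masking chain, the following NumPy sketch runs a uniform absorbing schedule in which each still-visible coordinate is masked with probability $1/(T-t+1)$ at step $t$ (all names here, including `MASK` and `forward_mask`, are illustrative, not from the papers):

```python
import numpy as np

MASK = -1  # stand-in mask symbol (assumption: real tokens are non-negative ints)

def forward_mask(x0, T, rng):
    """Run T steps of the forward 'noising' chain: at step t, each
    still-visible coordinate is independently replaced by MASK with
    probability 1/(T - t + 1), a uniform schedule that guarantees the
    sequence is fully masked by step T."""
    x = x0.copy()
    trajectory = [x.copy()]
    for t in range(1, T + 1):
        p = 1.0 / (T - t + 1)              # p = 1 at the final step
        hit = rng.random(x.shape) < p
        x = np.where(hit, MASK, x)         # masking is absorbing
        trajectory.append(x.copy())
    return trajectory
```

The realized trajectory implicitly records the order in which positions were masked, which is exactly the permutation that Section 1 conditions on.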

A critical insight of ARMD is that conditioning on a realized corruption schedule amounts to conditioning on a random permutation $\pi$ capturing the order in which tokens are masked (and thus will be unmasked during generation). The training objective, in its continuous-time or block-wise factorization, admits an explicit decomposition:

$$\mathcal{L}_{\mathrm{MDM}} = \mathbb{E}_{\pi}\left[\mathcal{L}_{\mathrm{AR}}(\pi)\right] = \sum_{\pi} w(\pi)\,\mathcal{L}_{\mathrm{AR}}(\pi)$$

where each AR loss $\mathcal{L}_{\mathrm{AR}}(\pi)$ corresponds to decoding in order $\pi$, and $w(\pi)$ is the probability of $\pi$ under the masking schedule (Garg et al., 24 Nov 2025). This bridges BERT-style masked modeling, sequence-permutation inference, and traditional autoregressive decoding.
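The decomposition can be checked on a toy example by enumerating all orders and averaging per-order AR losses under a uniform $w(\pi)$ (a minimal sketch; the `logprob` callable is a stand-in for a real model, and all names are illustrative):

```python
import itertools

def ar_loss(x, order, logprob):
    """Negative log-likelihood of sequence x decoded in the given order:
    each token is scored conditioned on the tokens already revealed."""
    revealed = {}
    nll = 0.0
    for pos in order:
        nll -= logprob(pos, x[pos], dict(revealed))
        revealed[pos] = x[pos]
    return nll

def mdm_loss_exact(x, logprob):
    """MDM objective written as a uniform mixture of AR losses over all
    decoding orders (the papers' masking schedules induce non-uniform
    weights w(pi); uniform weights are used here for simplicity)."""
    orders = list(itertools.permutations(range(len(x))))
    return sum(ar_loss(x, o, logprob) for o in orders) / len(orders)
```

For a context-blind model every order gives the same loss, so the mixture collapses to a single AR loss; for a real model the weights $w(\pi)$ determine which orders dominate training.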

2. Blockwise Causal Reframing and Model Architecture

ARMD reveals that, given a masking trajectory, tokens unmasked at the same step form blocks, yielding a 'blockwise-causal' structure: to generate a token in block $t$, it suffices to condition on all previous blocks, which themselves may be generated in parallel. Formally:

  • For a permutation $\pi$, partition $x_{1:L}^{\pi}$ into blocks $[\mathcal{X}^{(1)}, \dots, \mathcal{X}^{(T)}]$.
  • Forward process: mask blocks sequentially.
  • Reverse process: for each block $t$, predict $x^{(0)}_{\mathcal{B}(t)}$ conditioned on $\{x^{(t)}_i : \mathcal{B}(i) \leq T - t\}$.

Architecturally, this is implemented in a permutation-equivariant Transformer with strictly blockwise-causal attention masks: a two-stream stack (causal, strictly causal) ensures all conditional likelihoods required for a given block are computed in a single forward pass (Karami et al., 23 Jan 2026). Layers are designed so that tokens only attend to earlier blocks, with permutation equivariance preserved within blocks.
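The blockwise-causal attention pattern for both streams can be sketched as a boolean mask over (query, key) pairs, given each position's block index under the sampled permutation (a minimal sketch with illustrative names, not the papers' implementation):

```python
import numpy as np

def blockwise_causal_mask(block_ids, strict=False):
    """Attention mask for a blockwise-causal Transformer: query token i
    may attend to key token j iff j's block strictly precedes i's block
    (strict stream) or is the same or earlier (causal stream).
    block_ids[i] is the block index of position i under the sampled
    permutation. Returns True where attention is allowed."""
    b = np.asarray(block_ids)
    q, k = b[:, None], b[None, :]
    return (k < q) if strict else (k <= q)
```

Because the mask depends only on block indices, tokens inside a block attend to each other symmetrically, preserving permutation equivariance within blocks; the strict variant supplies the second stream's conditionals.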

3. Training Procedures and Decoding Strategies

Training involves sampling random permutations (masking orders), optionally guided by a progressive permutation curriculum (beginning with left-to-right, then increasing order randomness). Objective:

$$\mathcal{L}_{\mathrm{diff}} = \mathbb{E}_q \left[ \sum_{t=1}^T \gamma(t) \sum_{n \in \mathcal{B}(t)} -\log p_\theta\left(x_n^0 \mid \text{past blocks}\right) \right]$$

with schedule $\gamma(t)$ reflecting the transition weights induced by the noise schedule (Karami et al., 23 Jan 2026).
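A single Monte Carlo term of this objective can be sketched as follows, with a callable `logprob` standing in for the model and `gamma` for the schedule (illustrative names, not the papers' code):

```python
import numpy as np

def diffusion_loss(x, block_sizes, gamma, logprob, rng):
    """One Monte Carlo term of the blockwise objective: sample a masking
    order, split it into blocks of the given sizes, and sum the
    gamma(t)-weighted NLL of each block given all earlier blocks.
    `logprob(pos, token, context)` is a stand-in for the model."""
    order = rng.permutation(len(x))
    loss, start, revealed = 0.0, 0, {}
    for t, size in enumerate(block_sizes, start=1):
        block = order[start:start + size]
        for pos in block:                 # every token in the block is
            loss -= gamma(t) * logprob(pos, x[pos], dict(revealed))
        for pos in block:                 # scored before any is revealed
            revealed[pos] = x[pos]
        start += size
    return loss
```

Revealing a block only after all its tokens are scored mirrors the parallel prediction of a block conditioned on strictly earlier blocks.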

Decoding supports sequential autoregressive sampling (one token at a time), native blockwise ARMD sampling (one block at a time), and strided parallel generation. For the latter, indices are partitioned into $S$ parallel streams, yielding significant inference acceleration at minimal perplexity cost, as shown empirically on language modeling benchmarks.
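One plausible realization of the strided partition (the papers do not pin down this exact scheme, so treat it as an assumption) interleaves positions across streams and decodes one token per stream per step:

```python
def strided_streams(L, S):
    """Partition positions 0..L-1 into S interleaved streams
    (position i goes to stream i mod S)."""
    return [list(range(s, L, S)) for s in range(S)]

def parallel_steps(L, S):
    """Group positions so that step k decodes one token from each
    stream in parallel, reducing L sequential steps to ceil(L / S)."""
    streams = strided_streams(L, S)
    steps, k = [], 0
    while any(k < len(st) for st in streams):
        steps.append([st[k] for st in streams if k < len(st)])
        k += 1
    return steps
```

With $L=10$ and $S=4$, ten sequential steps collapse to three parallel ones, which is the source of the reported latency reduction.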

4. Theoretical Equivalence to Weighted AR Models and Learnable Decoding Orders

A central theoretical advance is the proof that MDM objectives with multivariate masking schedules are equivalent to expectations over weighted AR objectives for all possible decoding orders (Garg et al., 24 Nov 2025). By introducing distinct, learnable masking schedules $\alpha_\ell(t)$ per token, one induces a non-uniform distribution $w(\pi)$ over orderings. Gradient-based training can then adapt both model and masking parameters to discover and exploit favorable orders, improving negative log-likelihood while maintaining competitive data fidelity metrics.
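A simplified, discrete stand-in for such learnable schedules assigns each token a rate parameter and samples decoding orders via the Gumbel-max trick, which yields a Plackett-Luce distribution over permutations (this is an assumption for illustration, not the papers' continuous-time $\alpha_\ell(t)$ parameterization):

```python
import numpy as np

def sample_order(rates, rng):
    """Sample a decoding order from per-token rate parameters: a token
    with a larger rate tends to be unmasked earlier. Adding i.i.d.
    Gumbel noise to log-rates and argsorting gives a Plackett-Luce
    sample over permutations."""
    g = rng.gumbel(size=len(rates))
    return np.argsort(-(np.log(np.asarray(rates)) + g))
```

Because the order distribution is differentiable in the rates (via relaxations), gradient-based training can shift probability mass toward favorable decoding orders, which is the mechanism the equivalence result licenses.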

This establishes ARMD as an explicit generalization of conventional (fixed-order) ARMs: by selecting schedule parameters, one traverses the spectrum from fixed left-to-right to fully randomized or even state-dependent orders, encompassing conventional diffusion and autoregressive training as special cases (Hoogeboom et al., 2021).

5. Empirical Benchmarks and Mechanism Analysis

Empirical studies confirm that ARMD models close the traditional performance gap between masked diffusion and autoregressive models on standard language modeling and discrete generation tasks. For example, ARMD-small ($\sim$125M params) at 180K training steps outperforms both D3PM categorical diffusion and GPT-2 small on LAMBADA, WikiText2, and 1BW; ARMD-medium ($\sim$345M) at 300K steps matches or exceeds strong AR baselines (Karami et al., 23 Jan 2026). Blockwise-parallel decoding with stride $S=4$ halves inference latency at a $\sim$1 perplexity penalty.

Mechanistic analyses following ARM$\to$MDM post-training demonstrate that, for tasks with only local sequential dependencies, the circuit structure and head/layer specialization of the ARM is largely preserved; for global-planning tasks, such as Countdown, diffusion post-training induces a systematic mechanism shift: circuitry is reorganized toward early-layer processing and more distributed, less sharply specialized features, supporting globally bidirectional context integration (Kong et al., 21 Jan 2026). Jaccard overlaps of edge attributions between ARM and MDM are $\sim$0.1–0.2 on local tasks, dropping to $<0.03$ on global tasks.
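The overlap metric itself is the standard Jaccard index over the two models' attributed edge sets (a minimal sketch; how the papers select the edge sets, e.g. top-$k$ attribution scores, is not specified here):

```python
def jaccard(a, b):
    """Jaccard overlap |A ∩ B| / |A ∪ B| between two edge-attribution
    sets, e.g. the top-k attributed circuit edges of two models.
    Returns 1.0 for two empty sets by convention."""
    a, b = set(a), set(b)
    return len(a & b) / len(a | b) if (a | b) else 1.0
```

An overlap near 0.1–0.2 thus means the two models share only a small fraction of their attributed edges even on local tasks.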

6. Extensions to Other Modalities and Domains

ARMD frameworks have been successfully instantiated in image, video, 3D shape, and structured data domains:

  • In vision, ARMD comprises an outer AR loop partitioning latent/image tokens and inner blockwise diffusion (with masking, distillation, and parallel inference), achieving SOTA FID and IS on ImageNet with over $30\times$ inference speedup (Gu et al., 19 Nov 2025).
  • In video, ARMD fuses framewise AR diffusion with MDM-inspired masking: static regions are copied from previous frames, while dynamic motion regions are predicted, reducing temporal drift and improving FVD and CLIP similarity on UCF-101 and MSR-VTT (Weng et al., 2023).
  • In 3D generation, LTM3D leverages masked autoencoder backbones, prepending condition tokens via prefix learning, and performing per-token diffusion in AR order. Guided sampling with latent reconstructions improves sample quality across signed distance fields, point clouds, and meshes (Kang et al., 30 May 2025).

These demonstrate the portability of ARMD to latent-tokenized domains, with conditional prefix learning and masked sampling schemes generalized across modalities.

7. High-Level Implications, Limitations, and Future Directions

ARMD unifies and extends the masked modeling, autoregressive, and diffusion modeling paradigms, yielding efficient, permutation-flexible blockwise architectures that rival or surpass strong ARMs in accuracy and sample coherence, particularly when enabled with learned masking schedules (Karami et al., 23 Jan 2026, Garg et al., 24 Nov 2025).

Identified limitations include increased per-layer FLOPs owing to the two-stream architecture, and potential sample-quality degradation under aggressive parallel sampling for data with strong, dense local dependencies. The progressive permutation curriculum and schedule design partially mitigate these, and empirical results indicate substantial gains in training efficiency (a 3–8$\times$ reduction in total steps over prior MDMs).

Anticipated advances include adaptive or data-dependent block partitioning, hybrid circuit designs combining induction head and global planner modules, extension of ARMD to billion-parameter scale language and multimodal models, and integration of reward-guided RL, distillation, and prefix tuning for downstream preference alignment and controlled generation (Gu et al., 19 Nov 2025, Kong et al., 21 Jan 2026). These directions promise robust, interpretable generative models with finely balanced trade-offs between global coherence, training speed, and parallelizable inference.
