BertsWin: Efficient 3D Pre-training Architecture

Updated 1 January 2026
  • BertsWin is a hybrid self-supervised pre-training architecture that mitigates topological sparsity by preserving a complete 3D token grid using BERT-style masking.
  • It employs Swin Transformer windowed attention to reduce attention complexity from O(N²) to O(N) in the number of tokens while maintaining local spatial context.
  • The design integrates a 3D CNN stem, a multi-component structural loss, and a GradientConductor optimizer to achieve rapid semantic convergence and resource efficiency.

BertsWin is a hybrid self-supervised pre-training architecture for three-dimensional volumetric data that resolves the topological sparsity encountered by conventional Masked Autoencoders (MAEs) when applied to 3D structures. It integrates a full BERT-style token masking strategy with Swin Transformer windowed attention to simultaneously maintain spatial topology and computational efficiency. BertsWin was introduced to address the structural discontinuities and slow convergence inherent to 3D MAE pre-training, and achieves substantial improvements in both semantic convergence speed and resource utilization, as demonstrated on cone beam computed tomography (CBCT) of the temporomandibular joint (TMJ) (Limarenko et al., 25 Dec 2025).

1. Motivation for BertsWin Pre-training

While state-of-the-art 2D masked autoencoders (MAEs) exploit the redundancy of planar images, masking 75% of the patches of a 3D volume produces severe topological sparsity. This disrupts anatomical continuity, fragmenting spatially connected regions and destroying geometric priors. The resulting fragmentation causes blocking artifacts and loss of context: the network is forced to reconstruct entire structures from sparse, disconnected tokens, which significantly slows convergence. Computationally, extending MAE to full 3D attention incurs a prohibitive $O(N^2)$ complexity, where $N$ is the number of volumetric patches.

BertsWin provides two key innovations: (a) it maintains a complete 3D token grid (mask tokens plus visible embeddings), preserving spatial topology throughout the encoder; (b) it replaces global attention with Swin Transformer windows, reducing attention cost to $O(N)$ for a fixed window size while retaining effective local spatial context. This approach targets rapid, resource-efficient learning of structural priors in 3D self-supervised pre-training (Limarenko et al., 25 Dec 2025).

2. High-Level Architecture

BertsWin comprises four distinct modules in a sequential arrangement:

A. 3D CNN Stem (Hybrid Patch Embedding):

  • Input volume $V\in\mathbb{R}^{B\times D\times H\times W}$ is split into non-overlapping $P^3$-sized patches, yielding $N=(D/P)\cdot(H/P)\cdot(W/P)$ tokens in total.
  • 25% of the patches are selected at random and passed through four 3D convolutional blocks to produce visible token embeddings $\mathbf{e}_\text{vis}\in\mathbb{R}^{B\times n_\text{vis}\times C}$.
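The patch split in step A can be sketched in NumPy as below. This is an illustrative sketch, not the authors' code: the raster patch ordering, the hypothetical helper names `patchify_3d` and `select_visible`, and the toy sizes (a $32^3$ volume with $P=16$) are assumptions. The stem's four convolutional blocks are not modeled; the sketch stops at the patches that would be fed to them.

```python
import numpy as np


def patchify_3d(volume: np.ndarray, P: int) -> np.ndarray:
    """Split a (B, D, H, W) volume into non-overlapping P^3 patches.

    Returns shape (B, N, P**3) with N = (D/P)*(H/P)*(W/P), patches in raster
    (depth, height, width) order. Hypothetical helper for illustration.
    """
    B, D, H, W = volume.shape
    assert D % P == 0 and H % P == 0 and W % P == 0, "D, H, W must be divisible by P"
    d, h, w = D // P, H // P, W // P
    # (B, d, P, h, P, w, P) -> (B, d, h, w, P, P, P): group voxels by patch
    x = volume.reshape(B, d, P, h, P, w, P).transpose(0, 1, 3, 5, 2, 4, 6)
    return x.reshape(B, d * h * w, P ** 3)


def select_visible(patches: np.ndarray, keep_ratio: float = 0.25, rng=None):
    """Pick a random keep_ratio fraction of patches per sample (the stem inputs)."""
    rng = np.random.default_rng() if rng is None else rng
    B, N, _ = patches.shape
    n_vis = int(round(keep_ratio * N))
    # Independent random subset per sample, sorted so raster order is kept
    idx = np.sort(np.stack([rng.permutation(N)[:n_vis] for _ in range(B)]), axis=1)
    visible = np.take_along_axis(patches, idx[:, :, None], axis=1)
    return idx, visible


vol = np.random.default_rng(0).normal(size=(2, 32, 32, 32)).astype(np.float32)
patches = patchify_3d(vol, P=16)
idx, visible = select_visible(patches, keep_ratio=0.25, rng=np.random.default_rng(1))
print(patches.shape, visible.shape)  # (2, 8, 4096) (2, 2, 4096)
```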

B. Full 3D Token Grid with Positional Embedding:

  • A binary mask $M\in\{0,1\}^{B\times N}$ records the visibility of each patch.
  • Visible embeddings are scattered to their locations in the grid; masked locations are filled with a learnable token $\mathbf{e}_\text{mask}\in\mathbb{R}^C$.
  • Fixed or learnable positional embeddings of shape $(N, C)$ are added for spatial context.
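For the fixed option, one common choice (assumed here, not confirmed by the source) is a 3D sine-cosine embedding that splits the $C$ channels evenly across the depth, height, and width axes. A minimal NumPy sketch with the hypothetical helper names `sincos_1d` and `sincos_3d`:

```python
import numpy as np


def sincos_1d(positions: np.ndarray, dim: int) -> np.ndarray:
    """Transformer-style sine/cosine embedding of 1D positions; dim must be even."""
    half = dim // 2
    omega = 1.0 / (10000.0 ** (np.arange(half) / half))   # geometric frequency ladder
    angles = positions[:, None] * omega[None, :]           # (L, dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (L, dim)


def sincos_3d(grid: tuple, C: int) -> np.ndarray:
    """Fixed (N, C) positional embedding for a d x h x w token grid in raster order."""
    d, h, w = grid
    assert C % 6 == 0, "C must split into three even per-axis chunks"
    zz, yy, xx = np.meshgrid(np.arange(d), np.arange(h), np.arange(w), indexing="ij")
    # One C/3-dimensional embedding per axis, concatenated along channels
    parts = [sincos_1d(a.ravel().astype(np.float64), C // 3) for a in (zz, yy, xx)]
    return np.concatenate(parts, axis=1)


pos = sincos_3d((2, 4, 4), C=96)
print(pos.shape)  # (32, 96)
```

The learnable alternative would simply replace `pos` with a trainable $(N, C)$ array.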

C. Single-Scale Swin Transformer Encoder:

  • Consists of 12 Swin blocks, each performing window-based local attention within fixed-size 3D windows, with shifted windows for inter-window information transfer, 12 attention heads, and no downsampling.
  • The output is a feature grid of shape $B\times N\times C$, at the same resolution as the input token grid.
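The windowing in the encoder can be sketched as follows, assuming a channels-last $(B, d, h, w, C)$ token grid, a hypothetical window of 2 tokens per axis (the actual window size is not given here), and a cyclic shift of half a window between consecutive blocks as in the original Swin Transformer. Attention is single-head and projection-free for brevity, and the attention mask that Swin applies to wrapped-around tokens after a shift is omitted.

```python
import numpy as np


def window_partition_3d(x: np.ndarray, ws: int) -> np.ndarray:
    """(B, d, h, w, C) grid -> (B * num_windows, ws**3, C) non-overlapping windows."""
    B, d, h, w, C = x.shape
    assert d % ws == 0 and h % ws == 0 and w % ws == 0, "grid must tile into windows"
    x = x.reshape(B, d // ws, ws, h // ws, ws, w // ws, ws, C)
    x = x.transpose(0, 1, 3, 5, 2, 4, 6, 7)   # window indices first, then in-window offsets
    return x.reshape(-1, ws ** 3, C)


def cyclic_shift_3d(x: np.ndarray, shift: int) -> np.ndarray:
    """Roll the grid by -shift along d, h, w so the next windows straddle old boundaries."""
    return np.roll(x, shift=(-shift, -shift, -shift), axis=(1, 2, 3))


def window_attention(windows: np.ndarray) -> np.ndarray:
    """Softmax self-attention computed independently inside each window."""
    scores = windows @ windows.transpose(0, 2, 1) / np.sqrt(windows.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ windows  # total cost ~ N * ws**3 * C, linear in N for fixed ws


grid = np.random.default_rng(0).normal(size=(1, 4, 4, 4, 8))
ws = 2
out_regular = window_attention(window_partition_3d(grid, ws))
out_shifted = window_attention(window_partition_3d(cyclic_shift_3d(grid, ws // 2), ws))
print(out_regular.shape)  # (8, 8, 8)
```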

D. 3D CNN Decoder:

  • Three transposed convolution layers (strides 4, 2, 2) upsample the features to reconstruct the full input volume $\hat{V}\in\mathbb{R}^{B\times D\times H\times W}$.
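
| Module | Input | Key Parameters |
|---|---|---|
| CNN Stem | Visible $P^3$ patches of $V$ | 4 blocks, patch size $P$, stride 2 |
| Token Grid + Positional Embed. | $\mathbf{e}_\text{vis}\in\mathbb{R}^{B\times n_\text{vis}\times C}$ | Binary mask $M$, mask token $\mathbf{e}_\text{mask}$ |
| Swin Transformer Encoder | Full token grid, $B\times N\times C$ | 12 blocks, shifted 3D windows, 12 heads |
| CNN Decoder | Encoder features, $B\times N\times C$ | 3 transposed-conv layers (strides 4, 2, 2) |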
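The decoder's upsampling factor follows from the transposed-convolution output-size formula (the one documented for PyTorch's `ConvTranspose3d` with dilation 1). The sketch below assumes kernel size equal to stride and zero padding, which the source does not state; under that assumption each layer multiplies the spatial size by its stride, so strides 4, 2, 2 give a total factor of 16. For the decoder to return to voxel resolution from a $(D/P)$-sided token grid, this suggests $P=16$, which would also match four stride-2 stem blocks ($2^4=16$). The helper names and the 128-voxel side are hypothetical.

```python
def conv_transpose_out(size: int, kernel: int, stride: int,
                       padding: int = 0, output_padding: int = 0) -> int:
    """Output length along one axis of a transposed convolution (dilation = 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def decoder_output_side(grid_side: int, strides=(4, 2, 2)) -> int:
    """Apply the three decoder layers, assuming kernel == stride and no padding."""
    side = grid_side
    for s in strides:
        side = conv_transpose_out(side, kernel=s, stride=s)  # multiplies side by s
    return side


D, P = 128, 16                        # hypothetical volume side and patch size
print(decoder_output_side(D // P))    # 128
```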

3. 3D BERT-style Masking Mechanism

BertsWin employs 3D patch-level masking, generalizing BERT-style masking to volumetric data. For $N$ patches and masking ratio $r$ (75% in the setting above), the binary mask $M_{b,i}$ for sample $b$, patch $i$ is defined as:

$$M_{b,i}=\begin{cases}1 & \text{if } i\in\mathcal{S}_b \text{ (visible)},\\ 0 & \text{otherwise (masked)},\end{cases}$$

where $\mathcal{S}_b$ is a uniformly random subset of $(1-r)N$ patch indices drawn for each sample.

Token construction proceeds via:

$$\mathbf{x}_{b,i}=\begin{cases}f_\text{stem}(\mathbf{v}_{b,i}) & \text{if } M_{b,i}=1,\\ \mathbf{e}_\text{mask} & \text{if } M_{b,i}=0,\end{cases}$$

where $\mathbf{v}_{b,i}$ denotes the voxel values in patch $i$ of sample $b$, $f_\text{stem}$ denotes the CNN stem embedding, and $\mathbf{e}_\text{mask}$ is the learnable embedding for masked positions. Positional embeddings $\mathbf{p}_i\in\mathbb{R}^C$ are added:

$$\mathbf{z}^{(0)}_{b,i}=\mathbf{x}_{b,i}+\mathbf{p}_i,\qquad i=1,\dots,N.$$

This complete grid, containing all visible and masked patches, is propagated through the encoder. By preserving the 3D topology during masking and processing, the architecture maintains anatomical coherence and improves convergence dynamics (Limarenko et al., 25 Dec 2025).
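The masking and token-construction steps can be sketched in NumPy as follows. The sketch uses the convention that $M_{b,i}=1$ marks a visible patch; the helper names, the toy sizes, and the random vectors standing in for real stem outputs $f_\text{stem}(\mathbf{v}_{b,i})$ and for the learned mask token are assumptions for illustration.

```python
import numpy as np


def random_visibility_mask(B: int, N: int, r: float = 0.75, rng=None) -> np.ndarray:
    """M[b, i] = 1 if patch i of sample b is visible; exactly (1 - r) * N ones per row."""
    rng = np.random.default_rng() if rng is None else rng
    n_vis = int(round((1 - r) * N))
    M = np.zeros((B, N), dtype=np.int8)
    for b in range(B):
        M[b, rng.permutation(N)[:n_vis]] = 1   # the random visible subset S_b
    return M


def build_token_grid(e_vis, M, e_mask, pos):
    """z0[b, i] = (stem embedding if M[b, i] == 1 else e_mask) + pos[i].

    e_vis: (B, n_vis, C) stem outputs, ordered by ascending patch index
    M: (B, N) visibility mask; e_mask: (C,) mask token; pos: (N, C)
    """
    B, N = M.shape
    x = np.broadcast_to(e_mask, (B, N, e_mask.shape[0])).copy()  # start fully masked
    for b in range(B):
        x[b, M[b] == 1] = e_vis[b]   # scatter visible embeddings into their slots
    return x + pos[None]             # every one of the N tokens gets a position


rng = np.random.default_rng(0)
B, N, C = 2, 8, 4
M = random_visibility_mask(B, N, r=0.75, rng=rng)
e_vis = rng.normal(size=(B, 2, C))        # stand-in for f_stem outputs
e_mask = np.full(C, -1.0)                 # stand-in for the learnable mask token
pos = 0.01 * np.arange(N * C).reshape(N, C)
z0 = build_token_grid(e_vis, M, e_mask, pos)
print(z0.shape)  # (2, 8, 4)
```

Unlike an MAE encoder, which drops masked tokens, the output keeps all $N$ positions, so the encoder always sees a dense 3D grid.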

4. Structural Priority Loss: Multi-Component Variance and PhysLoss

BertsWin introduces a structural loss that decomposes the per-patch mean squared error (MSE) into three components, brightness ($L_b$), contrast ($L_c$), and structure ($L_s$), defined for each pair of predicted and target patches $(x, y)$:

  • Brightness: $L_b=(\mu_x-\mu_y)^2$
  • Contrast: $L_c=(\sigma_x-\sigma_y)^2$
  • Structure: $L_s=2\sigma_x\sigma_y(1-\rho_{xy})$

where $\mu$ is the patch mean, $\sigma$ is the patch standard deviation, and $\rho_{xy}$ is the patchwise correlation. With population statistics, the three terms sum exactly to the MSE of a patch of $n$ voxels: $\frac{1}{n}\sum_{k=1}^{n}(x_k-y_k)^2=L_b+L_c+L_s$.

The Multi-Component Variance loss is:

$$\mathcal{L}_\text{MCV}=w_b L_b+w_c L_c+w_s L_s,$$

with per-component weights $w_b$, $w_c$, $w_s$.
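The decomposition can be checked numerically. The NumPy sketch below computes the brightness $(\mu_x-\mu_y)^2$, contrast $(\sigma_x-\sigma_y)^2$, and structure $2\sigma_x\sigma_y(1-\rho_{xy})$ terms with population statistics (`ddof=0`), writing the structure term as $2(\sigma_x\sigma_y-\operatorname{cov}_{xy})$ so it stays defined for constant patches. The default weights are placeholders, not the paper's values; with all three set to 1 the loss reduces to plain MSE, so it is the unequal weighting that shifts emphasis between brightness, contrast, and structure.

```python
import numpy as np


def mcv_components(x: np.ndarray, y: np.ndarray):
    """Brightness, contrast, and structure terms for one predicted/target patch pair."""
    x = np.asarray(x, dtype=np.float64).ravel()
    y = np.asarray(y, dtype=np.float64).ravel()
    mu_x, mu_y = x.mean(), y.mean()
    sd_x, sd_y = x.std(), y.std()               # population std (ddof=0)
    cov = ((x - mu_x) * (y - mu_y)).mean()      # equals rho * sd_x * sd_y
    L_b = (mu_x - mu_y) ** 2
    L_c = (sd_x - sd_y) ** 2
    L_s = 2.0 * (sd_x * sd_y - cov)             # = 2 * sd_x * sd_y * (1 - rho)
    return L_b, L_c, L_s


def mcv_loss(x, y, w_b=1.0, w_c=1.0, w_s=1.0):
    """Weighted sum of the three terms (hypothetical placeholder weights)."""
    L_b, L_c, L_s = mcv_components(x, y)
    return w_b * L_b + w_c * L_c + w_s * L_s


rng = np.random.default_rng(0)
pred, target = rng.normal(size=(16, 16, 16)), rng.normal(size=(16, 16, 16))
print(np.isclose(mcv_loss(pred, target), np.mean((pred - target) ** 2)))  # True
```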

PhysLoss further prioritizes structurally critical regions by computing $\mathcal{L}_\text{MCV}$ over three domains:

  • Global patch domain $\Omega_\text{global}$,
  • Soft-tissue mask $\Omega_\text{soft}$,
  • Bone-surface shell mask $\Omega_\text{bone}$.

The complete loss function is:

$$\mathcal{L}_\text{Phys}=\lambda_\text{global}\,\mathcal{L}_\text{MCV}(\Omega_\text{global})+\lambda_\text{soft}\,\mathcal{L}_\text{MCV}(\Omega_\text{soft})+\lambda_\text{bone}\,\mathcal{L}_\text{MCV}(\Omega_\text{bone}),$$

with coefficients $\lambda_\text{global}$, $\lambda_\text{soft}$, $\lambda_\text{bone}$. This multi-component loss supports anatomically aware learning and accelerates semantic convergence (Limarenko et al., 25 Dec 2025).
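A minimal sketch of the three-domain combination, assuming the soft-tissue and bone-surface shell masks are supplied as boolean voxel masks (how they are derived is not described here). For brevity the MCV statistics are pooled over all voxels of each region rather than computed per patch, and the component weights and domain coefficients default to hypothetical unit values.

```python
import numpy as np


def mcv_on_region(pred, target, region, w=(1.0, 1.0, 1.0)) -> float:
    """MCV loss over the voxels where the boolean `region` is True."""
    x = pred[region].astype(np.float64)
    y = target[region].astype(np.float64)
    if x.size == 0:
        return 0.0                                   # empty region contributes nothing
    mu_x, mu_y, sd_x, sd_y = x.mean(), y.mean(), x.std(), y.std()
    cov = ((x - mu_x) * (y - mu_y)).mean()
    L_b, L_c, L_s = (mu_x - mu_y) ** 2, (sd_x - sd_y) ** 2, 2.0 * (sd_x * sd_y - cov)
    return w[0] * L_b + w[1] * L_c + w[2] * L_s


def phys_loss(pred, target, soft_mask, bone_shell_mask, lambdas=(1.0, 1.0, 1.0)) -> float:
    """lambda_global*MCV(all) + lambda_soft*MCV(soft) + lambda_bone*MCV(bone shell)."""
    global_region = np.ones(pred.shape, dtype=bool)
    regions = (global_region, soft_mask, bone_shell_mask)
    return sum(lam * mcv_on_region(pred, target, r) for lam, r in zip(lambdas, regions))


rng = np.random.default_rng(0)
target = rng.normal(size=(8, 8, 8))
pred = target.copy()
pred[0, 0, 0] += 1.0                      # a single error inside the soft-tissue region
soft = np.zeros(target.shape, dtype=bool)
soft[:4] = True
bone = np.zeros(target.shape, dtype=bool)
bone[6:] = True
print(phys_loss(pred, target, soft, bone) > 0)  # True
```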

5. GradientConductor (GCond) Optimizer

GradientConductor (GCond) is a custom optimizer combining features of LION (sign updates), LARS (trust-ratio scaling), and Adam (bias correction). A parameter tensor $\theta$ is updated at step $t$ as follows:

  1. First-moment estimate (momentum) with bias correction:

$$m_t=\beta_1 m_{t-1}+(1-\beta_1)\,g_t,\qquad \hat{m}_t=\frac{m_t}{1-\beta_1^{t}}$$

  2. Trust-ratio scaling (LARS):

$$r_t=\frac{\lVert\theta_{t-1}\rVert}{\lVert\hat{m}_t\rVert+\epsilon}$$

  3. Parameter update using the sign of the first moment (LION-style):

$$\theta_t=\theta_{t-1}-\eta_t\,\operatorname{sign}(\hat{m}_t)$$

where $g_t=\nabla_\theta\mathcal{L}$ is the gradient, $\beta_1$ is the momentum coefficient, $\epsilon$ is a small stability constant, and $\eta_t=\eta\,r_t$ is the effective learning rate. Only the first moment $m_t$ is stored, reducing optimizer-state memory by ∼50% relative to AdamW, which keeps both first and second moments. The optimizer provides stable warm-up (bias correction), memory efficiency, and per-layer (trust-ratio) scaling of updates (Limarenko et al., 25 Dec 2025).
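The three steps can be sketched as a small NumPy optimizer. This is a hypothetical reconstruction for illustration, not the authors' implementation: the class name `GCondSketch`, the default hyperparameters, applying the trust ratio per parameter tensor, and falling back to a ratio of 1 for all-zero tensors (a common LARS convention) are assumptions.

```python
import numpy as np


class GCondSketch:
    """Hypothetical GCond-style optimizer: bias-corrected momentum, trust ratio, sign update."""

    def __init__(self, lr: float = 1e-4, beta1: float = 0.9, eps: float = 1e-8):
        self.lr, self.beta1, self.eps = lr, beta1, eps
        self.m = {}   # only state: one first-moment buffer per parameter (AdamW keeps two)
        self.t = 0

    def step(self, params: dict, grads: dict) -> None:
        """Update each float array in `params` in place using the matching gradient."""
        self.t += 1
        for name, theta in params.items():
            g = grads[name]
            m = self.beta1 * self.m.get(name, np.zeros_like(theta)) + (1 - self.beta1) * g
            self.m[name] = m
            m_hat = m / (1 - self.beta1 ** self.t)           # Adam-style bias correction
            w_norm = np.linalg.norm(theta)
            # LARS-style trust ratio; fall back to 1 for zero-initialised tensors
            ratio = w_norm / (np.linalg.norm(m_hat) + self.eps) if w_norm > 0 else 1.0
            theta -= self.lr * ratio * np.sign(m_hat)        # LION-style sign step


opt = GCondSketch(lr=0.1)
params = {"w": np.array([3.0, -4.0]), "b": np.zeros(3)}
opt.step(params, {"w": np.array([1.0, -2.0]), "b": np.ones(3)})
print(params["b"])  # [-0.1 -0.1 -0.1]
```

Because $\operatorname{sign}(\hat{m}_t)=\operatorname{sign}(m_t)$, in this sketch bias correction changes the step only through the trust ratio: early in training the uncorrected $m_t$ is shrunk toward zero, which would inflate $r_t$, and the correction removes that inflation.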

6. Computational Complexity and Empirical Convergence

FLOP Analysis

Per-module GFLOPs (encoder, CNN stem, and decoder) were compared against a MONAI ViT-MAE baseline, first at a standard input resolution with patch size $P$ (thus $N=(D/P)\cdot(H/P)\cdot(W/P)$ tokens) and then at a higher resolution. BertsWin's cost grows linearly in $N$, whereas the ViT baseline's global self-attention cost grows quadratically in $N$.

BertsWin therefore achieves theoretical FLOP parity at standard resolution and a ~4.1× reduction at high resolution, attributed to its window-based attention.
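The scaling argument can be illustrated with a rough per-layer attention cost model. The sketch counts multiply-accumulates for the Q/K/V/output projections plus the two attention matrix products; the embedding width $C=768$ and the 64-token window are hypothetical, and MLP, stem, and decoder costs are ignored, so it shows the trend rather than reproducing the reported 4.1× figure.

```python
from typing import Optional


def attention_macs(n_tokens: int, dim: int, window_tokens: Optional[int] = None) -> int:
    """Approximate multiply-accumulates of one self-attention layer.

    4*N*C^2 for the Q, K, V and output projections, plus 2*N*L*C for the
    score (QK^T) and weighted-value products, where each token attends to
    L = N tokens (global) or L = window_tokens tokens (windowed).
    """
    L = n_tokens if window_tokens is None else window_tokens
    return 4 * n_tokens * dim ** 2 + 2 * n_tokens * L * dim


C, WINDOW = 768, 4 ** 3           # hypothetical width and window volume
for side in (8, 16):              # token-grid side; doubling it multiplies N by 8
    N = side ** 3
    glob, win = attention_macs(N, C), attention_macs(N, C, WINDOW)
    print(f"N={N}: global/window cost ratio = {glob / win:.2f}")
```

Doubling the grid side multiplies the windowed cost by exactly 8, while the global cost grows faster because of its $N^2$ term, so the gap widens with resolution.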

Convergence Speed

  • MONAI ViT-MAE (L2 + AdamW): 660 epochs to best validation MSE.
  • BertsWin (L2 + AdamW): 114 epochs, a semantic speedup of about 5.8× (660/114).
  • BertsWin + PhysLoss + GCond: 44 epochs, a total speedup of 15× (660/44).
  • Due to GFLOPs parity per epoch, GPU-hour requirements drop by the same factors, offering substantial practical acceleration (Limarenko et al., 25 Dec 2025).
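The speedup figures are simple epoch ratios against the baseline; since per-epoch GFLOPs are at parity, the same ratios carry over to compute. A short sketch reproducing them from the reported epoch counts (the dictionary labels and the helper name `speedups` are illustrative):

```python
def speedups(epochs_to_best: dict, baseline: str) -> dict:
    """Ratio of the baseline's epochs-to-best to each configuration's epochs-to-best."""
    base = epochs_to_best[baseline]
    return {name: base / epochs for name, epochs in epochs_to_best.items()}


epochs_to_best = {
    "MONAI ViT-MAE (L2 + AdamW)": 660,
    "BertsWin (L2 + AdamW)": 114,
    "BertsWin + PhysLoss + GCond": 44,
}
for name, s in speedups(epochs_to_best, "MONAI ViT-MAE (L2 + AdamW)").items():
    print(f"{name}: {s:.1f}x")
# MONAI ViT-MAE (L2 + AdamW): 1.0x
# BertsWin (L2 + AdamW): 5.8x
# BertsWin + PhysLoss + GCond: 15.0x
```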

7. Summary and Context

BertsWin resolves topological sparsity in masked pre-training for 3D volumes by restoring a complete token grid, leveraging Swin-style local attention for scalable computation, and implementing a structural loss with anatomical prioritization. Combined with a memory-efficient custom optimizer, the architecture trains efficiently and converges rapidly, achieving up to a 15× reduction in training epochs with maintained or reduced per-iteration FLOPs. These innovations are empirically validated on 3D TMJ CT segmentation, addressing both computational and topological bottlenecks in 3D self-supervised learning (Limarenko et al., 25 Dec 2025).

References (1)
