TP-UNet: Temporal & Topology-Aware Segmentation
- TP-UNet is a name shared by two UNet-based medical image segmentation frameworks: one embeds temporal prompts, the other embeds topological priors.
- The temporal-prompt variant fuses text and image features through contrastive semantic alignment and cross-attention, improving mDice over a standard UNet by 4.44% on UW-Madison and 6.08% on LITS.
- The topology-preserving variant warps a template mask with a diffeomorphic deformation field under Jacobian regularization, so that each segmentation inherits the template's correct topology and connectivity.
TP-UNet refers to two distinct but influential frameworks in medical image segmentation: Temporal Prompt Guided UNet for integrating temporal priors via prompts (Wang et al., 2024), and the Topology-Preserving Segmentation Network, leveraging diffeomorphic deformation fields for topological correctness (Zhang et al., 2022). Each addresses limitations of standard UNet-based architectures by incorporating a domain-specific prior, temporal or topological, into the deep learning pipeline, and each reports gains over standard UNet baselines on its target tasks.
1. Temporal Prompt Guided UNet
The Temporal Prompt Guided UNet (TP-UNet) (Wang et al., 2024) extends conventional UNet models for 2D medical image segmentation by integrating temporal information using textual prompts. Traditional UNet variants process each axial slice of a 3D scan independently, omitting organ-order and slice-position cues that experienced radiologists leverage intuitively. This often results in confusion among adjacent organs with similar appearances but distinct anatomical progression along the scan axis.
Motivation and Temporal Priors
The appearance probability of each abdominal organ as a function of axial slice position can be modeled by an organ-specific normal distribution over the normalized slice index $t = i/N$ (slice $i$ of $N$): for example, the stomach commonly appears in lower-index slices, the small intestine in the mid-range, and the large intestine in later slices. Standard UNet architectures ignore this temporal prior; TP-UNet injects it directly.
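To make the slice-position prior concrete, the sketch below evaluates organ-specific normal densities at the normalized slice index $t = i/N$. The means and standard deviations in `ORGAN_PRIORS` are hypothetical placeholders chosen only to match the qualitative ordering described above (stomach early, small intestine mid-range, large intestine later); they are not fitted values from Wang et al. (2024). In TP-UNet the prior reaches the model through the prompts, not through explicit densities like these.

```python
import math

# Hypothetical (mu, sigma) over t = i/N in [0, 1]; illustrative only,
# NOT values reported by Wang et al. (2024).
ORGAN_PRIORS = {
    "stomach": (0.25, 0.12),
    "small intestine": (0.50, 0.15),
    "large intestine": (0.70, 0.15),
}


def normal_pdf(t, mu, sigma):
    """Density of N(mu, sigma^2) evaluated at t."""
    z = (t - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def slice_prior(i, n):
    """Relative appearance likelihood of each organ at slice i of n."""
    t = i / n  # normalized slice index
    return {organ: normal_pdf(t, mu, s) for organ, (mu, s) in ORGAN_PRIORS.items()}


if __name__ == "__main__":
    prior = slice_prior(20, 100)       # t = 0.2, an early slice
    print(max(prior, key=prior.get))   # expected: stomach
```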
Each slice is paired with an auto-generated prompt:
"This is {an MRI / a CT} of the {organ} with a segmentation period of {i/N}."
The prompt encodes the imaging modality, the target organ, and the slice's relative position, directly guiding the model's expectations for that slice's content.
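A minimal sketch of filling the prompt template in Python. The helper name `make_prompt`, leaving 0- vs 1-based slice indexing to the caller, and rendering the period $i/N$ as a two-decimal fraction are assumptions; the source specifies only the template string.

```python
def make_prompt(modality, organ, i, n):
    """Build a TP-UNet-style prompt for slice i of an n-slice scan (hypothetical helper)."""
    article = {"MRI": "an MRI", "CT": "a CT"}[modality]
    period = i / n  # relative slice position in [0, 1]
    return (f"This is {article} of the {organ} "
            f"with a segmentation period of {period:.2f}.")


print(make_prompt("MRI", "stomach", 12, 48))
# This is an MRI of the stomach with a segmentation period of 0.25.
```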
2. Architecture and Multimodal Integration
TP-UNet is implemented as a multimodal extension of the classic 2D UNet:
- Backbone: Four-block encoder with max-pooling, and a decoder with upsampling and skip connections.
- Text/Image Embedding: The deepest encoder output $f^I$ is paired with a text prompt embedding $f^T$, where textual features are extracted with CLIP+LoRA or Electra+SFT encoders.
Two core mechanisms are used between encoder and decoder:
Semantic Alignment via Contrastive Learning
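A bidirectional InfoNCE loss aligns the text and image embedding spaces, pulling matched image-prompt pairs together and pushing unmatched pairs apart. Given the image feature $f^I_i$ of sample $i$ in a batch of $B$ and its corresponding prompt embedding $f^T_i$, with similarity $s(\cdot,\cdot)$ (e.g., cosine) and temperature $\tau$, the two directional losses in the standard form are
$$\mathcal{L}_{I\to T} = -\frac{1}{B}\sum_{i=1}^{B}\log\frac{\exp\big(s(f^I_i, f^T_i)/\tau\big)}{\sum_{j=1}^{B}\exp\big(s(f^I_i, f^T_j)/\tau\big)}, \qquad \mathcal{L}_{T\to I} = -\frac{1}{B}\sum_{i=1}^{B}\log\frac{\exp\big(s(f^T_i, f^I_i)/\tau\big)}{\sum_{j=1}^{B}\exp\big(s(f^T_i, f^I_j)/\tau\big)},$$
with total contrastive loss
$$\mathcal{L}_{\text{con}} = \tfrac{1}{2}\big(\mathcal{L}_{I\to T} + \mathcal{L}_{T\to I}\big).$$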
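The symmetric InfoNCE objective (image-to-text and text-to-image cross-entropy over a batch similarity matrix, averaged) can be sketched in plain Python for a batch of paired feature vectors. This sketch assumes cosine similarity, and the default temperature of 0.07 is a hypothetical placeholder (a common CLIP-style choice), not the value used in TP-UNet.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce(anchors, candidates, tau):
    """Mean over i of -log softmax_j(s(anchor_i, cand_j) / tau), evaluated at j = i."""
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [cosine(a, c) / tau for c in candidates]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_z - logits[i]
    return total / len(anchors)


def symmetric_contrastive_loss(img_feats, txt_feats, tau=0.07):
    """0.5 * (image->text + text->image) InfoNCE; tau=0.07 is a placeholder."""
    l_i2t = info_nce(img_feats, txt_feats, tau)
    l_t2i = info_nce(txt_feats, img_feats, tau)
    return 0.5 * (l_i2t + l_t2i)
```

With perfectly aligned, mutually orthogonal pairs the loss approaches 0, and when every similarity is equal it equals $\log B$, which makes a quick sanity check.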
Cross-Attention Fusion
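Aligned image and prompt embeddings are concatenated and fused via pixel-wise cross-attention, which in the standard scaled dot-product form is computed as
$$\mathrm{CrossAttn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,$$
where $Q$, $K$, and $V$ are learned linear projections of the image and prompt features and $d_k$ is the key dimension.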
The resulting feature map is merged by a convolution and is injected into the decoder via the skip-connection for downstream mask prediction.
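A dependency-free sketch of scaled dot-product cross-attention, $\mathrm{softmax}(QK^\top/\sqrt{d_k})V$, treating each pixel's feature as a query and each prompt token's feature as a key/value. Which side supplies the queries, the learned projections $W_Q$, $W_K$, $W_V$ (taken as identity here), and any multi-head structure are assumptions rather than details confirmed for TP-UNet.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def cross_attention(queries, keys, values):
    """softmax(Q K^T / sqrt(d_k)) V for lists of row vectors.

    queries: one vector per pixel (image side, assumed)
    keys, values: one vector per prompt token (text side, assumed)
    """
    d_k = len(keys[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # weighted sum of value vectors, one output channel at a time
        out.append([sum(w * v[c] for w, v in zip(weights, values))
                    for c in range(len(values[0]))])
    return out
```

With a single key, every query simply copies that key's value; with identical keys, each output is the mean of the values.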
3. Optimization, Experimental Datasets, and Results
Loss Functions
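Segmentation supervision combines binary cross-entropy and Tversky loss,
$$\mathcal{L}_{\text{seg}} = \lambda_{\text{BCE}}\,\mathcal{L}_{\text{BCE}} + \lambda_{\text{Tv}}\,\mathcal{L}_{\text{Tversky}}, \qquad \mathcal{L}_{\text{Tversky}} = 1 - \frac{TP}{TP + \alpha\,FP + \beta\,FN},$$
where $TP$, $FP$, and $FN$ are soft true-positive, false-positive, and false-negative counts; Dice loss is optionally reported as well.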
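A sketch of a combined BCE + Tversky objective on a flattened soft prediction, where the Tversky loss is $1 - TP/(TP + \alpha FP + \beta FN)$ with soft counts. The unit weights and $\alpha = \beta = 0.5$ are hypothetical defaults, not TP-UNet's settings.

```python
import math


def bce(probs, targets, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clamped to avoid log(0)."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


def tversky_loss(probs, targets, alpha=0.5, beta=0.5, eps=1e-7):
    """1 - TP / (TP + alpha*FP + beta*FN) using soft counts."""
    tp = sum(p * y for p, y in zip(probs, targets))
    fp = sum(p * (1 - y) for p, y in zip(probs, targets))
    fn = sum((1 - p) * y for p, y in zip(probs, targets))
    return 1.0 - (tp + eps) / (tp + alpha * fp + beta * fn + eps)


def seg_loss(probs, targets, w_bce=1.0, w_tversky=1.0):
    """Weighted sum of BCE and Tversky loss (weights are placeholders)."""
    return w_bce * bce(probs, targets) + w_tversky * tversky_loss(probs, targets)
```

With $\alpha = \beta = 0.5$ the Tversky loss reduces to the soft Dice loss; raising $\beta$ penalizes false negatives more heavily, which is the usual reason to prefer it for small structures.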
Datasets and Setup
- UW-Madison GI Tract (MRI):
- 26,746 training, 3,820 validation, and 7,642 test 2D slices; multi-organ (stomach, small intestine, large intestine)
- 7:1:2 data split
- LITS 2017 (CT):
- 131 3D CT scans, from which 10,967 2D slices are used for liver segmentation
- 7:1:2 split
Data augmentation uses CoarseDropout, HorizontalFlip, and ShiftScaleRotate. Optimization uses AdamW with weight decay and a cosine-annealed learning rate, and the contrastive alignment uses a temperature $\tau$.
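Cosine annealing decays the learning rate from its initial value $\eta_{\max}$ to a floor $\eta_{\min}$ along half a cosine period. The closed form below is the one documented for PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` (without restarts), which in a PyTorch setup would typically be paired with `torch.optim.AdamW`. The step count and learning rate in the example are hypothetical, not the paper's values.

```python
import math


def cosine_annealing_lr(step, total_steps, lr_max, lr_min=0.0):
    """lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * step / total_steps))."""
    progress = min(step, total_steps) / total_steps  # hold at lr_min after the end
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Hypothetical schedule: 100 steps starting from lr_max = 1e-3.
schedule = [cosine_annealing_lr(s, 100, 1e-3) for s in (0, 50, 100)]
```

The rate starts at $\eta_{\max}$, passes through the midpoint $(\eta_{\max} + \eta_{\min})/2$ halfway through training, and reaches $\eta_{\min}$ at the final step.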
Quantitative Outcomes
| Dataset | Text encoder | Dice Score | Jaccard | Improvement over UNet |
|---|---|---|---|---|
| UW-Madison (MRI) | Electra | 0.9266 | 0.8943 | +4.44% mDice |
| LITS (CT) | Not specified | 0.9125 | 0.8780 | +6.08% mDice |
TP-UNet reports state-of-the-art performance, outperforming Swin-UNet by 1.3% Dice on UW-Madison and the previous best LITS result by 9.21% (Wang et al., 2024). Qualitative results show sharper boundaries and fewer false positives, especially on slices where organ appearance is ambiguous.
Ablation Studies
Removing the temporal prompts reduces mDice by 2.1% on UW-Madison, and eliminating prompts entirely causes a 5.36% drop on LITS. Semantic alignment and cross-attention fusion each add further gains, but the temporal prompt is the dominant contributor.
4. Topology-Preserving Segmentation Network (TPSN)
The Topology-Preserving Segmentation Network (TPSN, also referred to as TP-UNet in some sources) (Zhang et al., 2022) addresses the problem of topology errors in medical segmentation. TPSN outputs a diffeomorphic deformation field that warps a simple binary template mask (with the desired connectivity) onto the anatomy in the image, guaranteeing that the resulting segmentation has the template's topological structure.
Mathematical Formulation
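Let $I:\Omega \subset \mathbb{R}^d \to \mathbb{R}$ be the image ($d$-dimensional, with $d = 2$ or $3$), and $M_0$ a template mask with correct topology. A UNet (2D or 3D) is adapted to take both $I$ and $M_0$ as input and outputs a deformation field $u$ describing a spatial map $\phi:\Omega \to \Omega$.
- Deformation: $\phi(x) = x + u(x)$, or $\phi$ predicted directly
- Prediction: $\hat{M} = M_0 \circ \phi$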
A spatial transformer layer performs the mask warping with bilinear (2D) or trilinear (3D) sampling.
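The warping step can be sketched in 2D with plain Python: each output pixel $x$ samples the template at $x + u(x)$ by bilinear interpolation, realizing the composition of the template with the deformation. Border clamping and storing the displacement as separate row and column arrays are assumptions of this sketch; a deep-learning implementation would typically use a differentiable sampler such as PyTorch's `torch.nn.functional.grid_sample`.

```python
import math


def bilinear_sample(img, y, x):
    """Sample a 2D list-of-lists at real-valued (y, x), clamping to the border."""
    h, w = len(img), len(img[0])
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    top = (1 - dx) * img[y0][x0] + dx * img[y0][x1]
    bottom = (1 - dx) * img[y1][x0] + dx * img[y1][x1]
    return (1 - dy) * top + dy * bottom


def warp_mask(template, disp_y, disp_x):
    """Backward warp: out(p) = template(p + u(p)), i.e. template composed with phi."""
    h, w = len(template), len(template[0])
    return [[bilinear_sample(template, i + disp_y[i][j], j + disp_x[i][j])
             for j in range(w)] for i in range(h)]
```

Because interpolation produces soft values along boundaries, the warped mask can be thresholded (for example at 0.5) to obtain a binary segmentation.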
Loss and Topology Regularization
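To guarantee topology preservation, the Jacobian determinant of $\phi$ is constrained to remain positive via a ReLU-based penalty that is nonzero only where $\det J_\phi$ falls below a small margin $\epsilon \ge 0$:
$$\mathcal{L}_{\text{Jac}} = \frac{1}{|\Omega|}\sum_{x\in\Omega}\mathrm{ReLU}\big(\epsilon - \det J_\phi(x)\big).$$
The total loss combines the Dice loss between the prediction and the ground-truth mask, the Jacobian penalty, and an optional Laplacian smoothness term:
$$\mathcal{L} = \lambda_{D}\,\mathcal{L}_{\text{Dice}} + \lambda_{J}\,\mathcal{L}_{\text{Jac}} + \lambda_{L}\,\mathcal{L}_{\text{Lap}},$$
with the weights $\lambda_D$, $\lambda_J$, $\lambda_L$ set as reported by Zhang et al. (2022).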
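The Jacobian constraint can be checked numerically with central finite differences. This 2D sketch computes $\det J_\phi = \det(I + \nabla u)$ at interior pixels for $\phi(x) = x + u(x)$ and averages $\mathrm{ReLU}(\epsilon - \det J_\phi)$. Unit grid spacing, interior-only evaluation, and the default margin $\epsilon = 0$ are assumptions, not TPSN's exact discretization.

```python
def jacobian_determinants(disp_y, disp_x):
    """det(J_phi) at interior pixels of phi(p) = p + u(p), using central differences.

    disp_y, disp_x: 2D lists (H x W, with H and W >= 3) holding the row and
    column components of the displacement u; unit grid spacing is assumed.
    """
    h, w = len(disp_y), len(disp_y[0])
    dets = []
    for i in range(1, h - 1):
        for j in range(1, w - 1):
            duy_dy = (disp_y[i + 1][j] - disp_y[i - 1][j]) / 2.0
            duy_dx = (disp_y[i][j + 1] - disp_y[i][j - 1]) / 2.0
            dux_dy = (disp_x[i + 1][j] - disp_x[i - 1][j]) / 2.0
            dux_dx = (disp_x[i][j + 1] - disp_x[i][j - 1]) / 2.0
            # J_phi = I + grad(u); determinant of the 2x2 matrix
            dets.append((1.0 + duy_dy) * (1.0 + dux_dx) - duy_dx * dux_dy)
    return dets


def jacobian_penalty(disp_y, disp_x, eps=0.0):
    """Mean of ReLU(eps - det J_phi); zero when every determinant is >= eps."""
    dets = jacobian_determinants(disp_y, disp_x)
    return sum(max(0.0, eps - d) for d in dets) / len(dets)
```

A negative determinant flags a location where the deformation folds the grid, which is where the warped template's topology could break; the identity field (zero displacement) has determinant 1 everywhere and incurs no penalty.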
Multi-Scale Cascade
A three-level coarse-to-fine cascade (ml-TPSN) improves boundary detail. Training and inference run at low, medium, and full resolution, and each TPSN module starts from the upsampled mask produced by the previous level.
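The coarse-to-fine control flow can be sketched independently of the network itself. Here `stages` stands in for three trained TPSN modules (hypothetical callables taking an image and a template and returning a warped mask), each level is assumed to double the resolution, and nearest-neighbour upsampling between levels is an assumption; the source states only that each module starts from the previous upsampled mask.

```python
def upsample_nearest(mask, factor=2):
    """Nearest-neighbour upsampling of a 2D list-of-lists mask by an integer factor."""
    return [[mask[i // factor][j // factor] for j in range(len(mask[0]) * factor)]
            for i in range(len(mask) * factor)]


def cascade(image_pyramid, template, stages):
    """Coarse-to-fine inference over a resolution pyramid.

    image_pyramid: images ordered from lowest to full resolution, each level
        assumed to be 2x the previous one.
    template: template mask at the lowest resolution.
    stages: one hypothetical callable per level, stage(image, mask) -> mask,
        standing in for a trained TPSN module.
    """
    mask = template
    for level, (image, stage) in enumerate(zip(image_pyramid, stages)):
        if level > 0:
            mask = upsample_nearest(mask, 2)  # previous output becomes the new template
        mask = stage(image, mask)
    return mask
```

Passing identity stages is a convenient way to check that the shapes line up: a 1x1 template comes out as a 4x4 mask after three levels.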
5. Experimental Evaluation and Comparative Analysis
Datasets and Protocol
- HAM10000 (2D dermoscopy): 10,015 images, split 9,000/1,015, resized to a fixed resolution
- KiTS21 (3D CT kidneys): 300 volumes, 210/70 split, resampled to a common grid size
Results
| Dataset | UNet | UNet+CCA | TEDSNet | TPSN | ml-TPSN |
|---|---|---|---|---|---|
| HAM10000 (2D) | 93.50±0.29% | 93.63±0.29% | 89.91±0.27% | 93.77±0.37% | 94.42±0.39% |
| KiTS21 (3D) | 92.76±0.38% | 92.81±0.38% | 87.67±0.29% | 92.73±0.48% | 93.17±0.41% |
ml-TPSN outperforms baseline UNet and UNet+CCA (connected-component analysis) on both datasets, while single-level TPSN is slightly better than both on HAM10000 and on par with them on KiTS21. The benefit is most visible in cases with partial data loss or low contrast, where TPSN yields segmentations that are, by construction, topologically correct and free of spurious disconnected components.
Inference and Post-processing
Given a test image and template mask, TPSN produces the final mask in a single forward pass, and no further post-processing is required to ensure topology correctness.
6. Distinctions and Nomenclature
The term "TP-UNet" is ambiguous, referring to both Temporal Prompt Guided UNet (Wang et al., 2024) and Topology-Preserving Segmentation Network (Zhang et al., 2022) in contemporary literature. Each framework leverages different forms of prior:
- Temporal Prompt Guidance: Uses slice-wise textual prompts encoding anatomical order.
- Topology Preservation: Incorporates topological priors via template-based diffeomorphic deformation and Jacobian regularization.
A plausible implication is that future works may combine both temporal and topological priors within unified segmentation models if anatomical constraints and scan protocols permit.
7. Prospective Extensions
For Temporal Prompt Guided UNet, proposed future directions include employing 3D prompt sequences, enabling uncertainty-aware prompt construction, and integrating clinical report information to improve volumetric spatial–temporal consistency (Wang et al., 2024).
For TPSN, plausible extensions involve templates with more complex topology, data-driven template adaptation, or fusion with semantic priors beyond topology, aiming to generalize topology preservation across multi-object and variable-connectivity segmentation tasks (Zhang et al., 2022).