STRAFE: Transformer-Based Survival Analysis

Updated 9 February 2026
  • STRAFE is a deep-learning architecture for discrete-time survival analysis that models longitudinal EHR visits using transformer-based sequence modeling and convolution to predict time-to-event outcomes.
  • It leverages pre-trained OMOP embeddings and sinusoidal timestamps to handle irregular event intervals and accommodate right-censored data effectively.
  • Evaluations on CKD patient cohorts demonstrate significant performance gains, with reduced MAE and improved AUC compared to classical and deep-learning baselines.

STRAFE is a deep-learning architecture for time-to-event (survival) prediction from longitudinal electronic health records (EHR), leveraging transformer-based sequence modeling and designed to handle irregular event intervals and censored outcomes. STRAFE was introduced for predicting chronic kidney disease (CKD) progression using large-scale claims data and demonstrated improved performance in both event-time and fixed-time risk prediction tasks compared to classical and deep-learning baselines (Zisser et al., 2023).

1. Architecture Overview

STRAFE models a patient's visit history as an ordered set of timestamped EHR visits, each consisting of a collection of OMOP concepts. Each visit $V_j^i$ for patient $i$ at time $t_j^i$ includes codes $C_j^i$ encoded through a pre-trained skip-gram embedding $\phi: C \rightarrow \mathbb{R}^{d_e}$ with $d_e=128$. The content embedding per visit is $\psi(V_j^i) = \sum_{c \in C_j^i} \phi(c)$. To capture irregular visit intervals, each timestamp is mapped via a sinusoidal embedding $\tau(V_j^i) = [\sin(\tilde{t}_j^i \cdot \omega) \,\|\, \cos(\tilde{t}_j^i \cdot \omega)]$, with $\omega$ a geometric progression of frequencies.
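The visit encoding described above can be sketched in a few lines of dependency-free Python. This is only an illustration under stated assumptions: the frequency schedule (`max_period=10000.0`), the toy embedding table `phi`, the 4-dimensional vectors, and all function names are hypothetical and are not taken from the STRAFE implementation, which uses 128-dimensional pre-trained embeddings.

```python
import math

def geometric_frequencies(dim, max_period=10000.0):
    """Return dim // 2 frequencies in geometric progression (assumed schedule)."""
    half = dim // 2
    return [max_period ** (-k / half) for k in range(half)]

def time_embedding(t, freqs):
    """tau: concatenate sin(t * w) and cos(t * w) over all frequencies w."""
    return [math.sin(t * w) for w in freqs] + [math.cos(t * w) for w in freqs]

def content_embedding(codes, phi):
    """psi: sum the embedding vectors phi[c] of every concept code in a visit."""
    dim = len(next(iter(phi.values())))
    total = [0.0] * dim
    for c in codes:
        for k, v in enumerate(phi[c]):
            total[k] += v
    return total

def encode_visit(codes, t, phi, freqs):
    """Per-visit encoding x_j = psi(V_j) + tau(V_j)."""
    psi = content_embedding(codes, phi)
    tau = time_embedding(t, freqs)
    return [a + b for a, b in zip(psi, tau)]

# Toy example: made-up codes and vectors with d_e = 4 instead of 128.
phi = {
    "hypertension": [0.1, 0.2, 0.3, 0.4],
    "ckd_stage_3": [0.5, 0.0, -0.1, 0.2],
}
freqs = geometric_frequencies(4)
x = encode_visit(["hypertension", "ckd_stage_3"], t=0.0, phi=phi, freqs=freqs)
```

Because the content and time embeddings are added rather than concatenated, both must share the same dimension, which is why the time embedding uses half as many frequencies as there are embedding dimensions.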

The per-visit encoding $x_j = \psi(V_j^i) + \tau(V_j^i)$ yields the visit sequence ii0, where ii1 is the maximum sequence length (truncated/padded). This sequence is processed by ii2 layer of multi-head self-attention (ii3 heads, dropout ii4), producing ii5.

Subsequently, ii6 is projected onto a fixed monthly time grid of ii7 months via a 1D convolution, yielding ii8. Temporal embeddings are added for each time point, and a second self-attention block (ii9, tjit_j^i0) models dependencies across months. A two-layer MLP, applied to each time point, outputs tjit_j^i1, the complement of the discrete monthly hazard.

The discrete-time survival function is given by:

tjit_j^i2

and the mean predicted time-to-event is tjit_j^i3.
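As a rough illustration of how a survival curve and a mean time-to-event follow from per-month complements of the hazard, the sketch below uses the standard discrete-time relations: the survival probability at month t is the running product of the per-month survival probabilities, and the (restricted) mean time is the sum of the survival curve over the grid. The name `q` and the exact summation range are assumptions made here; the source's own notation and indexing may differ.

```python
from itertools import accumulate
from operator import mul

def survival_curve(q):
    """S(t) = q_1 * ... * q_t, where q_k = 1 - hazard at month k."""
    return list(accumulate(q, mul))

def mean_time_to_event(q):
    """Restricted mean survival time in months: sum of S(t) over the grid."""
    return sum(survival_curve(q))

# Toy three-month grid with made-up per-month survival probabilities.
q = [0.9, 0.8, 0.5]
S = survival_curve(q)
mean_months = mean_time_to_event(q)
```

The mean computed this way is bounded by the length of the grid, so patients predicted to survive beyond the horizon all receive values close to the horizon itself.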

2. Survival Modeling and Loss Function

STRAFE is formulated for discrete-time survival analysis and accommodates right-censored data through a custom likelihood. For each patient tjit_j^i4, outcomes are encoded as tjit_j^i5, where tjit_j^i6 is the event or censoring time and tjit_j^i7 is the event indicator. The joint loss over a batch is

tjit_j^i8

with observed-case loss:

tjit_j^i9

and censored-case loss:

CjiC_j^i0

This framework does not assume proportional hazards and optimizes the discrete event likelihood across both observed and censored outcomes.
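To make the treatment of censoring concrete, here is a minimal sketch of a textbook discrete-time negative log-likelihood. It is not claimed to be the exact loss used by STRAFE: the per-patient form, the averaging over the batch, and the names `survival_nll` and `batch_loss` are assumptions for illustration. An observed event contributes the probability of surviving the earlier months and then failing; a censored patient contributes only the probability of surviving through the censoring month.

```python
import math

def survival_nll(q, time, event, eps=1e-12):
    """Negative log-likelihood for one patient on a monthly grid.

    q[k] is the predicted probability of surviving month k + 1 given
    survival up to it. `time` is the 1-based month of the event or of
    censoring; `event` is 1 if the event was observed, 0 if censored.
    """
    log_surv_before = sum(math.log(max(qk, eps)) for qk in q[: time - 1])
    if event:
        # Observed: survive months 1..time-1, then fail in month `time`.
        return -(log_surv_before + math.log(max(1.0 - q[time - 1], eps)))
    # Censored: we only know the patient survived through month `time`.
    return -(log_surv_before + math.log(max(q[time - 1], eps)))

def batch_loss(batch):
    """Mean NLL over (q, time, event) triples (averaging is an assumption)."""
    return sum(survival_nll(q, t, e) for q, t, e in batch) / len(batch)

# Toy batch: one observed event and one censored patient, both at month 2.
q = [0.9, 0.8, 0.5]
loss_observed = survival_nll(q, time=2, event=1)
loss_censored = survival_nll(q, time=2, event=0)
loss = batch_loss([(q, 2, 1), (q, 2, 0)])
```

No proportional-hazards structure appears anywhere in this likelihood: each month's probability is free to vary per patient, which is the property the section attributes to the framework.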

3. Training Protocol and Implementation

STRAFE employs OMOP concept embeddings pre-trained on 35 million claims using skip-gram word2vec (window=90 days, CjiC_j^i1, vocabulary size ≈36,000). Patient sequences are standardized to CjiC_j^i2 visits. Model selection was guided via grid search. Key hyperparameters include:

  • First self-attention: CjiC_j^i3, CjiC_j^i4, dropout=0.3
  • 1D convolution projects CjiC_j^i5 visits to CjiC_j^i6 monthly bins
  • Second self-attention: CjiC_j^i7, CjiC_j^i8
  • Optimization: Adam, batch size 256, learning rate CjiC_j^i9, no explicit weight decay or momentum
  • Implementation framework: PyTorch (Tesla K80 GPU); comparison baselines in scikit-survival and Pycox

These choices were intended to maximize out-of-sample predictive performance without adding explicit architectural complexity or heavy regularization.

4. Performance Evaluation

STRAFE was evaluated on a real-world dataset of over 130,000 CKD stage 3 patients, with the following results on a held-out cohort (ϕ:CRde\phi: C \rightarrow \mathbb{R}^{d_e}0):

Time-to-event prediction (48-month horizon):

| Model | C-index | MAE (months) |
|---|---|---|
| RSF (BOW) | 0.609 | 32.33 |
| RSF (emb) | 0.719 | 31.85 |
| DeepHit (BOW) | 0.580 | 28.39 |
| DeepHit (emb) | 0.714 | 28.59 |
| STRAFE | 0.710 | 22.16 |
| STRAFE-LSTM | 0.710 | 21.59 |
| Uncont. STRAFE | 0.690 | 22.14 |
| Uncont. LSTM | 0.711 | 23.04 |

Fixed-time risk (AUC-ROC at 6, 12, 24 months):

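| Model | 6 months | 12 months | 24 months |
|---|---|---|---|
| LR (BOW) | 0.622 | 0.598 | 0.603 |
| LR (emb) | 0.711 | 0.710 | 0.720 |
| SARD | 0.725 | 0.731 | 0.748 |
| RSF (emb) | 0.719 | 0.723 | 0.683 |
| DeepHit (emb) | 0.729 | 0.728 | 0.725 |
| STRAFE | 0.751 | 0.754 | 0.764 |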

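STRAFE reduced mean absolute error (MAE) by roughly 22% relative to DeepHit (22.16 vs. 28.39 to 28.59 months) and improved AUC over SARD by 0.016 to 0.026 across the three horizons (p ≈ ϕ:CRde\phi: C \rightarrow \mathbb{R}^{d_e}1 at 24 months). Embedding usage was the principal driver of C-index gains: RSF rose from 0.609 to 0.719 and DeepHit from 0.580 to 0.714 when BOW inputs were replaced by embeddings. STRAFE's own C-index (0.710) is on par with, not above, the embedding-based baselines. Subgroup analysis indicated AUCs rising to ~0.80 in patients <60 years, with male AUC 0.761 vs female 0.748.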
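The relative differences quoted here can be recomputed from the tabulated values. The short check below copies the MAE and AUC figures from the two results tables; the helper name `relative_reduction` is introduced only for this illustration.

```python
# Values copied from the two results tables above.
mae = {"DeepHit (BOW)": 28.39, "DeepHit (emb)": 28.59, "STRAFE": 22.16}
auc = {"SARD": (0.725, 0.731, 0.748), "STRAFE": (0.751, 0.754, 0.764)}

def relative_reduction(baseline, new):
    """Fractional reduction of `new` relative to `baseline`."""
    return (baseline - new) / baseline

# MAE reduction of STRAFE against each DeepHit variant.
mae_reduction = {
    name: relative_reduction(value, mae["STRAFE"])
    for name, value in mae.items()
    if name != "STRAFE"
}

# Absolute AUC gain over SARD at 6, 12, and 24 months.
auc_gain = [round(s - b, 3) for s, b in zip(auc["STRAFE"], auc["SARD"])]
```

On these numbers the MAE reduction against either DeepHit variant comes out between 0.21 and 0.23, and the AUC gain over SARD shrinks as the horizon lengthens.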

In top-decile risk stratification, STRAFE achieved a PPV of 20.9% at 12 months (vs base rate 6.67%, a 3x lift) and 28.4% at 24 months (vs 14.98%).
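For readers unfamiliar with the metric, top-decile PPV is the event rate among the 10% of patients with the highest predicted risk, and lift is that rate divided by the cohort base rate. The sketch below is a generic illustration on synthetic data; the function name `top_fraction_ppv` and the toy scores are hypothetical and do not come from the STRAFE evaluation code.

```python
def top_fraction_ppv(scores, labels, fraction=0.1):
    """PPV among the highest-scoring `fraction` of patients, and lift over base rate.

    Assumes at least one positive label, so the base rate is non-zero.
    """
    n_top = max(1, int(len(scores) * fraction))
    ranked = sorted(zip(scores, labels), key=lambda p: p[0], reverse=True)
    ppv = sum(label for _, label in ranked[:n_top]) / n_top
    base_rate = sum(labels) / len(labels)
    return ppv, ppv / base_rate

# Synthetic cohort: 20 patients, 4 events, two of which rank in the top decile.
scores = [i / 20 for i in range(20)]
labels = [0] * 20
for idx in (19, 18, 5, 2):
    labels[idx] = 1
ppv, lift = top_fraction_ppv(scores, labels)
```

Applying the same ratio to the reported figures gives a lift of about 3.1 at 12 months (20.9% vs 6.67%) and about 1.9 at 24 months (28.4% vs 14.98%).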

5. Model Interpretability and Visualization

STRAFE's first self-attention matrix ϕ:CRde\phi: C \rightarrow \mathbb{R}^{d_e}2, defined as ϕ:CRde\phi: C \rightarrow \mathbb{R}^{d_e}3, quantifies visit-to-visit relatedness. High-attention visit pairs can be visualized as graph nodes (visits colored by dominant ICD chapter) with edge weights reflecting attention scores. In documented ablation studies, removing highly weighted visits caused substantial shifts in predicted survival, confirming their importance for outcome risk.

This mechanism enables per-patient explanation by highlighting the visits most influential for time-to-deterioration predictions. This suggests utility for targeted clinical interventions, e.g. attribution of risk to hypertension versus respiratory events in CKD management.
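The idea of reading visit-to-visit relatedness off an attention matrix can be sketched with standard scaled dot-product attention. This is an assumption-laden illustration: it uses the usual row-wise softmax of $QK^\top/\sqrt{d}$, a single head, and made-up query and key vectors, and the function names are hypothetical. The exact definition used in STRAFE may differ.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_matrix(Q, K):
    """A[i][j]: how strongly visit i attends to visit j (rows sum to 1)."""
    d = len(Q[0])
    return [
        softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K])
        for q in Q
    ]

def top_visit_pairs(A, k=3):
    """Return the k off-diagonal (i, j, weight) triples with the largest weight."""
    pairs = [(i, j, w) for i, row in enumerate(A) for j, w in enumerate(row) if i != j]
    return sorted(pairs, key=lambda p: p[2], reverse=True)[:k]

# Toy queries and keys for three visits (made-up numbers).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
A = attention_matrix(Q, K)
edges = top_visit_pairs(A, k=2)
```

The returned triples map directly onto the graph view described above: each `(i, j, weight)` is a candidate edge between visits i and j, and visits that appear in many high-weight edges are natural candidates for an ablation check.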

6. Clinical Impact and Use Cases

STRAFE's ability to model right-censored survival, to exploit sequence structure in visit-level health data, and to provide granular per-patient predictions supports its application in intervention targeting for chronic disease management. The roughly threefold lift in PPV among top-decile risk patients at 12 months (about 1.9x at 24 months), together with the lower MAE in time-to-event prediction, indicates advantages over both classical survival forests and prior deep-learning methods. Its explainability may facilitate clinical deployment by surfacing patient-specific risk drivers and supporting transparent decision support in high-stakes medical settings (Zisser et al., 2023).
