- The paper introduces HierRetro, a hierarchical framework that integrates reaction center identification, action prediction, and termination decisions for improved retrosynthesis prediction accuracy.
- HierRetro leverages contrastive learning to pretrain a molecular encoder, implicitly capturing 3D structural information from 2D graphs, and achieves high accuracy on the USPTO-50k benchmark.
- Ablation studies demonstrate that both contrastive learning pretraining and the reaction center type prediction module significantly enhance the model's performance.
The paper introduces HierRetro, a hierarchical framework designed for retrosynthesis prediction, integrating reaction center identification, action prediction, and termination decisions into a unified pipeline. It addresses limitations in existing methods by systematically classifying reaction center types (atom or bond) to enhance decision-making and leverages contrastive learning to pretrain a molecular encoder, capturing 3D structural information implicitly from 2D molecular graphs.
The paper's key contributions are:
- Hierarchical Model: Decomposes retrosynthesis into distinct modules to enhance prediction accuracy and flexibility in addressing diverse molecular reactions.
- Contrastive Learning: Incorporates contrastive learning to pretrain a molecular encoder, capturing 3-dimensional structural information of input molecules using solely 2-dimensional molecular graphs.
The HierRetro framework consists of several key components:
- Molecular Encoder: Employs the Uni-Mol+ architecture, utilizing Transformer blocks to process atom-level and pair-level representations, capturing both local and global molecular features.
- Atom representations encode individual atomic properties.
- Pair representations combine bond features with 2D graph topology.
- Contrastive Learning Pretraining: Enhances the encoder’s ability to capture 3D structural insights without requiring explicit 3D data, inspired by the 3D-Infomax approach.
- Employs a dual-network setup: a 2D network utilizing molecular connectivity and a 3D network incorporating conformer information.
- Pretraining is performed using the GEOM-Drug dataset, which contains 304,466 molecules annotated with diverse quantum mechanical (QM) properties and 3D conformers.
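
As a rough illustration of the dual-network objective, the sketch below computes an InfoNCE-style contrastive loss that pulls each molecule's 2D-network embedding toward its own 3D-network embedding and away from those of other molecules in the batch. The exact loss used by HierRetro is not stated in this summary, so the loss form, the temperature `tau`, and the function names `cosine` and `info_nce` are assumptions made for illustration.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def info_nce(z2d, z3d, tau=0.1):
    """Assumed InfoNCE-style loss: row i of z2d (2D network output) is the
    anchor, row i of z3d (3D network output) is its positive, and all other
    rows of z3d act as negatives."""
    total = 0.0
    for i, anchor in enumerate(z2d):
        logits = [cosine(anchor, z) / tau for z in z3d]
        m = max(logits)  # subtract the max for numerical stability
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        total += -(logits[i] - log_denom)
    return total / len(z2d)

# Toy batch of two molecules with 2-dimensional embeddings.
aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

In this toy batch the loss is close to zero when the 2D and 3D embeddings of the same molecule agree, and large when they are swapped. That is the signal that lets a 2D-only encoder absorb conformer information during pretraining.
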
- Multitasking Agent: Addresses multiple interconnected tasks critical for retrosynthesis prediction, including reaction center identification, action prediction, and termination decision.
- Updates atom and pair representations at each step t using information from the current and previous states:
- $h_a^t = W_{ac} h_a^t + W_{ap} h_a^{t-1}$
- $h_a^t$ is the updated atom representation at step $t$
- $W_{ac}$ is the weight matrix applied to the current atom representations
- $h_a^{t-1}$ is the previous atom representation
- $W_{ap}$ is the weight matrix applied to the previous atom representations
- $h_b^t = W_{bc} h_b^t + W_{bp} h_b^{t-1}$
- $h_b^t$ is the updated pair representation at step $t$
- $W_{bc}$ is the weight matrix applied to the current pair representations
- $h_b^{t-1}$ is the previous pair representation
- $W_{bp}$ is the weight matrix applied to the previous pair representations
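
The update rule above amounts to two linear maps whose outputs are summed. A minimal sketch for a single atom vector, using plain Python lists; the helper names `matvec` and `update_state` and the toy weights are hypothetical and are not taken from the paper.

```python
def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def update_state(W_cur, W_prev, h_cur, h_prev):
    """Combine the current and previous representations:
    h_new = W_cur @ h_cur + W_prev @ h_prev."""
    cur = matvec(W_cur, h_cur)
    prev = matvec(W_prev, h_prev)
    return [c + p for c, p in zip(cur, prev)]

# Toy weights: identity for the current state, 0.5 * identity for the previous one.
W_ac = [[1.0, 0.0], [0.0, 1.0]]
W_ap = [[0.5, 0.0], [0.0, 0.5]]
h_new = update_state(W_ac, W_ap, h_cur=[1.0, 2.0], h_prev=[4.0, 6.0])
# h_new == [3.0, 5.0]
```

The same function applies unchanged to pair representations, with the pair weight matrices in place of the atom ones.
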
- Performs three key tasks: reaction center identification, action prediction, and termination decision.
- Reaction Center Identification: Consists of three modules: Reaction Center Type Prediction (RCP), Atom Center Prediction (AC), and Bond Center Prediction (BC).
- RCP Module: Determines whether the reaction center is an atom or a bond using graph-level super node features $h_g^t$:
- $p_{RCP}^t = \sigma(W_{RCP} h_g^t)$
- $p_{RCP}^t$ is the predicted probability for the reaction center being a bond
- $\sigma$ is the sigmoid function
- $W_{RCP}$ is the weight matrix for the RCP module
- $h_g^t$ is the graph-level super node features
- AC Module: Predicts the atom reaction center among all atoms in the molecule using the atom representation $h_a^t$:
- $p_{AC}^t = \mathrm{softmax}(W_{AC} h_a^t)$
- $p_{AC}^t$ is the predicted probability distribution over all atoms
- $\mathrm{softmax}$ is the softmax function
- $W_{AC}$ is the weight matrix for the AC module
- $h_a^t$ is the atom representation
- BC Module: Predicts the bond reaction center among all atom pairs in the molecule using the flattened bond pair representations $h_b^t$:
- $p_{BC}^t = \mathrm{softmax}(W_{BC} h_b^t)$
- $p_{BC}^t$ is the predicted probability distribution over all atom pairs
- $W_{BC}$ is the weight matrix for the BC module
- $h_b^t$ is the bond representation
- Overall reaction center probability $p_{RC}^t$:
- $p_{RC}^t = \mathrm{concat}\left[(1 - p_{RCP}^t) \cdot p_{AC}^t,\; p_{RCP}^t \cdot p_{BC}^t\right]$
- Loss function for reaction center identification:
- $\mathcal{L}_{RC} = -\sum_t \sum_{k \in RC^t} y^t_{RC,k} \log p^t_{RC,k}$
- $\mathcal{L}_{RC}$ is the loss for reaction center identification
- $y^t_{RC,k}$ is the ground truth probability for reaction center $k$
- $p^t_{RC,k}$ is the predicted probability for reaction center $k$
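
The hierarchical combination of the RCP, AC, and BC outputs can be sketched in a few lines. The example takes raw scores in place of the learned projections (the products of the module weight matrices with the representations), and it assumes a one-hot ground truth, so the cross-entropy reduces to the negative log-probability of the true center. The function names and toy scores are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

def reaction_center_probs(rcp_logit, atom_scores, bond_scores):
    """Concatenate atom-center and bond-center distributions, weighted by
    the probability that the reaction center is a bond."""
    p_rcp = sigmoid(rcp_logit)   # P(center is a bond)
    p_ac = softmax(atom_scores)  # distribution over atoms
    p_bc = softmax(bond_scores)  # distribution over atom pairs
    return [(1.0 - p_rcp) * p for p in p_ac] + [p_rcp * p for p in p_bc]

def rc_loss(p_rc, target_index):
    """Cross-entropy for a one-hot target (assumption)."""
    return -math.log(p_rc[target_index])

# Three atoms and two candidate bonds; the true center is the first bond (index 3).
p_rc = reaction_center_probs(0.0, [2.0, 0.0, 0.0], [1.0, 1.0])
loss = rc_loss(p_rc, target_index=3)
```

Because each softmax sums to one, the concatenated vector is itself a valid probability distribution over all atoms and atom pairs, which is what allows a single cross-entropy loss to cover both center types.
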
- Action Prediction: Determines the specific chemical modifications required at the identified reaction center and consists of Atom Action Prediction (AA) and Bond Action Prediction (BA) modules.
- AA Module: Predicts the probability distribution of atom actions occurring at a given atom index $i$ using the atom feature $h_{a,i}^t$:
- $p_{AA,i}^t = \mathrm{softmax}(W_{AA} h_{a,i}^t)$
- $p_{AA,i}^t$ is the predicted probability distribution over possible atom actions
- $W_{AA}$ is the weight matrix for the AA module
- $h_{a,i}^t$ is the atom feature of the identified atom
- BA Module: Predicts the probability distribution of bond actions occurring at a given bond index $ij$ using the pair feature $h_{b,ij}^t$:
- $p_{BA,ij}^t = \mathrm{softmax}(W_{BA} h_{b,ij}^t)$
- $p_{BA,ij}^t$ is the predicted probability distribution over possible bond actions
- $W_{BA}$ is the weight matrix for the BA module
- $h_{b,ij}^t$ is the pair feature of the identified bond
- Losses for action prediction modules:
- Atom action loss: $\mathcal{L}_{AA} = -\sum_t \log\left(p^t_{AA,i_t}\right)$
- $\mathcal{L}_{AA}$ is the atom action loss
- $i_t$ is the atom index identified at step $t$
- $p^t_{AA,i_t}$ is the atom action probability at atom $i_t$
- Bond action loss: $\mathcal{L}_{BA} = -\sum_t \left[\log\left(p^t_{BA,ij_t}\right) + \log\left(p^t_{AA,i_t}\right) + \log\left(p^t_{AA,j_t}\right)\right]$
- $\mathcal{L}_{BA}$ is the bond action loss
- $ij_t$ is the bond index identified at step $t$
- $p^t_{BA,ij_t}$ is the bond action probability at bond $ij_t$
- $p^t_{AA,i_t}$ and $p^t_{AA,j_t}$ are the atom action probabilities for the two atoms involved in the bond
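
One way to read the bond action loss is as the negative log-likelihood of three predictions made together: the bond action and the atom actions on the bond's two end atoms. The sketch below assumes all three log terms sit inside the negated sum, and takes the ground-truth probabilities directly as inputs; the function names and numbers are hypothetical.

```python
import math

def atom_action_nll(p_atom):
    """Negative log-likelihood of the ground-truth action at one atom."""
    return -math.log(p_atom)

def bond_action_nll(p_bond, p_atom_i, p_atom_j):
    """Negative log-likelihood of the ground-truth bond action together with
    the ground-truth atom actions on the bond's two end atoms."""
    return -(math.log(p_bond) + math.log(p_atom_i) + math.log(p_atom_j))

# One step in which the identified center is a bond between atoms i and j.
step_loss = bond_action_nll(p_bond=0.5, p_atom_i=0.8, p_atom_j=0.5)
joint = 0.5 * 0.8 * 0.5  # product of the three probabilities
```

Under this reading, `step_loss` equals `-log(joint)`, so minimizing the bond action loss maximizes the probability of getting the bond edit and both atom edits right at the same step.
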
- Termination Determination: Predicts whether the retrosynthesis process should conclude or proceed to the next reaction step using graph-level super node features $h_g^t$:
- Termination probability $p_T^t$:
- $p_T^t = \sigma(W_T h_g^t)$
- $p_T^t$ is the termination probability
- $W_T$ is the weight matrix for the termination module
- $h_g^t$ is the graph-level super node features
- Loss function for the termination decision task:
- $\mathcal{L}_T = -\sum_t \left( y_T^t \log(p_T^t) + (1 - y_T^t) \log(1 - p_T^t) \right)$
- $\mathcal{L}_T$ is the loss for termination
- $y_T^t$ is the ground truth termination probability
- $p_T^t$ is the predicted termination probability
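
The termination loss is a binary cross-entropy summed over editing steps. A minimal sketch, assuming one scalar logit per step (standing in for the projection of the super node features) and labels of 0 for "continue" and 1 for "stop"; the function name, the `eps` guard, and the toy values are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def termination_loss(logits, labels, eps=1e-12):
    """Binary cross-entropy summed over steps; eps guards against log(0)."""
    total = 0.0
    for z, y in zip(logits, labels):
        p = sigmoid(z)  # predicted termination probability at this step
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total

# A three-step trajectory that should stop only at the last step.
labels = [0, 0, 1]
loss_good = termination_loss([-3.0, -2.0, 4.0], labels)  # confident and correct
loss_bad = termination_loss([3.0, 2.0, -4.0], labels)    # confident and wrong
```

A trajectory whose termination logits agree with the labels yields a small loss, while the reversed logits are penalized heavily.
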
- Dynamic Adaptive Multi-Task Learning (DAMT): Integrates the four losses ($\mathcal{L}_{RC}$, $\mathcal{L}_{AA}$, $\mathcal{L}_{BA}$, and $\mathcal{L}_{T}$) by dynamically adjusting loss weights based on task complexity and normalizing magnitudes to ensure balanced learning.
The model consists of 6 encoder blocks with an atom hidden dimension of 256 and a pair hidden dimension of 128. It was trained using the AdamW optimizer with a WarmUpWrapper and a ReduceLROnPlateau learning rate scheduler.
The paper evaluates HierRetro against template-based, template-free, and semi-template methods using top-k exact match accuracy on the USPTO-50k benchmark dataset. With the reaction type unknown, HierRetro achieved a top-3 accuracy of 78.3%, outperforming all other models. With the reaction type known, it achieved a top-3 accuracy of 89.4%, surpassing Graph2Edits by 1.9%, and a top-5 accuracy of 93.3%, outperforming Graph2Edits and LocalRetro by 1.8% and 0.9%, respectively. In the round-trip accuracy evaluation, the model achieved 94.7% for k=3 and 96.2% for k=5, outperforming all models except LocalRetro; at k=10, it achieved the highest accuracy of 97.9%, surpassing Graph2Edits.
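
For readers unfamiliar with the metric, top-k exact match accuracy counts a test reaction as correct when the ground-truth reactant set appears among the model's k highest-ranked predictions. The sketch below assumes predictions and ground truths are already canonicalized strings, so plain string equality is a valid comparison; the function name and the toy SMILES are illustrative only and are not taken from the benchmark.

```python
def top_k_accuracy(ranked_predictions, ground_truths, k):
    """Fraction of examples whose ground truth appears in the top-k predictions.

    ranked_predictions: one list per example, best prediction first.
    ground_truths: one canonical string per example.
    """
    hits = sum(
        1
        for preds, truth in zip(ranked_predictions, ground_truths)
        if truth in preds[:k]
    )
    return hits / len(ground_truths)

# Three toy examples with two ranked predictions each.
predictions = [
    ["CCO.CC(=O)O", "CCO.CC(=O)Cl"],  # correct at rank 1
    ["c1ccccc1Br", "c1ccccc1I"],      # correct at rank 2
    ["CC(=O)Cl", "CC(=O)O"],          # never correct
]
truths = ["CCO.CC(=O)O", "c1ccccc1I", "CCN"]

top1 = top_k_accuracy(predictions, truths, k=1)
top2 = top_k_accuracy(predictions, truths, k=2)
```

Here `top1` is 1/3 and `top2` is 2/3. Accuracy can only stay the same or increase as k grows, which matches the top-3 and top-5 figures reported above.
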
Ablation studies demonstrated that incorporating a contrastive learning (CL)-based pre-trained encoder improved top-1 accuracy by 2.5%, and adding the reaction center type prediction module increased reaction center identification accuracy by approximately 5%.
The paper includes an analysis of multiple reaction centers, addressing the challenges of limited generalization caused by the USPTO dataset’s constrained size through a permutation-based augmentation strategy. Analysis of drug molecules Fruquintinib and Nirogacestat, not included in the training dataset, demonstrates the model’s practical applications in multistep retrosynthesis predictions.