Hierarchical Framework for Retrosynthesis Prediction with Enhanced Reaction Center Localization

Published 29 Nov 2024 in physics.chem-ph | (2411.19503v1)

Abstract: Retrosynthesis is essential for designing synthetic pathways for complex molecules and can be revolutionized by AI to automate and accelerate chemical synthesis planning for drug discovery and materials science. Here, we propose a hierarchical framework for retrosynthesis prediction that systematically integrates reaction center identification, action prediction, and termination decision into a unified pipeline. Leveraging a molecular encoder pretrained with contrastive learning, the model captures both atom- and bond-level representations, enabling accurate identification of reaction centers and prediction of chemical actions. The framework addresses the scarcity of data with multiple reaction centers through augmentation strategies, enhancing the model's ability to generalize to diverse reaction scenarios. The proposed approach achieves competitive performance across benchmark datasets, with notably high top-k accuracy and exceptional reaction center identification capabilities, demonstrating its robustness in handling complex transformations. These advancements position the framework as a promising tool for future applications in material design and drug discovery.

Summary

  • The paper introduces HierRetro, a hierarchical framework that integrates reaction center identification, action prediction, and termination decisions for improved retrosynthesis prediction accuracy.
  • HierRetro leverages contrastive learning to pretrain a molecular encoder, implicitly capturing 3D structural information from 2D graphs, and achieves high accuracy on the USPTO-50k benchmark.
  • Ablation studies demonstrate that both contrastive learning pretraining and the reaction center type prediction module significantly enhance the model's performance.

The paper introduces HierRetro, a hierarchical framework designed for retrosynthesis prediction, integrating reaction center identification, action prediction, and termination decisions into a unified pipeline. It addresses limitations in existing methods by systematically classifying reaction center types (atom or bond) to enhance decision-making and leverages contrastive learning to pretrain a molecular encoder, capturing 3D structural information implicitly from 2D molecular graphs.

The paper's key contributions are:

  • Hierarchical Model: Decomposes retrosynthesis into distinct modules to enhance prediction accuracy and flexibility in addressing diverse molecular reactions.
  • Contrastive Learning: Incorporates contrastive learning to pretrain a molecular encoder, capturing 3D structural information of input molecules from 2D molecular graphs alone.

The HierRetro framework consists of several key components:

  1. Molecular Encoder: Employs the Uni-Mol+ architecture, utilizing Transformer blocks to process atom-level and pair-level representations, capturing both local and global molecular features.
    • Atom representations encode individual atomic properties.
    • Pair representations combine bond features with 2D graph topology.
  2. Contrastive Learning Pretraining: Enhances the encoder’s ability to capture 3D structural insights without requiring explicit 3D data, inspired by the 3D-Infomax approach.
    • Employs a dual-network setup: a 2D network utilizing molecular connectivity and a 3D network incorporating conformer information.
    • Pretraining is performed using the GEOM-Drug dataset, which contains 304,466 molecules annotated with diverse quantum mechanical (QM) properties and 3D conformers.
  3. Multitasking Agent: Addresses multiple interconnected tasks critical for retrosynthesis prediction, including reaction center identification, action prediction, and termination decision.
    • Updates atom and pair representations at each step $t$ using information from the current and previous states:
      • $h_a^t = W_a^c h_a^t + W_a^p h_a^{t-1}$
        • $h_a^t$ is the updated atom representation at step $t$
        • $W_a^c$ is the weight matrix applied to the current atom representations
        • $h_a^{t-1}$ is the previous atom representation
        • $W_a^p$ is the weight matrix applied to the previous atom representations
      • $h_b^t = W_b^c h_b^t + W_b^p h_b^{t-1}$
        • $h_b^t$ is the updated pair representation at step $t$
        • $W_b^c$ is the weight matrix applied to the current pair representations
        • $h_b^{t-1}$ is the previous pair representation
        • $W_b^p$ is the weight matrix applied to the previous pair representations
    • Performs three key tasks: reaction center identification, action prediction, and termination decision.
  4. Reaction Center Identification: Consists of three modules: Reaction Center Type Prediction (RCP), Atom Center Prediction (AC), and Bond Center Prediction (BC).
    • RCP Module: Determines whether the reaction center is an atom or a bond using graph-level super node features $h_g^t$:
      • $p_{RCP}^t = \sigma(W_{RCP} h_g^t)$
        • $p_{RCP}^t$ is the predicted probability for the reaction center being a bond
        • $\sigma$ is the sigmoid function
        • $W_{RCP}$ is the weight matrix for the RCP module
        • $h_g^t$ is the graph-level super node features
    • AC Module: Predicts the atom reaction center among all atoms in the molecule using the atom representation $h_a^t$:
      • $p_{AC}^t = \mathrm{softmax}(W_{AC} h_a^t)$
        • $p_{AC}^t$ is the predicted probability distribution over all atoms
        • $\mathrm{softmax}$ is the softmax function
        • $W_{AC}$ is the weight matrix for the AC module
        • $h_a^t$ is the atom representation
    • BC Module: Predicts the bond reaction center among all atom pairs in the molecule using the flattened bond pair representations $h_b^t$:
      • $p_{BC}^t = \mathrm{softmax}(W_{BC} h_b^t)$
        • $p_{BC}^t$ is the predicted probability distribution over all atom pairs
        • $W_{BC}$ is the weight matrix for the BC module
        • $h_b^t$ is the bond representation
    • Overall reaction center probability $p_{RC}^t$:
      • $p_{RC}^t = \mathrm{concat}\left[(1 - p_{RCP}^t) \cdot p_{AC}^t,\; p_{RCP}^t \cdot p_{BC}^t\right]$
    • Loss function for reaction center identification:
      • $\mathcal{L}_{RC} = -\sum_t \sum_{k \in RC^t} y^t_{RC,k} \log p^t_{RC,k}$
        • $\mathcal{L}_{RC}$ is the loss for reaction center identification
        • $y^t_{RC,k}$ is the ground truth probability for reaction center $k$
        • $p^t_{RC,k}$ is the predicted probability for reaction center $k$
  5. Action Prediction: Determines the specific chemical modifications required at the identified reaction center and consists of Atom Action Prediction (AA) and Bond Action Prediction (BA) modules.
    • AA Module: Predicts the probability distribution of atom actions occurring at a given atom index $i$ using the atom feature $h^t_{a,i}$:
      • $p^t_{AA,i} = \mathrm{softmax}(W_{AA} h^t_{a,i})$
        • $p^t_{AA,i}$ is the predicted probability distribution over possible atom actions
        • $W_{AA}$ is the weight matrix for the AA module
        • $h^t_{a,i}$ is the atom feature of the identified atom
    • BA Module: Predicts the probability distribution of bond actions occurring at a given bond index $ij$ using the pair feature $h_{b,ij}^t$:
      • $p_{BA,ij}^t = \mathrm{softmax}(W_{BA} h_{b,ij}^t)$
        • $p_{BA,ij}^t$ is the predicted probability distribution over possible bond actions
        • $W_{BA}$ is the weight matrix for the BA module
        • $h_{b,ij}^t$ is the pair feature of the identified bond
    • Losses for action prediction modules:
      • Atom action loss: $\mathcal{L}_{AA} = -\sum_t \log(p_{AA,i_t}^t)$
        • $\mathcal{L}_{AA}$ is the atom action loss
        • $i_t$ is the atom index identified at step $t$
        • $p_{AA,i_t}^t$ is the atom action probability at atom $i_t$, evaluated at the ground-truth action
      • Bond action loss: $\mathcal{L}_{BA} = -\sum_t \left[ \log(p_{BA,ij_t}^t) + \log(p_{AA,i_t}^t) + \log(p_{AA,j_t}^t) \right]$
        • $\mathcal{L}_{BA}$ is the bond action loss
        • $ij_t$ is the bond index identified at step $t$
        • $p_{BA,ij_t}^t$ is the bond action probability at bond $ij_t$
        • $p_{AA,i_t}^t$ and $p_{AA,j_t}^t$ are the atom action probabilities for the two atoms involved in the bond
  6. Termination Determination: Predicts whether the retrosynthesis process should conclude or proceed to the next reaction step using graph-level super node features $h_g^t$:
    • Termination probability $p_T^t$:
      • $p_T^t = \sigma(W_T h_g^t)$
        • $p_T^t$ is the termination probability
        • $W_T$ is the weight matrix for the termination module
        • $h_g^t$ is the graph-level super node features
    • Loss function for the termination decision task:
      • $\mathcal{L}_T = -\sum_t \bigg( y_T^t \log(p_T^t) + (1-y_T^t) \log(1-p_T^t) \bigg)$
        • $\mathcal{L}_T$ is the loss for termination
        • $y_T^t$ is the ground truth termination probability
        • $p_T^t$ is the predicted termination probability
  7. Dynamic Adaptive Multi-Task Learning (DAMT): Integrates the four losses ($\mathcal{L}_{RC}$, $\mathcal{L}_{AA}$, $\mathcal{L}_{BA}$, and $\mathcal{L}_{T}$) by dynamically adjusting loss weights based on task complexity and normalizing magnitudes to ensure balanced learning.
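The reaction center, action, and termination formulas can be checked with a small sketch that works on a single step $t$. It assumes the linear projections ($W_{RCP} h_g^t$, $W_{AC} h_a^t$, and so on) have already produced raw logits, and that targets are one-hot, so each loss term reduces to the negative log-probability of the ground-truth entry. All function names are hypothetical. The DAMT weighting is not reproduced because the summary does not specify it.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reaction_center_probs(rcp_logit: float, atom_logits: list[float],
                          bond_logits: list[float]) -> list[float]:
    """p_RC = concat[(1 - p_RCP) * p_AC, p_RCP * p_BC] for one step."""
    p_rcp = sigmoid(rcp_logit)   # probability that the center is a bond
    p_ac = softmax(atom_logits)  # distribution over atoms
    p_bc = softmax(bond_logits)  # distribution over atom pairs
    return [(1.0 - p_rcp) * p for p in p_ac] + [p_rcp * p for p in p_bc]

def rc_loss(p_rc: list[float], target: int) -> float:
    """One step of L_RC with one-hot y_RC; atoms are indexed first, then pairs."""
    return -math.log(p_rc[target])

def atom_action_loss(p_aa_i: list[float], action: int) -> float:
    """One step of L_AA at the identified atom i."""
    return -math.log(p_aa_i[action])

def bond_action_loss(p_ba_ij: list[float], bond_action: int,
                     p_aa_i: list[float], action_i: int,
                     p_aa_j: list[float], action_j: int) -> float:
    """One step of L_BA: the bond action plus the actions on both bonded atoms."""
    return -(math.log(p_ba_ij[bond_action])
             + math.log(p_aa_i[action_i])
             + math.log(p_aa_j[action_j]))

def termination_loss(term_logit: float, y: int) -> float:
    """One step of L_T: binary cross-entropy with p_T = sigmoid(logit),
    written in the numerically stable logits form."""
    x = term_logit
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

# Two atoms and four atom pairs, all logits zero:
p_rc = reaction_center_probs(0.0, [0.0, 0.0], [0.0] * 4)
print(p_rc)              # [0.25, 0.25, 0.125, 0.125, 0.125, 0.125]
print(rc_loss(p_rc, 2))  # -log(0.125), about 2.079
```

The gating by $p_{RCP}^t$ keeps $p_{RC}^t$ a single normalized distribution over atoms and atom pairs, so one cross-entropy term covers both kinds of reaction center.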

The model, consisting of 6 encoder blocks with an atom hidden dimension of 256 and a pair hidden dimension of 128, was trained using the AdamW optimizer with a WarmUpWrapper and a ReduceLROnPlateau learning rate scheduler.
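The summary names the optimizer and schedulers but gives none of their settings. The sketch below shows how a linear warmup can hand over to plateau-based decay. It keeps the reported architecture sizes and fills every schedule value with hypothetical placeholders. In PyTorch, `torch.optim.AdamW` and `torch.optim.lr_scheduler.ReduceLROnPlateau` are standard classes, but `WarmUpWrapper` is not part of core PyTorch. The warmup behavior here (a linear ramp, then hand-off) is therefore an assumption.

```python
from dataclasses import dataclass

@dataclass
class EncoderConfig:
    # Reported in the summary
    num_blocks: int = 6
    atom_hidden_dim: int = 256
    pair_hidden_dim: int = 128

@dataclass
class ScheduleConfig:
    # Hypothetical values; the summary does not report them
    base_lr: float = 1e-4
    warmup_steps: int = 1000
    factor: float = 0.5
    patience: int = 5

class WarmupThenPlateau:
    """Linear warmup to base_lr, then plateau-based decay.

    A simplified stand-in for a warmup wrapper around ReduceLROnPlateau,
    not the code used in the paper.
    """

    def __init__(self, cfg: ScheduleConfig):
        self.cfg = cfg
        self.step_count = 0
        self.current_lr = cfg.base_lr
        self.best = float("inf")
        self.bad_evals = 0

    def lr(self) -> float:
        if self.step_count < self.cfg.warmup_steps:
            # ramp linearly from base_lr / warmup_steps up to base_lr
            return self.cfg.base_lr * (self.step_count + 1) / self.cfg.warmup_steps
        return self.current_lr

    def step(self) -> None:
        """Call once per optimizer step."""
        self.step_count += 1

    def report(self, val_loss: float) -> None:
        """Call after each validation pass; decay only starts after warmup."""
        if self.step_count < self.cfg.warmup_steps:
            return
        if val_loss < self.best:
            self.best = val_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
            # Mirrors ReduceLROnPlateau's trigger (more than `patience`
            # non-improving evaluations), without its threshold or cooldown.
            if self.bad_evals > self.cfg.patience:
                self.current_lr *= self.cfg.factor
                self.bad_evals = 0
```

Keeping warmup and decay in one object makes the hand-off point explicit. Plateau checks are ignored until warmup finishes, so noisy early validation losses cannot trigger a premature decay.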

The paper evaluates HierRetro against template-based, template-free, and semi-template methods using top-k exact match accuracy on the USPTO-50k benchmark dataset. In the reaction-type-unknown setting, HierRetro achieved a top-3 accuracy of 78.3%, outperforming all other models. In the reaction-type-known setting, it achieved a top-3 accuracy of 89.4%, surpassing Graph2Edits by 1.9 percentage points, and a top-5 accuracy of 93.3%, outperforming Graph2Edits and LocalRetro by 1.8 and 0.9 percentage points, respectively. In the round-trip accuracy evaluation, the model achieved 94.7% for k=3 and 96.2% for k=5, outperforming all models except LocalRetro; at k=10, HierRetro achieved the highest accuracy, 97.9%, surpassing Graph2Edits.
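Top-k exact match counts a test product as solved when its ground-truth reactant set appears among the model's k highest-ranked predictions. The following is a minimal sketch under two assumptions: predictions are reactant SMILES strings, and each SMILES has already been canonicalized (for example with RDKit). The helper names and example molecules are hypothetical.

```python
def normalize(reactants: str) -> str:
    """Order-insensitive key for a multi-reactant SMILES string.

    Only sorts the dot-separated fragments; it assumes each fragment is
    already a canonical SMILES.
    """
    return ".".join(sorted(reactants.split(".")))

def top_k_accuracy(ranked_predictions: list[list[str]], targets: list[str], k: int) -> float:
    """Fraction of products whose ground-truth reactants appear in the top-k predictions."""
    if len(ranked_predictions) != len(targets):
        raise ValueError("need exactly one ranked prediction list per target")
    hits = 0
    for ranked, target in zip(ranked_predictions, targets):
        top_k = {normalize(p) for p in ranked[:k]}
        if normalize(target) in top_k:
            hits += 1
    return hits / len(targets)

# Three hypothetical test products with ranked reactant predictions
preds = [["CCO.CC(=O)O", "CCO"], ["c1ccccc1", "CCN"], ["CCl"]]
truth = ["CC(=O)O.CCO", "CCN", "CBr"]
for k in (1, 2):
    print(k, top_k_accuracy(preds, truth, k))
```

Round-trip accuracy differs: it asks whether a forward-reaction model maps a predicted reactant set back to the original product. It therefore needs a forward model and is not sketched here.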

Ablation studies demonstrated that incorporating a contrastive learning (CL)-based pretrained encoder improved top-1 accuracy by 2.5 percentage points, and that adding the reaction center type prediction module increased reaction center identification accuracy by approximately 5 percentage points.
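The ablation credits part of the gain to the CL-pretrained encoder. The summary says only that this pretraining is inspired by 3D-Infomax and contrasts a 2D network with a 3D network over GEOM-Drug conformers. The sketch below assumes an InfoNCE/NT-Xent-style objective over a batch and takes both networks' output embeddings as given. It illustrates the idea rather than the paper's exact loss; the function names and the temperature value are hypothetical.

```python
import math

def cosine(u: list[float], v: list[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def contrastive_2d_3d_loss(z2d: list[list[float]], z3d: list[list[float]],
                           temperature: float = 0.1) -> float:
    """InfoNCE-style loss: row i of z2d and row i of z3d describe the same molecule.

    Each 2D embedding is pulled toward its own 3D embedding (the positive) and
    pushed away from the other molecules' 3D embeddings in the batch (negatives).
    """
    n = len(z2d)
    total = 0.0
    for i in range(n):
        logits = [cosine(z2d[i], z3d[j]) / temperature for j in range(n)]
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[i]  # cross-entropy with target index i
    return total / n

aligned = contrastive_2d_3d_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = contrastive_2d_3d_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)  # True
```

Minimizing this loss pushes the 2D encoder to produce embeddings that identify each molecule's own 3D embedding. This is the sense in which 3D information is captured implicitly from 2D graphs.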

The paper also analyzes reactions with multiple reaction centers, using a permutation-based augmentation strategy to counter the limited generalization caused by the small size of the USPTO dataset. Case studies on the drug molecules Fruquintinib and Nirogacestat, which are not included in the training dataset, demonstrate the model's practical application to multistep retrosynthesis prediction.
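The summary does not detail the permutation-based augmentation. One plausible reading, sketched here purely as an assumption, is that a reaction with several reaction centers admits several equally valid edit orders. Under that reading, each ordering becomes its own training trajectory, and the termination target is set only on the final step. The edit tuples and action names below are hypothetical.

```python
from itertools import islice, permutations
from typing import Optional

def augment_center_orderings(edits: list[tuple],
                             max_orderings: Optional[int] = None) -> list[list[tuple]]:
    """Build one training trajectory per ordering of a reaction's center edits.

    Each trajectory is a list of (edit, stop) pairs, where stop is the
    termination target y_T: 0 while more centers remain, 1 on the last edit.
    """
    if not edits:
        return []
    trajectories = []
    # islice(..., None) keeps every ordering; an int caps the count
    for order in islice(permutations(edits), max_orderings):
        stops = [0] * (len(order) - 1) + [1]
        trajectories.append(list(zip(order, stops)))
    return trajectories

# Hypothetical edits: two bond centers and one atom center
edits = [("bond", 3, 7, "break"), ("bond", 7, 8, "break"), ("atom", 12, "add_H")]
for traj in augment_center_orderings(edits):
    print(traj)
```

The number of orderings grows as n! in the number of centers, so `max_orderings` caps how many trajectories are kept per reaction.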

Authors (2)