Alessandra Toniato,*ab Alain C. Vaucher,ab Philippe Schwaller ab and Teodoro Laino ab
aIBM Research Europe, Säumerstrasse 4, 8803 Rüschlikon, Switzerland. E-mail: ato@zurich.ibm.com
bNational Center for Competence in Research-Catalysis (NCCR-Catalysis), Zurich, Switzerland
First published on 16th February 2023
Over the past four years, several research groups demonstrated the combination of domain-specific language representation with recent NLP architectures to accelerate innovation in a wide range of scientific fields. Chemistry is a great example. Among the various chemical challenges addressed with language models, retrosynthesis demonstrates some of the most distinctive successes and limitations. Single-step retrosynthesis, the task of identifying reactions able to decompose a complex molecule into simpler structures, can be cast as a translation problem, in which a text-based representation of the target molecule is converted into a sequence of possible precursors. A common issue is a lack of diversity in the proposed disconnection strategies: the suggested precursors typically fall within the same reaction family, which limits the exploration of the chemical space. We present a retrosynthesis Transformer model that increases the diversity of the predictions by prepending a classification token to the language representation of the target molecule. At inference, the use of these prompt tokens allows us to steer the model towards different kinds of disconnection strategies. We show that the diversity of the predictions improves consistently, which enables recursive synthesis tools to circumvent dead ends and, consequently, to suggest synthesis pathways for more complex molecules.
Irrespective of the approach, either template-free or template-based, the principle that underlies all these methods is the same: a model is trained on some data (often the compound to synthesize, given as a text string, an embedding, or a graph) and then evaluated by comparing its output to a target (the set of “optimal” precursors). However, this perspective is sometimes at odds with the chemistry at hand. In fact, for each target molecule there is generally a wide variety of valid disconnections, connecting the target molecule to different sets of precursors. If the dataset were hypothetically perfectly balanced, all conceivable reactions leading to a target molecule would be evenly represented, but in practice this is far from being the case. Existing reaction datasets, and consequently models, give more weight to well-represented reaction classes, thus penalising more interesting but less frequent disconnections. For example, Fig. 1 shows insufficient diversity in a proposed list of disconnections. Here, we interpret diversity as “chemical class diversity”, considering a model more diverse in its predictions if these belong to different reaction classes as defined by NameRXN.14
Fig. 1 Classes of the single-step baseline predictions for the 2,3-diamino-6-nitrotoluene molecule, as produced by the model of Schwaller et al.6 Predictions are ordered by model confidence; all but one are different forms of deprotection.
To increase the diversity of the predictions of single-step text-based retrosynthesis models and counteract the effect of imbalanced datasets, we propose a prompt-based scheme that guides the language model towards more diverse predictions. We introduce a modified transformer-based model.6,15 Inspired by work on prompt-based learning in natural language processing,16–19 we show that prepending class information during training (as an additional token) leads to more diverse predictions at inference. We experiment with different classification strategies, including the clustering of reaction fingerprints,20 to evaluate the appropriate number of tokens. We compare the cluster token prompt model to a baseline translation model in terms of topN accuracy, round-trip accuracy, class diversity and coverage. After training our model on the proprietary Pistachio21 data, we increased the class diversity of the predictions to an average of 5.3 classes per reaction target, compared to 1.9 for the pristine model, while retaining a high round-trip accuracy of 62% for the disconnections.
At test time, the input product molecule can be concatenated with each of the available cluster tokens (see Fig. 2, bottom), generating X equivalent inputs, where X is the number of cluster tokens used. The first token seen by the Transformer is the cluster token, which steers the predictions towards disconnections typical of that class. Collecting the highest-ranked prediction, commonly called the top one (top1) prediction, for each of the X class tokens (and possibly additional predictions for each token), leads to a set of disconnections more diverse than the topN outputs of a regular Transformer model, which we use as a baseline. The advantage of this strategy is that the steering acts as a weak influence on the predictions rather than a forcing term (such as the application of a specific template, which can only fail or succeed). In comparison to the baseline model, the cluster token prompt approach allows the model to “select” from a limited pool of options while still leaving it considerable flexibility. In the following section, we present our models and results in more detail.
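As an illustration, the inference fan-out just described fits in a few lines of Python. This is a minimal sketch, not the released implementation: the `translate` callable stands in for any beam-search prediction call of a trained model, and its name and signature are our assumptions.

```python
def diverse_retrosynthesis(product_smiles: str, n_tokens: int, translate, topk: int = 2):
    """Query the model once per cluster token and pool all predictions.

    `translate(source, n_best)` is assumed to return the `n_best`
    highest-ranked precursor sets for a given source string.
    """
    predictions = []
    for i in range(n_tokens):
        prompted = f"[{i}] {product_smiles}"  # the cluster token is the first token seen
        predictions.extend(translate(prompted, n_best=topk))
    return predictions  # n_tokens * topk candidates, pooled over all tokens

# e.g. 12 cluster tokens with 2 beams each yield 24 candidate disconnections
```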
The data were first suitably pre-processed (see Section 3.1). We used two ways to produce the cluster tokens prepended to each reaction: the first relies on the NameRXN classification and the second on a K-means clustering algorithm. For the K-means clustering, we identified the clusters from the reaction fingerprints20 (see Section 3.3 for details). The models tested are described below:
• baseline: a Transformer model6,15 with no cluster-token information.
• 12clusters: a model that uses as tokens all the first-level classes available from NameRXN (i.e. classes ‘0’ to ‘11’).
• 3clustersRandom: a model built on top of the 12 NameRXN classes, which we grouped randomly into 3 clusters.
• 4clustersRandom: same as the model above, but with 4 clusters.
• 3clustersKmeans: a model resulting from the application of the K-means clustering algorithm with 3 clusters to the first 3 dimensions obtained from a PCA of the reaction fingerprints.
• 4clustersKmeans: same as the model above, but with 4 clusters.
• 12clustersKmeans: same as the model above, but with 12 clusters.
• optimalKmeans: in this model, we estimated the optimal PCA dimension for the fingerprints (14) and the optimal number of clusters (10). The procedure is described in Section 3.3.
Once the token was identified for each reaction, it was prepended to the SMILES string in the following format: [i] for i = 0…X − 1 (see Fig. 2), with X being the number of tokens available in each model.
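A minimal sketch of this construction, assuming reaction SMILES in the merged `precursors>>product` format and a pre-computed cluster assignment; the function name is ours:

```python
def make_training_pair(rxn_smiles: str, cluster_id: int) -> tuple[str, str]:
    """Build a (source, target) pair with the cluster token prepended."""
    precursors, product = rxn_smiles.split(">>")
    source = f"[{cluster_id}] {product}"  # e.g. "[3] CC(C)=O" for cluster 3
    target = precursors
    return source, target
```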
For the model evaluation, we randomly split the data into training/validation/test sets in the proportion 80/10/10 for five different random seeds (a minimal splitting sketch is shown after this list), and we proceeded as follows:
(1) We chose one of the splits at random and trained all the cluster token prompt models. We tested them against the validation set and chose the best-performing model.
(2) We then merged the training and validation sets for each of the five seeds and trained the best prompt-based model plus the baseline model.
(3) We compared the baseline and best models trained in this way against the test sets.
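A minimal sketch of the splitting protocol using scikit-learn's `train_test_split`; the placeholder reaction list and the function name are ours, not the paper's code.

```python
from sklearn.model_selection import train_test_split

reactions = [f"precursors{i}>>product{i}" for i in range(100)]  # placeholder data

def split_dataset(data, seed):
    """80/10/10 train/validation/test split for a given random seed."""
    train, rest = train_test_split(data, test_size=0.2, random_state=seed)
    valid, test = train_test_split(rest, test_size=0.5, random_state=seed)
    return train, valid, test

splits = {seed: split_dataset(reactions, seed) for seed in range(5)}
# Step (1) uses one split's validation set for model selection;
# step (2) retrains on train + validation for all five seeds.
```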
Each model, including the baseline, was trained for 260 000 steps on a single GPU (approximately 48 hours of training); no improvement in the loss was observed at later checkpoints.
In Fig. 3, we report the results for the prompt-based models evaluated on the validation set. For each model, we retained topN = X × topk = 24 predictions, where X is the number of class tokens of the model and topk is the number of predictions retained for each token-concatenated sample (e.g. for the 12clusters model, X = 12 and topk = 2). The plots report four metrics of interest as a function of the number of topN predictions analyzed (see Section 3.4 for the metric definitions). To compare the models properly, we looked only at the top20 predictions (not top24), because for the optimalKmeans model only 20 predictions per sample were produced (2 for each of the 10 token-conditioned inputs).
Fig. 3 Model metrics. Top left: coverage. Top right: topN accuracy. Bottom left: class diversity. Bottom right: round-trip accuracy. For the definition of each metric, please refer to Section 3.4.
All cluster token models show good coverage (above 95%) from the top3 predictions onwards; the 12clustersKmeans model is the only one performing poorly in this respect. The accuracy increases slowly and reaches a top20 value between 18% and 25% for all models. Note that, in addition to reactants, our retrosynthesis models predict a wide range of precursors and are not limited to the disconnected fragments only. The ground truth therefore often appears with a slightly different set of reagents, which explains the low accuracy values. When a model can produce multiple correct answers, accuracy is not the most informative metric, and several publications have questioned the suitability of top1/topN accuracy for single-step retrosynthesis models.6,28 We consider the round-trip accuracy to be more interesting (see Section 3.4). It measures the ability to recover the input molecule by running a forward reaction model on the predicted precursors (details on the forward model are in Section 3.2). This metric decreases with the number of topN predictions considered, and the decay is more pronounced for models using a larger number of tokens (12clusters, 12clustersKmeans). This is to be expected, since we are asking for disconnections under conditions that may be impossible to satisfy for some input molecules; however, a high coverage guarantees at least one valid proposed disconnection per input molecule. Round-trip accuracy still does not account for the fact that the top20 predictions for a sample, even if correct, can all collapse into one, for example when the model predicts an identical set of reactants multiple times (or, in the setting with reagents, the same reactants with a different solvent). For this reason, the final metric that we report, the class diversity, is the most interesting one, as it accounts for all these issues. It measures the average number of different (NameRXN) classes predicted for a given input, considering only the valid predictions (see Section 3.4). Its value depends strongly on the number of cluster tokens used and differs between the two strategies (NameRXN classes versus K-means clustering). As a clarifying example, a class diversity of 5 means that there are at least 5 valid predictions that use distinctly different chemistry. A baseline with an average class diversity of 1.9 over 20 predictions means that, even if all predictions are valid, on average only 1.9 of them are interesting by virtue of being distinctly different from one another.
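For illustration, the class-diversity computation can be sketched as follows. `is_valid` and `reaction_class` stand in for the round-trip validity check (forward model) and the NameRXN classifier, respectively; both names are our assumptions.

```python
def class_diversity(predictions_per_target, is_valid, reaction_class):
    """Average number of distinct reaction classes among valid predictions.

    `predictions_per_target` is an iterable of prediction lists, one list
    of topN predicted precursor sets per target molecule; invalid
    predictions do not contribute to the class count.
    """
    counts = []
    for preds in predictions_per_target:
        classes = {reaction_class(p) for p in preds if is_valid(p)}
        counts.append(len(classes))
    return sum(counts) / len(counts)
```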
Using more tokens results in more diverse predictions (5.2 for the 12clusters model at top20), but also in a higher number of incorrect predictions. The 12clustersKmeans model instead loses round-trip accuracy without relevant compensation in class diversity. The most interesting models are 12clusters, from the point of view of the increased class diversity, and optimalKmeans, which reaches decent class-diversity values and could also be used in settings where reaction classification labels are not available.
In a second step, we took the best models (12clusters and optimalKmeans) and compared their performance against the baseline. We evaluated the models on five randomly chosen test splits, where, this time, the validation set was included in the training. The results on the top20 predictions are reported in Fig. 4. As can be seen, the prompt-based models do indeed boost the diversity of the predictions: on the test set, we achieve an average class-diversity gain of about 3.4 points for the 12clusters model. For completeness, we report in Appendix B the behaviour of the baseline and best models as a function of the topN predictions, with standard errors.
Fig. 4 Final comparison of the best prompt-based models and the baseline against the test set. The values of the metrics reported are averaged across five random seeds. For convenience, standard error values are reported in Table 1.
Table 1 shows the (top20) metrics with standard error bounds for the three models under consideration, generated from the five different random seed experiments.
Model | Coverage | Accuracy | Round-trip accuracy | Class diversity
---|---|---|---|---
Baseline | 96.58 ± 0.06% | 28.28 ± 0.05% | 79.50 ± 0.68% | 1.90 ± 0.01
optimalKmeans | 97.69 ± 0.04% | 19.02 ± 0.47% | 66.27 ± 0.95% | 3.67 ± 0.02
12clusters | 97.94 ± 0.06% | 18.42 ± 0.31% | 62.03 ± 0.53% | 5.27 ± 0.05
For comparison, we report in Fig. 5 an example prediction with the baseline model and with the 12clusters model. While for the baseline all proposed disconnections belong to the class of Saponification reactions (6), for the 12clusters model we observe much more diversity in terms of reaction classes. Moreover, looking at the main reactants generated, the prompt-based model proposes different alternatives (e.g. an Acylation reaction versus a Saponification).
Fig. 5 Example predictions with the baseline retrosynthesis model and the prompt-based model.
• removal of duplicate and invalid reactions.
• merging of reactants and reagents: in chemistry, reactants are the main actors in a reaction, but they are helped by other molecules (e.g. solvents) that allow the reaction to take place without contributing atoms to the final product. In this work we merged reactants and reagents (jointly referred to as ‘precursors’) on the left-hand side of the reaction (e.g. A>B>C → A.B>>C; see the sketch after this list).
• set on the precursors: since the number of times a molecule appears in a patent reaction bears no real relationship to the stoichiometry, we made the molecules unique.
• removal of multi-product reactions: this operation was performed after removing residual precursor molecules from the product side.
• removal of reactions where the product contains atom types not present on the precursor side.
• removal of single-atom products.
• removal of reactions where the absolute formal charge exceeded the value of 2.
• removal of reactions where the maximum number of tokens was above 500.
• removal of reactions with the same set of precursors, but different products.
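Two of these steps, the reactant/reagent merge and the deduplication of the precursors, can be illustrated with a short sketch operating on `reactants>reagents>product` reaction SMILES; the function name is ours.

```python
def merge_and_deduplicate(rxn_smiles: str) -> str:
    """Merge reagents into the precursor side and keep each molecule once."""
    reactants, reagents, product = rxn_smiles.split(">")
    molecules = [m for m in (reactants + "." + reagents).split(".") if m]
    unique = sorted(set(molecules))  # "set on the precursors"
    return ".".join(unique) + ">>" + product

# merge_and_deduplicate("CCO.CCO>O>CC=O") -> "CCO.O>>CC=O"
```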
We provide the already cleaned public dataset USPTO 50k26 together with the code.
The cleaned dataset was randomly split into training, validation and test sets (80%/10%/10%) for five different random seeds. One of these splits was used to choose the best cluster token model, while the comparison with the baseline was performed on all five random seeds, merging the validation and training sets.
Fig. 6 Analysis of the relevant PCA components. Two drops in the variance are observed around the 2nd/3rd components and a smaller one around the 14th component.
For the 12clustersKmeans, 3clustersKmeans and 4clustersKmeans models, we kept only the first three components. For the optimalKmeans model we went further and included the first 14 components. For the K-means clustering, we then relied on a fixed number of clusters for the first models (12clustersKmeans, 3clustersKmeans and 4clustersKmeans), whereas for the optimalKmeans model we first performed an analysis to determine the optimal grouping.33 This can be done by measuring the sum of the squared distances to the nearest cluster centre (the inertia) and plotting it against the number of clusters used. The optimal k coincides with the elbow of the plot, where the change in inertia starts to become less significant. The inertia plots can be found in Appendix C.
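A minimal sketch of this selection procedure with scikit-learn, using random vectors as a stand-in for the real reaction fingerprints:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

fingerprints = np.random.rand(1000, 256)  # stand-in for the rxnfp vectors
reduced = PCA(n_components=14).fit_transform(fingerprints)

inertias = []
for k in range(2, 21):
    km = KMeans(n_clusters=k, random_state=0, n_init=10).fit(reduced)
    inertias.append(km.inertia_)
# Plotting `inertias` against k and picking the elbow gave k = 10
# (with 14 PCA components) for the optimalKmeans model.
```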
Fig. 7 shows the clusters generated for the training set of the optimalKmeans model. The plots for the other K-means-models can be found in Appendix C.
Fig. 7 t-SNE projection of 50 000 training samples for the optimalKmeans model. The different colours represent the different K-means clusters.
With $M$ the number of test products, $x_i$ the $i$-th target, $P_i$ its set of topN predicted precursor sets, $F$ the forward model and $c(p)$ the NameRXN class of a prediction $p$, the metrics described above can be written as follows (reconstructed here from their textual definitions in Section 3.4):

$$\text{coverage} = \frac{1}{M}\sum_{i=1}^{M} \mathbb{1}\big[\exists\, p \in P_i : F(p) = x_i\big] \qquad (1)$$

$$\text{round-trip accuracy} = \frac{1}{M}\sum_{i=1}^{M} \frac{\big|\{p \in P_i : F(p) = x_i\}\big|}{|P_i|} \qquad (2)$$

$$\text{class diversity} = \frac{1}{M}\sum_{i=1}^{M} \big|\{c(p) : p \in P_i,\ F(p) = x_i\}\big| \qquad (3)$$
The whole data processing procedure, the dataset, the scripts and the models are available with the code. For this smaller dataset, we built three random models and three models based on the clustering of reaction fingerprints, using 2, 5 and 10 tokens. As for Pistachio, we chose the best cluster token prompt-based models by comparing them on the validation set, and we concluded the analysis with a comparison against the baseline on the test sets of five random seeds.
In Fig. 8, we compare the cluster token prompt-based models trained on USPTO 50k, while in Fig. 9, we compare the final best models.
Fig. 8 Metrics for the models trained on USPTO 50k. Top left: coverage. Top right: topN accuracy. Bottom left: class diversity. Bottom right: round-trip accuracy.
Fig. 9 Final comparison of the best cluster token prompt-based models and the baseline against the test set for the open source dataset. The values of the metrics reported are averaged across 5 random seeds. For convenience, standard error values are reported in Table 2. |
Differently from the results with Pistachio, the models can better predict the ground-truth precursors. Note that USPTO 50k is a smaller dataset in which only reactants, and not reagents, are reported (in contrast to Pistachio), so the training task is much easier. At the same time, the round-trip accuracy is quite low, even though the forward model used for the evaluations was trained on the same USPTO 50k dataset and reached an accuracy of 77.46% (the classification model reached an accuracy of 95.29%). This behaviour can be ascribed to the dataset being too small for the forward model to generalize sufficiently well. In addition, for USPTO 50k both the accuracy and the round-trip accuracy of the prompt models increase considerably compared to the baseline model, whereas the trend is inverted for the Pistachio dataset (see Fig. 4). We believe that the increased accuracy with respect to the baseline can be ascribed to the size and simplicity of the open-source dataset: a model trained on USPTO 50k sees fewer examples from each class, which gives more specificity to the conditioning token, providing the model with an additional hint for the prediction relative to the baseline (higher topN accuracy). Since the task is also easier (reactants only), the round-trip accuracy increases as well. For Pistachio this does not happen because the reaction space is much larger and more diverse and includes reagents. In this case the conditioning has access to more reactions, so many predictions can consist of the original disconnection with different reagents (lower topN accuracy), which the proxy forward model might not be confident enough to validate (lower round-trip accuracy).
Looking at Fig. 9, we see that for the 10clusters model, which uses all the reaction class IDs as single tokens, the class diversity increases to 3.1. The best top20 accuracy, as well as the best round-trip accuracy, is reached by the 10clustersKmeans model.
We also report the standard error values at top20 predictions for the best models, computed with the same random seeds; the values can be found in Table 2. We observe that the standard errors are larger for the open-source models, which can again be ascribed to the smaller dataset: with only 50k data points we cannot create splits as general as those obtained from the 2 million data samples of Pistachio. The 10clustersKmeans model is the best compromise across all metrics.
Model | Coverage | Accuracy | Round-trip accuracy | Class diversity
---|---|---|---|---
Baseline | 94.64 ± 0.97% | 70.94 ± 0.31% | 23.13 ± 0.46% | 1.54 ± 0.29
10clusters | 97.28 ± 0.10% | 67.84 ± 0.47% | 30.84 ± 0.65% | 3.06 ± 0.07
10clustersKmeans | 97.49 ± 0.15% | 74.09 ± 0.17% | 41.21 ± 0.69% | 2.60 ± 0.04
Fig. 10 Metrics for the baseline model trained on Pistachio. Top left: coverage. Top right: topN accuracy. Bottom left: class diversity. Bottom right: round-trip accuracy.
The same plots are reported in Fig. 11 and 12 for the 12clusters and optimalKmeans models. For all models, the standard error on the class diversity is quite high, varying considerably across compounds, but it is similar for the cluster token prompt-based models and the baseline.
Fig. 11 Metrics for the 12clusters model trained on Pistachio. Top left: coverage. Top right: topN accuracy. Bottom left: class diversity. Bottom right: round-trip accuracy.
Fig. 12 Metrics for the optimalKmeans model trained on Pistachio. Top left: coverage. Top right: topN accuracy. Bottom left: class diversity. Bottom right: round-trip accuracy.
The cluster token prompt models trained with Pistachio are also accessible through the IBM RXN for Chemistry website.34