-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathtest_mapper.py
123 lines (98 loc) · 4.87 KB
/
test_mapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pytest
from rxnmapper import RXNMapper
from .utils import assert_correct_maps
@pytest.fixture(scope="module")
def rxn_mapper() -> RXNMapper:
    """Provide one RXNMapper instance shared by all tests in this module.

    The fixture is module-scoped so that the model weights are loaded only
    once instead of once per test.
    """
    mapper = RXNMapper()
    return mapper
def test_example_maps(rxn_mapper: RXNMapper):
    """Map a batch of three example reactions and compare with references.

    Each case is kept together as (input reaction SMILES, expected mapped
    reaction SMILES, expected confidence); the inputs and the expected result
    dictionaries are derived from that single table.
    """
    cases = [
        (
            "CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F",
            "[CH3:1][CH:2]([CH3:3])[SH:4].CN(C)C=O.F[c:5]1[n:6][cH:7][cH:8][cH:9][c:10]1[F:11].O=C([O-])[O-].[K+].[K+]>>[CH3:1][CH:2]([CH3:3])[S:4][c:5]1[n:6][cH:7][cH:8][cH:9][c:10]1[F:11]",
            0.9565619900376546,
        ),
        (
            "C1COCCO1.CC(C)(C)OC(=O)CONC(=O)NCc1cccc2ccccc12.Cl>>O=C(O)CONC(=O)NCc1cccc2ccccc12",
            "C1COCCO1.CC(C)(C)[O:3][C:2](=[O:1])[CH2:4][O:5][NH:6][C:7](=[O:8])[NH:9][CH2:10][c:11]1[cH:12][cH:13][cH:14][c:15]2[cH:16][cH:17][cH:18][cH:19][c:20]12.Cl>>[O:1]=[C:2]([OH:3])[CH2:4][O:5][NH:6][C:7](=[O:8])[NH:9][CH2:10][c:11]1[cH:12][cH:13][cH:14][c:15]2[cH:16][cH:17][cH:18][cH:19][c:20]12",
            0.9704424331552834,
        ),
        (
            "C=CCN=C=S.CNCc1ccc(C#N)cc1.NNC(=O)c1cn2c(n1)CCCC2>>C=CCN1C(C2=CN3CCCCC3=N2)=NN=C1N(C)CC1=CC=C(C#N)C=C1",
            "S=[C:17]=[N:4][CH2:3][CH:2]=[CH2:1].[NH:18]([CH3:19])[CH2:20][c:21]1[cH:22][cH:23][c:24]([C:25]#[N:26])[cH:27][cH:28]1.O=[C:5]([c:6]1[cH:7][n:8]2[c:9]([n:10]1)[CH2:11][CH2:12][CH2:13][CH2:14]2)[NH:15][NH2:16]>>[CH2:1]=[CH:2][CH2:3][n:4]1[c:5](-[c:6]2[cH:7][n:8]3[c:9]([n:10]2)[CH2:11][CH2:12][CH2:13][CH2:14]3)[n:15][n:16][c:17]1[N:18]([CH3:19])[CH2:20][c:21]1[cH:22][cH:23][c:24]([C:25]#[N:26])[cH:27][cH:28]1",
            0.919023506871605,
        ),
    ]

    # Split the case table back into the batch input and the reference output.
    reaction_smiles = [unmapped for unmapped, _, _ in cases]
    expected_maps = [
        {"mapped_rxn": mapped, "confidence": confidence}
        for _, mapped, confidence in cases
    ]

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps(reaction_smiles)
    assert_correct_maps(predicted_maps, expected_maps)
def test_fragment_bond(rxn_mapper: RXNMapper):
    """Map a reaction whose multi-fragment compounds are joined with `~`."""
    reaction = "CC[O-]~[Na+].BrCC.[Na+]~[H-]>>CCOCC"
    expected_map = {
        "mapped_rxn": "Br[CH2:2][CH3:1].[Na+]~[O-:3][CH2:4][CH3:5].[H-]~[Na+]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5]",
        "confidence": 0.9606074439250337,
    }

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps([reaction])
    assert_correct_maps(predicted_maps, [expected_map])
def test_extended_smiles_format(rxn_mapper: RXNMapper):
    """Map a reaction whose fragment grouping is given as a `|f:...|` suffix.

    The expected mapped reaction carries a `|f:...|` suffix as well, with
    indices that match the compound order of the mapped output.
    """
    reaction = "CC[O-].[Na+].BrCC.[Na+].[H-]>>CCOCC |f:0.1,3.4|"
    expected_map = {
        "mapped_rxn": "Br[CH2:2][CH3:1].[Na+].[O-:3][CH2:4][CH3:5].[H-].[Na+]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5] |f:1.2,3.4|",
        "confidence": 0.9606074439250337,
    }

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps([reaction])
    assert_correct_maps(predicted_maps, [expected_map])
def test_no_canonicalization(rxn_mapper: RXNMapper):
    """Map a non-canonical reaction SMILES with canonicalization disabled."""
    reaction = "C(C)O.BrC(C)>>CCOCC"

    # The first reactant keeps its parenthesis in the mapped reaction, which
    # is the desired outcome. The parenthesis of the second reactant is lost;
    # there is probably not much that can be done to preserve it.
    expected_map = {
        "mapped_rxn": "[CH2:2]([CH3:1])[OH:3].Br[CH2:4][CH3:5]>>[CH3:1][CH2:2][O:3][CH2:4][CH3:5]",
        "confidence": 0.9754605679009868,
    }

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps(
        [reaction], canonicalize_rxns=False
    )
    assert_correct_maps(predicted_maps, [expected_map])
def test_reaction_with_invalid_valence(rxn_mapper: RXNMapper):
    """Map a reaction containing molecules with an invalid valence.

    "BrCFC" and "CCOCFC" are syntactically valid SMILES whose valence is
    invalid; the model is nevertheless expected to produce a mapping for
    them. Canonicalization is turned off for this call.
    """
    reaction = "CCO.BrCFC>>CCOCFC"
    expected_map = {
        "mapped_rxn": "[CH3:1][CH2:2][OH:3].Br[CH2:4][F:5][CH3:6]>>[CH3:1][CH2:2][O:3][CH2:4][F:5][CH3:6]",
        "confidence": 0.9730222157275086,
    }

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps(
        [reaction], canonicalize_rxns=False
    )
    assert_correct_maps(predicted_maps, [expected_map])
def test_multiple_products(rxn_mapper: RXNMapper):
    """Map a reaction that has more than one product.

    The reaction is roughly the reverse of the ether formation used in
    ``test_fragment_bond``: the ether is the reactant and the alkoxide salt
    and the alkyl bromide are the products.
    """
    reaction = "CCOCC.[Na+]~[Br-]>>CC[O-]~[Na+].BrCC"
    expected_map = {
        "mapped_rxn": "[CH3:1][CH2:2][O:6][CH2:5][CH3:4].[Br-:3]~[Na+:7]>>[CH3:1][CH2:2][Br:3].[CH3:4][CH2:5][O-:6]~[Na+:7]",
        "confidence": 0.7312865053856896,
    }

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps([reaction])
    assert_correct_maps(predicted_maps, [expected_map])
def test_reaction_with_dative_bond(rxn_mapper: RXNMapper):
    """Map a reaction whose reactant SMILES contains a dative bond (`->`).

    Canonicalization is turned off for this call.
    """
    reaction = "COC(=O)CCBr.O=C[O-]->[K+]>>COC(=O)CCOC(=O)C"
    expected_map = {
        "mapped_rxn": "[CH3:1][O:2][C:3](=[O:4])[CH2:5][CH2:6]Br.[O:9]=[CH:8][O-:7]->[K+]>>[CH3:1][O:2][C:3](=[O:4])[CH2:5][CH2:6][O:7][C:8](=[O:9])[CH3:10]",
        "confidence": 0.9322116305783666,
    }

    predicted_maps = rxn_mapper.get_attention_guided_atom_maps(
        [reaction], canonicalize_rxns=False
    )
    assert_correct_maps(predicted_maps, [expected_map])