from torch_geometric.loader import DataLoader
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem
+
+ from utils.utils import time_limit, TimeoutException
from utils.visualise import PDBFile
+ from spyrmsd import molecule, graph

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
still_frames = 10
@@ -91,13 +94,65 @@ def perturb_seeds(data, pdb=None):
def sample(conformers, model, sigma_max=np.pi, sigma_min=0.01 * np.pi, steps=20, batch_size=32,
-           ode=False, likelihood=None, pdb=None):
+           ode=False, likelihood=None, pdb=None, pg_weight_log_0=None, pg_repulsive_weight_log_0=None,
+           pg_weight_log_1=None, pg_repulsive_weight_log_1=None, pg_kernel_size_log_0=None,
+           pg_kernel_size_log_1=None, pg_langevin_weight_log_0=None, pg_langevin_weight_log_1=None,
+           pg_invariant=False, mol=None):
+
    conf_dataset = InferenceDataset(conformers)
    loader = DataLoader(conf_dataset, batch_size=batch_size, shuffle=False)

    sigma_schedule = 10 ** np.linspace(np.log10(sigma_max), np.log10(sigma_min), steps + 1)[:-1]
    eps = 1 / steps

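+    # Setup for the pairwise guidance ('pg') term. The pg_*_log_0 / pg_*_log_1 arguments are
+    # log10 endpoints of schedules interpolated over the sampling steps; guidance is only
+    # active when pg_weight_log_0 and pg_weight_log_1 are both given.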
+    if pg_weight_log_0 is not None and pg_weight_log_1 is not None:
+        edge_index, edge_mask = conformers[0].edge_index, conformers[0].edge_mask
+        edge_list = [[] for _ in range(torch.max(edge_index) + 1)]
+
+        for p in edge_index.T:
+            edge_list[p[0]].append(p[1])
+
+        rot_bonds = [(p[0], p[1]) for i, p in enumerate(edge_index.T) if edge_mask[i]]
+
+        dihedral = []
+        for a, b in rot_bonds:
+            c = edge_list[a][0] if edge_list[a][0] != b else edge_list[a][1]
+            d = edge_list[b][0] if edge_list[b][0] != a else edge_list[b][1]
+            dihedral.append((c.item(), a.item(), b.item(), d.item()))
+        dihedral_numpy = np.asarray(dihedral)
+        dihedral = torch.tensor(dihedral)
+
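+        # dihedral holds one (c, a, b, d) tuple per rotatable bond: a-b is the bond itself and
+        # c, d are arbitrary neighbours on either side, defining the torsion angles used below.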
+        if pg_invariant:
+            try:
+                with time_limit(10):
+                    mol = molecule.Molecule.from_rdkit(mol)
+
+                    aprops = mol.atomicnums
+                    am = mol.adjacency_matrix
+
+                    # convert the molecule to a graph
+                    G = graph.graph_from_adjacency_matrix(am, aprops)
+
+                    # get all the possible graph isomorphisms (automorphisms of the molecule)
+                    isomorphisms = graph.match_graphs(G, G)
+                    isomorphisms = [iso[0] for iso in isomorphisms]
+                    isomorphisms = np.asarray(isomorphisms)
+
+                    # keep only the distinct ways the isomorphisms act on the dihedral atoms
+                    dih_iso = isomorphisms[:, dihedral_numpy]
+                    dih_iso = np.unique(dih_iso, axis=0)
+
+                    if len(dih_iso) > 32:
+                        print("reduce isomorphisms from", len(dih_iso), "to", 32)
+                        dih_iso = dih_iso[np.random.choice(len(dih_iso), replace=False, size=32)]
+                    else:
+                        print("isomorphisms", len(dih_iso))
+                    dih_iso = torch.from_numpy(dih_iso).to(device)
+
+            except TimeoutException:
+                print("Timeout computing isomorphisms, falling back to the non-invariant kernel")
+                pg_invariant = False
+
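+        # When pg_invariant is set, dih_iso lists the distinct permutations of the dihedral atoms
+        # under graph automorphisms; minimising over them below makes the kernel's torsion
+        # distance invariant to these symmetries.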
    for batch_idx, data in enumerate(loader):

        dlogp = torch.zeros(data.num_graphs)
@@ -112,6 +167,10 @@ def sample(conformers, model, sigma_max=np.pi, sigma_min=0.01 * np.pi, steps=20,
            z = torch.normal(mean=0, std=1, size=data_gpu.edge_pred.shape)
            score = data_gpu.edge_pred.cpu()

+            t = sigma_idx / steps  # t is really 1-t
+            pg_weight = 10 ** (pg_weight_log_0 * t + pg_weight_log_1 * (1 - t)) if pg_weight_log_0 is not None and pg_weight_log_1 is not None else 0.0
+            pg_repulsive_weight = 10 ** (pg_repulsive_weight_log_0 * t + pg_repulsive_weight_log_1 * (1 - t)) if pg_repulsive_weight_log_0 is not None and pg_repulsive_weight_log_1 is not None else 1.0
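+            # t runs from 0 at the first (highest-noise) step towards 1 at the last, so the
+            # *_log_1 endpoints govern the weights early in sampling and *_log_0 late.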
+
            if ode:
                perturb = 0.5 * g ** 2 * eps * score
                if likelihood:
@@ -120,6 +179,34 @@ def sample(conformers, model, sigma_max=np.pi, sigma_min=0.01 * np.pi, steps=20,
            else:
                perturb = g ** 2 * eps * score + g * np.sqrt(eps) * z

+            if pg_weight > 0:
+                n = data.num_graphs
+                if pg_invariant:
+                    S, D, _ = dih_iso.shape
+                    dih_iso_cat = dih_iso.reshape(-1, 4)
+                    tau = get_torsion_angles(dih_iso_cat, data_gpu.pos, n)
+                    tau_diff = tau.unsqueeze(1) - tau.unsqueeze(0)
+                    tau_diff = torch.fmod(tau_diff + 3 * np.pi, 2 * np.pi) - np.pi
+                    tau_diff = tau_diff.reshape(n, n, S, D)
+                    tau_matrix = torch.sum(tau_diff ** 2, dim=-1, keepdim=True)
+                    tau_matrix, indices = torch.min(tau_matrix, dim=2)
+                    tau_diff = torch.gather(tau_diff, 2, indices.unsqueeze(-1).repeat(1, 1, 1, D)).squeeze(2)
+                else:
+                    tau = get_torsion_angles(dihedral, data_gpu.pos, n)
+                    tau_diff = tau.unsqueeze(1) - tau.unsqueeze(0)
+                    tau_diff = torch.fmod(tau_diff + 3 * np.pi, 2 * np.pi) - np.pi
+                    assert torch.all(tau_diff < np.pi + 0.1) and torch.all(tau_diff > -np.pi - 0.1), tau_diff
+                    tau_matrix = torch.sum(tau_diff ** 2, dim=-1, keepdim=True)
+
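+                # tau_diff: pairwise differences of the conformers' torsion angles wrapped to
+                # [-pi, pi); tau_matrix: squared torsion distances between conformers, fed to
+                # the RBF kernel below.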
+                kernel_size = 10 ** (pg_kernel_size_log_0 * t + pg_kernel_size_log_1 * (1 - t)) if pg_kernel_size_log_0 is not None and pg_kernel_size_log_1 is not None else 1.0
+                langevin_weight = 10 ** (pg_langevin_weight_log_0 * t + pg_langevin_weight_log_1 * (1 - t)) if pg_langevin_weight_log_0 is not None and pg_langevin_weight_log_1 is not None else 1.0
+
+                k = torch.exp(-1 / kernel_size * tau_matrix)
+                repulsive = torch.sum(2 / kernel_size * tau_diff * k, dim=1).cpu().reshape(-1) / n
+
+                perturb = (0.5 * g ** 2 * eps * score) + langevin_weight * (0.5 * g ** 2 * eps * score + g * np.sqrt(eps) * z)
+                perturb += pg_weight * (g ** 2 * eps * (score + pg_repulsive_weight * repulsive))
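+                # perturb combines the deterministic drift, a Langevin corrector scaled by
+                # langevin_weight, and a guidance term weighted by pg_weight whose repulsive
+                # component (scaled by pg_repulsive_weight) pushes the conformers' torsions apart.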
+
            conf_dataset.apply_torsion_and_update_pos(data, perturb.numpy())
            data_gpu.pos = data.pos.to(device)