Skip to content

Commit 5f713b4

Browse files
committed
addition of particle guidance
1 parent fcad6fb commit 5f713b4

File tree

5 files changed

+177
-3
lines changed

5 files changed

+177
-3
lines changed

README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,44 @@ Then to test it:
6767

6868
python test_boltzmann.py --model_dir workdir/boltz_T500 --temp 500 --model_steps 20 --original_model_dir /workdir/drugs_seed_boltz/ --out boltzmann.out
6969

70+
71+
## Particle Guidance sampling
72+
73+
In [this manuscript]() we propose a new sampling method for jointly sampling a set of particles using diffusion models, which we call particle guidance. We demonstrate that, for the task of molecular conformer generation, it provides significant improvements in precision and recall compared to standard I.I.D. diffusion sampling. To run particle guidance sampling with torsional diffusion and replicate the results of the paper (you can similarly run it on your own molecules), use the commands below.
74+
75+
For the permutation invariant kernel guidance (higher quality, slower):
76+
77+
# minimizing recall error
78+
python generate_confs.py --tqdm --batch_size 128 --no_energy --inference_steps=20 --model_dir=workdir/drugs_default --test_csv=data/DRUGS/test_smiles.csv --pg_invariant=True --pg_kernel_size_log_0=1.7565691770646286 --pg_kernel_size_log_1=1.1960868735428605 --pg_langevin_weight_log_0=-2.2245183818892103 --pg_langevin_weight_log_1=-2.403905082248579 --pg_repulsive_weight_log_0=-2.158537381110402 --pg_repulsive_weight_log_1=-2.717482077162461 --pg_weight_log_0=0.8004013644746992 --pg_weight_log_1=-0.9255658381081596
79+
# minimizing precision error
80+
python generate_confs.py --tqdm --batch_size 128 --no_energy --inference_steps=20 --model_dir=workdir/drugs_default --test_csv=data/DRUGS/test_smiles.csv --pg_invariant=True --pg_kernel_size_log_0=-0.9686202580381296 --pg_kernel_size_log_1=-0.7808409291022302 --pg_langevin_weight_log_0=-2.434216242826782 --pg_langevin_weight_log_1=-0.2602238633333869 --pg_repulsive_weight_log_0=-2.0439285313973237 --pg_repulsive_weight_log_1=-1.468234554877924 --pg_weight_log_0=0.3495680598729498 --pg_weight_log_1=-0.22001939454654185
81+
82+
83+
For the non-permutation invariant kernel guidance (faster, slightly lower quality, but still better than I.I.D.):
84+
85+
# minimizing recall error
86+
python generate_confs.py --tqdm --batch_size 128 --no_energy --inference_steps=20 --model_dir=workdir/drugs_default --test_csv=data/DRUGS/test_smiles.csv --pg_kernel_size_log_0=2.35958 --pg_kernel_size_log_1=-0.78826 --pg_langevin_weight_log_0=-1.55054 --pg_langevin_weight_log_1=-2.70316 --pg_repulsive_weight_log_0=1.01317 --pg_repulsive_weight_log_1=-2.68407 --pg_weight_log_0=0.60504 --pg_weight_log_1=-1.15020
87+
# minimizing precision error
88+
python generate_confs.py --tqdm --batch_size 128 --no_energy --inference_steps=20 --model_dir=workdir/drugs_default --test_csv=data/DRUGS/test_smiles.csv --pg_kernel_size_log_0=1.29503 --pg_kernel_size_log_1=1.45944 --pg_langevin_weight_log_0=-2.88867 --pg_langevin_weight_log_1=-2.47591 --pg_repulsive_weight_log_0=-1.01222 --pg_repulsive_weight_log_1=-1.91253 --pg_weight_log_0=-0.16253 --pg_weight_log_1=0.79355
89+
7090
## Citation
91+
92+
If you use this code, please cite:
93+
7194
@article{jing2022torsional,
7295
title={Torsional Diffusion for Molecular Conformer Generation},
7396
author={Bowen Jing and Gabriele Corso and Jeffrey Chang and Regina Barzilay and Tommi Jaakkola},
7497
journal={arXiv preprint arXiv:2206.01729},
7598
year={2022}
7699
}
77100

101+
If you employ the particle guidance sampling technique, please also cite:
102+
103+
@article{corso2023particle,
104+
title={Particle Guidance: non-I.I.D. Diverse Sampling with Diffusion Models},
105+
author={Gabriele Corso and Yilun Xu and Valentin de Bortoli and Regina Barzilay and Tommi Jaakkola},
106+
year={2023}
107+
}
108+
78109
## License
79110
MIT

diffusion/sampling.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from torch_geometric.loader import DataLoader
1010
from rdkit import Chem, Geometry
1111
from rdkit.Chem import AllChem
12+
13+
from utils.utils import time_limit, TimeoutException
1214
from utils.visualise import PDBFile
15+
from spyrmsd import molecule, graph
1316

1417
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1518
still_frames = 10
@@ -91,13 +94,65 @@ def perturb_seeds(data, pdb=None):
9194

9295

9396
def sample(conformers, model, sigma_max=np.pi, sigma_min=0.01 * np.pi, steps=20, batch_size=32,
94-
ode=False, likelihood=None, pdb=None):
97+
ode=False, likelihood=None, pdb=None, pg_weight_log_0=None, pg_repulsive_weight_log_0=None,
98+
pg_weight_log_1=None, pg_repulsive_weight_log_1=None, pg_kernel_size_log_0=None,
99+
pg_kernel_size_log_1=None, pg_langevin_weight_log_0=None, pg_langevin_weight_log_1=None,
100+
pg_invariant=False, mol=None):
101+
95102
conf_dataset = InferenceDataset(conformers)
96103
loader = DataLoader(conf_dataset, batch_size=batch_size, shuffle=False)
97104

98105
sigma_schedule = 10 ** np.linspace(np.log10(sigma_max), np.log10(sigma_min), steps + 1)[:-1]
99106
eps = 1 / steps
100107

108+
if pg_weight_log_0 is not None and pg_weight_log_1 is not None:
109+
edge_index, edge_mask = conformers[0].edge_index, conformers[0].edge_mask
110+
edge_list = [[] for _ in range(torch.max(edge_index) + 1)]
111+
112+
for p in edge_index.T:
113+
edge_list[p[0]].append(p[1])
114+
115+
rot_bonds = [(p[0], p[1]) for i, p in enumerate(edge_index.T) if edge_mask[i]]
116+
117+
dihedral = []
118+
for a, b in rot_bonds:
119+
c = edge_list[a][0] if edge_list[a][0] != b else edge_list[a][1]
120+
d = edge_list[b][0] if edge_list[b][0] != a else edge_list[b][1]
121+
dihedral.append((c.item(), a.item(), b.item(), d.item()))
122+
dihedral_numpy = np.asarray(dihedral)
123+
dihedral = torch.tensor(dihedral)
124+
125+
if pg_invariant:
126+
try:
127+
with time_limit(10):
128+
mol = molecule.Molecule.from_rdkit(mol)
129+
130+
aprops = mol.atomicnums
131+
am = mol.adjacency_matrix
132+
133+
# Convert molecules to graphs
134+
G = graph.graph_from_adjacency_matrix(am, aprops)
135+
136+
# Get all the possible graph isomorphisms
137+
isomorphisms = graph.match_graphs(G, G)
138+
isomorphisms = [iso[0] for iso in isomorphisms]
139+
isomorphisms = np.asarray(isomorphisms)
140+
141+
# filter out those having an effect on the dihedrals
142+
dih_iso = isomorphisms[:, dihedral_numpy]
143+
dih_iso = np.unique(dih_iso, axis=0)
144+
145+
if len(dih_iso) > 32:
146+
print("reduce isomorphisms from", len(dih_iso), "to", 32)
147+
dih_iso = dih_iso[np.random.choice(len(dih_iso), replace=False, size=32)]
148+
else:
149+
print("isomorphisms", len(dih_iso))
150+
dih_iso = torch.from_numpy(dih_iso).to(device)
151+
152+
except TimeoutException as e:
153+
print("Timeout generating with non invariant kernel")
154+
pg_invariant = False
155+
101156
for batch_idx, data in enumerate(loader):
102157

103158
dlogp = torch.zeros(data.num_graphs)
@@ -112,6 +167,10 @@ def sample(conformers, model, sigma_max=np.pi, sigma_min=0.01 * np.pi, steps=20,
112167
z = torch.normal(mean=0, std=1, size=data_gpu.edge_pred.shape)
113168
score = data_gpu.edge_pred.cpu()
114169

170+
t = sigma_idx / steps # t is really 1-t
171+
pg_weight = 10**(pg_weight_log_0 * t + pg_weight_log_1 * (1 - t)) if pg_weight_log_0 is not None and pg_weight_log_1 is not None else 0.0
172+
pg_repulsive_weight = 10**(pg_repulsive_weight_log_0 * t + pg_repulsive_weight_log_1 * (1 - t)) if pg_repulsive_weight_log_0 is not None and pg_repulsive_weight_log_1 is not None else 1.0
173+
115174
if ode:
116175
perturb = 0.5 * g ** 2 * eps * score
117176
if likelihood:
@@ -120,6 +179,34 @@ def sample(conformers, model, sigma_max=np.pi, sigma_min=0.01 * np.pi, steps=20,
120179
else:
121180
perturb = g ** 2 * eps * score + g * np.sqrt(eps) * z
122181

182+
if pg_weight > 0:
183+
n = data.num_graphs
184+
if pg_invariant:
185+
S, D, _ = dih_iso.shape
186+
dih_iso_cat = dih_iso.reshape(-1, 4)
187+
tau = get_torsion_angles(dih_iso_cat, data_gpu.pos, n)
188+
tau_diff = tau.unsqueeze(1) - tau.unsqueeze(0)
189+
tau_diff = torch.fmod(tau_diff + 3 * np.pi, 2 * np.pi) - np.pi
190+
tau_diff = tau_diff.reshape(n, n, S, D)
191+
tau_matrix = torch.sum(tau_diff ** 2, dim=-1, keepdim=True)
192+
tau_matrix, indices = torch.min(tau_matrix, dim=2)
193+
tau_diff = torch.gather(tau_diff, 2, indices.unsqueeze(-1).repeat(1, 1, 1, D)).squeeze(2)
194+
else:
195+
tau = get_torsion_angles(dihedral, data_gpu.pos, n)
196+
tau_diff = tau.unsqueeze(1) - tau.unsqueeze(0)
197+
tau_diff = torch.fmod(tau_diff+3*np.pi, 2*np.pi)-np.pi
198+
assert torch.all(tau_diff < np.pi + 0.1) and torch.all(tau_diff > -np.pi - 0.1), tau_diff
199+
tau_matrix = torch.sum(tau_diff**2, dim=-1, keepdim=True)
200+
201+
kernel_size = 10 ** (pg_kernel_size_log_0 * t + pg_kernel_size_log_1 * (1 - t)) if pg_kernel_size_log_0 is not None and pg_kernel_size_log_1 is not None else 1.0
202+
langevin_weight = 10 ** (pg_langevin_weight_log_0 * t + pg_langevin_weight_log_1 * (1 - t)) if pg_langevin_weight_log_0 is not None and pg_langevin_weight_log_1 is not None else 1.0
203+
204+
k = torch.exp(-1 / kernel_size * tau_matrix)
205+
repulsive = torch.sum(2/kernel_size*tau_diff*k, dim=1).cpu().reshape(-1) / n
206+
207+
perturb = (0.5 * g ** 2 * eps * score) + langevin_weight * (0.5 * g ** 2 * eps * score + g * np.sqrt(eps) * z)
208+
perturb += pg_weight * (g ** 2 * eps * (score + pg_repulsive_weight * repulsive))
209+
123210
conf_dataset.apply_torsion_and_update_pos(data, perturb.numpy())
124211
data_gpu.pos = data.pos.to(device)
125212

generate_confs.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@
3434
parser.add_argument('--batch_size', type=int, default=32, help='Number of conformers generated in parallel')
3535
parser.add_argument('--xtb', type=str, default=None, help='If set, it indicates path to local xtb main directory')
3636
parser.add_argument('--no_energy', action='store_true', default=False, help='If set skips computation of likelihood, energy etc')
37+
38+
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so the flag could never be
    switched off explicitly.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling; argparse
            turns this into a regular "invalid value" usage error.
    """
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f'expected a boolean, got {value!r}')


# Particle guidance options, forwarded unchanged to `sample`.
# Each `*_log_0` / `*_log_1` pair gives the two endpoints of a schedule that
# `sample` evaluates as 10 ** (log_0 * t + log_1 * (1 - t)).
# Guidance is only active when both `--pg_weight_log_0` and
# `--pg_weight_log_1` are provided; otherwise the guidance weight is 0.
parser.add_argument('--pg_weight_log_0', type=float, default=None)
parser.add_argument('--pg_weight_log_1', type=float, default=None)
parser.add_argument('--pg_repulsive_weight_log_0', type=float, default=None)
parser.add_argument('--pg_repulsive_weight_log_1', type=float, default=None)
parser.add_argument('--pg_langevin_weight_log_0', type=float, default=None)
parser.add_argument('--pg_langevin_weight_log_1', type=float, default=None)
parser.add_argument('--pg_kernel_size_log_0', type=float, default=None)
parser.add_argument('--pg_kernel_size_log_1', type=float, default=None)
# Selects the permutation-invariant kernel; accepts e.g. `--pg_invariant=True`.
parser.add_argument('--pg_invariant', type=_str2bool, default=False)
3747
args = parser.parse_args()
3848

3949
"""
@@ -113,7 +123,15 @@ def sample_confs(raw_smi, n_confs, smi):
113123

114124
if not args.no_model and n_rotable_bonds > 0.5:
115125
conformers = sample(conformers, model, args.sigma_max, args.sigma_min, args.inference_steps,
116-
args.batch_size, args.ode, args.likelihood, pdb)
126+
args.batch_size, args.ode, args.likelihood, pdb,
127+
pg_weight_log_0=args.pg_weight_log_0, pg_weight_log_1=args.pg_weight_log_1,
128+
pg_repulsive_weight_log_0=args.pg_repulsive_weight_log_0,
129+
pg_repulsive_weight_log_1=args.pg_repulsive_weight_log_1,
130+
pg_kernel_size_log_0=args.pg_kernel_size_log_0,
131+
pg_kernel_size_log_1=args.pg_kernel_size_log_1,
132+
pg_langevin_weight_log_0=args.pg_langevin_weight_log_0,
133+
pg_langevin_weight_log_1=args.pg_langevin_weight_log_1,
134+
pg_invariant=args.pg_invariant, mol=mol)
117135

118136
if args.dump_pymol:
119137
if not osp.isdir(args.dump_pymol):

utils/torsion.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,25 @@ def perturb_batch(data, torsion_updates, split=False, return_updates=False):
100100
idx_edges += mask_rotate.shape[0]
101101
if return_updates:
102102
return pos_new, torsion_update_list
103-
return pos_new
103+
return pos_new
104+
105+
106+
def bdot(a, b):
    """Batched dot product over the last dimension.

    Multiplies ``a`` and ``b`` elementwise and sums over the final axis. That
    axis is kept with size 1 so the result broadcasts against the inputs.
    """
    return torch.sum(a * b, dim=-1, keepdim=True)


def get_torsion_angles(dihedral, batch_pos, batch_size):
    """Compute signed torsion angles for a batch of conformers.

    Args:
        dihedral: 2-D integer index tensor whose rows are ``(c, a, b, d)`` atom
            indices; ``a``-``b`` is the bond the angle is measured around.
        batch_pos: atom coordinates of all conformers; reshaped internally to
            ``(batch_size, -1, 3)``, so every conformer must contribute the
            same number of atoms.
        batch_size: number of conformers stacked in ``batch_pos``.

    Returns:
        Tensor of shape ``(batch_size, len(dihedral))`` with the angle, in
        radians, between the components of ``d`` and ``c`` perpendicular to the
        ``a``-``b`` axis, signed by their orientation around that axis.
    """
    batch_pos = batch_pos.reshape(batch_size, -1, 3)

    c, a, b, d = dihedral[:, 0], dihedral[:, 1], dihedral[:, 2], dihedral[:, 3]
    pos_a = batch_pos[:, a]
    axis = batch_pos[:, b] - pos_a
    axis_sq_norm = bdot(axis, axis)

    # Remove the component along the a-b axis, leaving the part of c (resp. d)
    # perpendicular to the bond.
    rel_c = batch_pos[:, c] - pos_a
    rel_d = batch_pos[:, d] - pos_a
    c_perp = rel_c - bdot(rel_c, axis) / axis_sq_norm * axis
    d_perp = rel_d - bdot(rel_d, axis) / axis_sq_norm * axis

    cos = bdot(d_perp, c_perp) / (
        torch.norm(d_perp, dim=-1, keepdim=True) * torch.norm(c_perp, dim=-1, keepdim=True))
    # Keep acos away from +-1, where its gradient is unbounded.
    cos = torch.clamp(cos, -1 + 1e-5, 1 - 1e-5)
    angle = torch.acos(cos)

    # `dim=-1` is required: without it torch.cross uses the first dimension of
    # size 3, which is the batch or dihedral axis whenever either has length 3.
    sign = torch.sign(bdot(torch.cross(d_perp, c_perp, dim=-1), axis))
    torsion_angles = (angle * sign).squeeze(-1)
    return torsion_angles

utils/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,20 @@ def start(self, tag):
5454
def end(self, tag):
5555
self.times[tag] += time.time() - self.starts[tag]
5656
del self.starts[tag]
57+
58+
59+
import signal
60+
from contextlib import contextmanager
61+
class TimeoutException(Exception):
    """Raised inside a `time_limit` block when the allotted time runs out."""


@contextmanager
def time_limit(seconds):
    """Context manager that aborts its body after ``seconds`` of wall time.

    A SIGALRM handler is installed for the duration of the block and raises
    `TimeoutException` when the timer fires. On exit the timer is cancelled
    and the SIGALRM handler that was active beforehand is restored, so the
    context manager does not leak its handler into the rest of the program.

    Args:
        seconds: time budget; integers and fractional values are accepted.
            A value of 0 arms no timer.

    Raises:
        TimeoutException: from within the body, if it runs past the budget.

    Note:
        Relies on SIGALRM, so it must be used from the main thread on a
        platform that provides that signal. Not re-entrant: leaving a nested
        `time_limit` cancels the timer of the enclosing one.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Cancel the timer first so it cannot fire while handlers are swapped.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # `None` means the old handler was not installed from Python and
        # cannot be passed back to signal.signal().
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)

0 commit comments

Comments
 (0)