Skip to content

Commit

Permalink
Rerun chai-1 in single-seq mode if MSA inputs cause error
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Jan 28, 2025
1 parent d64b4f7 commit 98d9614
Showing 1 changed file with 41 additions and 8 deletions.
49 changes: 41 additions & 8 deletions forks/chai-lab/chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def run_inference(
num_diffn_timesteps: int = 200,
seed: int | None = None,
device: torch.device | None = None,
rerun_in_single_seq_mode: bool = True,
) -> StructureCandidates:
# Prepare inputs
assert fasta_file.exists(), fasta_file
Expand Down Expand Up @@ -346,14 +347,46 @@ def run_inference(
restraint_context=restraint_context,
)

return run_folding_on_context(
feature_context,
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
seed=seed,
device=device,
)
try:
folding_outputs = run_folding_on_context(
feature_context,
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
seed=seed,
device=device,
)
except Exception as e:
print(f"Error during folding: {e}")
if rerun_in_single_seq_mode:
print("Rerunning in single-sequence mode")
msa_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)
msa_profile_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)
feature_context = AllAtomFeatureContext(
chains=chains,
structure_context=merged_context,
msa_context=msa_context,
profile_msa_context=msa_profile_context,
template_context=template_context,
embedding_context=embedding_context,
restraint_context=restraint_context,
)
folding_outputs = run_folding_on_context(
feature_context,
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
seed=seed,
device=device,
)
else:
raise e

return folding_outputs


def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor:
Expand Down

0 comments on commit 98d9614

Please sign in to comment.