Skip to content

Commit f2d7aa8

Browse files
authored
Improve error message for too long reactions (#58)
1 parent 89df26c commit f2d7aa8

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

rxnmapper/core.py

+8
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ def convert_batch_to_attns(
128128
return_tensors="pt",
129129
)
130130
parsed_input = {k: v.to(self.device) for k, v in encoded_ids.items()}
131+
132+
max_input_length = parsed_input["input_ids"].shape[1]
133+
max_supported_by_model = self.model.config.max_position_embeddings
134+
if max_input_length > max_supported_by_model:
135+
raise ValueError(
136+
f"Reaction SMILES has {max_input_length} tokens, should be at most {max_supported_by_model}."
137+
)
138+
131139
with torch.no_grad():
132140
output = self.model(**parsed_input)
133141
attentions = output[2]

tests/test_mapper.py

+12
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,15 @@ def test_reaction_with_asterisks(rxn_mapper: RXNMapper):
137137

138138
results = rxn_mapper.get_attention_guided_atom_maps(rxns, canonicalize_rxns=False)
139139
assert_correct_maps(results, expected)
140+
141+
142+
def test_too_long_reaction_smiles_produce_exception_with_understandable_error_message(
143+
rxn_mapper: RXNMapper,
144+
):
145+
# dummy reaction with 1 + 3 + 500 * 2 + 3 + 1 = 1008 tokens
146+
rxn = "C=C" + "[C+][C-]" * 500 + ">>CC"
147+
148+
with pytest.raises(ValueError) as excinfo:
149+
_ = rxn_mapper.get_attention_guided_atom_maps([rxn], canonicalize_rxns=False)
150+
151+
assert "Reaction SMILES has 1008 tokens, should be" in str(excinfo.value)

0 commit comments

Comments (0)