Commit f2d7aa8 1 parent 89df26c commit f2d7aa8 Copy full SHA for f2d7aa8
File tree 2 files changed +20
-0
lines changed
2 files changed +20
-0
lines changed Original file line number Diff line number Diff line change @@ -128,6 +128,14 @@ def convert_batch_to_attns(
128
128
return_tensors = "pt" ,
129
129
)
130
130
parsed_input = {k : v .to (self .device ) for k , v in encoded_ids .items ()}
131
+
132
+ max_input_length = parsed_input ["input_ids" ].shape [1 ]
133
+ max_supported_by_model = self .model .config .max_position_embeddings
134
+ if max_input_length > max_supported_by_model :
135
+ raise ValueError (
136
+ f"Reaction SMILES has { max_input_length } tokens, should be at most { max_supported_by_model } ."
137
+ )
138
+
131
139
with torch .no_grad ():
132
140
output = self .model (** parsed_input )
133
141
attentions = output [2 ]
Original file line number Diff line number Diff line change @@ -137,3 +137,15 @@ def test_reaction_with_asterisks(rxn_mapper: RXNMapper):
137
137
138
138
results = rxn_mapper .get_attention_guided_atom_maps (rxns , canonicalize_rxns = False )
139
139
assert_correct_maps (results , expected )
140
+
141
+
142
+ def test_too_long_reaction_smiles_produce_exception_with_understandable_error_message (
143
+ rxn_mapper : RXNMapper ,
144
+ ):
145
+ # dummy reaction with 1 + 3 + 500 * 2 + 3 + 1 = 1008 tokens
146
+ rxn = "C=C" + "[C+][C-]" * 500 + ">>CC"
147
+
148
+ with pytest .raises (ValueError ) as excinfo :
149
+ _ = rxn_mapper .get_attention_guided_atom_maps ([rxn ], canonicalize_rxns = False )
150
+
151
+ assert "Reaction SMILES has 1008 tokens, should be" in str (excinfo .value )
You can’t perform that action at this time.
0 commit comments