Commit a500911

fix(gpt): drop deprecated usage of get_max_length()
1 parent: bf0ec25

File tree: 1 file changed (+4 −1 lines)

ChatTTS/model/gpt.py

Lines changed: 4 additions & 1 deletion
@@ -187,7 +187,10 @@ def _prepare_generation_inputs(
             if cache_position is not None
             else past_key_values.get_seq_length()
         )
-        max_cache_length = past_key_values.get_max_length()
+        try:
+            max_cache_length = past_key_values.get_max_cache_shape()
+        except:
+            max_cache_length = past_key_values.get_max_length()  # deprecated in transformers 4.48
         cache_length = (
             past_length
             if max_cache_length is None
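
For reference, here is a minimal sketch of the fallback pattern this commit applies, pulled out of the diff: call the newer transformers Cache API first and only fall back to the deprecated method when it is missing, so the code runs on both older and newer transformers releases. The helper name _max_cache_length is hypothetical and not part of the commit.

    def _max_cache_length(past_key_values):
        # Prefer the newer Cache API where it exists.
        try:
            return past_key_values.get_max_cache_shape()
        except AttributeError:
            # Older transformers releases only expose get_max_length(),
            # which is deprecated as of transformers 4.48.
            return past_key_values.get_max_length()

The commit inlines this with a bare except:, which also swallows unrelated errors; catching AttributeError, as sketched above, narrows the fallback to the missing-method case.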
