@@ -74,6 +74,7 @@ def __init__(
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
+        softmax_dtype: Optional[Union[str, torch.dtype]] = None,
         batch_size: Optional[Union[int, str]] = 1,
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
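Note on the new parameter: `softmax_dtype` mirrors the existing `dtype` argument but only controls the precision of the final `log_softmax`, with `None` preserving current behavior. A minimal usage sketch follows, assuming this `__init__` belongs to the harness's `HFLM` wrapper; the import path and model name are illustrative, not confirmed by the diff:

# Hedged usage sketch; class path and model name are assumptions.
from lm_eval.models.huggingface import HFLM

lm = HFLM(
    pretrained="EleutherAI/pythia-70m",
    dtype="bfloat16",         # load weights and run forward passes in bf16
    softmax_dtype="float32",  # but compute log_softmax in fp32
)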
@@ -234,6 +235,9 @@ def __init__(
         self.batch_schedule = 1
         self.batch_sizes = {}
         self.max_batch_size = max_batch_size
+        self.softmax_dtype = (
+            get_dtype(softmax_dtype) if softmax_dtype is not None else None
+        )

         if str(batch_size).startswith("auto"):
             batch_size = batch_size.split(":")
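`get_dtype` here is the harness's existing helper in `lm_eval.models.utils`; its gist, re-implemented below as an illustrative sketch rather than the canonical source, is a string-to-`torch.dtype` lookup. Keeping the attribute `None` when the argument is unset means "do not cast":

import torch
from typing import Union

def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    # Illustrative re-implementation: map e.g. "float32" -> torch.float32;
    # pass torch.dtype values (and the string "auto") through unchanged.
    if isinstance(dtype, str) and dtype != "auto":
        return getattr(torch, dtype)
    return dtype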
@@ -768,7 +772,11 @@ def forward_batch(batch_size):
                     (batch_size, max_length), device=self.device
                 ).long()
             for _ in range(5):
-                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841
+                out = F.log_softmax(  # noqa: F841
+                    self._model_call(test_batch, **call_kwargs),
+                    dim=-1,
+                    dtype=self.softmax_dtype,
+                )

             return batch_size
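Passing `dtype=self.softmax_dtype` into the auto-batch-size probe keeps the memory measurement honest: upcasting the `[batch, seq, vocab]` log-probabilities to fp32 can roughly double their footprint versus fp16/bf16, so the probe should allocate what real evaluation will. `F.log_softmax` casts its input to `dtype` before the reduction and treats `dtype=None` as a no-op, which is why the unset case stays backward compatible. A small self-contained check (bf16 chosen for CPU portability):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)

out_same = F.log_softmax(x, dim=-1)                       # stays bf16 (dtype=None)
out_fp32 = F.log_softmax(x, dim=-1, dtype=torch.float32)  # cast to fp32 first

assert out_same.dtype == torch.bfloat16
assert out_fp32.dtype == torch.float32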
@@ -1200,7 +1208,9 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
             }

             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, **call_kwargs), dim=-1
+                self._model_call(batched_inps, **call_kwargs),
+                dim=-1,
+                dtype=self.softmax_dtype,
             )  # [batch, padding_length (inp or cont), vocab]

             for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
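The same upcast in the loglikelihood path is where it matters for scores: computing `log_softmax` in half precision over a large vocabulary can drift measurably from an fp32 reference, which shifts per-token log-probabilities. A quick sketch of the effect on synthetic logits (values illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = (torch.randn(1, 50_000) * 10).to(torch.bfloat16)  # wide vocab, large range

lp_low  = F.log_softmax(logits, dim=-1)                       # reduced in bf16
lp_fp32 = F.log_softmax(logits, dim=-1, dtype=torch.float32)  # upcast first

print((lp_low.float() - lp_fp32).abs().max())  # nonzero low-precision drift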