
Commit e4a7b69

Avelina9X and baberabb authored
Added softmax_dtype argument to HFLM to coerce log_softmax computations (#2921)
* Added softmax_dtype argument to coerce log_softmax computations

* move softmax_dtype

---------

Co-authored-by: Baber <baber@hey.com>
1 parent 930d837 commit e4a7b69
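
For context, a minimal usage sketch of the new argument. The checkpoint name and the dtype values below are illustrative assumptions, not taken from this commit; only the pretrained, dtype, and softmax_dtype parameter names come from HFLM itself.

# Hypothetical usage: run the model in bf16 but coerce log_softmax to fp32.
from lm_eval.models.huggingface import HFLM

lm = HFLM(
    pretrained="EleutherAI/pythia-70m",  # assumption: any HF causal-LM checkpoint
    dtype="bfloat16",                    # weights / forward pass in bf16
    softmax_dtype="float32",             # new in this commit; None keeps the model dtype
)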

File tree

1 file changed: +12 -2 lines changed


lm_eval/models/huggingface.py

Lines changed: 12 additions & 2 deletions
@@ -74,6 +74,7 @@ def __init__(
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
+        softmax_dtype: Optional[Union[str, torch.dtype]] = None,
         batch_size: Optional[Union[int, str]] = 1,
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
@@ -234,6 +235,9 @@ def __init__(
         self.batch_schedule = 1
         self.batch_sizes = {}
         self.max_batch_size = max_batch_size
+        self.softmax_dtype = (
+            get_dtype(softmax_dtype) if softmax_dtype is not None else None
+        )

         if str(batch_size).startswith("auto"):
             batch_size = batch_size.split(":")
@@ -768,7 +772,11 @@ def forward_batch(batch_size):
                     (batch_size, max_length), device=self.device
                 ).long()
             for _ in range(5):
-                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841
+                out = F.log_softmax(  # noqa: F841
+                    self._model_call(test_batch, **call_kwargs),
+                    dim=-1,
+                    dtype=self.softmax_dtype,
+                )

             return batch_size

@@ -1200,7 +1208,9 @@ def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
                 }

             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, **call_kwargs), dim=-1
+                self._model_call(batched_inps, **call_kwargs),
+                dim=-1,
+                dtype=self.softmax_dtype,
             )  # [batch, padding_length (inp or cont), vocab]

             for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
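
As a standalone illustration of the F.log_softmax dtype argument that this commit threads self.softmax_dtype into (a sketch with made-up tensor shapes, not code from the repository):

# With dtype=None, log_softmax runs in the input's dtype; passing a wider dtype
# casts the logits first, which is the "coercion" the commit title refers to
# (typically fp16/bf16 model logits upcast to fp32 for numerical stability).
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5)  # hypothetical [batch, vocab] logits in float32

default = F.log_softmax(logits, dim=-1)                       # stays float32
widened = F.log_softmax(logits, dim=-1, dtype=torch.float64)  # computed in float64

print(default.dtype, widened.dtype)  # torch.float32 torch.float64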
