
Commit 0bfb022

batch loglikelihood_rolling across requests (#2559)
* batch all rolling token windows
* nit
* copy to vllm
* fix max_length for `get_rolling_token_windows`
* bugfix
* bugfix
* add type hints
1 parent 976d8a0 commit 0bfb022
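
The change replaces the old one-request-at-a-time loop with a flatten/score/regroup scheme: every rolling token window is tagged with the index of the request it came from, all windows are scored together in fixed-size batches, and the per-window log-likelihoods are then summed back per request. Below is a minimal standalone sketch of that bookkeeping (not code from this commit; `fake_loglikelihood_tokens` and the toy window lists are made-up stand-ins for `_loglikelihood_tokens` and the tokenized documents):

```python
from typing import List, Tuple

def fake_loglikelihood_tokens(windows) -> List[Tuple[float, bool]]:
    # Stand-in for _loglikelihood_tokens: one (nll, is_greedy) pair per window.
    return [(-float(len(ctx) + len(cont)), False) for ctx, cont in windows]

# Pretend these are the (context, continuation) rolling windows for three requests.
windows_per_request = [
    [([0], [1, 2]), ([2], [3, 4])],                # request 0 -> 2 windows
    [([0], [5])],                                  # request 1 -> 1 window
    [([0], [6, 7]), ([7], [8]), ([8], [9])],       # request 2 -> 3 windows
]

# 1) Flatten: tag every window with the index of the request it came from.
all_windows = []
request_window_counts = []
for req_idx, windows in enumerate(windows_per_request):
    all_windows.extend((req_idx, w) for w in windows)
    request_window_counts.append(len(windows))

# 2) Score the flattened windows in fixed-size batches.
batch_size = 4
all_nlls = []
for i in range(0, len(all_windows), batch_size):
    batch_indices, batch_windows = zip(*all_windows[i : i + batch_size])
    all_nlls.extend(zip(batch_indices, fake_loglikelihood_tokens(batch_windows)))

# 3) Regroup: windows were appended in request order, so a contiguous slice of
#    length request_window_counts[k] belongs to request k.
loglikelihoods, current_idx = [], 0
for count in request_window_counts:
    chunk = all_nlls[current_idx : current_idx + count]
    loglikelihoods.append(sum(nll for _, (nll, _is_greedy) in chunk))
    current_idx += count

print(loglikelihoods)  # one summed log-likelihood per request
```

Because windows are appended in request order, the contiguous-slice regrouping at the end recovers exactly one summed score per request, which is what the diffs below implement for the HF and vLLM backends.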

File tree: 3 files changed, +110 -50 lines changed

lm_eval/models/huggingface.py (54 additions, 31 deletions)

@@ -905,8 +905,6 @@ def _select_cont_toks(
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
-
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -915,10 +913,17 @@ def loglikelihood_rolling(
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
 
-        for (string,) in tqdm(
-            [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
         ):
-            rolling_token_windows = list(
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
@@ -931,37 +936,55 @@ def loglikelihood_rolling(
             )
 
             # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-            pad_amnt = 0
-            if self.world_size > 1:
-                # We pad out the external document-level iterator so the inner iterator doesn't hang
-                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
-                gathered = (
-                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
-                )
+            windows = [(None,) + x for x in rolling_token_windows]
 
-                pad_amnt = max(gathered) - gathered[self.rank]
-                if pad_amnt > 0:
-                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
 
-            string_nll = self._loglikelihood_tokens(
-                requests=rolling_token_windows,
-                disable_tqdm=True,
-                override_bs=adaptive_batch_size,
+        # Handle distributed case padding
+        pad_amnt = 0
+        if self.world_size > 1:
+            mytensor = torch.tensor(len(all_windows), device=self.device)
+            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+            pad_amnt = max(gathered) - gathered[self.rank]
+            if pad_amnt > 0:
+                all_windows += pad_amnt * [all_windows[0]]
+
+        all_nlls = []
+        batch_size = adaptive_batch_size or self.batch_size
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
+
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+                override_bs=len(batch_windows),
             )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))
 
-            if (self.world_size > 1) and (pad_amnt > 0):
-                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
-            else:
-                # discard is_greedy
-                string_nll = [x[0] for x in string_nll]
-
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
+        # Remove padding if necessary
+        if (self.world_size > 1) and (pad_amnt > 0):
+            all_nlls = all_nlls[:-pad_amnt]
 
-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
 
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )
 
         return loglikelihoods
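
One behavioural detail of the diff above: the distributed padding that used to happen per document (gathering each rank's window count and repeating the first window so no rank hangs) now happens once over the flattened `all_windows` list, and the padded results are dropped before regrouping. A toy illustration of that padding step, with made-up counts and no real `accelerator` involved:

```python
# Toy sketch of the padding logic (gathered counts and window names are invented).
gathered = [5, 7, 6]            # len(all_windows) on each rank, after a gather
rank = 0
my_windows = [f"window_{i}" for i in range(gathered[rank])]

pad_amnt = max(gathered) - gathered[rank]
if pad_amnt > 0:
    # Repeat the first window so every rank runs the same number of batches.
    my_windows += pad_amnt * [my_windows[0]]

scores = [f"nll_of_{w}" for w in my_windows]   # stand-in for the batched scoring
if pad_amnt > 0:
    scores = scores[:-pad_amnt]                # drop the results for the padding
assert len(scores) == gathered[rank]
```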

lm_eval/models/vllm_causallms.py (48 additions, 15 deletions)

@@ -102,7 +102,7 @@ def __init__(
         self.batch_size = (
             "auto"
             if isinstance(batch_size, str) and "auto" in batch_size
-            else batch_size
+            else int(batch_size)
         )
         if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
@@ -281,10 +281,21 @@ def run_inference_one_model(
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
-
-        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
-            rolling_token_windows = list(
+        adaptive_batch_size = None
+        if self.batch_size == "auto":
+            adaptive_batch_size = len(requests)
+
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
+        ):
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -297,20 +308,42 @@ def loglikelihood_rolling(
                 )
             )
 
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            windows = [(None,) + x for x in rolling_token_windows]
 
-            string_nll = self._loglikelihood_tokens(
-                rolling_token_windows,
-            )
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
 
-            # discard is_greedy
-            string_nll = [x[0] for x in string_nll]
+        all_nlls = []
+        batch_size = adaptive_batch_size or int(self.batch_size)
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
 
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+            )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))
 
-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
 
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )
 
         return loglikelihoods
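
The small `__init__` change (`else int(batch_size)`) matters because the new chunking loop uses `self.batch_size` directly as a `range` step, and vLLM model args often arrive as strings. A quick sketch of why the cast is needed (the `batch_size_arg` value is hypothetical):

```python
# Why the int(...) cast matters: batch_size is often parsed from a CLI/model_args
# string, and range() needs an integer step.
batch_size_arg = "8"   # hypothetical value arriving as a string
batch_size = (
    "auto"
    if isinstance(batch_size_arg, str) and "auto" in batch_size_arg
    else int(batch_size_arg)
)
assert batch_size == 8
print(list(range(0, 20, batch_size)))  # [0, 8, 16]; a str step would raise TypeError
```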

lm_eval/utils.py (8 additions, 4 deletions)

@@ -10,7 +10,7 @@
 import re
 from dataclasses import asdict, is_dataclass
 from itertools import islice
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, Tuple
 
 import numpy as np
 import yaml
@@ -201,7 +201,9 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
     return [f for f in filenames if "/samples_" in f and ".json" in f]
 
 
-def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
+def get_rolling_token_windows(
+    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
+) -> Generator[Tuple[List[int], List[int]], None, None]:
     """
     - context_len allows for a rolling window context, allowing each prediction window to potentially
       condition on some context
@@ -228,7 +230,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
 
     # Special handling for first window: predict all tokens
     first_seq_len = min(max_seq_len, len(token_list))
-    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
+    yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
     predicted += first_seq_len
 
     while predicted < len(token_list):
@@ -242,7 +244,9 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
         predicted += window_pred_len
 
 
-def make_disjoint_window(pair):
+def make_disjoint_window(
+    pair: Tuple[List[int], List[int]],
+) -> Tuple[List[int], List[int]]:
     """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
     a, b = pair
     return a[: len(a) - (len(b) - 1)], b
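
The `lm_eval/utils.py` changes only add type hints and drop the redundant parentheses around the first `yield`; behaviour is unchanged. A small usage sketch of the two annotated helpers (assumes `lm_eval` is installed; the token values and window sizes are arbitrary):

```python
from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

tokens = list(range(10))          # pretend token ids
windows = [
    make_disjoint_window(pair)
    for pair in get_rolling_token_windows(
        tokens, prefix_token=-1, max_seq_len=4, context_len=1
    )
]
# Each element is a (context, continuation) pair of token lists, matching the
# Generator[Tuple[List[int], List[int]], None, None] annotation added here.
for ctx, cont in windows:
    print(ctx, cont)
```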
