@@ -905,8 +905,6 @@ def _select_cont_toks(
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
-
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -915,10 +913,17 @@ def loglikelihood_rolling(
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size

-        for (string,) in tqdm(
-            [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
         ):
-            rolling_token_windows = list(
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
@@ -931,37 +936,55 @@ def loglikelihood_rolling(
             )

             # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-            pad_amnt = 0
-            if self.world_size > 1:
-                # We pad out the external document-level iterator so the inner iterator doesn't hang
-                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
-                gathered = (
-                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
-                )
+            windows = [(None,) + x for x in rolling_token_windows]

-                pad_amnt = max(gathered) - gathered[self.rank]
-                if pad_amnt > 0:
-                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
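+            # (the per-request counts let us split the flattened results back out later)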

-            string_nll = self._loglikelihood_tokens(
-                requests=rolling_token_windows,
-                disable_tqdm=True,
-                override_bs=adaptive_batch_size,
+        # Handle distributed case padding
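+        # (pad out the flattened window list so the batch loop doesn't hang on any rank)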
+        pad_amnt = 0
+        if self.world_size > 1:
+            mytensor = torch.tensor(len(all_windows), device=self.device)
+            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+            pad_amnt = max(gathered) - gathered[self.rank]
+            if pad_amnt > 0:
+                all_windows += pad_amnt * [all_windows[0]]
+
+        all_nlls = []
+        batch_size = adaptive_batch_size or self.batch_size
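+        # assumes an integer batch size here; "auto" sets adaptive_batch_size above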
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
+
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+                override_bs=len(batch_windows),
             )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))

-            if (self.world_size > 1) and (pad_amnt > 0):
-                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
-            else:
-                # discard is_greedy
-                string_nll = [x[0] for x in string_nll]
-
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
+        # Remove padding if necessary
+        if (self.world_size > 1) and (pad_amnt > 0):
+            all_nlls = all_nlls[:-pad_amnt]

-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
+
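+            # cache this loglikelihood_rolling request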
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )

         return loglikelihoods
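
For intuition, here is a minimal standalone sketch of the flatten-batch-regroup pattern the new code follows. score_batch stands in for self._loglikelihood_tokens; every name in the sketch is illustrative, not part of lm-evaluation-harness.

# Minimal sketch of the flatten -> batch -> regroup pattern above.
from typing import List, Tuple


def score_batch(windows: Tuple[str, ...]) -> List[float]:
    # Dummy scorer: pretend each window's NLL equals its length.
    return [float(len(w)) for w in windows]


def rolling_scores(docs: List[str], window: int = 4, batch_size: int = 2) -> List[float]:
    all_windows: List[Tuple[int, str]] = []  # (doc_idx, window) pairs, flattened
    window_counts: List[int] = []            # windows per doc, for regrouping
    for doc_idx, doc in enumerate(docs):
        chunks = [doc[i : i + window] for i in range(0, len(doc), window)]
        all_windows.extend((doc_idx, c) for c in chunks)
        window_counts.append(len(chunks))

    # Score in fixed-size batches that ignore document boundaries.
    results: List[Tuple[int, float]] = []
    for i in range(0, len(all_windows), batch_size):
        idxs, chunks = zip(*all_windows[i : i + batch_size])
        results.extend(zip(idxs, score_batch(chunks)))

    # Regroup: window_counts says how many consecutive results belong to each doc.
    totals, cursor = [], 0
    for count in window_counts:
        totals.append(sum(score for _, score in results[cursor : cursor + count]))
        cursor += count
    return totals


print(rolling_scores(["abcdefgh", "xyz"]))  # -> [8.0, 3.0]

The payoff of batching across document boundaries: a run of short documents no longer produces under-filled batches, because windows from many requests share each model call.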