diff --git a/finetune/adapter.py b/finetune/adapter.py
index fc815830..cdc40be8 100644
--- a/finetune/adapter.py
+++ b/finetune/adapter.py
@@ -30,6 +30,8 @@ from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader
 
 
 instruction_tuning = True
@@ -113,6 +115,7 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.
@@ -120,8 +123,10 @@ def train(
     """
     step_count = 0
 
-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -223,6 +228,46 @@ def pad_right(x, pad_id):
     return x, y
 
 
+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py
index c686cd15..741c4a1c 100644
--- a/finetune/adapter_v2.py
+++ b/finetune/adapter_v2.py
@@ -36,6 +36,8 @@ from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader
 
 
 eval_interval = 600
@@ -119,6 +121,7 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.
@@ -126,8 +129,10 @@ def train(
     """
     step_count = 0
 
-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -227,6 +232,46 @@ def pad_right(x, pad_id):
     return x, y
 
 
+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
diff --git a/finetune/full.py b/finetune/full.py
index bf94da49..58246861 100644
--- a/finetune/full.py
+++ b/finetune/full.py
@@ -25,6 +25,8 @@ from lit_llama.tokenizer import Tokenizer
 from lit_llama.utils import save_model_checkpoint
 from scripts.prepare_alpaca import generate_prompt
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader
 
 
 instruction_tuning = True
@@ -95,6 +97,7 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.
@@ -103,8 +106,10 @@
     step_count = 0
     model.train()
 
-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
 
         if step_count <= warmup_iters:
@@ -208,6 +213,46 @@ def pad_right(x, pad_id):
     return x, y
 
 
+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
diff --git a/finetune/length_grouped_sampler.py b/finetune/length_grouped_sampler.py
new file mode 100644
index 00000000..412e1304
--- /dev/null
+++ b/finetune/length_grouped_sampler.py
@@ -0,0 +1,101 @@
+# Derived from https://github.com/huggingface/transformers
+# ------------------------------------------------------------------------------------------
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------------------
+
+from typing import Optional, List
+import logging
+import torch
+from torch.utils.data import Dataset, Sampler
+
+logger = logging.getLogger(__name__)
+
+
+def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
+    """
+    Return a list of indices so that each slice of `batch_size` consecutive indices correspond to elements of similar
+    lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size `mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
+
+    return [i for megabatch in megabatches for i in megabatch]
+
+
+class LengthGroupedSampler(Sampler):
+    r"""
+    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+    keeping a bit of randomness.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        dataset: Optional[Dataset] = None,
+        lengths: Optional[List[int]] = None,
+        model_input_name: Optional[str] = None,
+        generator=None,
+    ):
+        if dataset is None and lengths is None:
+            raise ValueError("One of dataset and lengths must be provided.")
+
+        self.batch_size = batch_size
+        if lengths is None:
+            model_input_name = model_input_name if model_input_name is not None else "input_ids"
+            if not isinstance(dataset[0], dict) or model_input_name not in dataset[0]:
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    f"'{model_input_name}' key."
+                )
+            lengths = [len(feature[model_input_name]) for feature in dataset]
+        elif isinstance(lengths, torch.Tensor):
+            logger.info(
+                "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
+            )
+            lengths = lengths.tolist()
+
+        self.lengths = lengths
+        self.generator = generator
+
+    def __len__(self):
+        return len(self.lengths)
+
+    def __iter__(self):
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
+        return iter(indices)
diff --git a/finetune/lora.py b/finetune/lora.py
index 18737015..e9298d34 100644
--- a/finetune/lora.py
+++ b/finetune/lora.py
@@ -22,6 +22,8 @@ from lit_llama.model import LLaMA, LLaMAConfig
 from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader
 
 
 instruction_tuning = True
@@ -90,6 +92,7 @@ def train(
     val_data: np.ndarray,
     tokenizer_path: str,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.
@@ -97,8 +100,10 @@ def train(
     """
     step_count = 0
 
-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -202,6 +207,46 @@ def pad_right(x, pad_id):
     return x, y
 
 
+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
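
Usage sketch (not part of the patch above): a minimal, hypothetical demonstration of how the new LengthGroupedSampler orders indices so that each micro-batch contains sequences of similar length, which is what lets the pad_sequence call in collate_fn add far less padding. The toy lengths below are made up; it assumes finetune/ is on the import path so that length_grouped_sampler can be imported directly, as get_dataloader does.

import torch
from length_grouped_sampler import LengthGroupedSampler

# Toy sequence lengths; in the finetune scripts these come from len(input_ids) per example.
lengths = torch.randint(low=8, high=512, size=(64,)).tolist()
sampler = LengthGroupedSampler(batch_size=4, lengths=lengths)

# With 64 examples and batch_size=4, mega-batches of 16 shuffled indices are sorted by
# descending length, so each consecutive slice of 4 indices has similar lengths, and the
# very first batch contains the longest sequence (so an OOM would surface immediately).
indices = list(sampler)
print([lengths[i] for i in indices[:4]])
print([lengths[i] for i in indices[4:8]])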