add group_by_length optional feature #398

Open · wants to merge 3 commits into main
Changes from 1 commit
apply to other finetuning scripts
nuance1979 committed Jun 22, 2023
commit 5fbe0ced8e213909da556caa96057edaea1670c8
49 changes: 47 additions & 2 deletions finetune/adapter.py
@@ -30,6 +30,8 @@
 from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader


 instruction_tuning = True
@@ -113,15 +115,18 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.

     Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
     """
     step_count = 0

-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -223,6 +228,46 @@ def pad_right(x, pad_id):
     return x, y


+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
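
Note (not part of the diff): a minimal sketch of what the new collate_fn produces, assuming items shaped like the tensors saved by scripts/prepare_alpaca.py. pad_sequence right-pads the inputs with its default value 0, and the labels with -1 so the padded positions match the -1 padding already used by the existing pad_right helper; the toy token ids and lengths below are made up for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy (input_ids, labels) examples of different lengths.
batch = [
    (torch.tensor([5, 6, 7, 8]), torch.tensor([-1, -1, 7, 8])),
    (torch.tensor([9, 10]), torch.tensor([-1, 10])),
]

x, y = zip(*batch)
batch_x = pad_sequence(x, batch_first=True)                    # pads inputs with 0
batch_y = pad_sequence(y, batch_first=True, padding_value=-1)  # pads labels with -1

print(batch_x)
# tensor([[ 5,  6,  7,  8],
#         [ 9, 10,  0,  0]])
print(batch_y)
# tensor([[-1, -1,  7,  8],
#         [-1, 10, -1, -1]])
```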
49 changes: 47 additions & 2 deletions finetune/adapter_v2.py
@@ -36,6 +36,8 @@
 from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader


 eval_interval = 600
@@ -119,15 +121,18 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.

     Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
     """
     step_count = 0

-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         if step_count <= warmup_iters:
             # linear warmup
             lr = learning_rate * step_count / warmup_iters
@@ -227,6 +232,46 @@ def pad_right(x, pad_id):
     return x, y


+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
49 changes: 47 additions & 2 deletions finetune/full.py
@@ -25,6 +25,8 @@
 from lit_llama.tokenizer import Tokenizer
 from lit_llama.utils import save_model_checkpoint
 from scripts.prepare_alpaca import generate_prompt
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset, DataLoader


 instruction_tuning = True
@@ -95,6 +97,7 @@ def train(
     train_data: np.ndarray,
     val_data: np.ndarray,
     out_dir: str,
+    group_by_length: bool = False,
 ) -> None:
     """The training loop.

@@ -103,8 +106,10 @@ def train(
     step_count = 0
     model.train()

-    for iter_num in range(max_iters):
-
+    loader = get_dataloader(fabric, train_data, micro_batch_size, group_by_length)
+    for iter_num, (input_ids, targets) in enumerate(loader):
+        if iter_num >= max_iters:
+            break
         is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0

         if step_count <= warmup_iters:
@@ -208,6 +213,46 @@ def pad_right(x, pad_id):
     return x, y


+class InstructionDataset(Dataset):
+    def __init__(self, data: list):
+        self._data = data
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, i: int):
+        input_ids = self._data[i]["input_ids"].type(torch.int64)
+        labels = self._data[i]["labels"].type(torch.int64)
+        return input_ids, labels
+
+
+def get_dataloader(
+    fabric: L.Fabric,
+    data: torch.Tensor,
+    micro_batch_size: int,
+    group_by_length: bool,
+):
+    from length_grouped_sampler import LengthGroupedSampler
+
+    def collate_fn(batch):
+        x, y = zip(*batch)
+        batch_x = pad_sequence(x, batch_first=True)
+        batch_y = pad_sequence(y, batch_first=True, padding_value=-1)
+        return batch_x, batch_y
+
+    dataset = InstructionDataset(data)
+    sampler = LengthGroupedSampler(micro_batch_size, lengths=[len(x) for x, _ in dataset]) if group_by_length else None
+    loader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        shuffle=(sampler is None),
+        sampler=sampler,
+        collate_fn=collate_fn,
+        pin_memory=True,
+    )
+    return fabric.setup_dataloaders(loader)
+
+
 def load_datasets(data_dir):
     train_data = torch.load(os.path.join(data_dir, "train.pt"))
     val_data = torch.load(os.path.join(data_dir, "test.pt"))
2 changes: 1 addition & 1 deletion finetune/length_grouped_sampler.py
@@ -98,4 +98,4 @@ def __len__(self):

     def __iter__(self):
         indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
-        return iter(indices)
\ No newline at end of file
+        return iter(indices)
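
Note (not part of the diff): length_grouped_sampler.py is only touched here for a trailing-newline fix, so its internals are not shown. For context, below is a minimal sketch of how get_length_grouped_indices could work, assuming the common "megabatch" recipe used by Hugging Face transformers' LengthGroupedSampler: shuffle globally, cut into megabatches, and sort each megabatch by length so samples of similar length end up in the same micro-batch. Names, defaults, and details are illustrative, not this file's actual implementation.

```python
from typing import List, Optional

import torch


def get_length_grouped_indices(
    lengths: List[int],
    batch_size: int,
    mega_batch_mult: int = 50,
    generator: Optional[torch.Generator] = None,
) -> List[int]:
    """Return indices ordered so each micro-batch holds sequences of similar length."""
    # Random global order keeps training stochastic across epochs.
    indices = torch.randperm(len(lengths), generator=generator).tolist()

    # Cut into megabatches of `mega_batch_mult * batch_size` indices.
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i : i + megabatch_size] for i in range(0, len(indices), megabatch_size)]

    # Within each megabatch, sort by sequence length (longest first).
    megabatches = [sorted(mb, key=lambda i: lengths[i], reverse=True) for mb in megabatches]

    return [i for mb in megabatches for i in mb]
```

With group_by_length=True, the DataLoader added above draws micro_batch_size consecutive indices from this order, so pad_sequence adds little padding per batch while the megabatch shuffle preserves some randomness.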