
Commit 10c7f21: add sequential sample packing

1 parent 4a73698

File tree

7 files changed: 169 additions, 11 deletions
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@ (new file)
base_model: meta-llama/Llama-3.2-1B
# optionally might have model_type or tokenizer_type
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out

test_value: true

sequence_len: 4096
sample_packing: true
sample_packing_sequentially: true
curriculum_sampling: true
eval_sample_packing: false
pad_to_sequence_len: true

adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
  - embed_tokens
  - lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: <|end_of_text|>
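
For reference, the packing-related keys in this example config can be checked programmatically. A minimal sketch, assuming the file above is saved as config.yml (hypothetical filename) and PyYAML is available:

import yaml

# Load the example config shown above (hypothetical path)
with open("config.yml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Options that drive order-preserving packing in this commit:
print(cfg["sample_packing"])               # True: pack multiple samples per sequence
print(cfg["sample_packing_sequentially"])  # True: next-fit packing that preserves sample order
print(cfg["curriculum_sampling"])          # True: iterate the dataset in order instead of shuffling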

src/axolotl/core/trainers/base.py

Lines changed: 1 addition & 0 deletions
@@ -419,6 +419,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
                batch_size=batch_size,
                group_size=self.args.sample_packing_group_size,
                bin_size=self.args.sample_packing_bin_size,
+               sequential=self.args.sample_packing_sequentially,
                drop_last=True,
            )
        if self.args.curriculum_sampling:

src/axolotl/core/training_args.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ class AxolotlTrainingMixins:
        default=False,
        metadata={"help": "Use sample packing for efficient training."},
    )
+   sample_packing_sequentially: bool = field(
+       default=False,
+       metadata={"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."},
+   )
    multipack_real_batches: bool = field(
        default=False,
        metadata={"help": "Use real batches for efficient training."},

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -795,6 +795,7 @@ class AxolotlInputConfig(
    sample_packing: Optional[bool] = None
    sample_packing_group_size: Optional[int] = 100_000
    sample_packing_bin_size: Optional[int] = 200
+   sample_packing_sequentially: Optional[bool] = None
    eval_sample_packing: Optional[bool] = None
    pad_to_sequence_len: Optional[bool] = None
    curriculum_sampling: Optional[bool] = None

src/axolotl/utils/samplers/multipack.py

Lines changed: 73 additions & 8 deletions
@@ -8,7 +8,7 @@

import numba
import numpy as np
-from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data import BatchSampler, Sampler, SequentialSampler

from axolotl.utils.distributed import reduce_and_broadcast

@@ -103,6 +103,55 @@ def allocate(
    return result, s, len(result) * c * n


+@numba.njit
+def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
+    """
+    Sequential allocator that preserves example order
+
+    Parameters:
+    - lengths: The lengths of all examples
+    - rank: The current rank (for distributed training)
+    - c: The capacity of each bin (maximum sequence length)
+    - n: Number of ranks
+
+    Returns:
+    - result: List of batches for the current rank
+    - total_used: Number of actual example tokens
+    - total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
+    """
+    result = []
+    total_used = 0
+
+    # First, do sequential packing into bins
+    all_bins = []
+    current_bin = [0 for i in range(0)]  # numba hint
+    remaining_capacity = c
+
+    for idx, size in enumerate(lengths):
+        if size <= remaining_capacity:
+            # Example fits in current bin
+            current_bin.append(idx)
+            remaining_capacity -= size
+            total_used += size
+        else:
+            # Example doesn't fit, start a new bin
+            if current_bin:  # Add non-empty bin to all_bins
+                all_bins.append(current_bin)
+            current_bin = [idx]
+            remaining_capacity = c - size
+            total_used += size
+
+    # Add the last bin if not empty
+    if current_bin:
+        all_bins.append(current_bin)
+
+    # Assign bins to ranks - each rank gets every n-th bin
+    for bin_idx in range(rank, len(all_bins), n):
+        result.append(all_bins[bin_idx])
+
+    return result, total_used, len(all_bins) * c
+
+
class MultipackBatchSampler(BatchSampler):
    """
    Batch Sampler class for multipack

@@ -117,13 +166,15 @@ def __init__(
        packing_efficiency_estimate: float = 1.0,
        drop_last: bool = False,
        num_count_samples: int = 16,
+       sequential: bool = False,
        **kwargs,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.batch_size = batch_size
        self.batch_max_len = batch_max_len
        self.lengths: np.ndarray = lengths
        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+       self.sequential = sequential

        assert isinstance(self.lengths, np.ndarray)

@@ -138,6 +189,10 @@ def __init__(
        # the minimum packed dataset length across all ranks determined by a gather/broadcast
        self.len_across_ranks = None

+       if self.sequential and not isinstance(sampler, SequentialSampler):
+           LOG.warn("using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?")
+
+
    def set_epoch(self, epoch: int):
        self.epoch = epoch

@@ -147,13 +202,23 @@ def generate_batches(self, set_stats=False):
        lengths = self.lengths[indices]
        lengths_cumsum = np.cumsum(lengths)

-       batches, total_used, total_slots = allocate(
-           lengths=lengths,
-           lengths_cumsum=lengths_cumsum,
-           rank=0,
-           c=self.batch_max_len,
-           n=1,
-       )
+       if self.sequential:
+           LOG.debug("using sequential sample packing algorithm")
+           batches, total_used, total_slots = allocate_sequentially(
+               lengths=lengths,
+               rank=0,
+               c=self.batch_max_len,
+               n=1,
+           )
+       else:
+           LOG.debug("using non-sequential sample packing algorithm")
+           batches, total_used, total_slots = allocate(
+               lengths=lengths,
+               lengths_cumsum=lengths_cumsum,
+               rank=0,
+               c=self.batch_max_len,
+               n=1,
+           )

        batches = [
            [
src/axolotl/utils/trainer.py

Lines changed: 7 additions & 2 deletions
@@ -13,7 +13,7 @@
import torch.cuda
from accelerate.logging import get_logger
from datasets import IterableDataset, disable_caching, enable_caching
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder

@@ -455,13 +455,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
        else:
            sampler_batch_size = cfg.micro_batch_size
            batch_max_len = cfg.sequence_len
+       if cfg.curriculum_sampling:
+           sampler = SequentialSampler(train_dataset)
+       else:
+           sampler = RandomSampler(train_dataset)
        sampler = MultipackBatchSampler(
-           sampler=RandomSampler(train_dataset),
+           sampler=sampler,
            lengths=get_dataset_lengths(train_dataset),
            batch_size=sampler_batch_size,
            batch_max_len=batch_max_len,
            group_size=cfg.sample_packing_group_size,
            bin_size=cfg.sample_packing_bin_size,
+           sequential=cfg.sample_packing_sequentially,
            drop_last=True,
        )
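
The sampler swap above only changes the order in which example indices reach the packer: with curriculum_sampling enabled, indices arrive in dataset order, and sequential packing then keeps that order inside the packed bins. A small standalone torch sketch (not part of the commit) illustrating the difference:

import torch
from torch.utils.data import RandomSampler, SequentialSampler

data = list(range(8))  # stand-in for a tokenized dataset

# curriculum_sampling: indices arrive in dataset order
print(list(SequentialSampler(data)))  # [0, 1, 2, 3, 4, 5, 6, 7]

# default: indices arrive shuffled each epoch
torch.manual_seed(0)
print(list(RandomSampler(data)))      # a permutation of 0..7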

tests/test_packed_batch_sampler.py

Lines changed: 3 additions & 1 deletion
@@ -34,7 +34,8 @@ class TestBatchedSamplerPacking:
        ],
    )
    @pytest.mark.parametrize("max_seq_length", [4096, 512])
-   def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
+   @pytest.mark.parametrize("sequential", [True, False])
+   def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length, sequential):
        import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401

        dataset = load_dataset(

@@ -70,6 +71,7 @@ def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
            batch_max_len=max_seq_length,
            group_size=100000,
            bin_size=200,
+           sequential=sequential
        )

        loader = DataLoader(
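
With the new parametrization, the same test covers both the order-preserving and the default packing paths; it can be run locally with, for example:

pytest tests/test_packed_batch_sampler.py -k test_packing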
