
Commit 4d36ecc

DreamGenX and winglian authored

Sequential sample packing (#2404) [skip ci]

* add sequential sample packing

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>

1 parent 7acf93b · commit 4d36ecc

File tree

7 files changed: +174 −11 lines changed
Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+base_model: meta-llama/Llama-3.2-1B
+# optionally might have model_type or tokenizer_type
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path:
+val_set_size: 0.0
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: true
+sample_packing_sequentially: true
+curriculum_sampling: true
+eval_sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+lora_modules_to_save:
+  - embed_tokens
+  - lm_head
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>

src/axolotl/core/trainers/base.py

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ def _create_multipack_sampler(
     packing_efficiency_estimate=self.args.sample_packing_efficiency,
     batch_max_len=batch_max_len,
     batch_size=batch_size,
+    sequential=self.args.sample_packing_sequentially,
     drop_last=True,
 )

src/axolotl/core/training_args.py

Lines changed: 6 additions & 0 deletions
@@ -34,6 +34,12 @@ class AxolotlTrainingMixins:
         default=False,
         metadata={"help": "Use sample packing for efficient training."},
     )
+    sample_packing_sequentially: bool = field(
+        default=False,
+        metadata={
+            "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
+        },
+    )
     multipack_real_batches: bool = field(
         default=False,
         metadata={"help": "Use real batches for efficient training."},

src/axolotl/utils/samplers/multipack.py

Lines changed: 74 additions & 8 deletions
@@ -8,7 +8,7 @@

 import numba
 import numpy as np
-from torch.utils.data import BatchSampler, Sampler
+from torch.utils.data import BatchSampler, Sampler, SequentialSampler

 from axolotl.utils.distributed import reduce_and_broadcast

@@ -103,6 +103,55 @@ def allocate(
     return result, s, len(result) * c * n


+@numba.njit
+def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
+    """
+    Sequential allocator that preserves example order
+
+    Parameters:
+    - lengths: The lengths of all examples
+    - rank: The current rank (for distributed training)
+    - c: The capacity of each bin (maximum sequence length)
+    - n: Number of ranks
+
+    Returns:
+    - result: List of batches for the current rank
+    - total_used: Number of actual example tokens
+    - total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
+    """
+    result = []
+    total_used = 0
+
+    # First, do sequential packing into bins
+    all_bins = []
+    current_bin = [0 for i in range(0)]  # numba hint
+    remaining_capacity = c
+
+    for idx, size in enumerate(lengths):
+        if size <= remaining_capacity:
+            # Example fits in current bin
+            current_bin.append(idx)
+            remaining_capacity -= size
+            total_used += size
+        else:
+            # Example doesn't fit, start a new bin
+            if current_bin:  # Add non-empty bin to all_bins
+                all_bins.append(current_bin)
+            current_bin = [idx]
+            remaining_capacity = c - size
+            total_used += size
+
+    # Add the last bin if not empty
+    if current_bin:
+        all_bins.append(current_bin)
+
+    # Assign bins to ranks - each rank gets every n-th bin
+    for bin_idx in range(rank, len(all_bins), n):
+        result.append(all_bins[bin_idx])
+
+    return result, total_used, len(all_bins) * c
+
+
 class MultipackBatchSampler(BatchSampler):
     """Batch sampler class for multipack"""

@@ -115,13 +164,15 @@ def __init__(
         packing_efficiency_estimate: float = 1.0,
         drop_last: bool = False,
         num_count_samples: int = 16,
+        sequential: bool = False,
         **kwargs,
     ):
         super().__init__(sampler, batch_size, drop_last)
         self.batch_size = batch_size
         self.batch_max_len = batch_max_len
         self.lengths: np.ndarray = lengths
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.sequential = sequential

         assert isinstance(self.lengths, np.ndarray)

@@ -136,6 +187,11 @@ def __init__(
         # the minimum packed dataset length across all ranks determined by a gather/broadcast
         self.len_across_ranks = None

+        if self.sequential and not isinstance(sampler, SequentialSampler):
+            LOG.warn(
+                "using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
+            )
+
     def set_epoch(self, epoch: int):
         self.epoch = epoch

@@ -145,13 +201,23 @@ def generate_batches(self, set_stats=False):
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)

-        batches, total_used, total_slots = allocate(
-            lengths=lengths,
-            lengths_cumsum=lengths_cumsum,
-            rank=0,
-            c=self.batch_max_len,
-            n=1,
-        )
+        if self.sequential:
+            LOG.debug("using sequential sample packing algorithm")
+            batches, total_used, total_slots = allocate_sequentially(
+                lengths=lengths,
+                rank=0,
+                c=self.batch_max_len,
+                n=1,
+            )
+        else:
+            LOG.debug("using non-sequential sample packing algorithm")
+            batches, total_used, total_slots = allocate(
+                lengths=lengths,
+                lengths_cumsum=lengths_cumsum,
+                rank=0,
+                c=self.batch_max_len,
+                n=1,
+            )

         batches = [
             [
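
To see how allocate_sequentially divides work across ranks, here is a plain-Python mirror of the function above (numba decorator dropped, lengths made up). Each rank takes every n-th bin, so examples stay in order within a rank:

    import numpy as np

    def allocate_sequentially_py(lengths, rank, c, n):
        # Same next-fit pass as allocate_sequentially above, minus the jit.
        all_bins, current_bin, remaining, total_used = [], [], c, 0
        for idx, size in enumerate(lengths):
            if size <= remaining:
                current_bin.append(idx)
                remaining -= size
            else:
                if current_bin:
                    all_bins.append(current_bin)
                current_bin, remaining = [idx], c - size
            total_used += size
        if current_bin:
            all_bins.append(current_bin)
        # Each rank takes every n-th bin.
        result = [all_bins[i] for i in range(rank, len(all_bins), n)]
        return result, total_used, len(all_bins) * c

    lengths = np.array([5, 3, 7, 2, 6, 4])
    for rank in range(2):
        print(rank, allocate_sequentially_py(lengths, rank, c=8, n=2))
    # 0 ([[0, 1], [3, 4]], 27, 32)
    # 1 ([[2], [5]], 27, 32)

Note that total_used and total_slots are computed over all bins, not just the current rank's share, matching the docstring above.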

src/axolotl/utils/schemas/config.py

Lines changed: 1 addition & 0 deletions
@@ -192,6 +192,7 @@ class AxolotlInputConfig(
     sample_packing: bool | None = None
     sample_packing_group_size: int | None = 100_000
     sample_packing_bin_size: int | None = 200
+    sample_packing_sequentially: bool | None = None
     eval_sample_packing: bool | None = None
     pad_to_sequence_len: bool | None = None
     curriculum_sampling: bool | None = None

src/axolotl/utils/trainer.py

Lines changed: 7 additions & 2 deletions
@@ -13,7 +13,7 @@
 import torch.cuda
 from accelerate.logging import get_logger
 from datasets import IterableDataset, disable_caching, enable_caching
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder

@@ -456,13 +456,18 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     else:
         sampler_batch_size = cfg.micro_batch_size
         batch_max_len = cfg.sequence_len
+    if cfg.curriculum_sampling:
+        sampler = SequentialSampler(train_dataset)
+    else:
+        sampler = RandomSampler(train_dataset)
     sampler = MultipackBatchSampler(
-        sampler=RandomSampler(train_dataset),
+        sampler=sampler,
         lengths=get_dataset_lengths(train_dataset),
         batch_size=sampler_batch_size,
         batch_max_len=batch_max_len,
         group_size=cfg.sample_packing_group_size,
         bin_size=cfg.sample_packing_bin_size,
+        sequential=cfg.sample_packing_sequentially,
         drop_last=True,
     )
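
Taken together: curriculum_sampling controls the order in which examples leave the sampler, while sample_packing_sequentially controls whether packing preserves that order. A condensed sketch of the wiring above, with toy stand-ins for cfg and train_dataset (constructor keywords taken from this diff):

    import numpy as np
    from torch.utils.data import RandomSampler, SequentialSampler
    from axolotl.utils.samplers.multipack import MultipackBatchSampler

    lengths = np.array([512, 1024, 4096, 256, 2048, 3072])  # made-up token counts
    dataset = list(range(len(lengths)))                     # stand-in dataset
    curriculum_sampling = True        # cfg.curriculum_sampling
    pack_sequentially = True          # cfg.sample_packing_sequentially

    sampler = (
        SequentialSampler(dataset) if curriculum_sampling else RandomSampler(dataset)
    )
    batch_sampler = MultipackBatchSampler(
        sampler=sampler,
        lengths=lengths,
        batch_size=2,             # cfg.micro_batch_size
        batch_max_len=4096,       # cfg.sequence_len
        group_size=100_000,       # cfg.sample_packing_group_size
        bin_size=200,             # cfg.sample_packing_bin_size
        sequential=pack_sequentially,
        drop_last=True,
    )
    # With sequential=True but a RandomSampler, packing still preserves the
    # shuffled stream order, and the sampler logs the warning added above.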

tests/test_packed_batch_sampler.py

Lines changed: 5 additions & 1 deletion
@@ -38,8 +38,11 @@ class TestBatchedSamplerPacking:
         ],
     )
     @pytest.mark.parametrize("max_seq_length", [4096, 512])
+    @pytest.mark.parametrize("sequential", [True, False])
     @enable_hf_offline
-    def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
+    def test_packing(
+        self, batch_size, num_workers, tokenizer, max_seq_length, sequential
+    ):
         import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401

         dataset = load_dataset(

@@ -75,6 +78,7 @@ def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
             batch_max_len=max_seq_length,
             group_size=100000,
             bin_size=200,
+            sequential=sequential,
         )

         loader = DataLoader(
