Commit 652b0bc

[dataset] Fix multinode packing (#4402)
1 parent 7f86b25 · commit 652b0bc

10 files changed: +63 additions, −15 deletions


docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 1 addition & 0 deletions
@@ -296,6 +296,7 @@ Megatron training parameters inherit from the Megatron parameters and the basic parameters.

 - add_version: Adds an extra directory `'<version>-<timestamp>'` under `save` to prevent weights from being overwritten. Defaults to True.
 - 🔥packing: Whether to use sequence packing. Defaults to False.
+- 🔥packing_cache: Specifies the packing cache directory. The default value is `None`, meaning the cache is stored under the path given by the `$MODELSCOPE_CACHE` environment variable. When using packing across multiple nodes, make sure the packing cache path is shared and identical on all nodes; you can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding `--packing_cache <shared_path>` on the command line.
 - 🔥streaming: Stream reading and processing of the dataset. Defaults to False. Typically set to True when handling large datasets. See the command-line parameters documentation for more streaming parameters.
 - lazy_tokenize: Defaults to False. If set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization happens during training (this saves memory).
 - max_epochs: Forces training to exit when `max_epochs` is reached, then validates and saves the weights. This parameter is especially useful with streaming datasets. Defaults to None.

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 0 deletions
@@ -354,6 +354,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, and `modules_to_save`.
 - 🔥packing: Whether to use sequence packing to improve computational efficiency. Defaults to False. Currently supports `swift pt/sft`.
 - Note: When using packing, combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44"; for details, see [this PR](https://github.com/huggingface/transformers/pull/31629).
 - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
+- packing_cache: Specifies the packing cache directory. The default value is `None`, meaning the cache is stored under the path given by the `$MODELSCOPE_CACHE` environment variable. When using packing across multiple nodes, make sure the packing cache path is shared and identical on all nodes; you can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding `--packing_cache <shared_path>` on the command line.
 - 🔥lazy_tokenize: Whether to use lazy_tokenize. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). Defaults to False for LLM training and True for MLLM training, to save memory.
 - acc_strategy: Strategy for computing accuracy during training and validation. Options are `seq`-level and `token`-level accuracy; defaults to `token`.
 - max_new_tokens: Overrides the generation parameters. The maximum number of tokens to generate when `predict_with_generate=True`; defaults to 64.

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 1 addition & 0 deletions
@@ -363,6 +363,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine
 - 🔥packing: Whether to use sequence packing to improve computational efficiency. The default value is False. Currently supports `swift pt/sft`.
 - Note: When using packing, please combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629).
 - Supported multimodal models reference: https://github.com/modelscope/ms-swift/blob/main/examples/train/packing/qwen2_5_vl.sh
+- packing_cache: Specifies the directory for packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache <shared_path>` argument in the command line.
 - 🔥lazy_tokenize: Whether to use lazy tokenization. If set to False, all dataset samples are tokenized before training (for multimodal models, this includes reading images from disk). This parameter defaults to False for LLM training, and True for MLLM training, to save memory.
 - acc_strategy: Strategy for calculating accuracy during training and validation. Options are `seq`-level and `token`-level accuracy, with `token` as the default.
 - max_new_tokens: Generation parameter override. The maximum number of tokens to generate when `predict_with_generate=True`, defaulting to 64.

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 1 addition & 0 deletions
@@ -307,6 +307,7 @@ Megatron training parameters inherit from Megatron parameters and basic parameters.

 - add_version: Adds a directory `<version>-<timestamp>` to `save` to prevent overwriting weights, default is True.
 - 🔥packing: Whether to use sequence packing, defaults to False.
+- 🔥packing_cache: Specifies the directory for packing cache. The default value is `None`, which means the cache will be stored in the path defined by the environment variable `$MODELSCOPE_CACHE`. When using the packing feature across multiple nodes, ensure that all nodes share the same packing cache directory. You can achieve this by setting the `MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache <shared_path>` argument in the command line.
 - 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation.
 - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory).
 - max_epochs: Forces the training to exit after reaching `max_epochs`, and performs validation and saving of the model weights. This parameter is especially useful when using a streaming dataset. Default is None.
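The four documentation entries above all describe the same requirement, so a single illustrative sketch may help: a minimal Python snippet (not part of this commit) showing the precedence the docs describe, where `--packing_cache` wins over `$MODELSCOPE_CACHE`. All concrete paths below are hypothetical examples.

```python
# Minimal sketch of the documented precedence (not the ms-swift implementation).
# The '/mnt/shared/...' paths are hypothetical; any mount visible to all nodes works.
import os
from typing import Optional


def resolve_packing_cache(packing_cache: Optional[str]) -> str:
    """--packing_cache takes precedence; otherwise fall back to $MODELSCOPE_CACHE."""
    if packing_cache:  # e.g. launched with --packing_cache /mnt/shared/packing_cache
        return packing_cache
    # Must point at the same shared path on every node for multinode packing.
    return os.environ['MODELSCOPE_CACHE']


os.environ['MODELSCOPE_CACHE'] = '/mnt/shared/modelscope_cache'
print(resolve_packing_cache(None))                         # /mnt/shared/modelscope_cache
print(resolve_packing_cache('/mnt/shared/packing_cache'))  # /mnt/shared/packing_cache
```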

swift/llm/argument/base_args/base_args.py

Lines changed: 19 additions & 4 deletions
@@ -9,8 +9,8 @@
 from swift.llm import Processor, Template, get_model_tokenizer, get_template, load_by_unsloth, safe_snapshot_download
 from swift.llm.utils import get_ckpt_dir
 from swift.plugin import extra_tuners
-from swift.utils import (check_json_format, get_dist_setting, get_logger, import_external_file, is_dist, is_master,
-                         set_device, use_hf_hub)
+from swift.utils import (check_json_format, check_shared_disk, get_dist_setting, get_logger, import_external_file,
+                         is_dist, is_master, set_device, use_hf_hub)
 from .data_args import DataArguments
 from .generation_args import GenerationArguments
 from .model_args import ModelArguments
@@ -78,12 +78,16 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat
     model_kwargs: Optional[Union[dict, str]] = None
     load_args: bool = True
     load_data_args: bool = False
-
+    # dataset
+    packing: bool = False
+    packing_cache: Optional[str] = None
+    custom_register_path: List[str] = field(default_factory=list)  # .py
+    # hub
     use_hf: bool = False
     # None: use env var `MODELSCOPE_API_TOKEN`
     hub_token: Optional[str] = field(
         default=None, metadata={'help': 'SDK token can be found in https://modelscope.cn/my/myaccesstoken'})
-    custom_register_path: List[str] = field(default_factory=list)  # .py
+    # dist
     ddp_timeout: int = 18000000
     ddp_backend: Optional[str] = None

@@ -128,6 +132,17 @@ def _init_adapters(self):
             safe_snapshot_download(adapter, use_hf=self.use_hf, hub_token=self.hub_token) for adapter in self.adapters
         ]

+    def _check_packing(self):
+        if not self.packing:
+            return
+        error = ValueError('When using the packing feature across multiple nodes, ensure that all nodes share '
+                           'the same packing cache directory. You can achieve this by setting the '
+                           '`MODELSCOPE_CACHE` environment variable or by adding the `--packing_cache '
+                           '<shared_path>` argument in the command line.')
+        check_shared_disk(error, self.packing_cache)
+        if self.packing_cache:
+            os.environ['PACKING_CACHE'] = self.packing_cache
+
     def __post_init__(self):
         if self.use_hf or use_hf_hub():
             self.use_hf = True
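As a reading aid, here is a compressed, hypothetical sketch of what the new `_check_packing` hook amounts to once `check_shared_disk` has verified the multinode setup; the path is a made-up example and this is not the swift call chain itself.

```python
# Condensed illustration only; in swift, check_shared_disk(error, packing_cache) runs first
# and raises on multinode runs where the directory is not visible from every rank.
import os

packing = True
packing_cache = '/mnt/shared/packing_cache'  # hypothetical shared mount

if packing and packing_cache:
    # Exported so that IndexedDataset.get_cache_dir (swift/llm/dataset/utils.py, below)
    # builds its per-dataset packing cache under the shared directory on every node.
    os.environ['PACKING_CACHE'] = packing_cache
```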

swift/llm/argument/train_args.py

Lines changed: 1 addition & 4 deletions
@@ -108,9 +108,6 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
     add_version: bool = True
     resume_only_model: bool = False
     create_checkpoint_symlink: bool = False
-
-    # dataset
-    packing: bool = False
     lazy_tokenize: Optional[bool] = None

     # plugin
@@ -174,8 +171,8 @@ def __post_init__(self) -> None:
         self.accelerator_config = {'dispatch_batches': False}
         self.training_args = TrainerFactory.get_training_args(self)
         self.training_args.remove_unused_columns = False
-
         self._add_version()
+        self._check_packing()

         if 'swanlab' in self.report_to:
             self._init_swanlab()

swift/llm/dataset/utils.py

Lines changed: 2 additions & 1 deletion
@@ -240,7 +240,8 @@ class IndexedDataset(Dataset):

     @staticmethod
     def get_cache_dir(dataset_name: str):
-        cache_dir = os.path.join(get_cache_dir(), 'tmp', dataset_name)
+        cache_dir = os.getenv('PACKING_CACHE') or os.path.join(get_cache_dir(), 'tmp')
+        cache_dir = os.path.join(cache_dir, dataset_name)
         os.makedirs(cache_dir, exist_ok=True)
         assert dataset_name is not None, f'dataset_name: {dataset_name}'
         return cache_dir
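A short, hypothetical usage sketch of the new lookup order (the paths and dataset name are invented; `get_cache_dir()` normally derives its root from `$MODELSCOPE_CACHE`):

```python
# Hypothetical values; mirrors the two-step resolution in IndexedDataset.get_cache_dir above.
import os

os.environ['PACKING_CACHE'] = '/mnt/shared/packing_cache'
base = os.getenv('PACKING_CACHE') or os.path.join('/root/.cache/modelscope', 'tmp')
print(os.path.join(base, 'my_dataset'))  # /mnt/shared/packing_cache/my_dataset
```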

swift/megatron/argument/train_args.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,6 @@ class MegatronTrainArguments(MegatronArguments, BaseArguments):
     add_version: bool = True
     # dataset
     lazy_tokenize: bool = False
-    packing: bool = False

     def init_model_args(self, config):
         self.megatron_model_meta = get_megatron_model_meta(self.model_type)
@@ -43,6 +42,7 @@ def __post_init__(self):
         self.load = to_abspath(self.load, check_path_exist=True)
         BaseArguments.__post_init__(self)
         self._init_save()
+        self._check_packing()
         self.seq_length = self.seq_length or self.max_length
         if self.streaming:
             self.dataloader_type = 'external'

swift/utils/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -10,10 +10,11 @@
 from .logger import get_logger
 from .np_utils import get_seed, stat_array, transform_jsonl_to_df
 from .tb_utils import TB_COLOR, TB_COLOR_SMOOTH, plot_images, read_tensorboard_file, tensorboard_smoothing
-from .torch_utils import (Serializer, activate_parameters, find_all_linears, find_embedding, find_layers, find_norm,
-                          freeze_parameters, gc_collect, get_current_device, get_device, get_device_count,
-                          get_model_parameter_info, get_n_params_grads, init_process_group, safe_ddp_context,
-                          seed_worker, set_default_ddp_config, set_device, show_layers, time_synchronize)
+from .torch_utils import (Serializer, activate_parameters, check_shared_disk, find_all_linears, find_embedding,
+                          find_layers, find_norm, freeze_parameters, gc_collect, get_current_device, get_device,
+                          get_device_count, get_model_parameter_info, get_n_params_grads, init_process_group,
+                          safe_ddp_context, seed_worker, set_default_ddp_config, set_device, show_layers,
+                          time_synchronize)
 from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port,
                     get_env_args, import_external_file, lower_bound, parse_args, patch_getattr, read_multi_line,
                     seed_everything, split_list, subprocess_run, test_time, upper_bound)

swift/utils/torch_utils.py

Lines changed: 31 additions & 1 deletion
@@ -21,7 +21,7 @@
 from transformers.trainer_utils import set_seed
 from transformers.utils import is_torch_cuda_available, is_torch_mps_available, is_torch_npu_available

-from .env import get_dist_setting, is_dist, is_dist_ta, is_local_master, is_master
+from .env import get_dist_setting, get_node_setting, is_dist, is_dist_ta, is_local_master, is_master
 from .logger import get_logger
 from .utils import deep_getattr

@@ -401,3 +401,33 @@ def seed_worker(worker_id: int, num_workers: int, rank: int):
     init_seed = torch.initial_seed() % 2**32
     worker_seed = num_workers * rank + init_seed
     set_seed(worker_seed)
+
+
+def check_shared_disk(error, cache_dir: Optional[str] = None):
+    nnodes = get_node_setting()[1]
+    if nnodes <= 1:
+        return True
+    assert dist.is_initialized()
+    if cache_dir is None:
+        cache_dir = os.path.join(get_cache_dir(), 'tmp')
+    os.makedirs(cache_dir, exist_ok=True)
+    tmp_path = os.path.join(cache_dir, 'check_shared_disk.tmp')
+    is_shared_disk = True
+    with safe_ddp_context(None, True):
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+    try:
+        with safe_ddp_context(None, True):
+            if is_master():
+                with open(tmp_path, 'w'):
+                    pass
+            else:
+                if not os.path.exists(tmp_path):
+                    is_shared_disk = False
+    finally:
+        if is_master() and os.path.exists(tmp_path):
+            os.remove(tmp_path)
+    shared_state = [None] * dist.get_world_size()
+    dist.all_gather_object(shared_state, is_shared_disk)
+    if not all(shared_state):
+        raise error