
Commit 3166525

Merge branch 'master' into tohtana/deepcompile
2 parents 30b89eb + a21e5b9 commit 3166525

File tree: 4 files changed, +34 -10 lines

- accelerator/hpu_accelerator.py
- blogs/huggingface-tp/README.md
- deepspeed/runtime/engine.py
- deepspeed/runtime/pipe/module.py

accelerator/hpu_accelerator.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -23,6 +23,8 @@ def __init__(self):
             import habana_frameworks.torch.hpu as hpu
             self.hpu = hpu
             torch.use_deterministic_algorithms(True)
+            # TODO: remove this WA when memory mapping break is resolved.
+            torch.utils.deterministic.fill_uninitialized_memory = False
         except ImportError as e:
             raise ValueError(
                 f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
```

blogs/huggingface-tp/README.md

Lines changed: 9 additions & 5 deletions
````diff
@@ -48,9 +48,15 @@ Figure 2 illustrates the basic flowchart, The division of TP and ZeRO is impleme
 
 # Usage
 
-Although we evaluated AutoTP training with Llama2 & Llama3 models in this blog, we expect compatibility with other Hugging Face models, especially [those](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) previously validated with AutoTP inference. Please upgrade accelerate and transformers to the master branch. We will add their minimum version once they have release tag.
 
 
+Although we evaluated AutoTP training with Llama2 & Llama3 models in this blog, we expect compatibility with other Hugging Face models, especially [those](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) previously validated with AutoTP inference.
+
+**Requirements**
+- `deepspeed >= 0.16.4`
+- `transformers >= 4.50.1`
+- `accelerate >= 1.6.0`
+
 **Enable TP training**
 
 Similar to ZeRO, AutoTP training is enabled using the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/) by specifying ```[tensor_parallel][autotp_size]```.
````
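
For reference, a minimal sketch of such a configuration; only the `tensor_parallel`/`autotp_size` key is documented in this hunk, the other entries are illustrative assumptions:

```python
# Sketch of a DeepSpeed config dict (could equally live in a ds_config.json file).
ds_config = {
    "train_batch_size": 32,                 # illustrative value
    "bf16": {"enabled": True},              # illustrative value
    "tensor_parallel": {"autotp_size": 4},  # documented key: TP degree for AutoTP training
    "zero_optimization": {"stage": 1},      # assumption: pair TP with a suitable ZeRO stage
}
```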

````diff
@@ -113,12 +119,10 @@ Models saved this way can be directly used for HF format inference without inter
 Saving Checkpoints remains compatible with HF transformers. Use [trainer.save_state()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_state) or set the save interval for automatic saving, which can be used to resume training.
 ```
 trainer.train(resume_from_checkpoint="your_saved_path/checkpoint-1200")
-)
 ```
 
 # Example
-We validated AutoTP training using supervised finetune training (SFT) task: [stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca). The original benchmark model used in this project is Llama2-7B.
-
+We validated AutoTP training using supervised finetune training (SFT) task: [stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca). The original benchmark model used in this project is Llama2-7B. The example code is also available [here](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/tensor_parallel)
 
 
 **Training Loss curve**
````
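
As a hedged sketch of the save/resume flow shown in this hunk (only `trainer.save_state()` and `resume_from_checkpoint` come from the text; the argument values and the `ds_config.json` path are placeholders):

```python
from transformers import Trainer, TrainingArguments

def build_trainer(model, train_dataset):
    """Sketch only: `model` and `train_dataset` are supplied by the caller."""
    args = TrainingArguments(
        output_dir="your_saved_path",  # checkpoints land in your_saved_path/checkpoint-*
        save_steps=1200,               # the "save interval" for automatic saving
        deepspeed="ds_config.json",    # placeholder path to the AutoTP config
    )
    return Trainer(model=model, args=args, train_dataset=train_dataset)

# Resuming, as in the hunk above:
# build_trainer(model, dataset).train(resume_from_checkpoint="your_saved_path/checkpoint-1200")
```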

````diff
@@ -216,7 +220,7 @@ The following loss curves depict SFT training, where gbs is uniformly set to 32,
 
 # Miscellaneous
 
-If users define their own dataloader, please ensure data consistency within ```deepspeed.utils.get_tensor_model_parallel_group()```. DeepSpeed provides basic validation functions to assist with this.
+If users define their own dataloader, please ensure data consistency within ```deepspeed.utils.groups.get_tensor_model_parallel_group()```. DeepSpeed provides basic validation functions to assist with this.
 
 Furthermore, if users are not using transformers library, you can replace the ```TensorParallel_Layer``` layer and its subclasses as needed. See ```prepare_tp_model``` function in ```unit/model_parallelism/test_autotp_training.py```. Users can also define different shard and gather for subclasses of ```TensorParallel_Layer.```
````
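
A hedged sketch of such a consistency check; only `deepspeed.utils.groups.get_tensor_model_parallel_group()` comes from the hunk above, while the helper itself and the `torch.distributed.get_global_rank` call are assumptions:

```python
import torch
import torch.distributed as dist
from deepspeed.utils import groups

def check_batch_consistency(batch: torch.Tensor) -> None:
    """Hypothetical helper: verify all ranks in the tensor-parallel group
    received an identical batch from their dataloaders."""
    tp_group = groups.get_tensor_model_parallel_group()
    src = dist.get_global_rank(tp_group, 0)  # global rank of the group's first rank
    reference = batch.clone()
    dist.broadcast(reference, src=src, group=tp_group)
    if not torch.equal(reference, batch):
        raise RuntimeError("Dataloader produced different batches across the TP group.")
```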

deepspeed/runtime/engine.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -3958,7 +3958,9 @@ def offload_states(self,
         param_offload_config = self.zero_offload_param()
         assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."
 
-        assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters."
+        assert not isinstance(
+            self.optimizer,
+            DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer."
 
         if device == OffloadDeviceEnum.none:
             logger.warning("No device specified for offloading states.")
@@ -3977,4 +3979,9 @@ def reload_states(self, non_blocking: bool = False) -> None:
         """
         assert self.zero_optimization_stage(
         ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3."
+
+        assert not isinstance(
+            self.optimizer,
+            DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer."
+
         self.optimizer.reload_states(non_blocking=non_blocking)
```
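
For context, a usage sketch of the API these assertions guard; `offload_states`/`reload_states` appear in the hunk, while the enum import path and the surrounding helper are assumptions:

```python
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum

def free_engine_memory_temporarily(engine):
    """Sketch: per the assertions above, this only works for ZeRO stage 3
    with a real optimizer attached and parameter offloading disabled."""
    engine.offload_states(device=OffloadDeviceEnum.cpu, pin_memory=True, non_blocking=False)
    # ... do something memory-hungry here (e.g. run evaluation with another model) ...
    engine.reload_states(non_blocking=False)
```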

deepspeed/runtime/pipe/module.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -443,26 +443,34 @@ def _partition_layers(self, method='uniform'):
 
         self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])
 
+    @staticmethod
+    def _recursive_getattr(module: torch.nn.Module, attr_name: str) -> torch.Tensor:
+        '''Allow getting an attribute like "linear.weight"'''
+        weight = module
+        for item in attr_name.split("."):
+            weight = getattr(weight, item)
+        return weight
+
     def allreduce_tied_weight_gradients(self):
         '''All reduce the gradients of the tied weights between tied stages'''
         for key, comm in self.tied_comms.items():
             for attr_name in comm['weight_attr']:
-                weight = getattr(self.tied_modules[key], attr_name)
+                weight = self._recursive_getattr(self.tied_modules[key], attr_name)
                 dist.all_reduce(weight.grad, group=comm['group'])
 
     def get_tied_weights_and_groups(self):
         weight_group_list = []
         for key, comm in self.tied_comms.items():
             for attr_name in comm['weight_attr']:
-                weight = getattr(self.tied_modules[key], attr_name)
+                weight = self._recursive_getattr(self.tied_modules[key], attr_name)
                 weight_group_list.append((weight, comm['group']))
         return weight_group_list
 
     def _synchronize_tied_weights(self):
         for key, comm in self.tied_comms.items():
             for attr_name in comm['weight_attr']:
                 dist.broadcast(
-                    getattr(comm['module'], attr_name),
+                    self._recursive_getattr(comm['module'], attr_name),
                     src=min(comm['ranks']),
                     group=comm['group'],
                 )
@@ -475,7 +483,10 @@ def _index_tied_modules(self):
 
         specs = self._layer_specs
         tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))
-        for key in tie_keys:
+        # Since Python 3.7, "Dictionary order is guaranteed to be insertion order."
+        # Sort tie_keys here so that orders of self.tied_comms.items() are consistent
+        # among ranks.
+        for key in sorted(tie_keys):
             # Find the layers that the tied module appears in
             tied_layers = []
             for idx, layer in enumerate(specs):
```
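
Incidentally, the traversal performed by the new `_recursive_getattr` helper matches what `operator.attrgetter` already does for dotted paths; a small standalone illustration with a made-up module:

```python
import operator
import torch

class TiedHead(torch.nn.Module):
    """Toy module with a nested attribute path, as in 'linear.weight'."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

head = TiedHead()
# attrgetter reaches the same object as walking the attributes by hand.
assert operator.attrgetter("linear.weight")(head) is head.linear.weight
```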
