Skip to content

add ds inject policies #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ jobs:

- name: Install MII
run: |
pip install git+https://github.com/microsoft/DeepSpeed.git
pip install .[dev,local]
pip install .[dev,local] git+https://github.com/microsoft/deepspeed.git@staging-mii-update

- name: Unit tests
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/formatting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:

- name: Install MII
run: |
pip install .[dev]
pip install .[dev] git+https://github.com/microsoft/deepspeed.git@staging-mii-update

- name: Formatting checks
run: |
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/nv-torch-latest-v100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ jobs:
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Install MII
run: |
pip install git+https://github.com/microsoft/DeepSpeed.git
pip install git+https://github.com/huggingface/transformers.git
pip install -U accelerate
pip install .[dev,local]
pip install .[dev,local] git+https://github.com/microsoft/deepspeed.git@staging-mii-update
ds_report
- name: Unit tests
run: |
Expand Down
20 changes: 20 additions & 0 deletions mii/policies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from .bert import HFBertLayerPolicy
from .gpt_neo import HFGPTNEOLayerPolicy
from .gpt_neox import GPTNEOXLayerPolicy
from .gptj import HFGPTJLayerPolicy
from .megatron import MegatronLayerPolicy
from .gpt2 import HFGPT2LayerPolicy
from .bloom import BLOOMLayerPolicy

replace_policies = [
HFBertLayerPolicy,
HFGPTNEOLayerPolicy,
GPTNEOXLayerPolicy,
HFGPTJLayerPolicy,
MegatronLayerPolicy,
HFGPT2LayerPolicy,
BLOOMLayerPolicy
]
69 changes: 69 additions & 0 deletions mii/policies/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from torch.nn.parameter import Parameter
from deepspeed.module_inject.base_policy import InjectBasePolicy


class HFBertLayerPolicy(InjectBasePolicy):
_orig_layer_class = None

def __init__(self, client_module, inference=False, preln=False):
super().__init__(inference)
self.client_module = client_module
self.preln = preln
if HFBertLayerPolicy._orig_layer_class is None:
try:
import transformers
HFBertLayerPolicy._orig_layer_class = [
transformers.models.bert.modeling_bert.BertLayer,
transformers.models.roberta.modeling_roberta.RobertaLayer
]
except:
HFBertLayerPolicy._orig_layer_class = None

def get_hidden_heads(self):
return self.client_module.attention.self.query.weight.shape[1], \
self.client_module.attention.self.num_attention_heads

def attention(self):
qw = self.client_module.attention.self.query.weight
qb = self.client_module.attention.self.query.bias
kw = self.client_module.attention.self.key.weight
kb = self.client_module.attention.self.key.bias
vw = self.client_module.attention.self.value.weight
vb = self.client_module.attention.self.value.bias

qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False)

return self.linear_layer, \
qkvw, \
qkvb, \
self.client_module.attention.output.dense.weight, \
self.client_module.attention.output.dense.bias, \
self.scale_attention, \
self.is_megatron_v2

def mlp(self):
if self.preln:
intermediate_ff = self.client_module.intermediate.dense_act
else:
intermediate_ff = self.client_module.intermediate.dense

return self.linear_layer, intermediate_ff.weight, intermediate_ff.bias, \
self.client_module.output.dense.weight, \
self.client_module.output.dense.bias

def layerNorm(self):
if self.preln:
attention_layernorm = self.client_module.PostAttentionLayerNorm
transformer_layernorm = self.client_module.PreAttentionLayerNorm
else:
attention_layernorm = self.client_module.attention.output.LayerNorm
transformer_layernorm = self.client_module.output.LayerNorm
return attention_layernorm.weight, \
attention_layernorm.bias, \
transformer_layernorm.weight, \
transformer_layernorm.bias
46 changes: 46 additions & 0 deletions mii/policies/bloom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.module_inject.base_policy import InjectBasePolicy


class BLOOMLayerPolicy(InjectBasePolicy):
_orig_layer_class = None

def __init__(self, client_module, inference=True):
super().__init__(inference, linear_layer=True)
self.client_module = client_module
try:
import transformers
BLOOMLayerPolicy._orig_layer_class = transformers.models.bloom.modeling_bloom.BloomBlock
global supported_models
supported_models.update(
{transformers.models.bloom.modeling_bloom.BloomModel})
except:
BLOOMLayerPolicy._orig_layer_class = None

def get_hidden_heads(self):
return self.client_module.self_attention.hidden_size, \
self.client_module.self_attention.num_heads

def attention(self):
return self.linear_layer, \
self.client_module.self_attention.query_key_value.weight, \
self.client_module.self_attention.query_key_value.bias, \
self.client_module.self_attention.dense.weight, \
self.client_module.self_attention.dense.bias, \
self.scale_attention, \
self.is_megatron_v2

def mlp(self):
return self.linear_layer, \
self.client_module.mlp.dense_h_to_4h.weight, \
self.client_module.mlp.dense_h_to_4h.bias, \
self.client_module.mlp.dense_4h_to_h.weight, \
self.client_module.mlp.dense_4h_to_h.bias

def layerNorm(self):
return self.client_module.post_attention_layernorm.weight, \
self.client_module.post_attention_layernorm.bias, \
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
44 changes: 44 additions & 0 deletions mii/policies/gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
from deepspeed.module_inject.base_policy import InjectBasePolicy


class HFGPT2LayerPolicy(InjectBasePolicy):
_orig_layer_class = None

def __init__(self, client_module, inference=True):
# HuggingFace GPT2 uses convolutional layer instead of linear layer
super().__init__(inference, linear_layer=False)
self.client_module = client_module
try:
import transformers
HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block
except:
HFGPT2LayerPolicy._orig_layer_class = None

def get_hidden_heads(self):
return self.client_module.attn.embed_dim, \
self.client_module.attn.num_heads

def attention(self):
return self.linear_layer, \
self.client_module.attn.c_attn.weight, \
self.client_module.attn.c_attn.bias, \
self.client_module.attn.c_proj.weight, \
self.client_module.attn.c_proj.bias, \
self.scale_attention, \
self.is_megatron_v2

def mlp(self):
return self.linear_layer, \
self.client_module.mlp.c_fc.weight, \
self.client_module.mlp.c_fc.bias, \
self.client_module.mlp.c_proj.weight, \
self.client_module.mlp.c_proj.bias

def layerNorm(self):
return self.client_module.ln_2.weight, \
self.client_module.ln_2.bias, \
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
51 changes: 51 additions & 0 deletions mii/policies/gpt_neo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from torch.nn.parameter import Parameter
from deepspeed.module_inject.base_policy import InjectBasePolicy


class HFGPTNEOLayerPolicy(InjectBasePolicy):
_orig_layer_class = None

def __init__(self, client_module, inference=True):
super().__init__(inference, scale_attention=False)
self.client_module = client_module
try:
import transformers
HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock
except:
HFGPTNEOLayerPolicy._orig_layer_class = None

def get_hidden_heads(self):
return self.client_module.attn.attention.q_proj.weight.shape[1], \
self.client_module.attn.attention.num_heads

def attention(self):
qw = self.client_module.attn.attention.q_proj.weight
kw = self.client_module.attn.attention.k_proj.weight
vw = self.client_module.attn.attention.v_proj.weight

qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)

return self.linear_layer, \
qkvw, \
None, \
self.client_module.attn.attention.out_proj.weight, \
self.client_module.attn.attention.out_proj.bias, \
self.scale_attention, \
self.is_megatron_v2

def mlp(self):
return self.linear_layer, \
self.client_module.mlp.c_fc.weight, \
self.client_module.mlp.c_fc.bias, \
self.client_module.mlp.c_proj.weight, \
self.client_module.mlp.c_proj.bias

def layerNorm(self):
return self.client_module.ln_2.weight, \
self.client_module.ln_2.bias, \
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
60 changes: 60 additions & 0 deletions mii/policies/gpt_neox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from packaging import version as pkg_version
from deepspeed.module_inject.base_policy import InjectBasePolicy


class GPTNEOXLayerPolicy(InjectBasePolicy):
_orig_layer_class = None
version = 0

def __init__(self, client_module, inference=True, megatron_v2=True):
super().__init__(inference, megatron_v2=megatron_v2)
self.client_module = client_module
if GPTNEOXLayerPolicy._orig_layer_class is None:
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
GPTNEOXLayerPolicy._orig_layer_class = None
else:
try:
from transformers import GPTNeoXLayer
GPTNEOXLayerPolicy._orig_layer_class = GPTNeoXLayer
except ImportError:
GPTNEOXLayerPolicy._orig_layer_class = None

def get_hidden_heads(self):
if GPTNEOXLayerPolicy.version == 0:
attention = self.client_module.attention
else:
attention = self.client_module.self_attention

return self.client_module.attention.query_key_value.weight.shape[1], \
self.client_module.attention.num_attention_heads

def attention(self):
if GPTNEOXLayerPolicy.version == 0:
attention = self.client_module.attention
else:
attention = self.client_module.self_attention

return self.linear_layer, \
attention.query_key_value.weight, \
attention.query_key_value.bias, \
attention.dense.weight, \
attention.dense.bias, \
self.scale_attention, \
self.is_megatron_v2

def mlp(self):
return self.linear_layer, \
self.client_module.mlp.dense_h_to_4h.weight, \
self.client_module.mlp.dense_h_to_4h.bias, \
self.client_module.mlp.dense_4h_to_h.weight, \
self.client_module.mlp.dense_4h_to_h.bias

def layerNorm(self):
return self.client_module.post_attention_layernorm.weight, \
self.client_module.post_attention_layernorm.bias, \
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
51 changes: 51 additions & 0 deletions mii/policies/gptj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import torch
from torch.nn.parameter import Parameter
from deepspeed.module_inject.base_policy import InjectBasePolicy


class HFGPTJLayerPolicy(InjectBasePolicy):
_orig_layer_class = None

def __init__(self, client_module, inference=True):
super().__init__(inference, scale_attention=True)
self.client_module = client_module
try:
import transformers
HFGPTJLayerPolicy._orig_layer_class = transformers.models.gptj.modeling_gptj.GPTJBlock
except:
HFGPTJLayerPolicy._orig_layer_class = None

def get_hidden_heads(self):
return self.client_module.attn.q_proj.weight.shape[1], \
self.client_module.attn.num_attention_heads

def attention(self):
qw = self.client_module.attn.q_proj.weight
kw = self.client_module.attn.k_proj.weight
vw = self.client_module.attn.v_proj.weight

qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)

return self.linear_layer, \
qkvw, \
None, \
self.client_module.attn.out_proj.weight, \
None, \
self.scale_attention, \
self.is_megatron_v2

def mlp(self):
return self.linear_layer, \
self.client_module.mlp.fc_in.weight, \
self.client_module.mlp.fc_in.bias, \
self.client_module.mlp.fc_out.weight, \
self.client_module.mlp.fc_out.bias

def layerNorm(self):
return None, \
None, \
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
Loading