Skip to content

Commit 4b37184

Browse files
committed
Adjust LayerNormANE bias to match torch.nn.LayerNorm equation and pin torch version
1 parent 2050f58 commit 4b37184

File tree

4 files changed

+23
-3
lines changed

4 files changed

+23
-3
lines changed

ane_transformers/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.1"
1+
__version__ = "0.1.2"

ane_transformers/huggingface/distilbert.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,26 @@
2222
"coremltools does not support dict outputs. Please set return_dict=False"
2323

2424

25+
# Note: torch.nn.LayerNorm and ane_transformers.reference.layer_norm.LayerNormANE
26+
# apply scale and bias terms in opposite orders. In order to accurately restore a
27+
# state_dict trained using the former into the latter, we adjust the bias term (dividing it by the weight)
28+
def correct_for_bias_scale_order_inversion(state_dict, prefix, local_metadata,
29+
strict, missing_keys,
30+
unexpected_keys, error_msgs):
31+
state_dict[prefix +
32+
'bias'] = state_dict[prefix + 'bias'] / state_dict[prefix +
33+
'weight']
34+
return state_dict
35+
36+
37+
class LayerNormANE(LayerNormANE):
38+
39+
def __init__(self, *args, **kwargs):
40+
super().__init__(*args, **kwargs)
41+
self._register_load_state_dict_pre_hook(
42+
correct_for_bias_scale_order_inversion)
43+
44+
2545
class Embeddings(modeling_distilbert.Embeddings):
2646
""" Embeddings module optimized for Apple Neural Engine
2747
"""

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torch>=1.10.0
1+
torch>=1.10.0,<=1.11.0
22
transformers>=4.18.0
33
coremltools>=5.2.0
44
yapf

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
long_description_content_type='text/markdown',
1515
author='Apple Inc.',
1616
install_requires=[
17-
"torch>=1.10.0",
17+
"torch>=1.10.0,<=1.11.0",
1818
"coremltools>=5.2.0",
1919
"transformers>=4.18.0",
2020
"protobuf>=3.1.0,<=3.20.1",

0 commit comments

Comments
 (0)