Commit 342b684

Llama3.1 (#2132)
1 parent a5337f5 commit 342b684

File tree: 8 files changed (+388, -29 lines)

keras_hub/src/models/falcon/falcon_backbone.py (+1, -1)

@@ -29,7 +29,7 @@ class FalconBackbone(Backbone):
         layer_norm_epsilon: float. Epsilon for the layer normalization layers in
             the transformer decoder.
         attention_dropout_rate: float. Dropout probability for the attention.
-        feedforward_dropout_rate: flaot. Dropout probability for the
+        feedforward_dropout_rate: float. Dropout probability for the
             feedforward.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,

keras_hub/src/models/llama/llama_attention.py (+24, -6)

@@ -3,7 +3,9 @@
 import keras
 from keras import ops

-from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
+from keras_hub.src.models.llama.llama_rotary_embedding import (
+    LlamaRotaryEmbedding,
+)
 from keras_hub.src.utils.keras_utils import clone_initializer
 from keras_hub.src.utils.keras_utils import fused_attention_op_available

@@ -16,7 +18,11 @@ def __init__(
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         kernel_initializer="glorot_uniform",
         dropout=0,
         **kwargs,
@@ -28,13 +34,16 @@ def __init__(

         self.num_key_value_groups = num_query_heads // num_key_value_heads
         self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length

         self.kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )

-        self.rope_scaling_factor = rope_scaling_factor
-
     def build(self, inputs_shape):
         # Einsum variables:
         # b = batch size
@@ -103,9 +112,13 @@ def build(self, inputs_shape):
         )
         self._output_dense.build((None, None, self.num_query_heads, head_dim))

-        self.rotary_embedding_layer = RotaryEmbedding(
+        self.rotary_embedding_layer = LlamaRotaryEmbedding(
             max_wavelength=self.rope_max_wavelength,
-            scaling_factor=self.rope_scaling_factor,
+            position_scaling_factor=self.rope_position_scaling_factor,
+            frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+            low_freq_factor=self.rope_low_freq_factor,
+            high_freq_factor=self.rope_high_freq_factor,
+            pretraining_sequence_length=self.rope_pretraining_sequence_length,
             dtype=self.dtype_policy,
         )

@@ -224,6 +237,11 @@ def get_config(self):
                 "num_key_value_heads": self.num_key_value_heads,
                 "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "kernel_initializer": keras.initializers.serialize(
                     self.kernel_initializer
                 ),

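The hunks above swap the generic RotaryEmbedding for a LlamaRotaryEmbedding layer (imported from keras_hub/src/models/llama/llama_rotary_embedding.py, presumably added among the changed files not shown in this excerpt) and forward the new Llama 3.1 rope-scaling arguments to it. For orientation only, here is a minimal NumPy sketch of the standard Llama 3.1 frequency adjustment those arguments describe; the function name and default values are illustrative and are not taken from the new layer's code.

    import numpy as np

    def llama3_adjust_inv_freq(
        inv_freq,
        frequency_adjustment_factor=8.0,  # illustrative Llama 3.1-style values
        low_freq_factor=1.0,
        high_freq_factor=4.0,
        pretraining_sequence_length=8192,
    ):
        # Wavelength covered by each rotary frequency band.
        wavelen = 2 * np.pi / inv_freq
        low_freq_wavelen = pretraining_sequence_length / low_freq_factor
        high_freq_wavelen = pretraining_sequence_length / high_freq_factor
        # Low frequencies (wavelengths longer than the pretraining context)
        # are slowed down by the full adjustment factor.
        adjusted = np.where(
            wavelen > low_freq_wavelen,
            inv_freq / frequency_adjustment_factor,
            inv_freq,
        )
        # The band between the low- and high-frequency cutoffs is smoothly
        # interpolated between the scaled and unscaled frequencies.
        smooth = (pretraining_sequence_length / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        blended = (
            (1 - smooth) * inv_freq / frequency_adjustment_factor
            + smooth * inv_freq
        )
        is_mid_band = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        return np.where(is_mid_band, blended, adjusted)

    # Standard rotary inverse frequencies for a head dimension of 64.
    inv_freq = 1.0 / (10000.0 ** (np.arange(0, 64, 2) / 64))
    print(llama3_adjust_inv_freq(inv_freq)[:4])

High frequencies (short wavelengths) stay untouched, so short-range positional resolution is preserved while long-range positions are stretched to cover the longer context.
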
keras_hub/src/models/llama/llama_backbone.py (+50, -16)

@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
     constructor.

     Args:
-        vocabulary_size (int): The size of the token vocabulary.
-        num_layers (int): The number of transformer layers.
-        num_query_heads (int): The number of query attention heads for
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads : int. The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
+        hidden_dim : int. The size of the transformer encoding and pooling
             layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in
+        intermediate_dim : int. The output dimension of the first Dense layer in
             a three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads
+        num_key_value_heads : int. The number of key and value attention heads
             for each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of
+        rope_max_wavelength : int. The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`.
-        layer_norm_epsilon (float, optional): Epsilon for the layer
-            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        rope_position_scaling_factor: float. The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`
+        rope_frequency_adjustment_factor: float. The scaling factor
+            used to scale the inverse frequencies. Defaults to `None`.
+        rope_low_freq_factor: float. The low frequency scaling
+            factor. Defaults to `None`.
+        rope_high_freq_factor: float. Used for Llama3.1+. The high
+            frequency scaling factor. Defaults to `None`.
+        rope_pretraining_sequence_length: int. Used for Llama3.1+.
+            Defaults to `None`.
+        layer_norm_epsilon : float. Epsilon for the layer normalization layers
+            in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
@@ -87,7 +95,11 @@ def __init__(
         intermediate_dim,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -110,7 +122,15 @@ def __init__(
                 num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
                 rope_max_wavelength=rope_max_wavelength,
-                rope_scaling_factor=rope_scaling_factor,
+                rope_position_scaling_factor=rope_position_scaling_factor,
+                rope_frequency_adjustment_factor=(
+                    rope_frequency_adjustment_factor
+                ),
+                rope_low_freq_factor=rope_low_freq_factor,
+                rope_high_freq_factor=rope_high_freq_factor,
+                rope_pretraining_sequence_length=(
+                    rope_pretraining_sequence_length
+                ),
                 layer_norm_epsilon=layer_norm_epsilon,
                 activation=ops.silu,
                 kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ def __init__(
         self.num_query_heads = num_query_heads
         self.hidden_dim = hidden_dim
         self.intermediate_dim = intermediate_dim
-        self.rope_max_wavelength = rope_max_wavelength
         self.num_key_value_heads = num_key_value_heads
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_position_scaling_factor": (
+                    self.rope_position_scaling_factor
+                ),
+                "rope_frequency_adjustment_factor": (
+                    self.rope_frequency_adjustment_factor
+                ),
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,

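With the widened constructor above, the new arguments pass straight through to the backbone. A minimal usage sketch, assuming the public `keras_hub.models.LlamaBackbone` export; the layer sizes are toy values, and the rope settings mirror the commonly cited Llama 3.1 configuration (adjustment factor 8.0, low/high frequency factors 1.0/4.0, 8192-token pretraining context) purely for illustration. Leaving the four new arguments at `None` presumably keeps the previous behavior of plain position-scaled rotary embeddings.

    import numpy as np
    import keras_hub

    backbone = keras_hub.models.LlamaBackbone(
        vocabulary_size=1000,  # toy sizes, not a real preset
        num_layers=2,
        num_query_heads=8,
        num_key_value_heads=4,
        hidden_dim=64,
        intermediate_dim=128,
        rope_position_scaling_factor=1.0,
        rope_frequency_adjustment_factor=8.0,
        rope_low_freq_factor=1.0,
        rope_high_freq_factor=4.0,
        rope_pretraining_sequence_length=8192,
    )
    outputs = backbone(
        {
            "token_ids": np.ones((1, 12), dtype="int32"),
            "padding_mask": np.ones((1, 12), dtype="int32"),
        }
    )
    print(outputs.shape)  # (1, 12, 64): (batch, sequence, hidden_dim)
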
keras_hub/src/models/llama/llama_decoder.py (+20, -3)

@@ -21,7 +21,11 @@
         num_query_heads,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         activation="silu",
         layer_norm_epsilon=1e-5,
         kernel_initializer="glorot_uniform",
@@ -34,7 +38,11 @@ def __init__(
         self.num_key_value_heads = num_key_value_heads

         self.rope_max_wavelength = rope_max_wavelength
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length

         self.dropout = dropout

@@ -53,7 +61,11 @@ def build(self, decoder_sequence_shape):
             num_query_heads=self.num_query_heads,
             num_key_value_heads=self.num_key_value_heads,
             rope_max_wavelength=self.rope_max_wavelength,
-            rope_scaling_factor=self.rope_scaling_factor,
+            rope_position_scaling_factor=self.rope_position_scaling_factor,
+            rope_frequency_adjustment_factor=self.rope_frequency_adjustment_factor,
+            rope_low_freq_factor=self.rope_low_freq_factor,
+            rope_high_freq_factor=self.rope_high_freq_factor,
+            rope_pretraining_sequence_length=self.rope_pretraining_sequence_length,
             kernel_initializer=clone_initializer(self.kernel_initializer),
             dropout=self.dropout,
             dtype=self.dtype_policy,
@@ -221,6 +233,11 @@ def get_config(self):
                 "num_query_heads": self.num_query_heads,
                 "rope_max_wavelength": self.rope_max_wavelength,
                 "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "activation": keras.activations.serialize(self.activation),
                 "layer_norm_epsilon": self.layer_norm_epsilon,

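The decoder changes above only thread the same five rope arguments through to LlamaAttention. For converting Hugging Face Llama 3.1 checkpoints (conversion code is among the changed files not shown in this excerpt), these arguments line up with the fields of the standard `rope_scaling` block in a Hugging Face config.json; the helper below is a hypothetical illustration of that mapping, not code from this commit.

    def rope_kwargs_from_hf_rope_scaling(rope_scaling):
        # `rope_scaling` is the Llama 3.1 block from a Hugging Face config.json;
        # returns keyword arguments for the keras-hub Llama classes touched by
        # this commit. Hypothetical mapping for illustration only.
        if rope_scaling is None:
            return {}
        return {
            "rope_frequency_adjustment_factor": rope_scaling["factor"],
            "rope_low_freq_factor": rope_scaling["low_freq_factor"],
            "rope_high_freq_factor": rope_scaling["high_freq_factor"],
            "rope_pretraining_sequence_length": rope_scaling[
                "original_max_position_embeddings"
            ],
        }

    print(
        rope_kwargs_from_hf_rope_scaling(
            {
                "rope_type": "llama3",
                "factor": 8.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 4.0,
                "original_max_position_embeddings": 8192,
            }
        )
    )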