@@ -30,22 +30,30 @@ class LlamaBackbone(Backbone):
     constructor.
 
     Args:
-        vocabulary_size (int): The size of the token vocabulary.
-        num_layers (int): The number of transformer layers.
-        num_query_heads (int): The number of query attention heads for
+        vocabulary_size: int. The size of the token vocabulary.
+        num_layers: int. The number of transformer layers.
+        num_query_heads: int. The number of query attention heads for
             each transformer.
-        hidden_dim (int): The size of the transformer encoding and pooling
+        hidden_dim: int. The size of the transformer encoding and pooling
             layers.
-        intermediate_dim (int): The output dimension of the first Dense layer in
+        intermediate_dim: int. The output dimension of the first Dense layer in
             a three-layer feedforward network for each transformer.
-        num_key_value_heads (int): The number of key and value attention heads
+        num_key_value_heads: int. The number of key and value attention heads
             for each transformer.
-        rope_max_wavelength (int, optional): The maximum angular wavelength of
+        rope_max_wavelength: int. The maximum angular wavelength of
             the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
-        rope_scaling_factor (float, optional): The scaling factor for
-            calculation of roatary embedding. Defaults to `1.0`.
-        layer_norm_epsilon (float, optional): Epsilon for the layer
-            normalization layers in the transformer decoder. Defaults to `1e-6`.
+        rope_position_scaling_factor: float. The scaling factor for
+            calculation of rotary embedding. Defaults to `1.0`.
+        rope_frequency_adjustment_factor: float. The scaling factor
+            used to scale the inverse frequencies. Defaults to `None`.
+        rope_low_freq_factor: float. The low frequency scaling
+            factor. Defaults to `None`.
+        rope_high_freq_factor: float. Used for Llama3.1+. The high
+            frequency scaling factor. Defaults to `None`.
+        rope_pretraining_sequence_length: int. Used for Llama3.1+.
+            Defaults to `None`.
+        layer_norm_epsilon: float. Epsilon for the layer normalization layers
+            in the transformer decoder. Defaults to `1e-6`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for model computations and weights. Note that some computations,
             such as softmax and layer normalization, will always be done at
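
Not part of the diff, but for orientation: the docstring marks the new `rope_*` arguments as Llama 3.1-era settings, so they presumably feed a Llama 3.1-style adjustment of the rotary inverse frequencies. The sketch below is only an assumption about that computation (the helper name is made up, and the real logic would live in the rotary embedding layer that this backbone merely configures):

import math

import numpy as np


def adjust_inv_freq_llama3(
    inv_freq,
    frequency_adjustment_factor,
    low_freq_factor,
    high_freq_factor,
    pretraining_sequence_length,
):
    """Illustrative Llama 3.1-style scaling of rotary inverse frequencies."""
    wavelen = 2 * math.pi / inv_freq
    low_freq_wavelen = pretraining_sequence_length / low_freq_factor
    high_freq_wavelen = pretraining_sequence_length / high_freq_factor

    # Low frequencies (long wavelengths) are divided by the adjustment
    # factor; high frequencies (short wavelengths) are left untouched.
    adjusted = np.where(
        wavelen > low_freq_wavelen,
        inv_freq / frequency_adjustment_factor,
        inv_freq,
    )

    # Frequencies between the two thresholds are smoothly interpolated
    # between the scaled and unscaled values.
    smooth = (pretraining_sequence_length / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    smoothed = (
        1 - smooth
    ) * inv_freq / frequency_adjustment_factor + smooth * inv_freq
    return np.where(medium, smoothed, adjusted)
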
@@ -87,7 +95,11 @@ def __init__(
         intermediate_dim,
         num_key_value_heads,
         rope_max_wavelength=10000,
-        rope_scaling_factor=1.0,
+        rope_position_scaling_factor=1.0,
+        rope_frequency_adjustment_factor=None,
+        rope_low_freq_factor=None,
+        rope_high_freq_factor=None,
+        rope_pretraining_sequence_length=None,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
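
For context (not part of the change): with the updated signature, a randomly initialized backbone could be constructed roughly as below. The import path assumes the usual keras_hub layout, and every hyperparameter value is illustrative, loosely echoing published Llama 3.1 RoPE settings.

from keras_hub.models import LlamaBackbone  # import path assumed

backbone = LlamaBackbone(
    vocabulary_size=32000,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=512,
    intermediate_dim=1024,
    rope_max_wavelength=500000,
    rope_position_scaling_factor=1.0,
    rope_frequency_adjustment_factor=8.0,
    rope_low_freq_factor=1.0,
    rope_high_freq_factor=4.0,
    rope_pretraining_sequence_length=8192,
)
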
@@ -110,7 +122,15 @@ def __init__(
                 num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
                 rope_max_wavelength=rope_max_wavelength,
-                rope_scaling_factor=rope_scaling_factor,
+                rope_position_scaling_factor=rope_position_scaling_factor,
+                rope_frequency_adjustment_factor=(
+                    rope_frequency_adjustment_factor
+                ),
+                rope_low_freq_factor=rope_low_freq_factor,
+                rope_high_freq_factor=rope_high_freq_factor,
+                rope_pretraining_sequence_length=(
+                    rope_pretraining_sequence_length
+                ),
                 layer_norm_epsilon=layer_norm_epsilon,
                 activation=ops.silu,
                 kernel_initializer=_llama_kernel_initializer(stddev=0.02),
@@ -152,9 +172,13 @@ def __init__(
         self.num_query_heads = num_query_heads
         self.hidden_dim = hidden_dim
         self.intermediate_dim = intermediate_dim
-        self.rope_max_wavelength = rope_max_wavelength
         self.num_key_value_heads = num_key_value_heads
-        self.rope_scaling_factor = rope_scaling_factor
+        self.rope_max_wavelength = rope_max_wavelength
+        self.rope_position_scaling_factor = rope_position_scaling_factor
+        self.rope_frequency_adjustment_factor = rope_frequency_adjustment_factor
+        self.rope_low_freq_factor = rope_low_freq_factor
+        self.rope_high_freq_factor = rope_high_freq_factor
+        self.rope_pretraining_sequence_length = rope_pretraining_sequence_length
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
         self.tie_word_embeddings = tie_word_embeddings
@@ -169,7 +193,17 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "rope_max_wavelength": self.rope_max_wavelength,
-                "rope_scaling_factor": self.rope_scaling_factor,
+                "rope_position_scaling_factor": (
+                    self.rope_position_scaling_factor
+                ),
+                "rope_frequency_adjustment_factor": (
+                    self.rope_frequency_adjustment_factor
+                ),
+                "rope_low_freq_factor": self.rope_low_freq_factor,
+                "rope_high_freq_factor": self.rope_high_freq_factor,
+                "rope_pretraining_sequence_length": (
+                    self.rope_pretraining_sequence_length
+                ),
                 "num_key_value_heads": self.num_key_value_heads,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,