Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 6ba1aca

Browse files
committed
add scale norm
1 parent e22cc2b commit 6ba1aca

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

mesh_tensorflow/transformer/transformer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,31 @@ def sublayer_rms_norm_subsampled(x, layer_stack, context, percentage=100.,
547547
mtf.square(var_activations), reduced_dim=var_dim)
548548
return x * mtf.rsqrt(variance + epsilon) * scale
549549

550+
@gin.configurable
551+
def sublayer_scale_norm(x, layer_stack, context, epsilon=1e-6, name="scale_norm"):
552+
"""Scale normalization.
553+
554+
Args:
555+
x: an input mtf.Tensor
556+
layer_stack: a LayerStack
557+
context: a Context
558+
epsilon: a float
559+
name: a string
560+
Returns:
561+
a mtf.Tensor
562+
"""
563+
del layer_stack
564+
model_dim = context.model.model_dim
565+
with tf.variable_scope(name):
566+
scale = mtf.get_variable(
567+
context.mesh,
568+
"scale",
569+
context.model.ensemble_dims,
570+
initializer=tf.ones_initializer(),
571+
dtype=context.variable_dtype)
572+
variance = mtf.reduce_mean(mtf.square(x), reduced_dim=model_dim)
573+
return x * mtf.rsqrt(variance + epsilon) * scale
574+
550575

551576
@gin.configurable
552577
def sublayer_residual(x, layer_stack, context):

0 commit comments

Comments
 (0)