Skip to content
This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit da119c8

Browse files
author
Mesh TensorFlow Team
committed
Merge pull request #295 from lucidrains:master
PiperOrigin-RevId: 360963493
2 parents 878832b + 6ba1aca commit da119c8

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

mesh_tensorflow/transformer/transformer.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -548,6 +548,36 @@ def sublayer_rms_norm_subsampled(x, layer_stack, context, percentage=100.,
548548
return x * mtf.rsqrt(variance + epsilon) * scale
549549

550550

551+
@gin.configurable
def sublayer_scale_norm(x,
                        layer_stack,
                        context,
                        epsilon=1e-6,
                        name="scale_norm"):
  """Scale normalization sublayer.

  Divides `x` by the root of its mean square over the model dimension (with
  `epsilon` added for numerical stability) and multiplies the result by a
  learned `scale` variable. The variable is created under the `name` variable
  scope with shape `context.model.ensemble_dims` and is initialized to ones.

  Args:
    x: an input mtf.Tensor
    layer_stack: a LayerStack (unused; accepted for the sublayer interface)
    context: a Context
    epsilon: a float, added to the mean square before the reciprocal sqrt
    name: a string, the variable scope in which "scale" is created

  Returns:
    a mtf.Tensor
  """
  del layer_stack  # Unused.
  with tf.variable_scope(name):
    gain = mtf.get_variable(
        context.mesh,
        "scale",
        context.model.ensemble_dims,
        initializer=tf.ones_initializer(),
        dtype=context.variable_dtype)
  mean_square = mtf.reduce_mean(
      mtf.square(x), reduced_dim=context.model.model_dim)
  inv_rms = mtf.rsqrt(mean_square + epsilon)
  return x * inv_rms * gain
579+
580+
551581
@gin.configurable
552582
def sublayer_residual(x, layer_stack, context):
553583
del layer_stack

0 commit comments

Comments
 (0)