From da43c271c8b430e3c2ebba654858d1bb0e5875bf Mon Sep 17 00:00:00 2001 From: William Fedus Date: Mon, 12 Apr 2021 20:28:35 -0700 Subject: [PATCH] Option to use mtf.Print to log which tokens are sent to which experts when run on CPU. PiperOrigin-RevId: 368137313 --- mesh_tensorflow/transformer/moe.py | 72 ++++++++++++++++++++-- mesh_tensorflow/transformer/transformer.py | 5 +- mesh_tensorflow/transformer/utils.py | 19 ++++-- 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/mesh_tensorflow/transformer/moe.py b/mesh_tensorflow/transformer/moe.py index fda505ee..5feaeeb2 100644 --- a/mesh_tensorflow/transformer/moe.py +++ b/mesh_tensorflow/transformer/moe.py @@ -65,7 +65,8 @@ def __init__(self, word_embed_mode=None, use_second_place_expert_prob=None, use_second_place_expert_prob_temp=None, - top_n_num_experts_per_token=3): + top_n_num_experts_per_token=3, + token_logging=False): self._hparams = HParams( moe_gating=moe_gating, moe_num_experts=num_experts, @@ -97,6 +98,7 @@ def __init__(self, use_second_place_expert_prob_temp), moe_top_n_num_experts_per_token=top_n_num_experts_per_token) self._activation = activation + self.token_logging = token_logging def call(self, context, x, losses=None): """Call the layer.""" @@ -116,7 +118,13 @@ def call(self, context, x, losses=None): output_dim = self._hparams.moe_output_dim else: output_dim = context.model.model_dim - y, loss = transformer_moe_layer_v1( + if self.token_logging: + tokens = _detokenize(context.inputs, context.model.vocabulary) + x = mtf.Print(x, [tokens], "tokens:", summarize=1000) + extras = _windows(context.inputs, context.length_dim) + else: + extras = None + y, loss, extras = transformer_moe_layer_v1( x, output_dim, self._hparams, @@ -127,7 +135,16 @@ def call(self, context, x, losses=None): nonpadding=context.nonpadding, activation=self._activation, num_microbatches=context.num_microbatches, - token_embeddings=context.input_embeddings) + token_embeddings=context.input_embeddings, + extras=extras) + + if extras: + extras = _detokenize(extras, context.model.vocabulary) + experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts) + extras = mtf.unstack(extras, experts_dim) + for i, t in enumerate(extras): + y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000) + if context.losses is not None: context.losses.append(loss) if not has_length_dim: @@ -139,6 +156,23 @@ def call(self, context, x, losses=None): return y +@gin.configurable +def _windows(ids, length_dim, window_start=0, window_end=0): + to_stack = [] + for offset in range(window_start, window_end + 1): + to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False)) + return mtf.stack(to_stack, "window", axis=ids.shape.ndims) + + +def _detokenize(ids, vocabulary): + return mtf.slicewise( + vocabulary.decode_tf, + [ids], + output_shape=mtf.Shape(ids.shape.dims[:-1]), + output_dtype=tf.string, + splittable_dims=ids.shape.dims[:-1]) + + class MoE2D(transformer.TransformerLayer): """Mixture of Experts Layer.""" @@ -202,7 +236,7 @@ def call(self, context, x, losses=None): def transformer_moe_layer_v1( inputs, output_dim, hparams, train, variable_dtype, layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu, - num_microbatches=None, token_embeddings=None): + num_microbatches=None, token_embeddings=None, extras=None): """Local mixture of experts that works well on TPU. Adapted from the paper https://arxiv.org/abs/1701.06538 @@ -281,6 +315,7 @@ def transformer_moe_layer_v1( [batch_dim(s), length_dim, input_dim]. 
These are the word embeddings that
      correspond to the inputs. These can optionally be used to make
      routing decisions.
+    extras: a tensor to dispatch (for debugging purposes)
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
   # over which those groups are split.
   batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                       orig_inputs.shape.dims[-1])
+
+  if extras:
+    extras_dims = extras.shape.dims[len(batch_and_length_dims):]
+
   # Hack: we assume that
   #   "outer_batch" == replication of experts
   #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -381,6 +420,11 @@
     token_embeddings = mtf.cast(
         mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)
 
+  if extras:
+    extras = mtf.reshape(
+        extras,
+        [outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)
+
   # Each sequence sends expert_capacity positions to each expert.
   if train:
     capacity_factor = hparams.moe_capacity_factor_train
@@ -503,6 +547,17 @@
           input_dim
       ]))
 
+  if extras:
+    extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
+                        mtf.Shape([
+                            outer_batch_dim, experts_dim_unsplit,
+                            num_groups_dim, expert_capacity_dim] + extras_dims))
+    extras = mtf.reshape(
+        extras,
+        mtf.Shape([
+            outer_batch_dim, experts_dim, batch_dim_unsplit,
+            expert_capacity_dim] + extras_dims))
+
   # Now feed the expert inputs through the experts.
   h = mtf.layers.dense_product(
       expert_inputs,
@@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
     k = _compute_output(k_h, layer_name="k_wo")
     outputs.append(q)
     outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+    return outputs, loss * hparams.moe_loss_coef, None
   else:
     output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    loss *= hparams.moe_loss_coef
+
+    if extras:
+      return output, loss, extras
+    else:
+      return output, loss, None
 
 
 def transformer_moe_layer_v2(
diff --git a/mesh_tensorflow/transformer/transformer.py b/mesh_tensorflow/transformer/transformer.py
index f8c20d9a..428e1679 100644
--- a/mesh_tensorflow/transformer/transformer.py
+++ b/mesh_tensorflow/transformer/transformer.py
@@ -722,7 +722,8 @@ def __init__(self,
                input_full_attention=False,
                loss_on_targets_only=False,
                loss_denominator=None,
-               token_dropout_rate=0.0):
+               token_dropout_rate=0.0,
+               vocabulary=None):
     """Create a Unitransformer.
 
     Args:
@@ -767,6 +768,7 @@ def __init__(self,
         same denominator as was used for the pretraining. This complication
         might be avoided by always using loss_denominator = 1.0.
       token_dropout_rate: an optional floating point value
+      vocabulary: an optional vocabularies.Vocabulary
     """
     self.layer_stack = layer_stack
     self.model_dim = mtf.Dimension("d_model", d_model)
@@ -807,6 +809,7 @@ def __init__(self,
       raise ValueError(
           "input_full_attention only makes sense with autoregressive")
     self.token_dropout_rate = token_dropout_rate
+    self.vocabulary = vocabulary
 
   @property
   def fully_autoregressive(self):
diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index e64ca8a2..e6e8d936 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -172,7 +172,9 @@ def build_model(model_type="bitransformer",
                 input_vocab_size=gin.REQUIRED,
                 output_vocab_size=gin.REQUIRED,
                 layout_rules=None,
-                mesh_shape=None):
+                mesh_shape=None,
+                input_vocabulary=None,
+                target_vocabulary=None):
   """Build a transformer model.
Currently, four types of models are supported:
 
@@ -214,15 +216,21 @@ def build_model(model_type="bitransformer",
     output_vocab_size: an integer
     layout_rules: optional, input to mtf.convert_to_layout_rules
     mesh_shape: optional, an input to mtf.convert_to_shape()
+    input_vocabulary: optional, a vocabularies.Vocabulary
+    target_vocabulary: optional, a vocabularies.Vocabulary
+
   Returns:
     a Unitransformer or Bitransformer
   """
   if model_type == "bitransformer":
-    return transformer.make_bitransformer(
+    ret = transformer.make_bitransformer(
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
         layout=layout_rules)
+    ret.encoder.vocabulary = input_vocabulary
+    ret.decoder.vocabulary = target_vocabulary
+    return ret
   elif model_type == "bi_student_teacher":
     return transformer.make_bi_student_teacher(
         input_vocab_size=input_vocab_size,
@@ -236,7 +244,8 @@ def build_model(model_type="bitransformer",
         input_vocab_size=input_vocab_size,
         output_vocab_size=output_vocab_size,
         mesh_shape=mesh_shape,
-        layout=layout_rules)
+        layout=layout_rules,
+        vocabulary=input_vocabulary)
   else:
     raise ValueError("unknown model_type")
 
@@ -2067,7 +2076,9 @@ def get_estimator(model_type, vocabulary, mesh_shape,
       input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
       output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
       layout_rules=layout_rules,
-      mesh_shape=mesh_shape)
+      mesh_shape=mesh_shape,
+      input_vocabulary=inputs_vocabulary(vocabulary),
+      target_vocabulary=targets_vocabulary(vocabulary))
 
   model_fn = tpu_estimator_model_fn(
       model_type=model_type,
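
Editor's note (commentary, not part of the applied patch): the core trick in
transformer_moe_layer_v1 above is to send the debug tensor (`extras`, here a
window of token ids) through the same one-hot dispatch tensor that routes the
expert inputs, so each expert capacity slot ends up holding the id of the
token it received; the ids are then detokenized with vocabulary.decode_tf and
printed per expert via mtf.Print. Below is a minimal NumPy sketch of that
dispatch step, with made-up shapes and gating decisions (none of these names
appear in the patch):

    # Hypothetical, self-contained illustration of the `extras` einsum in
    # transformer_moe_layer_v1: contract token ids against a one-hot
    # dispatch tensor to see which token landed in which expert slot.
    import numpy as np

    group_size, num_experts, capacity = 6, 2, 3
    token_ids = np.array([11, 22, 33, 44, 55, 66])   # one group of tokens
    expert_for_token = np.array([0, 1, 0, 0, 1, 1])  # gating choice per token
    slot_for_token = np.array([0, 0, 1, 2, 1, 2])    # capacity slot per token

    # dispatch[s, e, c] == 1.0 iff token s goes to expert e, capacity slot c.
    dispatch = np.zeros((group_size, num_experts, capacity))
    dispatch[np.arange(group_size), expert_for_token, slot_for_token] = 1.0

    # Same contraction pattern as the mtf.einsum over `extras`: summing over
    # the token dimension leaves, per expert and slot, the id of the token
    # dispatched there (0 for unused slots).
    per_expert_ids = np.einsum("s,sec->ec", token_ids, dispatch).astype(int)
    for e, ids in enumerate(per_expert_ids):
        print("EXPERT %d:" % e, ids)  # analogous to the mtf.Print loop above
    # EXPERT 0: [11 33 44]
    # EXPERT 1: [22 55 66]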
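
The `_windows` helper makes each logged entry a short window of token ids
rather than a single id, by stacking shifted copies of the id sequence along
a new "window" dimension. A NumPy analogue for non-negative offsets (again a
sketch, not the mtf API):

    import numpy as np

    def windows(ids, window_start=0, window_end=0):
        """Stack left-shifted copies of `ids`; mirrors mtf.shift(wrap=False)."""
        to_stack = []
        for offset in range(window_start, window_end + 1):
            shifted = np.zeros_like(ids)   # zero-pad the tail, no wrap-around
            if offset < len(ids):
                shifted[:len(ids) - offset] = ids[offset:]
            to_stack.append(shifted)
        return np.stack(to_stack, axis=-1)  # trailing "window" dimension

    print(windows(np.array([7, 8, 9, 10]), 0, 2))
    # [[ 7  8  9]
    #  [ 8  9 10]
    #  [ 9 10  0]
    #  [10  0  0]]

Since the MoE layer class and _windows are gin-configurable, the logging would
presumably be enabled with bindings along the lines of
`MoE1D.token_logging = True` and `_windows.window_end = 2` (binding names
unverified); and because mtf.Print only takes effect outside compiled TPU
execution, this path is intended for CPU runs, as the commit subject says.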