This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Option to use mtf.Print to log which tokens are sent to which experts when run on CPU. #329

Open. Wants to merge 1 commit into base: master.
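A minimal sketch of how the new option might be switched on from a training script, assuming the usual gin plumbing around `MoE1D` and the `_windows` helper added in this diff (the binding strings below are assumptions, not part of the patch):

import gin

# token_logging is the new MoE1D constructor argument introduced here.
gin.bind_parameter("MoE1D.token_logging", True)
# _windows is @gin.configurable in moe.py; widen the logged window to one
# token of context on each side (both offsets default to 0).
gin.bind_parameter("_windows.window_start", -1)
gin.bind_parameter("_windows.window_end", 1)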
mesh_tensorflow/transformer/moe.py (66 additions, 6 deletions)
@@ -65,7 +65,8 @@ def __init__(self,
word_embed_mode=None,
use_second_place_expert_prob=None,
use_second_place_expert_prob_temp=None,
top_n_num_experts_per_token=3):
top_n_num_experts_per_token=3,
token_logging=False):
self._hparams = HParams(
moe_gating=moe_gating,
moe_num_experts=num_experts,
@@ -97,6 +98,7 @@ def __init__(self,
use_second_place_expert_prob_temp),
moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
self._activation = activation
self.token_logging = token_logging

def call(self, context, x, losses=None):
"""Call the layer."""
@@ -116,7 +118,13 @@ def call(self, context, x, losses=None):
output_dim = self._hparams.moe_output_dim
else:
output_dim = context.model.model_dim
y, loss = transformer_moe_layer_v1(
if self.token_logging:
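# Decode the raw input ids to text and print them, then build a
# per-position window of ids to route through the same gating as x.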
tokens = _detokenize(context.inputs, context.model.vocabulary)
x = mtf.Print(x, [tokens], "tokens:", summarize=1000)
extras = _windows(context.inputs, context.length_dim)
else:
extras = None
y, loss, extras = transformer_moe_layer_v1(
x,
output_dim,
self._hparams,
@@ -127,7 +135,16 @@ def call(self, context, x, losses=None):
nonpadding=context.nonpadding,
activation=self._activation,
num_microbatches=context.num_microbatches,
token_embeddings=context.input_embeddings)
token_embeddings=context.input_embeddings,
extras=extras)

if extras:
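# Decode the dispatched ids and print, for each expert, the tokens it received.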
extras = _detokenize(extras, context.model.vocabulary)
experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
extras = mtf.unstack(extras, experts_dim)
for i, t in enumerate(extras):
y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000)

if context.losses is not None:
context.losses.append(loss)
if not has_length_dim:
@@ -139,6 +156,23 @@ def call(self, context, x, losses=None):
return y


@gin.configurable
def _windows(ids, length_dim, window_start=0, window_end=0):
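  """Stacks shifted copies of `ids` along a new "window" dimension.

  Position i of the output collects ids[i + offset] for each offset in
  [window_start, window_end]; mtf.shift with wrap=False zero-pads where
  the window runs off the sequence.
  """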
to_stack = []
for offset in range(window_start, window_end + 1):
to_stack.append(mtf.shift(ids, -offset, length_dim, wrap=False))
return mtf.stack(to_stack, "window", axis=ids.shape.ndims)


def _detokenize(ids, vocabulary):
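  """Decodes a Tensor of token ids to tf.string, dropping the last dimension."""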
return mtf.slicewise(
vocabulary.decode_tf,
[ids],
output_shape=mtf.Shape(ids.shape.dims[:-1]),
output_dtype=tf.string,
splittable_dims=ids.shape.dims[:-1])
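To make the window semantics concrete, here is a plain-Python analogue of `_windows` (an illustration only, not part of the patch; it assumes mtf.shift's convention that a shift of -offset with wrap=False pulls ids[i + offset] into position i and zero-pads past the ends):

def windows_1d(ids, window_start=0, window_end=0):
  # Position i collects ids[i + off] for each offset in the window, with 0
  # standing in wherever the window runs off the sequence.
  n = len(ids)
  return [[ids[i + off] if 0 <= i + off < n else 0
           for off in range(window_start, window_end + 1)]
          for i in range(n)]

# windows_1d([5, 6, 7], window_start=-1, window_end=1)
# -> [[0, 5, 6], [5, 6, 7], [6, 7, 0]]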


class MoE2D(transformer.TransformerLayer):
"""Mixture of Experts Layer."""

@@ -202,7 +236,7 @@ def call(self, context, x, losses=None):
def transformer_moe_layer_v1(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
num_microbatches=None, token_embeddings=None):
num_microbatches=None, token_embeddings=None, extras=None):
"""Local mixture of experts that works well on TPU.

Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +315,7 @@ def transformer_moe_layer_v1(
[batch_dim(s), length_dim, input_dim]. These are the word embeddings that
correspond to the inputs. These can optionally be used to make
routing decisions.
extras: a tensor to dispatch (for debugging purposes)

Returns:
outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
# over which those groups are split.
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
orig_inputs.shape.dims[-1])

if extras:
extras_dims = extras.shape.dims[len(batch_and_length_dims):]

# Hack: we assume that
# "outer_batch" == replication of experts
# mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -381,6 +420,11 @@ def transformer_moe_layer_v1(
token_embeddings = mtf.cast(
mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

if extras:
extras = mtf.reshape(
extras,
[outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)

# Each sequence sends expert_capacity positions to each expert.
if train:
capacity_factor = hparams.moe_capacity_factor_train
@@ -503,6 +547,17 @@ def transformer_moe_layer_v1(
input_dim
]))

if extras:
extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
mtf.Shape([
outer_batch_dim, experts_dim_unsplit,
num_groups_dim, expert_capacity_dim] + extras_dims))
extras = mtf.reshape(
extras,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit,
expert_capacity_dim] + extras_dims))

# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
@@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
k = _compute_output(k_h, layer_name="k_wo")
outputs.append(q)
outputs.append(k)
return outputs, loss * hparams.moe_loss_coef
return outputs, loss * hparams.moe_loss_coef, None
else:
output = _compute_output(h, layer_name="wo")
return output, loss * hparams.moe_loss_coef
loss *= hparams.moe_loss_coef

if extras:
return output, loss, extras
else:
return output, loss, None


def transformer_moe_layer_v2(
mesh_tensorflow/transformer/transformer.py (4 additions, 1 deletion)
@@ -722,7 +722,8 @@ def __init__(self,
input_full_attention=False,
loss_on_targets_only=False,
loss_denominator=None,
token_dropout_rate=0.0):
token_dropout_rate=0.0,
vocabulary=None):
"""Create a Unitransformer.

Args:
@@ -767,6 +768,7 @@ def __init__(self,
same denominator as was used for the pretraining. This complication
might be avoided by always using loss_denominator = 1.0.
token_dropout_rate: an optional floating point value
vocabulary: an optional vocabularies.Vocabulary
"""
self.layer_stack = layer_stack
self.model_dim = mtf.Dimension("d_model", d_model)
@@ -807,6 +809,7 @@ def __init__(self,
raise ValueError(
"input_full_attention only makes sense with autoregressive")
self.token_dropout_rate = token_dropout_rate
self.vocabulary = vocabulary

@property
def fully_autoregressive(self):
mesh_tensorflow/transformer/utils.py (15 additions, 4 deletions)
@@ -172,7 +172,9 @@ def build_model(model_type="bitransformer",
input_vocab_size=gin.REQUIRED,
output_vocab_size=gin.REQUIRED,
layout_rules=None,
mesh_shape=None):
mesh_shape=None,
input_vocabulary=None,
target_vocabulary=None):
"""Build a transformer model.

Currently, four types of models are supported:
@@ -214,15 +216,21 @@ def build_model(model_type="bitransformer",
output_vocab_size: an integer
layout_rules: optional, input to mtf.convert_to_layout_rules
mesh_shape: optional, an input to mtf.convert_to_shape()
input_vocabulary: optional, a vocabularies.Vocabulary
target_vocabulary: optional, a vocabularies.Vocabulary

Returns:
a Unitransformer or Bitransformer
"""
if model_type == "bitransformer":
return transformer.make_bitransformer(
ret = transformer.make_bitransformer(
input_vocab_size=input_vocab_size,
output_vocab_size=output_vocab_size,
mesh_shape=mesh_shape,
layout=layout_rules)
ret.encoder.vocabulary = input_vocabulary
ret.decoder.vocabulary = target_vocabulary
return ret
elif model_type == "bi_student_teacher":
return transformer.make_bi_student_teacher(
input_vocab_size=input_vocab_size,
@@ -236,7 +244,8 @@
input_vocab_size=input_vocab_size,
output_vocab_size=output_vocab_size,
mesh_shape=mesh_shape,
layout=layout_rules)
layout=layout_rules,
vocabulary=input_vocabulary)
else:
raise ValueError("unknown model_type")

@@ -2067,7 +2076,9 @@ def get_estimator(model_type, vocabulary, mesh_shape,
input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
layout_rules=layout_rules,
mesh_shape=mesh_shape)
mesh_shape=mesh_shape,
input_vocabulary=inputs_vocabulary(vocabulary),
target_vocabulary=targets_vocabulary(vocabulary))

model_fn = tpu_estimator_model_fn(
model_type=model_type,
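Taken together, the three files thread the vocabulary down to the layer that needs it: get_estimator passes inputs_vocabulary(vocabulary) and targets_vocabulary(vocabulary) into build_model; build_model stores them on the model (setting encoder.vocabulary and decoder.vocabulary on a bitransformer, or passing vocabulary=input_vocabulary to a unitransformer); and MoE1D.call reads context.model.vocabulary to decode token ids for mtf.Print.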