This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 44928c0

afrozenator and Mesh TensorFlow Team authored and committed
[Mesh-TF] Add is_training as an arg to mtf.dropout
PiperOrigin-RevId: 361088273
1 parent da119c8 commit 44928c0

13 files changed: +86 −79 lines

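This commit threads an explicit is_training flag through mtf.dropout and the mtf.layers attention and feed-forward helpers, replacing the earlier pattern of guarding each dropout call with an `if is_training:` branch at the call site. A minimal before/after sketch of the call-site change (the tensor `x` and the flag name here are illustrative placeholders, not lines taken from any one file in the diff):

  # Before: each caller guarded dropout itself.
  if is_training:
    x = mtf.dropout(x, keep_prob=0.9)

  # After: the flag is passed in, and mtf.dropout returns x unchanged
  # when is_training is False (or keep_prob == 1.0).
  x = mtf.dropout(x, is_training, keep_prob=0.9)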

mesh_tensorflow/bert/bert.py

Lines changed: 4 additions & 2 deletions
@@ -239,7 +239,7 @@ def __init__(self,
           self.embedding_output += short_position_table
         self.embedding_output = self.normalize(self.embedding_output)
         self.embedding_output = mtf.dropout(
-            self.embedding_output,
+            self.embedding_output, is_training,
             keep_prob=1.0 - self.config.layer_output_dropout_prob)

       with tf.variable_scope("encoder"):
@@ -283,7 +283,8 @@ def __init__(self,
             else:
               raise ValueError("unknown layer type " + layer_type)
             x = mtf.dropout(
-                x, keep_prob=1.0 - self.config.layer_output_dropout_prob)
+                x, is_training,
+                keep_prob=1.0 - self.config.layer_output_dropout_prob)
             layer_output = prev_layer_output + x
             if self.config.residual_structure == "original":
               layer_output = self.normalize(layer_output)
@@ -363,6 +364,7 @@ def self_attention(self, x, attention_bias):
     # seem a bit unusual, but is taken from the original Transformer paper.
     attention_probs = mtf.dropout(
         attention_probs,
+        is_training=(self.config.attention_probs_dropout_prob == 0.0),
         keep_prob=1.0 - self.config.attention_probs_dropout_prob)

     output = mtf.einsum([attention_probs, values],

mesh_tensorflow/bert/run_classifier.py

Lines changed: 2 additions & 3 deletions
@@ -625,9 +625,8 @@ def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
       initializer=tf.zeros_initializer())

   with tf.variable_scope("loss"):
-    if is_training:
-      # I.e., 0.1 dropout
-      output_layer = mtf.dropout(output_layer, keep_prob=0.9)
+    # I.e., 0.1 dropout
+    output_layer = mtf.dropout(output_layer, is_training, keep_prob=0.9)
     logits = mtf.einsum([output_layer, output_weights],
                         reduced_dims=[hidden_dim])
     logits = logits + output_bias

mesh_tensorflow/experimental/unet.py

Lines changed: 4 additions & 4 deletions
@@ -428,15 +428,14 @@ def conv_with_spatial_partition(
     bn_update_ops = []

   x = mtf.leaky_relu(x, 0.1)
-
-  if is_training:
-    x = mtf.dropout(x, keep_p)
+  x = mtf.dropout(x, is_training, keep_p)

   return x, bn_update_ops


 def deconv_with_spatial_partition(
     x, sampled_2d_slices, image_nx_dim, image_ny_dim, n_filters, keep_p,
+    is_training,
     odim_name, variable_dtype, name):
   """Deconvolution with spatial partition."""
   if sampled_2d_slices:
@@ -456,7 +455,7 @@ def deconv_with_spatial_partition(
         name=name,
     )

-  x = mtf.dropout(x, keep_p)
+  x = mtf.dropout(x, is_training, keep_p)

   return x

@@ -570,6 +569,7 @@ def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
         x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
         FLAGS.n_base_filters * (2**depth),
         FLAGS.dropout_keep_p,
+        is_training,
         'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
         variable_dtype, 'deconv_{}_0'.format(depth))
     x = mtf.concat(

mesh_tensorflow/layers.py

Lines changed: 28 additions & 15 deletions
@@ -1131,6 +1131,7 @@ def my_fn(x):


 def dense_relu_dense(x,
                      hidden_channels,
+                     is_training,
                      dropout=0.0,
                      dropout_broadcast_dims=None,
                      master_dtype=tf.float32,
@@ -1142,6 +1143,7 @@ def dense_relu_dense(x,
   Args:
     x: a mtf.Tensor
     hidden_channels: a mtf.Dimension - channels in the hidden layer
+    is_training: a boolean, set to true while training
     dropout: an optional float
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1156,9 +1158,8 @@ def dense_relu_dense(x,
   h = dense(x, hidden_channels,
             use_bias=False, activation=mtf.relu,
             master_dtype=master_dtype, slice_dtype=slice_dtype, name="wi")
-  if dropout != 0.0:
-    h = mtf.dropout(h, 1.0 - dropout,
-                    noise_shape=h.shape - dropout_broadcast_dims)
+  h = mtf.dropout(h, is_training, 1.0 - dropout,
+                  noise_shape=h.shape - dropout_broadcast_dims)
   return dense(h, io_channels, use_bias=False, activation=None,
                master_dtype=master_dtype, slice_dtype=slice_dtype,
                name="wo")
@@ -1187,6 +1188,7 @@ def local_self_attention_spatial_blocks(
     query_antecedent,
     kv_channels,
     heads,
+    is_training,
     memory_w_dim=None,
     mask_right=False,
     master_dtype=tf.float32,
@@ -1205,6 +1207,7 @@ def local_self_attention_spatial_blocks(
       must have the same size as query_length, but a different name.
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is true if training, else false.
     memory_w_dim: mtf Dimension, for the memory width block.
     mask_right: bool, flag specifying whether we mask out attention to the right
       for the decoder.
@@ -1255,7 +1258,7 @@ def local_self_attention_spatial_blocks(
   mask = attention_bias_local_block(
       query_antecedent.mesh, w_dim, memory_w_dim)

-  output = dot_product_attention(q, k, v, mask=mask)
+  output = dot_product_attention(q, k, v, mask=mask, is_training=is_training)

   return mtf.einsum(
       [output, wo], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
@@ -1264,6 +1267,7 @@ def local_self_attention_spatial_blocks(
 def masked_local_attention_1d(x,
                               kv_channels,
                               heads,
+                              is_training,
                               window_size=128,
                               master_dtype=tf.float32,
                               slice_dtype=tf.float32,
@@ -1280,6 +1284,7 @@ def masked_local_attention_1d(x,
     x: a mtf.Tensor with shape batch_dims + [length, io_channels]
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is True if training else False.
     window_size: an integer
     master_dtype: a tf.dtype (deprecated - use params arg)
     slice_dtype: a tf.dtype (deprecated - use params arg)
@@ -1351,7 +1356,7 @@ def masked_local_attention_1d(x,
     # Note: The first window_size-1 positions can see back into pre-time
     # where all the keys and values are zero. We could mask this out, but we
     # don't.
-    o = dot_product_attention(q, k, v, mask=mask)
+    o = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
     o = mtf.reshape(o, batch_dims + [heads, length, kv_channels])
   return mtf.einsum([o, wo], mtf.Shape(batch_dims + [length, io_channels]))

@@ -1408,7 +1413,7 @@ def masked_local_attention_1d_incremental(x,
       mtf.mod(step_num, window_length.size))
   k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
   v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
-  o = dot_product_attention(q, k, v, mask=None)
+  o = dot_product_attention(q, k, v, mask=None, is_training=False)
   y = mtf.einsum([o, wo], x.shape)
   return y, k, v

@@ -1441,6 +1446,7 @@ def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
 def local_2d_self_attention_spatial_blocks(query_antecedent,
                                            kv_channels,
                                            heads,
+                                           is_training,
                                            memory_h_dim=None,
                                            memory_w_dim=None,
                                            mask_right=False,
@@ -1460,6 +1466,7 @@ def local_2d_self_attention_spatial_blocks(query_antecedent,
       query_length, but a different name.
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is True while training else False.
     memory_h_dim: mtf Dimension, for the memory height block.
     memory_w_dim: mtf Dimension, for the memory width block.
     mask_right: bool, flag specifying whether we mask out attention to the right
@@ -1515,7 +1522,7 @@ def local_2d_self_attention_spatial_blocks(query_antecedent,
   mask = attention_bias_local_2d_block(query_antecedent.mesh, h_dim, w_dim,
                                        memory_h_dim, memory_w_dim)

-  output = dot_product_attention(q, k, v, mask=mask)
+  output = dot_product_attention(q, k, v, mask=mask, is_training=is_training)

   return mtf.einsum(
       [output, wo],
@@ -1592,6 +1599,7 @@ def dot_product_attention(q,
                           k,
                           v,
                           mask,
+                          is_training,
                           dropout=0.0,
                           dropout_broadcast_dims=None,
                           extra_logit=None):
@@ -1605,6 +1613,7 @@ def dot_product_attention(q,
     v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
       match with q.
     mask: mask Tensor (see attention_mask())
+    is_training: a boolean, set to true while training
     dropout: a float.
     dropout_broadcast_dims: an optional list of mtf.Dimension
     extra_logit: an optional scalar or tensor
@@ -1618,10 +1627,9 @@ def dot_product_attention(q,
   if mask is not None:
     logits += mask
   weights = mtf.softmax(logits, length_kv, extra_logit=extra_logit)
-  if dropout != 0.0:
-    weights = mtf.dropout(
-        weights, 1.0 - dropout,
-        noise_shape=weights.shape - dropout_broadcast_dims)
+  weights = mtf.dropout(
+      weights, is_training, 1.0 - dropout,
+      noise_shape=weights.shape - dropout_broadcast_dims)
   depth_v = v.shape.dims[-1]
   outputs_shape = mtf.Shape(q.shape.dims[:-1] + [depth_v])
   outputs = mtf.einsum([weights, v], outputs_shape)
@@ -1633,6 +1641,7 @@ def multihead_attention(query_antecedent,
                         mask,
                         kv_channels,
                         heads,
+                        is_training,
                         dropout=0.0,
                         dropout_broadcast_dims=None,
                         master_dtype=tf.float32,
@@ -1653,6 +1662,7 @@ def multihead_attention(query_antecedent,
     mask: mask Tensor (see attention_mask())
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is True while training, false otherwise.
     dropout: a floating point value
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1692,7 +1702,7 @@ def multihead_attention(query_antecedent,
         [memory_antecedent, wv],
         mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
     o = dot_product_attention(
-        q, k, v, mask, dropout, dropout_broadcast_dims)
+        q, k, v, mask, is_training, dropout, dropout_broadcast_dims)
     return mtf.einsum(
         [o, wo], mtf.Shape(batch_dims + [query_length, io_channels]))

@@ -1756,7 +1766,7 @@ def multihead_self_attention_incremental(query_antecedent,
       mtf.greater(mtf.range(
           query_antecedent.mesh, memory_length, dtype=tf.int32), step_num),
       q.dtype) * -1e9
-  o = dot_product_attention(q, k, v, mask)
+  o = dot_product_attention(q, k, v, mask, is_training=False)
   y = mtf.einsum([o, wo], query_antecedent.shape)
   return y, k, v

@@ -1792,7 +1802,7 @@ def multihead_encdec_attention_incremental(query_antecedent,
   q = mtf.einsum(
       [query_antecedent, wq],
       mtf.Shape(query_dims + [heads, kv_channels]))
-  o = dot_product_attention(q, k, v, mask)
+  o = dot_product_attention(q, k, v, mask, is_training=False)
   return mtf.einsum([o, wo], query_antecedent.shape)


@@ -1931,6 +1941,7 @@ def multihead_self_attention_memory_compressed(x,
                                                compression_factor,
                                                kv_channels,
                                                heads,
+                                               is_training,
                                                dropout=0.0,
                                                dropout_broadcast_dims=None,
                                                master_dtype=tf.float32,
@@ -1948,6 +1959,7 @@ def multihead_self_attention_memory_compressed(x,
     compression_factor: an integer
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a boolean, set to true while training
     dropout: a floating point value
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1989,7 +2001,8 @@ def multihead_self_attention_memory_compressed(x,
   else:
     mask = None
   o = dot_product_attention(
-      q, k, v, mask, dropout, dropout_broadcast_dims, extra_logit=0.0)
+      q, k, v, mask, is_training, dropout, dropout_broadcast_dims,
+      extra_logit=0.0)
   return mtf.einsum(
       [o, wo], mtf.Shape(batch_dims + [length, io_channels]))
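With the updated mesh_tensorflow/layers.py signatures, model code forwards its own training flag instead of wrapping each dropout in a conditional, and dropout inside dot_product_attention is now gated by the is_training argument rather than by the caller. A hypothetical caller is sketched below; the encoder_block helper, its argument names, and the 0.1 dropout rate are illustrative assumptions, not part of this commit:

  import mesh_tensorflow as mtf

  def encoder_block(x, kv_channels, heads, hidden_channels, is_training):
    # Self-attention; attention-weight dropout is applied only when
    # is_training is True, per the new dot_product_attention signature.
    y = mtf.layers.multihead_attention(
        x, None, mask=None, kv_channels=kv_channels, heads=heads,
        is_training=is_training, dropout=0.1)
    # Feed-forward layer with the same flag threaded through.
    y = mtf.layers.dense_relu_dense(
        x + y, hidden_channels, is_training=is_training, dropout=0.1)
    return x + y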

mesh_tensorflow/layers_test.py

Lines changed: 7 additions & 3 deletions
@@ -194,7 +194,8 @@ def testDenseReluDense(self):
     mtf_inputs = mtf.import_tf_tensor(
         mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
     mtf_outputs = mtf.layers.dense_relu_dense(mtf_inputs,
-                                              hidden_channels=hidden_dim)
+                                              hidden_channels=hidden_dim,
+                                              is_training=False)
     mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
         shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
@@ -232,6 +233,7 @@ def testMaskedLocalAttention1D(self, batch, length, io_channels, kv_channels,
         mtf_query,
         kv_channels=kv_channels_dim,
         heads=heads_dim,
+        is_training=False,
         window_size=window_size)
     mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
         shape=[], layout={}, devices=[""])
@@ -280,7 +282,8 @@ def testDotProductAttention(
         mtf_query,
         mtf_key,
         mtf_value,
-        mask=None)
+        mask=None,
+        is_training=False)
     mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
         shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})
@@ -320,7 +323,8 @@ def testMultiheadAttention(self, kv_channels, heads):
         memory_antecedent=None,
         mask=None,
         kv_channels=kv_channels_dim,
-        heads=heads_dim)
+        heads=heads_dim,
+        is_training=False)
     mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
         shape=[], layout={}, devices=[""])
     lowering = mtf.Lowering(graph, {mesh: mesh_impl})

mesh_tensorflow/ops.py

Lines changed: 5 additions & 2 deletions
@@ -5829,7 +5829,8 @@ def random_normal(mesh, shape, **kwargs):
   return RandomOperation(mesh, shape, tf.random.normal, **kwargs).outputs[0]


-def dropout(x, keep_prob=None, rate=None, noise_shape=None, name=None):
+def dropout(x, is_training, keep_prob=None, rate=None, noise_shape=None,
+            name=None):
   """Randomly set some elements to 0 and scale up the rest.

   Dropout rate should be specified in exactly one of two ways:
@@ -5842,6 +5843,8 @@ def dropout(x, keep_prob=None, rate=None, noise_shape=None, name=None):

   Args:
     x: a Tensor
+    is_training: a boolean, set to true while training, if false dropout becomes
+      an identity function.
     keep_prob: a float between 0.0 and 1.0
     rate: a float between 0.0 and 1.0
     noise_shape: an optional Shape (a subset of x.shape)
@@ -5858,7 +5861,7 @@ def dropout(x, keep_prob=None, rate=None, noise_shape=None, name=None):
   if noise_shape is None:
     noise_shape = x.shape
   with tf.variable_scope(name, default_name="dropout"):
-    if keep_prob == 1.0:
+    if keep_prob == 1.0 or not is_training:
       return x
     noise = cast(less(random_uniform(
         x.mesh, noise_shape,
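The key behavioral change is in mtf.dropout itself: with the new signature, the op reduces to the identity whenever is_training is False or keep_prob is 1.0, so evaluation and inference graphs need no separate code path. A small self-contained check of that behavior (a sketch only, reusing the single-device placement-mesh setup that appears in layers_test.py; the dimension names and values are illustrative):

  import numpy as np
  import tensorflow.compat.v1 as tf
  import mesh_tensorflow as mtf

  tf.disable_v2_behavior()

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", 4)
  channels_dim = mtf.Dimension("channels", 8)

  x = mtf.import_tf_tensor(
      mesh, tf.ones([4, 8]), shape=mtf.Shape([batch_dim, channels_dim]))
  # With is_training=False, dropout should return its input unchanged,
  # even though keep_prob is well below 1.0.
  y = mtf.dropout(x, is_training=False, keep_prob=0.5)

  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual = lowering.export_to_tf_tensor(y)

  with tf.Session() as sess:
    np.testing.assert_allclose(sess.run(actual), np.ones([4, 8]))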

mesh_tensorflow/transformer/attention.py

Lines changed: 9 additions & 12 deletions
@@ -79,10 +79,9 @@ def attention(q,
     logits += mtf.cast(bias, logits.dtype)
   weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
   weights = mtf.cast(weights, v.dtype)
-  if dropout_rate != 0.0:
-    weights = mtf.dropout(
-        weights, 1.0 - dropout_rate,
-        noise_shape=weights.shape - dropout_broadcast_dims)
+  weights = mtf.dropout(
+      weights, context.train, 1.0 - dropout_rate,
+      noise_shape=weights.shape - dropout_broadcast_dims)
   outputs_shape = q.shape - key_dim + value_dim
   outputs = mtf.einsum([weights, v], outputs_shape)
   outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim)
@@ -150,10 +149,9 @@ def hybrid_attention(q,
       lower_log_weights, memory_length_dim, extra_logit=extra_logit)

   weights = doubly_coeff * doubly_weights + (1. - doubly_coeff) * upper_weights
-  if dropout_rate != 0.0:
-    weights = mtf.dropout(
-        weights, 1.0 - dropout_rate,
-        noise_shape=weights.shape - dropout_broadcast_dims)
+  weights = mtf.dropout(
+      weights, context.train, 1.0 - dropout_rate,
+      noise_shape=weights.shape - dropout_broadcast_dims)
   outputs_shape = q.shape - key_dim + value_dim
   outputs = mtf.einsum([weights, v], outputs_shape)
   return outputs
@@ -328,10 +326,9 @@ def synthetic_attention(q,
     logits += bias

   weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
-  if dropout_rate != 0.0:
-    weights = mtf.dropout(
-        weights, 1.0 - dropout_rate,
-        noise_shape=weights.shape - dropout_broadcast_dims)
+  weights = mtf.dropout(
+      weights, context.train, 1.0 - dropout_rate,
+      noise_shape=weights.shape - dropout_broadcast_dims)

   if synthesize and "plus" not in synthesize_mode:
     if synthesize_mode == "dense_minus":
