@@ -1131,6 +1131,7 @@ def my_fn(x):
 
 def dense_relu_dense(x,
                      hidden_channels,
+                     is_training,
                      dropout=0.0,
                      dropout_broadcast_dims=None,
                      master_dtype=tf.float32,
@@ -1142,6 +1143,7 @@ def dense_relu_dense(x,
   Args:
     x: a mtf.Tensor
     hidden_channels: a mtf.Dimension - channels in the hidden layer
+    is_training: a boolean, set to true while training
     dropout: an optional float
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1156,9 +1158,8 @@ def dense_relu_dense(x,
   h = dense(x, hidden_channels,
             use_bias=False, activation=mtf.relu,
             master_dtype=master_dtype, slice_dtype=slice_dtype, name="wi")
-  if dropout != 0.0:
-    h = mtf.dropout(h, 1.0 - dropout,
-                    noise_shape=h.shape - dropout_broadcast_dims)
+  h = mtf.dropout(h, is_training, 1.0 - dropout,
+                  noise_shape=h.shape - dropout_broadcast_dims)
   return dense(h, io_channels, use_bias=False, activation=None,
                master_dtype=master_dtype, slice_dtype=slice_dtype,
                name="wo")
@@ -1187,6 +1188,7 @@ def local_self_attention_spatial_blocks(
     query_antecedent,
     kv_channels,
     heads,
+    is_training,
     memory_w_dim=None,
     mask_right=False,
     master_dtype=tf.float32,
@@ -1205,6 +1207,7 @@ def local_self_attention_spatial_blocks(
       must have the same size as query_length, but a different name.
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is true if training, else false.
     memory_w_dim: mtf Dimension, for the memory width block.
     mask_right: bool, flag specifying whether we mask out attention to the right
       for the decoder.
@@ -1255,7 +1258,7 @@ def local_self_attention_spatial_blocks(
     mask = attention_bias_local_block(
         query_antecedent.mesh, w_dim, memory_w_dim)
 
-  output = dot_product_attention(q, k, v, mask=mask)
+  output = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
 
   return mtf.einsum(
       [output, wo], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
@@ -1264,6 +1267,7 @@ def local_self_attention_spatial_blocks(
 def masked_local_attention_1d(x,
                               kv_channels,
                               heads,
+                              is_training,
                               window_size=128,
                               master_dtype=tf.float32,
                               slice_dtype=tf.float32,
@@ -1280,6 +1284,7 @@ def masked_local_attention_1d(x,
     x: a mtf.Tensor with shape batch_dims + [length, io_channels]
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is True if training else False.
     window_size: an integer
     master_dtype: a tf.dtype (deprecated - use params arg)
     slice_dtype: a tf.dtype (deprecated - use params arg)
@@ -1351,7 +1356,7 @@ def masked_local_attention_1d(x,
   # Note: The first window_size-1 positions can see back into pre-time
   # where all the keys and values are zero. We could mask this out, but we
   # don't.
-  o = dot_product_attention(q, k, v, mask=mask)
+  o = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
   o = mtf.reshape(o, batch_dims + [heads, length, kv_channels])
   return mtf.einsum([o, wo], mtf.Shape(batch_dims + [length, io_channels]))
 
@@ -1408,7 +1413,7 @@ def masked_local_attention_1d_incremental(x,
       mtf.mod(step_num, window_length.size))
   k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
   v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
-  o = dot_product_attention(q, k, v, mask=None)
+  o = dot_product_attention(q, k, v, mask=None, is_training=False)
   y = mtf.einsum([o, wo], x.shape)
   return y, k, v
 
@@ -1441,6 +1446,7 @@ def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
 def local_2d_self_attention_spatial_blocks(query_antecedent,
                                            kv_channels,
                                            heads,
+                                           is_training,
                                            memory_h_dim=None,
                                            memory_w_dim=None,
                                            mask_right=False,
@@ -1460,6 +1466,7 @@ def local_2d_self_attention_spatial_blocks(query_antecedent,
       query_length, but a different name.
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is True while training else False.
     memory_h_dim: mtf Dimension, for the memory height block.
     memory_w_dim: mtf Dimension, for the memory width block.
     mask_right: bool, flag specifying whether we mask out attention to the right
@@ -1515,7 +1522,7 @@ def local_2d_self_attention_spatial_blocks(query_antecedent,
     mask = attention_bias_local_2d_block(query_antecedent.mesh, h_dim, w_dim,
                                          memory_h_dim, memory_w_dim)
 
-  output = dot_product_attention(q, k, v, mask=mask)
+  output = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
 
   return mtf.einsum(
       [output, wo],
@@ -1592,6 +1599,7 @@ def dot_product_attention(q,
                           k,
                           v,
                           mask,
+                          is_training,
                           dropout=0.0,
                           dropout_broadcast_dims=None,
                           extra_logit=None):
@@ -1605,6 +1613,7 @@ def dot_product_attention(q,
     v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
       match with q.
     mask: mask Tensor (see attention_mask())
+    is_training: a boolean, set to true while training
     dropout: a float.
     dropout_broadcast_dims: an optional list of mtf.Dimension
     extra_logit: an optional scalar or tensor
@@ -1618,10 +1627,9 @@ def dot_product_attention(q,
   if mask is not None:
     logits += mask
   weights = mtf.softmax(logits, length_kv, extra_logit=extra_logit)
-  if dropout != 0.0:
-    weights = mtf.dropout(
-        weights, 1.0 - dropout,
-        noise_shape=weights.shape - dropout_broadcast_dims)
+  weights = mtf.dropout(
+      weights, is_training, 1.0 - dropout,
+      noise_shape=weights.shape - dropout_broadcast_dims)
   depth_v = v.shape.dims[-1]
   outputs_shape = mtf.Shape(q.shape.dims[:-1] + [depth_v])
   outputs = mtf.einsum([weights, v], outputs_shape)
@@ -1633,6 +1641,7 @@ def multihead_attention(query_antecedent,
                         mask,
                         kv_channels,
                         heads,
+                        is_training,
                         dropout=0.0,
                         dropout_broadcast_dims=None,
                         master_dtype=tf.float32,
@@ -1653,6 +1662,7 @@ def multihead_attention(query_antecedent,
     mask: mask Tensor (see attention_mask())
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, is True while training, false otherwise.
     dropout: a floating point value
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1692,7 +1702,7 @@ def multihead_attention(query_antecedent,
       [memory_antecedent, wv],
       mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
   o = dot_product_attention(
-      q, k, v, mask, dropout, dropout_broadcast_dims)
+      q, k, v, mask, is_training, dropout, dropout_broadcast_dims)
   return mtf.einsum(
       [o, wo], mtf.Shape(batch_dims + [query_length, io_channels]))
 
@@ -1756,7 +1766,7 @@ def multihead_self_attention_incremental(query_antecedent,
       mtf.greater(mtf.range(
           query_antecedent.mesh, memory_length, dtype=tf.int32), step_num),
       q.dtype) * -1e9
-  o = dot_product_attention(q, k, v, mask)
+  o = dot_product_attention(q, k, v, mask, is_training=False)
   y = mtf.einsum([o, wo], query_antecedent.shape)
   return y, k, v
 
@@ -1792,7 +1802,7 @@ def multihead_encdec_attention_incremental(query_antecedent,
   q = mtf.einsum(
       [query_antecedent, wq],
       mtf.Shape(query_dims + [heads, kv_channels]))
-  o = dot_product_attention(q, k, v, mask)
+  o = dot_product_attention(q, k, v, mask, is_training=False)
   return mtf.einsum([o, wo], query_antecedent.shape)
 
 
@@ -1931,6 +1941,7 @@ def multihead_self_attention_memory_compressed(x,
                                                compression_factor,
                                                kv_channels,
                                                heads,
+                                               is_training,
                                                dropout=0.0,
                                                dropout_broadcast_dims=None,
                                                master_dtype=tf.float32,
@@ -1948,6 +1959,7 @@ def multihead_self_attention_memory_compressed(x,
     compression_factor: an integer
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a boolean, set to true while training
     dropout: a floating point value
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1989,7 +2001,8 @@ def multihead_self_attention_memory_compressed(x,
   else:
     mask = None
   o = dot_product_attention(
-      q, k, v, mask, dropout, dropout_broadcast_dims, extra_logit=0.0)
+      q, k, v, mask, is_training, dropout, dropout_broadcast_dims,
+      extra_logit=0.0)
   return mtf.einsum(
       [o, wo], mtf.Shape(batch_dims + [length, io_channels]))
 
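The hunks above replace rate-based gating (`if dropout != 0.0:`) with an explicit `is_training` flag that every caller threads down to `mtf.dropout` and `dot_product_attention`, so evaluation and incremental-decoding paths pass `is_training=False` and never drop activations or attention weights. Below is a minimal sketch of that pattern in plain TensorFlow, not the Mesh TensorFlow API changed above; `dropout_if_training` is a hypothetical helper used only for illustration.

import tensorflow as tf

def dropout_if_training(x, is_training, keep_prob):
  """Apply dropout only while training; return x unchanged otherwise."""
  # Gate on the training flag, not on the dropout rate, so eval and
  # incremental decoding are deterministic even when keep_prob < 1.0.
  if not is_training or keep_prob >= 1.0:
    return x
  return tf.nn.dropout(x, rate=1.0 - keep_prob)

# Training step: elements are randomly zeroed with probability 0.1.
h = tf.ones([2, 4])
h_train = dropout_if_training(h, is_training=True, keep_prob=0.9)
# Eval / decoding step: identity, regardless of the configured rate.
h_eval = dropout_if_training(h, is_training=False, keep_prob=0.9)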