@@ -323,7 +323,7 @@ def test_fold_conv_batchnorm(self):
323
323
324
324
def test_fold_conv_batchnorm_optim (self ):
325
325
# this is too high
326
- for optim , cnt in [(nn .optim .Adam , 18 ), (nn .optim .SGD , 15 )]:
326
+ for optim , cnt in [(nn .optim .Adam , 18 ), (nn .optim .SGD , 11 )]:
327
327
with self .subTest (optim = optim .__name__ ):
328
328
with Tensor .train ():
329
329
img = Tensor .ones (1 ,3 ,4 ,4 )
@@ -1070,7 +1070,7 @@ def test_sgd_conv_fuse(self):
1070
1070
opt = nn .optim .SGD (nn .state .get_parameters (c1 ))
1071
1071
opt .zero_grad ()
1072
1072
c1 (img ).relu ().sum ().backward ()
1073
- check_schedule (opt .schedule_step (), 5 )
1073
+ check_schedule (opt .schedule_step (), 3 )
1074
1074
1075
1075
def test_sgd_2convs_fuse (self ):
1076
1076
with Tensor .train ():
@@ -1081,7 +1081,7 @@ def test_sgd_2convs_fuse(self):
1081
1081
opt = nn .optim .SGD (nn .state .get_parameters ([c1 , c2 ]))
1082
1082
opt .zero_grad ()
1083
1083
c2 (c1 (img ).relu ()).relu ().sum ().backward ()
1084
- check_schedule (opt .schedule_step (), 9 )
1084
+ check_schedule (opt .schedule_step (), 7 )
1085
1085
1086
1086
def test_fold_2convs_sgd_nesterov_momentum_wd (self ):
1087
1087
with Tensor .train ():
@@ -1092,7 +1092,7 @@ def test_fold_2convs_sgd_nesterov_momentum_wd(self):
1092
1092
opt = nn .optim .SGD (nn .state .get_parameters ([c1 , c2 ]), nesterov = True , momentum = 0.9 , weight_decay = 0.1 )
1093
1093
opt .zero_grad ()
1094
1094
c2 (c1 (img ).relu ()).relu ().sum ().backward ()
1095
- check_schedule (opt .schedule_step (), 11 )
1095
+ check_schedule (opt .schedule_step (), 9 )
1096
1096
1097
1097
def test_sgd_4convs_fuse (self ):
1098
1098
with Tensor .train ():
@@ -1105,7 +1105,7 @@ def test_sgd_4convs_fuse(self):
1105
1105
opt = nn .optim .SGD (nn .state .get_parameters ([c1 , c2 , c3 , c4 ]))
1106
1106
opt .zero_grad ()
1107
1107
c4 (c3 (c2 (c1 (img ).relu ()).relu ()).relu ()).relu ().sum ().backward ()
1108
- check_schedule (opt .schedule_step (), 21 )
1108
+ check_schedule (opt .schedule_step (), 17 )
1109
1109
1110
1110
def test_sgd_4convs_fuse_conv_bw (self ):
1111
1111
with Tensor .train ():
@@ -1118,7 +1118,7 @@ def test_sgd_4convs_fuse_conv_bw(self):
1118
1118
opt = nn .optim .SGD (nn .state .get_parameters ([c1 , c2 , c3 , c4 ]))
1119
1119
opt .zero_grad ()
1120
1120
c4 (c3 (c2 (c1 (img ).relu ()).relu ()).relu ()).relu ().sum ().backward ()
1121
- with Context (FUSE_CONV_BW = 1 ): check_schedule (opt .schedule_step (), 18 )
1121
+ with Context (FUSE_CONV_BW = 1 ): check_schedule (opt .schedule_step (), 14 )
1122
1122
1123
1123
@unittest .skipUnless (is_dtype_supported (dtypes .half ), "need half" )
1124
1124
def test_prefer_half_buffer (self ):
0 commit comments