Commit 96bff0b

contiguous is no longer needed in SGD [pr] (tinygrad#8760)
* contiguous is no longer needed in SGD [pr]
* add allow condition
1 parent efc7971 commit 96bff0b

File tree

3 files changed: +12 −12 lines changed

Diff for: test/test_schedule.py

+6 −6

@@ -323,7 +323,7 @@ def test_fold_conv_batchnorm(self):
 
   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 15)]:
+    for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)

@@ -1070,7 +1070,7 @@ def test_sgd_conv_fuse(self):
       opt = nn.optim.SGD(nn.state.get_parameters(c1))
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 5)
+      check_schedule(opt.schedule_step(), 3)
 
   def test_sgd_2convs_fuse(self):
     with Tensor.train():

@@ -1081,7 +1081,7 @@ def test_sgd_2convs_fuse(self):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 9)
+      check_schedule(opt.schedule_step(), 7)
 
   def test_fold_2convs_sgd_nesterov_momentum_wd(self):
     with Tensor.train():

@@ -1092,7 +1092,7 @@ def test_fold_2convs_sgd_nesterov_momentum_wd(self):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 11)
+      check_schedule(opt.schedule_step(), 9)
 
   def test_sgd_4convs_fuse(self):
     with Tensor.train():

@@ -1105,7 +1105,7 @@ def test_sgd_4convs_fuse(self):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
       opt.zero_grad()
       c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 21)
+      check_schedule(opt.schedule_step(), 17)
 
   def test_sgd_4convs_fuse_conv_bw(self):
     with Tensor.train():

@@ -1118,7 +1118,7 @@ def test_sgd_4convs_fuse_conv_bw(self):
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
       c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 18)
+      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
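For context on the counts above: check_schedule builds the schedule for the given tensors and asserts how many kernels it contains. Below is a minimal sketch of that flow, assuming a recent tinygrad API; the shapes are illustrative, and the real check_schedule helper in this file also filters the schedule, so its count can differ from a raw len().

# sketch only: illustrative shapes, not copied from the tests
from tinygrad import Tensor
from tinygrad.nn import Conv2d
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters

with Tensor.train():
  img = Tensor.ones(1, 3, 4, 4)
  c1 = Conv2d(3, 16, 3)
  opt = SGD(get_parameters(c1))
  opt.zero_grad()
  c1(img).relu().sum().backward()
  step = opt.schedule_step()           # list[Tensor] making up one SGD step
  sched = step[0].schedule(*step[1:])  # ScheduleItems (kernels) needed to realize it
  print(len(sched))                    # smaller after this commit, since grads skip .contiguous()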

Diff for: tinygrad/engine/schedule.py

+6 −3

@@ -222,10 +222,13 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
       assign_preloads[x.buf_uop] = None
       # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
       if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
+        # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+        if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
         # if it has a single view and it's equal when you shrink a contig, it's fine
-        if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
-          raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                             +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+        elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
+        # otherwise, it's not fine
+        else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                                 +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
   # capture process replay
   if CAPTURE_PROCESS_REPLAY:
     with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
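The new allow condition accepts a self-operand LOAD whose only non-contiguity comes from broadcasting: a single view whose stride-0 (expanded) axes, once shrunk back to size 1, leave a contiguous view. A small illustration of that check with ShapeTracker follows; the shapes are illustrative assumptions, not taken from the commit.

# sketch: a (1,16) buffer read through a broadcast to (4,16) has one view with stride 0 on axis 0
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((1, 16)).expand((4, 16))
print(st.contiguous)  # False: the broadcast read is not a plain contiguous load
# shrink every expanded (stride-0) axis back to size 1, mirroring the new check above
shrunk = st.shrink(tuple((0, 1) if stride == 0 else (0, s) for s, stride in zip(st.shape, st.views[0].strides)))
print(shrunk.contiguous)  # True under this assumption: the augmented assign no longer needs .contiguous()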

Diff for: tinygrad/nn/optim.py

+0 −3

@@ -77,9 +77,6 @@ def __init__(self, params:list[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-
 
   def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
     for i, (t, g) in enumerate(zip(self.params, grads)):
-      # contiguous is needed since the grads can allegedly form a "diamond"
-      # TODO: fix this in lazy.py
-      g = g.contiguous()
       if self.tcoef != 0:
         r1 = t.detach().square().sum().sqrt()
         r2 = g.square().sum().sqrt()
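With the contiguous call gone, the gradient no longer has to be realized into its own buffer before the update, so the scheduler can fuse the gradient computation into the assign kernel; the kernel counts in test_schedule.py above drop accordingly. A simplified sketch of one parameter update, not the full LARS/SGD implementation with momentum and weight decay:

# simplified sketch of one SGD parameter update (lr and the tensors are illustrative)
def sgd_update(t, g, lr=0.01):
  # previously: g = g.contiguous()  # forced an extra kernel per parameter
  return t.assign(t.detach() - lr * g)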
