Skip to content

Commit 431a866

Browse files
authored
fix multi Ops.CONTIGUOUS_BACKWARD [pr] (tinygrad#8843)
1 parent 07d3676 commit 431a866

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

test/test_multitensor.py

+3
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ def test_backprop_conv(self):
288288
optim.step()
289289
out.numpy()
290290

291+
def test_backprop_conv_wino(self):
292+
with Context(WINO=1): self.test_backprop_conv()
293+
291294
def test_backward_sum(self):
292295
x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0)
293296
w = Tensor([1.,2,3,4], requires_grad=True).shard(devices_2)

tinygrad/engine/multi.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=
158158
(UPat(Ops.FLIP, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), flip_multi),
159159
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
160160
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
161-
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
161+
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
162+
src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
162163
])
163164

164165
@track_rewrites(named=True)

0 commit comments

Comments
 (0)