Skip to content

Commit 4e6665b

Browse files
authored
different way to write torch backend (tinygrad#9197)
* different way to write torch backend
* both backends
* more work
* simpler code
* more work
* test both
* imply unwrap/wrap
* `FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add` works
* ready to start making test_ops work in torch backend
* backward pass, `TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add` works
* `FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_simple_conv2d` works
* matmul backward is broken with as_strided
1 parent 041b6d5 commit 4e6665b

File tree

6 files changed

+169
-63
lines changed

6 files changed

+169
-63
lines changed

examples/other_mnist/beautiful_mnist_torch.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ def forward(self, x):
2626
return self.lin(torch.flatten(x, 1))
2727

2828
if __name__ == "__main__":
29-
if getenv("TINY_BACKEND"):
29+
if getenv("TINY_BACKEND2"):
30+
import extra.torch_backend.backend2
31+
device = torch.device("cpu")
32+
elif getenv("TINY_BACKEND"):
3033
import extra.torch_backend.backend
3134
device = torch.device("tiny")
3235
else:

extra/torch_backend/backend.py

+74-48
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,35 @@
11
from tinygrad import Tensor, dtypes
2-
from tinygrad.helpers import DEBUG
2+
from tinygrad.helpers import DEBUG, getenv
33
import torch, pathlib
4+
torch.autograd.grad_mode.set_multithreading_enabled(False)
5+
6+
# https://pytorch.org/docs/stable/torch.compiler_ir.html
47

58
# TODO: don't replicate this in cpp
69
torch_to_tiny_dtype = {
710
torch.float32: dtypes.float32,
811
torch.float64: dtypes.float64,
12+
torch.uint8: dtypes.uint8,
13+
torch.int8: dtypes.int8,
914
torch.int32: dtypes.int32,
1015
torch.int64: dtypes.int64,
1116
torch.bool: dtypes.bool,
1217
}
1318

1419
import torch.utils.cpp_extension
1520
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
16-
wrap, unwrap = mod.wrap, mod.unwrap
21+
def wrap(x:Tensor) -> torch.Tensor:
  """Return the torch.Tensor produced by the C++ extension's `wrap` for tinygrad Tensor `x`."""
  return mod.wrap(x)
def unwrap(x:torch.Tensor) -> Tensor:
  """Return the result of the C++ extension's `unwrap` for torch.Tensor `x`.

  Raises AssertionError if `x` is not a torch.Tensor, before the extension is called.
  """
  # name the expected type: the previous message printed x's own type after "isn't", which contradicted itself
  assert isinstance(x, torch.Tensor), f"expected torch.Tensor, got {type(x)}"
  return mod.unwrap(x)
1725
class TinyBackend: pass
1826
torch.utils.rename_privateuse1_backend("tiny")
1927
torch._register_device_module("tiny", TinyBackend)
2028
torch.utils.generate_methods_for_privateuse1_backend()
2129

22-
@torch.library.impl("aten::view", "privateuseone")
23-
def view(x, sz): return mod.wrap(mod.unwrap(x).reshape(sz))
24-
25-
@torch.library.impl("aten::min", "privateuseone")
26-
def min(x): return mod.wrap(mod.unwrap(x).min())
27-
28-
@torch.library.impl("aten::max", "privateuseone")
29-
def max(x): return mod.wrap(mod.unwrap(x).max())
30-
3130
@torch.library.impl("aten::zero_", "privateuseone")
3231
def zero_(x):
33-
tt = mod.unwrap(x)
32+
tt = unwrap(x)
3433
tt.replace(tt.zeros_like())
3534

3635
@torch.library.impl("aten::fill_.Scalar", "privateuseone")
@@ -51,11 +50,14 @@ def as_strided(tensor, size, stride, storage_offset=None):
5150
if size == [] and storage_offset is not None:
5251
# TODO: is this right?
5352
return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
54-
print(tensor.shape, size, stride, storage_offset)
53+
# broadcast
54+
if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
55+
print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
56+
return wrap(Tensor.zeros(*size))
5557
raise NotImplementedError("fix as_strided")
5658

5759
@torch.library.impl("aten::empty_strided", "privateuseone")
58-
def empty_strided(size, stride, dtype, layout, device, pin_memory):
60+
def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
5961
if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
6062
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
6163
return wrap(ret)
@@ -68,49 +70,73 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
6870

6971
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
  """Implement aten::convolution_overrideable for the "tiny" device using tinygrad convolutions.

  `input`, `weight` and the optional `bias` are unwrapped to tinygrad Tensors. When `transposed` is
  set the call goes to `conv_transpose2d` (which also receives `output_padding`), otherwise to `conv2d`.
  The result is wrapped back into a torch.Tensor.
  """
  x, w = unwrap(input), unwrap(weight)
  b = unwrap(bias) if bias is not None else None
  if transposed:
    # previously `transposed` and `output_padding` were ignored, so a transposed conv was silently computed as a plain conv2d
    return wrap(x.conv_transpose2d(w, b, groups=groups, stride=stride, dilation=dilation, padding=padding,
                                   output_padding=output_padding))
  return wrap(x.conv2d(w, b, groups=groups, stride=stride, dilation=dilation, padding=padding))
7377

7478
@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src, dest):
  """Copy `src` into `dest` for the tiny->tiny, tiny->cpu and cpu->tiny device pairs.

  Raises NotImplementedError for any other combination of source and destination device.
  """
  route = (str(src.device), str(dest.device))
  if route == ("tiny", "tiny"):
    unwrap(dest).replace(unwrap(src), allow_shape_mismatch=True)
  elif route == ("tiny", "cpu"):
    # TODO: is there a better way?
    # resize the cpu destination to src's element count and then to src's shape before copying the numpy data in
    dest.resize_(src.numel()).resize_(src.shape)
    dest.copy_(torch.from_numpy(unwrap(src).numpy()))
  elif route == ("cpu", "tiny"):
    unwrap(dest).assign(Tensor(src.numpy()))
  else:
    raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")
8490

85-
@torch.library.impl("aten::exp2.out", "privateuseone")
86-
def exp2_out(x, out): unwrap(out).replace(unwrap(x).exp2(), allow_shape_mismatch=True)
87-
88-
@torch.library.impl("aten::ceil.out", "privateuseone")
89-
def ceil_out(x, out): unwrap(out).replace(unwrap(x).ceil(), allow_shape_mismatch=True)
90-
91-
@torch.library.impl("aten::abs.out", "privateuseone")
92-
def abs_out(x, out): unwrap(out).replace(unwrap(x).abs(), allow_shape_mismatch=True)
93-
94-
@torch.library.impl("aten::bitwise_and.Tensor", "privateuseone")
95-
def bitwise_and_tensor(x, y): return wrap(unwrap(x) & unwrap(y))
96-
97-
@torch.library.impl("aten::add.Tensor", "privateuseone")
98-
def add_tensor(x, y): return wrap(unwrap(x) + unwrap(y))
99-
100-
@torch.library.impl("aten::mul.Tensor", "privateuseone")
101-
def mul_tensor(x, y): return wrap(unwrap(x) * unwrap(y))
102-
103-
@torch.library.impl("aten::div.Tensor", "privateuseone")
104-
def div_tensor(x, y): return wrap(unwrap(x) / unwrap(y))
105-
106-
@torch.library.impl("aten::eq.Tensor", "privateuseone")
107-
def eq_tensor(x, y): return wrap(unwrap(x).eq(unwrap(y)))
108-
109-
@torch.library.impl("aten::ne.Tensor", "privateuseone")
110-
def ne_tensor(x, y): return wrap(unwrap(x).ne(unwrap(y)))
111-
112-
@torch.library.impl("aten::ne.Scalar", "privateuseone")
113-
def ne_scalar(x, y): return wrap(unwrap(x).ne(y))
91+
@torch.library.impl("aten::cat.out", "privateuseone")
92+
def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
93+
94+
@torch.library.impl("aten::index.Tensor", "privateuseone")
95+
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
96+
97+
tiny_backend = {
98+
"aten.view": Tensor.reshape,
99+
"aten.add.Tensor": Tensor.add,
100+
"aten.sub.Tensor": Tensor.sub,
101+
"aten.mul.Tensor": Tensor.mul,
102+
"aten.div.Tensor": Tensor.div,
103+
"aten.add_.Tensor": lambda x,y: x.assign(x.add(y)),
104+
"aten.pow.Tensor_Scalar": Tensor.pow,
105+
"aten.bitwise_and.Tensor": Tensor.bitwise_and,
106+
"aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
107+
"aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
108+
"aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
109+
"aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
110+
"aten.exp2": Tensor.exp2,
111+
"aten.min": Tensor.min,
112+
"aten.max": Tensor.max,
113+
"aten.relu": Tensor.relu,
114+
"aten.mean": Tensor.mean,
115+
"aten.neg": Tensor.neg,
116+
"aten.mm": Tensor.matmul,
117+
}
114118

115-
@torch.library.impl("aten::gt.Scalar", "privateuseone")
116-
def gt_scalar(x, y): return wrap(unwrap(x) > y)
119+
# there's earlier things to hook here
120+
#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
121+
#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
122+
#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
123+
#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),
124+
125+
def wrap_fxn(k,f):
  """Adapt function `f` so it can be called with torch tensors.

  The returned callable unwraps every torch.Tensor among its positional and keyword arguments,
  calls `f` with the results, and wraps whatever `f` returns. `k` is the op name; the adapter
  itself does not use it.
  """
  def _to_tiny(obj):
    # anything that is not a torch.Tensor is forwarded unchanged
    return unwrap(obj) if isinstance(obj, torch.Tensor) else obj
  def nf(*args, **kwargs):
    tiny_args = [_to_tiny(a) for a in args]
    tiny_kwargs = {name: _to_tiny(val) for name, val in kwargs.items()}
    return wrap(f(*tiny_args, **tiny_kwargs))
  return nf
132+
133+
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
134+
135+
if getenv("TORCH_DEBUG"):
136+
from torch.utils._python_dispatch import TorchDispatchMode
137+
class DispatchLog(TorchDispatchMode):
138+
def __torch_dispatch__(self, func, types, args, kwargs=None):
139+
#print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
140+
print(f"Dispatch Log: {func}")
141+
return func(*args, **(kwargs or {}))
142+
DispatchLog().__enter__()

extra/torch_backend/backend2.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from tinygrad import Tensor, dtypes
2+
import torch, contextlib
3+
from torch.utils._python_dispatch import TorchDispatchMode
4+
5+
torch_to_tiny_dtype = {
6+
torch.float32: dtypes.float32,
7+
torch.float64: dtypes.float64,
8+
torch.int32: dtypes.int32,
9+
torch.int64: dtypes.int64,
10+
torch.bool: dtypes.bool,
11+
}
12+
13+
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
  """Return a TTensor around an uninitialized tinygrad Tensor of shape `size`.

  `dtype` is a torch dtype translated through `torch_to_tiny_dtype`; a dtype missing from that
  table raises KeyError. `layout`, `device`, `pin_memory` and `memory_format` are accepted but unused.
  """
  # dtype defaults to None, which is not a key of the table; fall back to torch's default dtype as torch.empty does
  if dtype is None: dtype = torch.get_default_dtype()
  return TTensor(Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype]))
15+
16+
tiny_backend = {
17+
"aten.empty.memory_format": empty_memory_format,
18+
"aten.view.default": lambda x,sz: TTensor(x.tiny.reshape(sz)),
19+
"aten.abs.default": lambda x: TTensor(x.tiny.abs()),
20+
"aten.eq.Tensor": lambda x,y: TTensor(x.tiny == y.tiny),
21+
"aten.bitwise_and.Tensor": lambda x,y: TTensor(x.tiny & y.tiny),
22+
"aten.ne.Scalar": lambda x,y: TTensor(x.tiny != y),
23+
"aten.mul.Tensor": lambda x,y: TTensor(x.tiny * y.tiny),
24+
"aten.masked_select.default": lambda x,y: TTensor(Tensor(x.tiny.numpy()[y.tiny.numpy()])),
25+
}
26+
27+
class TTensor(torch.Tensor):
  """torch.Tensor wrapper subclass that carries a tinygrad Tensor and resolves dispatched ops through `tiny_backend`."""
  tiny: Tensor  # the tinygrad tensor stored on the wrapper by __new__
  context = contextlib.nullcontext

  @staticmethod
  def __new__(cls, tiny, *args, **kwargs):
    # the wrapper is created from the tinygrad tensor's shape only; extra args/kwargs are ignored
    out = torch.Tensor._make_wrapper_subclass(cls, tiny.shape)
    torch._C._set_throw_on_mutable_data_ptr(out)
    out.tiny = tiny
    return out
  def __repr__(self): return super().__repr__(tensor_contents=f"{self.tiny}")
  def __torch_dispatch__(cls, func, types, args, kwargs=None):
    """Look up `str(func)` in `tiny_backend` and call the replacement; raise NotImplementedError if there is none."""
    # kwargs defaults to None: normalize before the log line, which previously called .keys() on it unconditionally
    kwargs = kwargs or {}
    print(f"Dispatch Log: {func}(*{[type(x) for x in args]}, **{kwargs.keys()})")
    new_func = tiny_backend.get(str(func), None)
    if new_func is None: raise NotImplementedError(f"add support for {func}")
    return new_func(*args, **kwargs)
44+
45+
class Dispatcher(TorchDispatchMode): __torch_dispatch__ = TTensor.__torch_dispatch__
46+
Dispatcher().__enter__()
47+
48+
if __name__ == "__main__":
49+
a = torch.empty((4,), dtype=torch.int)

extra/torch_backend/test.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,48 @@
22
import unittest
33
import torch
44
import numpy as np
5-
import extra.torch_backend.backend # "tiny" backend is installed
5+
from tinygrad.helpers import getenv
6+
if getenv("TINY_BACKEND2"):
7+
import extra.torch_backend.backend2
8+
device = "cpu"
9+
else:
10+
import extra.torch_backend.backend
11+
device = "tiny"
612

713
class TestTorchBackend(unittest.TestCase):
  """Smoke tests that run basic torch ops on the module-level `device`."""
  def test_numpy_ones(self):
    a = torch.ones(4, device=device)
    np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])

  # renamed from a second `test_numpy_ones`: the duplicate name replaced the test above in the class body, so it never ran
  def test_numpy_ones_int32(self):
    a = torch.ones(4, dtype=torch.int32, device=device)
    assert a.dtype == torch.int32
    np.testing.assert_equal(a.cpu().numpy(), [1,1,1,1])

  def test_plus(self):
    a = torch.ones(4, device=device)
    b = torch.ones(4, device=device)
    c = a+b
    np.testing.assert_equal(c.cpu().numpy(), [2,2,2,2])

  def test_exp2(self):
    a = torch.ones(4, device=device)
    b = a.exp2()
    print(b)

  def test_eq(self):
    a = torch.ones(4, device=device)
    b = torch.ones(4, device=device)
    c = a == b
    print(c.cpu().numpy())

  def test_isfinite(self):
    a = torch.ones(4, device=device)
    np.testing.assert_equal(torch.isfinite(a).cpu().numpy(), [True, True, True, True])

  # TODO: why
  def test_str(self):
    a = torch.ones(4, device=device)
    print(str(a))
3748

3849
if __name__ == "__main__":

extra/torch_backend/wrapped_tensor.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::
1010
}
1111
}
1212

13+
// Hooks object for the PrivateUse1 backend; the only member it overrides is hasPrimaryContext.
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  // NOTE: no idea what this is
  // Returns true for every device index; the argument is not inspected.
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
};
17+
18+
// Registers a heap-allocated OpenRegHooksInterface as the PrivateUse1 hooks interface.
// Always returns 0; the return value exists only so the call can initialize a global.
int register_hook() {
  at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
  return 0;
}
// Initializing this global runs register_hook() when the extension's globals are initialized.
int temp_register_hook = register_hook();
23+
1324
// code from chatgpt
1425
struct GILSafeDeleter {
1526
void operator()(PyObject* ptr) const {
@@ -56,6 +67,8 @@ static caffe2::TypeMeta dtypeFromName(const std::string &dtype_name) {
5667
} else if (dtype_name == "int") { return caffe2::TypeMeta::Make<int32_t>();
5768
} else if (dtype_name == "long") { return caffe2::TypeMeta::Make<int64_t>();
5869
} else if (dtype_name == "bool") { return caffe2::TypeMeta::Make<bool>();
70+
} else if (dtype_name == "char") { return caffe2::TypeMeta::Make<char>();
71+
} else if (dtype_name == "unsigned char") { return caffe2::TypeMeta::Make<unsigned char>();
5972
}
6073
throw std::runtime_error("Unsupported dtype: " + dtype_name);
6174
}

test/test_ops.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from tinygrad.tensor import _to_np_dtype
99
from tinygrad.device import is_dtype_supported
1010

11+
if getenv("TINY_BACKEND"):
12+
import extra.torch_backend.backend # noqa: F401 # pylint: disable=unused-import
13+
torch.set_default_device("tiny")
14+
1115
if CI:
1216
warnings.filterwarnings("ignore", message="Non-empty compiler output encountered")
1317

@@ -46,8 +50,8 @@ def compare(s, tinygrad_output, torch_output, atol, rtol):
4650
if DEBUG >= 6:
4751
np.set_printoptions(linewidth=200, suppress=True)
4852
print(ret.numpy())
49-
print(out.detach().numpy())
50-
compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
53+
print(out.detach().cpu().numpy())
54+
compare("forward pass", ret.numpy(), out.detach().cpu().numpy(), atol=atol, rtol=rtol)
5155

5256
torch_fbp, tinygrad_fbp = np.nan, np.nan
5357
if not forward_only and not FORWARD_ONLY:
@@ -65,7 +69,7 @@ def compare(s, tinygrad_output, torch_output, atol, rtol):
6569
tinygrad_fbp = time.monotonic() - st
6670

6771
for i, (t, tt_grad) in enumerate(zip(ts, tst_grads)):
68-
compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol)
72+
compare(f"backward pass tensor {i}", tt_grad.numpy(), t.grad.detach().cpu().numpy(), atol=grad_atol, rtol=grad_rtol)
6973

7074
"""
7175
(ret+1).square().mean().backward()
@@ -90,7 +94,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
9094
for i in range(len(ts)):
9195
# NOTE: torch default int64 for python ints input
9296
if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
93-
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
97+
tst = [Tensor(x.detach().cpu().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
9498
return ts, tst
9599

96100
class TestOps(unittest.TestCase):

0 commit comments

Comments
 (0)