Skip to content

Commit 97bc723

Browse files
authored
torch backend works for ResNet-18 (tinygrad#9200)
* torch backend progress, a few more functions * resnet works * pillow * tv
1 parent f92820d commit 97bc723

File tree

4 files changed

+101
-20
lines changed

4 files changed

+101
-20
lines changed

.github/workflows/test.yml

+4-1
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,17 @@ jobs:
153153
- name: Setup Environment
154154
uses: ./.github/actions/setup-tinygrad
155155
with:
156-
key: torch-backend
156+
key: torch-backend-pillow-torchvision
157157
deps: testing_minimal
158+
pydeps: "pillow torchvision"
158159
- name: Install ninja
159160
run: |
160161
sudo apt update || true
161162
sudo apt install -y --no-install-recommends ninja-build
162163
- name: Test one op
163164
run: PYTHONPATH=. FORWARD_ONLY=1 TINY_BACKEND=1 python3 test/test_ops.py TestOps.test_add
165+
- name: Test ResNet-18
166+
run: PYTHONPATH=. python3 extra/torch_backend/example.py
164167
- name: Test Ops with TINY_BACKEND (expect failure)
165168
run: PYTHONPATH=. TINY_BACKEND=1 pytest test/test_ops.py || true
166169
- name: Test beautiful_mnist in torch with TINY_BACKEND (expect failure)

extra/torch_backend/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
data

extra/torch_backend/backend.py

+77-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from tinygrad import Tensor, dtypes
2-
from tinygrad.helpers import DEBUG, getenv
2+
from tinygrad.helpers import DEBUG, getenv, prod
3+
TORCH_DEBUG = getenv("TORCH_DEBUG")
34
import torch, pathlib
45
torch.autograd.grad_mode.set_multithreading_enabled(False)
56

@@ -46,31 +47,44 @@ def masked_select(self, mask):
4647
return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))
4748

4849
@torch.library.impl("aten::as_strided", "privateuseone")
49-
def as_strided(tensor, size, stride, storage_offset=None):
50-
if size == [] and storage_offset is not None:
51-
# TODO: is this right?
52-
return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
53-
# broadcast
54-
if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
55-
print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
56-
return wrap(Tensor.zeros(*size))
50+
def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
51+
#return tensor.cpu().as_strided(size, stride).tiny()
52+
if TORCH_DEBUG >= 1: print("** NOTE: this as_strided is wrong", tensor.shape, size, stride, storage_offset)
53+
54+
if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1):
55+
# this is squeeze/unsqueeze
56+
return tensor.reshape(size)
57+
58+
# TODO: how do i know this is permute?
59+
if tensor.shape == (1000, 512) and size == [512, 1000] and stride == [0, 1]:
60+
return wrap(unwrap(tensor).permute(1,0))
61+
62+
#print(tensor.cpu().numpy())
5763
raise NotImplementedError("fix as_strided")
5864

5965
@torch.library.impl("aten::empty_strided", "privateuseone")
60-
def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
61-
if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
66+
def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
67+
if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
6268
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
6369
return wrap(ret)
6470

6571
@torch.library.impl("aten::empty.memory_format", "privateuseone")
6672
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
67-
if DEBUG >= 2: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
68-
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
73+
if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
74+
ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
6975
return wrap(ret)
7076

77+
@torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
78+
def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
79+
# TODO: support return_indices in tinygrad
80+
ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
81+
# TODO: this is wrong
82+
return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
83+
7184
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
7285
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
73-
#print(f"{input.shape=} {weight.shape=} {bias.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
86+
if TORCH_DEBUG >= 1:
87+
print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
7488
return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
7589
groups=groups, stride=stride, dilation=dilation, padding=padding))
7690
#raise NotImplementedError("need convolution")
@@ -94,45 +108,89 @@ def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for
94108
@torch.library.impl("aten::index.Tensor", "privateuseone")
95109
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
96110

111+
# register some decompositions
112+
from torch._decomp import get_decompositions
113+
aten = torch.ops.aten
114+
decomps = [
115+
aten.native_batch_norm,
116+
aten.addmm,
117+
# NOTE: many of these don't work or cause infinite loops
118+
#aten.var_mean,
119+
#aten.var,
120+
#aten.rsqrt,
121+
#aten.max_pool2d_with_indices,
122+
]
123+
for k,v in get_decompositions(decomps).items():
124+
key = str(k._schema).split("(")[0]
125+
if TORCH_DEBUG >= 2: print("register decomp for", k)
126+
torch.library.impl(key, "privateuseone")(v)
127+
97128
tiny_backend = {
98129
"aten.view": Tensor.reshape,
99130
"aten.add.Tensor": Tensor.add,
100131
"aten.sub.Tensor": Tensor.sub,
101132
"aten.mul.Tensor": Tensor.mul,
102133
"aten.div.Tensor": Tensor.div,
103-
"aten.add_.Tensor": lambda x,y: x.assign(x.add(y)),
134+
"aten.add_.Tensor": lambda x,y,alpha=1: x.assign(x.add(y)*alpha),
104135
"aten.pow.Tensor_Scalar": Tensor.pow,
105136
"aten.bitwise_and.Tensor": Tensor.bitwise_and,
106137
"aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
107138
"aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
108139
"aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
109140
"aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
141+
"aten.le.Tensor": Tensor.__le__, "aten.le.Scalar": Tensor.__le__,
142+
"aten.abs": Tensor.abs,
143+
"aten.exp": Tensor.exp,
110144
"aten.exp2": Tensor.exp2,
111145
"aten.min": Tensor.min,
112146
"aten.max": Tensor.max,
113147
"aten.relu": Tensor.relu,
148+
"aten.relu_": lambda x: x.assign(x.relu()),
114149
"aten.mean": Tensor.mean,
150+
"aten.mean.dim": Tensor.mean,
115151
"aten.neg": Tensor.neg,
152+
"aten.reciprocal": Tensor.reciprocal,
153+
"aten.sqrt": Tensor.sqrt,
154+
"aten.rsqrt": Tensor.rsqrt,
116155
"aten.mm": Tensor.matmul,
156+
"aten.var.correction": Tensor.var,
157+
# TODO: support var_mean in tinygrad
158+
"aten.var_mean.correction": lambda self, dims, keepdim=False, correction=1: (self.var(dims, keepdim, correction), self.mean(dims, keepdim)),
159+
# NOTE: axis=[] in torch means all, change tinygrad?
160+
"aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
161+
out.replace(Tensor.sum(self, axis if len(axis) else None, keepdim), allow_shape_mismatch=True),
162+
"aten.argmax": Tensor.argmax,
163+
"aten.scatter.value": Tensor.scatter,
164+
"aten.gather": Tensor.gather,
165+
"aten.where.self": Tensor.where,
166+
"aten._log_softmax": lambda self,dim,half_to_float: self.softmax(dim),
167+
"aten.random_": lambda self:
168+
self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype)),
169+
"aten.uniform_": lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high)),
170+
"aten.normal_": lambda self, low=0, high=1: self.assign(Tensor.normal(*self.shape, low=low, high=high)),
117171
}
118172

119-
# there's earlier things to hook here
173+
# NOTE: there's earlier things to hook these, so the .out form isn't needed
120174
#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
121175
#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
122176
#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
123177
#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),
124178

125179
def wrap_fxn(k,f):
126180
def nf(*args, **kwargs):
127-
#print(k, len(args), kwargs.keys())
181+
if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
182+
{k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
128183
args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
129184
kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
130-
return wrap(f(*args, **kwargs))
185+
out = f(*args, **kwargs)
186+
if isinstance(out, Tensor): return wrap(out)
187+
elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
188+
else: raise RuntimeError(f"unknown output type {type(out)}")
131189
return nf
132190

133191
for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
134192

135-
if getenv("TORCH_DEBUG"):
193+
if TORCH_DEBUG:
136194
from torch.utils._python_dispatch import TorchDispatchMode
137195
class DispatchLog(TorchDispatchMode):
138196
def __torch_dispatch__(self, func, types, args, kwargs=None):

extra/torch_backend/example.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from PIL import Image
2+
import torch, torchvision, pathlib
3+
import torchvision.transforms as transforms
4+
import extra.torch_backend.backend
5+
device = "tiny"
6+
torch.set_default_device(device)
7+
8+
if __name__ == "__main__":
9+
img = Image.open(pathlib.Path(__file__).parent.parent.parent / "test/models/efficientnet/Chicken.jpg").convert('RGB')
10+
transform = transforms.Compose([
11+
transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
12+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
13+
])
14+
img = transform(img).unsqueeze(0).to(device)
15+
16+
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT).eval()
17+
out = model(img).detach().cpu().numpy()
18+
print("output:", out.shape, out.argmax())
19+
assert out.argmax() == 7 # cock

0 commit comments

Comments
 (0)