Skip to content

Commit 8b16c65

Browse files
authored
add compile3 benchmark [pr] (tinygrad#8929)
1 parent 79fb5c6 commit 8b16c65

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

examples/openpilot/compile3.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -98,26 +98,37 @@ def test_vs_compile(run, new_inputs, test_val=None):
9898
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
9999
return val
100100

101-
def test_vs_onnx(new_inputs, test_val, onnx_file):
101+
def test_vs_onnx(new_inputs, test_val, onnx_file, ort=False):
102102
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
103103
onnx_model = onnx.load(onnx_file)
104104

105-
if getenv("ORT"):
105+
timings = []
106+
if ort:
106107
# test with onnxruntime
107108
import onnxruntime as ort
108109
onnx_session = ort.InferenceSession(onnx_file)
109-
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
110+
for _ in range(1 if test_val is not None else 5):
111+
st = time.perf_counter()
112+
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
113+
timings.append(time.perf_counter() - st)
110114
new_torch_out = onnx_output[0]
111-
print("got ort outputs")
112115
else:
113116
# test with torch
114-
from test.models.test_onnx import run_onnx_torch
115-
# NOTE: we have to correct the order here
116-
new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
117-
print("got torch outputs")
118-
119-
np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
120-
print("test vs onnx passed")
117+
import torch
118+
from onnx2torch import convert
119+
inputs = {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}
120+
torch_model = convert(onnx_model).float()
121+
with torch.no_grad():
122+
for _ in range(1 if test_val is not None else 5):
123+
st = time.perf_counter()
124+
torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()])
125+
timings.append(time.perf_counter() - st)
126+
new_torch_out = torch_out.numpy()
127+
128+
if test_val is not None:
129+
np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
130+
print("test vs onnx passed")
131+
return timings
121132

122133
if __name__ == "__main__":
123134
onnx_file = fetch(OPENPILOT_MODEL)
@@ -131,4 +142,12 @@ def test_vs_onnx(new_inputs, test_val, onnx_file):
131142
sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}
132143

133144
test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
134-
if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)
145+
if getenv("BENCHMARK"):
146+
for be in ["torch", "ort"]:
147+
try:
148+
timings = test_vs_onnx(new_inputs, None, onnx_file, be=="ort")
149+
print(f"timing {be}: {min(timings)*1000:.2f} ms")
150+
except Exception as e:
151+
print(f"{be} fail with {e}")
152+
if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file, getenv("ORT"))
153+

0 commit comments

Comments
 (0)