@@ -98,26 +98,37 @@ def test_vs_compile(run, new_inputs, test_val=None):
   np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
   return val
 
-def test_vs_onnx(new_inputs, test_val, onnx_file):
+def test_vs_onnx(new_inputs, test_val, onnx_file, ort=False):
   new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
   onnx_model = onnx.load(onnx_file)
 
-  if getenv("ORT"):
+  timings = []
+  if ort:
     # test with onnxruntime
     import onnxruntime as ort
     onnx_session = ort.InferenceSession(onnx_file)
-    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
+    for _ in range(1 if test_val is not None else 5):
+      st = time.perf_counter()
+      onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
+      timings.append(time.perf_counter() - st)
     new_torch_out = onnx_output[0]
-    print("got ort outputs")
   else:
     # test with torch
-    from test.models.test_onnx import run_onnx_torch
-    # NOTE: we have to correct the order here
-    new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
-    print("got torch outputs")
-
-  np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
-  print("test vs onnx passed")
+    import torch
+    from onnx2torch import convert
+    inputs = {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}
+    torch_model = convert(onnx_model).float()
+    with torch.no_grad():
+      for _ in range(1 if test_val is not None else 5):
+        st = time.perf_counter()
+        torch_out = torch_model(*[torch.tensor(x) for x in inputs.values()])
+        timings.append(time.perf_counter() - st)
+    new_torch_out = torch_out.numpy()
+
+  if test_val is not None:
+    np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
+    print("test vs onnx passed")
+  return timings
 
 if __name__ == "__main__":
   onnx_file = fetch(OPENPILOT_MODEL)
@@ -131,4 +142,12 @@ def test_vs_onnx(new_inputs, test_val, onnx_file):
                 sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}
 
   test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
-  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)
+  if getenv("BENCHMARK"):
+    for be in ["torch", "ort"]:
+      try:
+        timings = test_vs_onnx(new_inputs, None, onnx_file, be == "ort")
+        print(f"timing {be}: {min(timings)*1000:.2f} ms")
+      except Exception as e:
+        print(f"{be} fail with {e}")
+  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file, getenv("ORT"))
+
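For context, the new `BENCHMARK` path times each ONNX backend with a best-of-N loop around `time.perf_counter()` and reports the fastest run. Below is a minimal standalone sketch of that timing pattern; `best_of_n` and `dummy_run` are illustrative names, not part of the diff, and the placeholder workload stands in for `onnx_session.run(...)` / `torch_model(...)`:

```python
import time

def best_of_n(fn, n=5):
  # run fn n times, keeping one wall-clock timing per run,
  # like the perf_counter loops added to test_vs_onnx above
  timings = []
  for _ in range(n):
    st = time.perf_counter()
    fn()
    timings.append(time.perf_counter() - st)
  return timings

def dummy_run():
  # placeholder workload (not from the diff)
  sum(i * i for i in range(100_000))

if __name__ == "__main__":
  timings = best_of_n(dummy_run)
  # report the fastest run, mirroring min(timings)*1000 in the diff
  print(f"best of {len(timings)}: {min(timings)*1000:.2f} ms")
```

Reporting the minimum rather than the mean is a common choice for wall-clock benchmarks, since the slower runs mostly reflect scheduling and warm-up noise.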