-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathexample_09.py
executable file
·61 lines (53 loc) · 1.73 KB
/
example_09.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
import numpy as np
import onnxruntime as rt
# Scratch driver for hand-built ONNX test models ("test.onnx" by default).
# Each example loads the model with onnxruntime, feeds it all-ones float32
# inputs and prints what it returns.  Select an example on the command line:
#
#     ./example_09.py [gemm|split|maxpool|concat|sum] [MODEL]
#
# With no arguments the "sum" example runs on "test.onnx", which was the
# only example enabled before (the others were parked in string literals).

DEFAULT_MODEL = "test.onnx"

# Explicit input shapes for the Gemm example: A (3x4), B (4x5), C (3x5).
GEMM_SHAPES = ((3, 4), (4, 5), (3, 5))


def ones_feed(inputs, dynamic_dim=1):
    """Build a feed dict mapping every input name to an all-ones float32 array.

    Args:
        inputs: Iterable of input descriptors exposing ``.name`` and
            ``.shape``, e.g. the result of ``InferenceSession.get_inputs()``.
        dynamic_dim: Size used for any dimension that is not a concrete
            ``int`` (symbolic names such as ``"batch"``, or ``None``), which
            ``np.ones`` would otherwise reject.

    Returns:
        dict mapping input name -> ``np.ndarray`` of float32 ones.
    """
    feed = {}
    for node in inputs:
        # Symbolic/unknown dims cannot be passed to np.ones; substitute a size.
        shape = tuple(
            dim if isinstance(dim, int) else dynamic_dim for dim in node.shape
        )
        feed[node.name] = np.ones(shape, dtype=np.float32)
    return feed


def run_with_ones(model_path=DEFAULT_MODEL):
    """Run the model at *model_path* on all-ones inputs and return its outputs.

    Every graph input is fed (see ``ones_feed``), so the same helper serves
    models with one, two or three inputs.  Returns the list produced by
    ``InferenceSession.run`` with all outputs requested.
    """
    sess = rt.InferenceSession(model_path)
    return sess.run(None, ones_feed(sess.get_inputs()))


def print_outputs(model_path=DEFAULT_MODEL):
    """Run *model_path* on all-ones inputs and print every output.

    Prints all outputs rather than fixed indices: MaxPool's second output
    (Indices) is optional, so unconditionally printing ``[1]`` raises
    ``IndexError`` on a single-output model.
    """
    for output in run_with_ones(model_path):
        print(output)


def check_gemm(model_path=DEFAULT_MODEL):
    """Feed all-ones A, B, C to a Gemm model and compare with ``A @ B + C``.

    The reference assumes alpha = beta = 1 and no transposes, which is the
    same assumption behind the hard-coded all-5 expectation (4 * 1 + 1) it
    replaces.  Prints and returns whether the model output matches exactly.
    """
    sess = rt.InferenceSession(model_path)
    a_in, b_in, c_in = (
        np.ones(shape, dtype=np.float32) for shape in GEMM_SHAPES
    )
    names = [node.name for node in sess.get_inputs()]
    feed = dict(zip(names, (a_in, b_in, c_in)))
    pred = sess.run(None, feed)[0]
    matches = np.array_equal(pred, a_in @ b_in + c_in)
    print("Gemm output matches A @ B + C:", matches)
    return matches


# Example name -> function taking a model path.
EXAMPLES = {
    "gemm": check_gemm,
    "split": print_outputs,
    "maxpool": print_outputs,
    "concat": print_outputs,
    "sum": print_outputs,
}


def main(argv=None):
    """Parse ``[EXAMPLE] [MODEL]`` from *argv* (default: sys.argv) and run it."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Run an ONNX test model on all-ones inputs."
    )
    parser.add_argument(
        "example", nargs="?", default="sum", choices=sorted(EXAMPLES)
    )
    parser.add_argument("model", nargs="?", default=DEFAULT_MODEL)
    args = parser.parse_args(argv)
    EXAMPLES[args.example](args.model)


if __name__ == "__main__":
    main()