This repository has been archived by the owner on Jul 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: mnist_tflite_model_test.py
64 lines (54 loc) · 2.11 KB
/
mnist_tflite_model_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Evaluate a TF Lite MNIST classifier on the MNIST test set.
#
# Loads TF_LITE_MODEL and resizes its input tensor to hold the whole test set,
# so every image is classified in a single invoke() call. It then prints
# accuracy, MSE and a per-class report, and saves a 4x10 grid of the first 40
# test digits, each titled with its predicted label, to ./mnist-model-test.jpg.
from pprint import pprint

import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    mean_squared_error,
)
from tensorflow.keras.datasets import mnist

TF_LITE_MODEL = './mnist.tflite'

# load MNIST test dataset (the training split is not used here)
(_, _), (x_test, y_test) = mnist.load_data()
print('test image shape:', x_test.shape)
print('test label shape:', y_test.shape)
print('')

# load TF Lite model and inspect input/output shape
print('Loading', TF_LITE_MODEL, '...\n')
interpreter = tf.lite.Interpreter(model_path=TF_LITE_MODEL)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print('input details:')
pprint(input_details)
print('')
print('output details:')
pprint(output_details)
print('')

# Resize only the *input* tensor to the full test set (batch prediction).
# The output tensor is left alone: its shape is derived from the input when
# tensors are re-allocated, as the 'new output shape' print below shows.
# Forcing it to y_test.shape (a 1-D label vector) would not match a
# per-class score output.
interpreter.resize_tensor_input(input_details[0]['index'], x_test.shape)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print('new input shape:', input_details[0]['shape'])
print('new output shape:', output_details[0]['shape'])
print('')

# make prediction
print('Predicting...')
# mnist.load_data() yields uint8 images, and set_tensor() rejects arrays whose
# dtype differs from the tensor's. Convert to whatever dtype the model
# declares; this is a no-op when the model already takes uint8.
# NOTE(review): no rescaling is applied; this assumes the model expects raw
# 0-255 pixel values (e.g. rescales internally). Confirm against training.
input_dtype = input_details[0]['dtype']
interpreter.set_tensor(input_details[0]['index'], x_test.astype(input_dtype))
interpreter.invoke()
# assumes the output is (batch, num_classes) scores; argmax picks the class
predicted = interpreter.get_tensor(output_details[0]['index']).argmax(axis=1)

# inspect metrics
# Builtin round() works both for NumPy scalars and for the plain Python
# floats that recent scikit-learn releases return (those lack .round()).
print('Prediction accuracy:', round(accuracy_score(y_test, predicted), 4))
# NOTE: MSE between class indices is not a meaningful classification metric
# (a 3-vs-8 mistake counts more than a 3-vs-4 one); shown for reference only.
print('Prediction MSE:', round(mean_squared_error(y_test, predicted), 4))
print('')

# compare prediction to real labels
print(classification_report(y_test, predicted))

# draw first 40 test digits and their predicted labels
fig = plt.figure(figsize=(6, 3))
for i, (image, label) in enumerate(zip(x_test[:40], predicted[:40])):
    ax = fig.add_subplot(4, 10, i + 1)
    ax.set_axis_off()
    ax.set_title(f'{label}')
    # draw on the explicit subplot rather than pyplot's implicit current axes
    ax.imshow(image, cmap='gray')
plt.tight_layout()
plt.savefig('./mnist-model-test.jpg')
plt.show()