-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdigit_guesser.py
148 lines (123 loc) · 4.69 KB
/
digit_guesser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
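"""
Written by Jordan Otsuji

digit_guesser.py prompts the user to draw a digit on the canvas, which the model then attempts to classify.

Controls: hold the left mouse button to draw, press Enter to print the model's prediction,
and press C or R to clear the canvas.
"""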
import pygame
import tensorflow as tf
import numpy as np
# Tile colors as 0xRRGGBB integers; Tile.draw hands them straight to pygame.draw.rect.
BLACK = 0x000000
WHITE = 0xFFFFFF

# Square window: both dimensions share one size.
WINDOW_HEIGHT = 600
WINDOW_WIDTH = WINDOW_HEIGHT

# Saved Keras model, loaded once when this module is imported.
MODEL_PATH = "digit_recognition_128_128_10.model"
model = tf.keras.models.load_model(MODEL_PATH)
class Tile:
    """One rectangular cell of the drawing grid.

    Attributes:
        x, y: pixel coordinates of the tile's top-left corner.
        width, height: pixel size of the tile.
        color: current fill color; every tile starts as WHITE.
    """

    def __init__(self, x, y, width, height):
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.color = WHITE

    def draw(self, canvas):
        """Fill this tile's rectangle on *canvas* with its current color.

        Args:
            canvas: the pygame surface to draw onto.
        """
        # pygame rects are (left, top, width, height), NOT (x1, y1, x2, y2);
        # passing the far corner here would make every tile far too large.
        pygame.draw.rect(canvas, self.color, (self.x, self.y, self.width, self.height))
class Canvas:
    """A rows x columns grid of Tile objects laid out over a width x height pixel area.

    Tile sizes use integer division, so any leftover pixels along the right and
    bottom edges belong to no tile.
    """

    def __init__(self, rows, columns, width, height):
        self.rows = rows
        self.columns = columns
        self.len = rows * columns  # total number of tiles in the grid
        self.width = width
        self.height = height
        self.tiles = []
        self.initTiles()

    def draw(self, canvas):
        """Draw every tile in the grid onto *canvas* (a pygame surface)."""
        for tileRow in self.tiles:
            for tile in tileRow:
                tile.draw(canvas)

    def getTile(self, clickPosition):
        """Return the tile under *clickPosition*, or None if no tile is there.

        Args:
            clickPosition: an (x, y) pixel position, e.g. from pygame.mouse.get_pos().

        Returns:
            The Tile containing the point, or None when the point lies outside the
            grid (including the unused margin left over by integer division).
        """
        if not self.tiles or not self.tiles[0]:
            return None
        tile_width = self.tiles[0][0].width
        tile_height = self.tiles[0][0].height
        if tile_width <= 0 or tile_height <= 0:
            # Degenerate grid (canvas narrower than its column/row count): nothing is clickable.
            return None
        # Integer division by the tile size maps a pixel to its grid cell.
        col = int(clickPosition[0]) // tile_width
        row = int(clickPosition[1]) // tile_height
        # Explicit bounds check rather than catching IndexError: it also rejects
        # negative coordinates, which Python indexing would otherwise silently wrap.
        if 0 <= row < len(self.tiles) and 0 <= col < len(self.tiles[row]):
            return self.tiles[row][col]
        return None

    def initTiles(self):
        """Append self.rows rows of self.columns evenly spaced tiles to self.tiles."""
        tile_width = self.width // self.columns
        tile_height = self.height // self.rows
        for row in range(self.rows):
            # Position = tile size * grid index, so tiles abut without gaps.
            self.tiles.append([
                Tile(tile_width * column, tile_height * row, tile_width, tile_height)
                for column in range(self.columns)
            ])

    def clear(self):
        """Reset every tile in the grid back to WHITE."""
        for tileRow in self.tiles:
            for tile in tileRow:
                tile.color = WHITE

    def convert_to_feature(self):
        """Encode the current drawing as input for the model.

        Returns:
            A nested list of shape (1, rows, columns) holding 0 for each WHITE tile
            and 1 for any other color. The extra outer list is the batch dimension
            (a batch of one) that TF requires.
        """
        feature = [
            [0 if tile.color == WHITE else 1 for tile in tileRow]
            for tileRow in self.tiles
        ]
        return [feature]
def main():
    """Run the pygame event loop until the window is closed.

    Controls:
        Left mouse button -- paint the tile under the cursor BLACK.
        Enter             -- print the model's class probabilities and its top prediction.
        C or R            -- clear the canvas.

    Uses the module-level ``canvas``, ``window`` and ``model`` objects. Never
    returns normally: closing the window shuts pygame down and exits via SystemExit.
    """
    # sys.exit is always available, unlike the site-injected quit() helper
    # (missing under `python -S` and in frozen executables).
    import sys

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit(0)
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_RETURN:
                    # Enter: print the model's guess for the current drawing.
                    probabilities = model.predict(canvas.convert_to_feature())
                    print(f'Predicted Probabilities: \n\t{probabilities}')
                    prediction = np.argmax(probabilities[0])
                    print(f"Model Prediction: {prediction}")
                elif event.key in (pygame.K_c, pygame.K_r):
                    canvas.clear()
            # Checked once per event, so painting happens as events (e.g. mouse
            # motion) arrive while the left button is held down.
            if pygame.mouse.get_pressed()[0]:
                currentTile = canvas.getTile(pygame.mouse.get_pos())
                if currentTile:
                    currentTile.color = BLACK
        # Redraw the whole grid and push it to the screen.
        canvas.draw(window)
        pygame.display.update()
# Start pygame and open the application window.
pygame.init()
window = pygame.display.set_mode((WINDOW_WIDTH, WINDOW_HEIGHT))
pygame.display.set_caption("Handwritten Digit Classification AI")

# 28x28 drawing grid stretched across the window; convert_to_feature emits one value per tile.
canvas = Canvas(28, 28, WINDOW_WIDTH, WINDOW_HEIGHT)

# Enter the event loop; it only ends when the window is closed.
main()