Skip to content

Commit dc5e758

Browse files
committed
single input digit prediction using neural network
1 parent ddc8171 commit dc5e758

File tree

3 files changed

+118
-8
lines changed

3 files changed

+118
-8
lines changed

Unit03/Neural_Networks.ipynb

+98-7
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@
265265
"cell_type": "markdown",
266266
"metadata": {},
267267
"source": [
268-
"### **3) Backpropagation**"
268+
"## **3) Backpropagation**"
269269
]
270270
},
271271
{
@@ -527,7 +527,7 @@
527527
"cell_type": "markdown",
528528
"metadata": {},
529529
"source": [
530-
"### 4) Improvements or Optimization (Gradient Descent)"
530+
"## **4) Improvements or Optimization (Gradient Descent)**"
531531
]
532532
},
533533
{
@@ -571,12 +571,14 @@
571571
"source": [
572572
"<hr>\n",
573573
"\n",
574-
"## Lets build a Neural Network"
574+
"## Lets build a Neural Network\n",
575+
"\n",
576+
"A simple neural network with one hidden layer."
575577
]
576578
},
577579
{
578580
"cell_type": "code",
579-
"execution_count": 1,
581+
"execution_count": 2,
580582
"metadata": {},
581583
"outputs": [],
582584
"source": [
@@ -746,7 +748,7 @@
746748
},
747749
{
748750
"cell_type": "code",
749-
"execution_count": 2,
751+
"execution_count": 3,
750752
"metadata": {},
751753
"outputs": [],
752754
"source": [
@@ -833,14 +835,14 @@
833835
},
834836
{
835837
"cell_type": "code",
836-
"execution_count": 14,
838+
"execution_count": 4,
837839
"metadata": {},
838840
"outputs": [
839841
{
840842
"name": "stdout",
841843
"output_type": "stream",
842844
"text": [
843-
"Accuracy: 94.17%\n"
845+
"Accuracy: 93.06%\n"
844846
]
845847
},
846848
{
@@ -909,6 +911,95 @@
909911
"print(f\"Accuracy: {accuracy:.2%}\")"
910912
]
911913
},
914+
{
915+
"cell_type": "markdown",
916+
"metadata": {},
917+
"source": [
918+
"### Predict Single Digit"
919+
]
920+
},
921+
{
922+
"cell_type": "code",
923+
"execution_count": 71,
924+
"metadata": {},
925+
"outputs": [
926+
{
927+
"name": "stdout",
928+
"output_type": "stream",
929+
"text": [
930+
"Input Single Digit 8x8 image vector : \n",
931+
"+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+\n",
932+
"| Pix1 | Pix2 | Pix3 | Pix4 | Pix5 | Pix6 | Pix7 | Pix8 |\n",
933+
"+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+\n",
934+
"| 0.00 | 0.00 | 0.12 | 0.81 | 0.50 | 0.00 | 0.00 | 0.00 |\n",
935+
"| 0.00 | 0.00 | 0.38 | 1.00 | 1.00 | 0.38 | 0.00 | 0.00 |\n",
936+
"| 0.00 | 0.00 | 0.31 | 0.94 | 0.81 | 0.69 | 0.00 | 0.00 |\n",
937+
"| 0.00 | 0.00 | 0.00 | 0.44 | 1.00 | 0.94 | 0.00 | 0.00 |\n",
938+
"| 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.88 | 0.21 | 0.00 |\n",
939+
"| 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.44 | 0.69 | 0.00 |\n",
940+
"| 0.00 | 0.00 | 0.00 | 0.19 | 0.25 | 0.25 | 1.00 | 0.15 |\n",
941+
"| 0.00 | 0.00 | 0.12 | 0.94 | 0.81 | 0.88 | 0.81 | 0.12 |\n",
942+
"+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+::::::::::::+\n",
943+
"Predicted Value : 9\n",
944+
"Actual Digit Image with Label :\n"
945+
]
946+
},
947+
{
948+
"data": {
949+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAFsAAABbCAYAAAAcNvmZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAADiUlEQVR4nO2dMUszWRiFT8yKBpkESxVRNGA6xU4RsRArwS6ghajY2cTCwi5g5x/wB8RWsB1sFARBERRbC2UwlpqZgIVOZqtvYWE/713m7tnd+c4D0yQnycvDy8zcuTO5uSRJEggKXf92Ab8Skk1EsolINhHJJiLZRCSbyG82oU6ng2azCc/zkMvl/uma/nckSYIoijA4OIiurm/6N7EgCIIEgDbDFgTBtx6tOtvzPABAEAQoFos2H/kpz8/Pxsza2poxs7y8bMzs7+9b1ZSWMAwxPDz8h6efYSX7x66jWCymlm0qCADy+bwx09PTY8ykrfXvYtrF6gBJRLKJSDYRySYi2UQkm4jVqZ9Lzs/PjZn7+3snmdHRUWNmY2PDmHGFOpuIZBORbCKSTUSyiUg2EckmItlE6IOa/v5+Y6ZUKhkzNgOW09NTY0aDmowi2UQkm4hkE5FsIpJNRLKJSDYR+qBmZWXFmKnX68bM7u6uMfP09GRREQ91NhHJJiLZRCSbiGQTkWwikk1EsonkbP4oIAxDlEoltFot+t38aXA1mzM1NfXt+7Z+1NlEJJuIZBORbCKSTUSyiUg2EckmQp+pYVKr1YwZm1khm4GPDepsIpJNRLKJSDYRySYi2UQkm4hkE/lPDmru7u6Mmff3d2PGNMMCuLmNLYoi43cA6mwqkk1Esomklh1FEWq1GkZGRlAoFDA7O4ubmxsXtWWO1LK3t7dxdnaGRqOBh4cHLC0tYXFxES8vLy7qyxSpZH98fODk5ASHh4eYn59HuVxGvV5HuVzG0dGRqxozQyrZX19fiOMYvb29f3q9UCjg8vIyVWFZJJVsz/MwMzODg4MDNJtNxHGM4+NjXF1d4fX11VWNmSH1oKbRaGBrawtDQ0PI5/OYnp7G6uoqbm9v/zJv85zLwsKCMWNza5nNwGdyctKYcUXqA+T4+DguLi7QbrcRBAGur6/x+fmJsbExF/VlCmfn2X19fRgYGMDb2xt837d6KuxXI/VuxPd9JEmCiYkJPD4+Ym9vD5VKBZubmy7qyxSpO7vVamFnZweVSgXr6+uYm5uD7/vo7u52UV+mSN3Z1WoV1WrVRS2ZR9dGiEg2EavdyI8nQcIwTP2DNhfabZaojOPYmOl0Ok6+x1Rzu90GYFG31qnhrVNj9QCTVmD6nsRyBSYr2cINOkASkWwikk1EsolINhHJJiLZRH4HEPMVOAcglpkAAAAASUVORK5CYII=",
950+
"text/plain": [
951+
"<Figure size 600x600 with 1 Axes>"
952+
]
953+
},
954+
"metadata": {},
955+
"output_type": "display_data"
956+
}
957+
],
958+
"source": [
959+
"import numpy as np\n",
960+
"from dataclasses import dataclass\n",
961+
"import pandas as pd\n",
962+
"from tabulate import tabulate\n",
963+
"from prettytable import PrettyTable\n",
964+
"\n",
965+
"def show_input_vector(vector):\n",
966+
" reshaped_X = vector.reshape(8, 8)\n",
967+
" df = pd.DataFrame(reshaped_X)\n",
968+
" fixed_width_df = df.map(lambda x: f\"{x:5.2f}\")\n",
969+
"\n",
970+
" t = PrettyTable(\n",
971+
" ['Pix1', 'Pix2', 'Pix3', 'Pix4', 'Pix5', 'Pix6', 'Pix7', 'Pix8'], \n",
972+
" align='c', \n",
973+
" horizontal_char=':',\n",
974+
" max_width=10, \n",
975+
" min_width=10\n",
976+
" ) \n",
977+
" t.add_rows(fixed_width_df.values.tolist())\n",
978+
" print(t)\n",
979+
"\n",
980+
"@dataclass\n",
981+
"class SingleDigit:\n",
982+
" images: np.array\n",
983+
" target: np.array\n",
984+
"\n",
985+
"test_idx = 31\n",
986+
"X_single_input = X[test_idx]\n",
987+
"y_single_label = y[test_idx]\n",
988+
"singleDigit = SingleDigit(\n",
989+
" # Shape of digits.images is (samples_n, 8, 8) where 8x8 are pixels\n",
990+
" images = np.array([digits.images[test_idx]]),\n",
991+
" target = np.array([digits.target[test_idx]])\n",
992+
")\n",
993+
" \n",
994+
"predict = np.argmax(nn.forward(X_single_input), axis=1)\n",
995+
"\n",
996+
"print(f\"Input Single Digit 8x8 image vector : \")\n",
997+
"show_input_vector(X_single_input)\n",
998+
"print(f\"Predicted Value : \", predict[0])\n",
999+
"print(f\"Actual Digit Image with Label :\")\n",
1000+
"show_digits(singleDigit, 1)"
1001+
]
1002+
},
9121003
{
9131004
"cell_type": "markdown",
9141005
"metadata": {},

poetry.lock

+18-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ plotly = "^5.23.0"
3232
plotly-express = "^0.4.1"
3333
spacy = "^3.7.6"
3434
nltk = "^3.9.1"
35+
tabulate = "^0.9.0"
36+
prettytable = "^3.11.0"
3537

3638
[tool.poetry.group.dev.dependencies]
3739
notebook = "^7.2.1"

0 commit comments

Comments
 (0)