-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy path37learn.cc
101 lines (80 loc) · 2.56 KB
/
37learn.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <unistd.h>

#include <cmath>
#include <iostream>
#include <string>

#include "mnistreader.hh"
#include "vizi.hh"
#include "ext/sqlitewriter/sqlwriter.hh"
using namespace std;
/**
 * Evaluate a linear 3-vs-7 classifier on every 3 and 7 in a test set.
 *
 * Each image is scored as img.dot(weights).sum() + bias. A score above 0 is
 * a "7" verdict; anything else is a "3". Images with other labels are skipped.
 * The resulting percentage is also printed to stdout.
 *
 * @param mntest   test images and labels to evaluate against
 * @param weights  28x28 weight tensor combined with each image via dot()
 * @param bias     scalar added to every score
 * @param sqw      optional writer; when non-null, one row with the columns
 *                 "label", "res" (the score) and "verdict" is added per
 *                 evaluated image
 * @return percentage (0-100) of evaluated images classified correctly, or 0
 *         when the test set contains no 3s or 7s at all
 */
float doTest(const MNISTReader& mntest, const Tensor<float>& weights, float bias, SQLiteWriter* sqw=nullptr)
{
  unsigned int corrects=0, wrongs=0;
  for(unsigned int n = 0 ; n < mntest.num(); ++n) {
    const int label = mntest.getLabel(n);
    if(label != 3 && label != 7)
      continue;
    Tensor img(28,28);
    mntest.pushImage(n, img);
    const float score = (img.dot(weights).sum()(0,0)) + bias; // the calculation
    const int predict = score > 0 ? 7 : 3;                    // the verdict
    if(sqw)
      sqw->addValue({{"label", label}, {"res", score}, {"verdict", predict}});
    if(predict == label)
      ++corrects;
    else
      ++wrongs;
  }

  const unsigned int total = corrects + wrongs;
  if(total == 0) {
    // Without this guard 100.0*0/0 would yield NaN and print "nan% correct".
    cout << "no 3s or 7s in test set" << endl;
    return 0;
  }
  const float perc = 100.0*corrects/total;
  cout << perc << "% correct" << endl;
  return perc;
}
/**
 * Train a 3-vs-7 linear classifier on the EMNIST digits training set.
 *
 * Walks the training images once, using only those labelled 3 or 7. Before
 * every 4th such image it does two things:
 *   - evaluates the current weights on the test set, stopping early once
 *     accuracy exceeds 98%;
 *   - saves the weights as a PNG.
 *
 * A 7 whose score is below +margin pushes the weights and bias up; a 3 whose
 * score is above -margin pushes them down. At the end, the weights are saved
 * and a final evaluation is logged to 37learn.sqlite3.
 */
int main()
{
  // Step size for weight and bias updates. Kept as double so that
  // `bias += learningRate` rounds exactly like the original 0.01 literal.
  constexpr double learningRate = 0.01;
  // Keep adjusting on a sample until its score is at least this far from the
  // 0 decision boundary, on the correct side.
  constexpr double margin = 2.0;

  MNISTReader mn("gzip/emnist-digits-train-images-idx3-ubyte.gz", "gzip/emnist-digits-train-labels-idx1-ubyte.gz");
  MNISTReader mntest("gzip/emnist-digits-test-images-idx3-ubyte.gz", "gzip/emnist-digits-test-labels-idx1-ubyte.gz");
  cout << "Have "<<mn.num() << " training images and " << mntest.num() << " validation images." <<endl;

  Tensor weights(28,28);
  weights.randomize(1.0/std::sqrt(28*28));
  saveTensor(weights, "random-weights.png", 252);
  float bias=0;

  // Start from a fresh database; the return value is ignored, so a missing
  // file is not treated as an error.
  unlink("37learn.sqlite3");
  SQLiteWriter sqw("37learn.sqlite3");

  // learningRate * identity. Assuming raw() products are matrix products
  // (TODO confirm), img.raw() * lr.raw() is img scaled by learningRate.
  Tensor lr(28,28);
  lr.identity(learningRate);

  int count=0; // number of 3/7 training samples processed so far
  for(unsigned int n = 0 ; n < mn.num(); ++n) {
    const int label = mn.getLabel(n);
    if(label != 3 && label != 7)
      continue;

    if(count % 4 == 0) {
      if(doTest(mntest, weights, bias) > 98.0)
        break;
      saveTensor(weights, "weights-"+to_string(count)+".png", 252);
    }

    Tensor img(28,28);
    mn.pushImage(n, img);
    const float res = (img.dot(weights).sum()(0,0)) + bias; // the calculation

    // NOTE(review): the message says "first image", but this fires when
    // count reaches 25001, not on the first sample. Confirm which is intended.
    if(count == 25001) {
      auto prod = img.dot(weights);
      saveTensor(img, "random-image.png", 252, true);
      saveTensor(prod, "random-prod.png", 252);
      cout<<"res for first image: " << res << '\n';
    }

    if(label == 7) {
      if(res < margin) {
        weights.raw() = weights.raw() + img.raw() * lr.raw();
        bias += learningRate;
      }
    } else if(res > -margin) {
      weights.raw() = weights.raw() - img.raw() * lr.raw();
      bias -= learningRate;
    }
    ++count;
  }

  saveTensor(weights, "weights-final.png", 252);
  doTest(mntest, weights, bias, &sqw);
  cout<<"Bias: "<<bias<<endl;
}