-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmnistposter.cc
61 lines (48 loc) · 1.53 KB
/
mnistposter.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "ext/stb/stb_image_write.h"
#include <array>
#include <vector>
#include <fstream>
#include <iostream>
#include "mnistreader.hh"
#include <fenv.h>
#include "misc.hh"
using namespace std;
int main(int argc, char **argv)
{
int filt=-1;
if(argc == 2)
filt= 1 + argv[1][0] - 'a';
feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW );
MNISTReader mn("gzip/emnist-letters-train-images-idx3-ubyte.gz", "gzip/emnist-letters-train-labels-idx1-ubyte.gz");
//MNISTReader mn("gzip/emnist-letters-test-images-idx3-ubyte.gz", "gzip/emnist-letters-test-labels-idx1-ubyte.gz");
cout<<"Have "<<mn.num()<<" images"<<endl;
constexpr int imgrows=1200, imgcols=1900;
vector<uint8_t> out;
out.resize(imgcols*imgrows);
auto pix = [&out, &imgrows, &imgcols](int col, int row) -> uint8_t&
{
return out[col + row*imgcols];
};
int count=0;
Batcher batcher(mn.num());
for(;;) {
auto b = batcher.getBatch(1);
if(b.empty())
break;
int n=b[0];
if(filt >=0 && mn.getLabel(n) != filt)
continue;
Tensor img(28,28);
mn.pushImage(n, img);
int x = 30 * (count % (imgcols/30 - 1)); // this many per row
int y = 30 * (count / (imgcols/30 - 1));
count++;
if(x+30 >= imgcols || y+30 >= imgrows)
break;
for(unsigned int r=0; r < img.getRows(); ++r)
for(unsigned int c=0; c < img.getCols(); ++c)
pix(x+c, y+r) = 255 - img(r,c)*255;
}
stbi_write_png("poster.png", imgcols, imgrows, 1, &out[0], imgcols);
}