-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmisc.hh
108 lines (95 loc) · 2.04 KB
/
misc.hh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#pragma once
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <deque>
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <random>
#include <vector>
/// Tunable training hyper-parameters.
///
/// Every member is zero-initialised by default, so a default-constructed
/// instance never holds indeterminate values. Aggregate initialisation such as
/// `HyperParameters{lr, momentum, batchMult}` continues to work.
struct HyperParameters
{
  float lr{0};        // presumably the learning rate (TODO confirm at use sites)
  float momentum{0};
  int batchMult{0};   // batch size is 8 * batchMult, see getBatchSize()

  /// Returns the batch size, i.e. 8 * batchMult.
  /// const so it can be queried through const references / const instances.
  /// NOTE(review): a negative batchMult converts to a very large unsigned value.
  unsigned int getBatchSize() const
  {
    return static_cast<unsigned int>(8 * batchMult);
  }
};
/// Record of training progress.
/// `trained` is a std::atomic counter; the other members have no
/// synchronization of their own.
struct TrainingProgress
{
  int batchno{0};
  float lastTook{0.0f};
  std::vector<float> losses;
  std::vector<float> corrects;
  std::atomic<unsigned int> trained{0};
};
extern struct TrainingProgress g_progress;
extern std::shared_ptr<HyperParameters> g_hyper;
int graphicsThread();
/// Hands out a shuffled pool of integers in successive batches.
///
/// The pool is shuffled once at construction. Each getBatch()/getBatchLocked()
/// call removes up to n values from the front, so every value is handed out
/// exactly once. getBatch() is unsynchronized; getBatchLocked() holds an
/// internal mutex. Concurrent consumers must therefore all use getBatchLocked().
class Batcher
{
public:
  /// Builds a pool holding 0..n-1 (empty when n <= 0).
  /// @param n    number of values in the pool
  /// @param rng  generator used for the shuffle. It is taken by value, so the
  ///             caller's generator is not advanced. When empty, a
  ///             std::random_device-seeded std::mt19937 is used.
  explicit Batcher(int n, std::optional<std::mt19937> rng=std::optional<std::mt19937>())
  {
    for(int i=0; i < n ; ++i)
      d_store.push_back(i);
    randomize(rng);
  }

  /// Builds a pool holding a copy of `in` (duplicates are kept).
  /// @param in   values to hand out
  /// @param rng  same meaning as for the int constructor. The default (empty)
  ///             keeps the previous random_device-seeded behaviour. Passing a
  ///             seeded generator makes the order reproducible.
  explicit Batcher(const std::vector<int>& in, std::optional<std::mt19937> rng=std::optional<std::mt19937>())
    : d_store(in.begin(), in.end())
  {
    randomize(rng);
  }

  /// Removes and returns up to n values from the front of the pool.
  /// Returns fewer when the pool runs out, and none when n <= 0.
  /// Not synchronized: do not mix with concurrent getBatchLocked() callers.
  std::deque<int> getBatch(int n)
  {
    std::deque<int> ret;
    for(int i = 0 ; !d_store.empty() && i < n; ++i) {
      ret.push_back(d_store.front());
      d_store.pop_front();
    }
    return ret;
  }

  /// Same as getBatch(), but holds the internal mutex for the whole removal.
  std::deque<int> getBatchLocked(int n)
  {
    std::lock_guard<std::mutex> l(d_mut);
    return getBatch(n);
  }

private:
  std::deque<int> d_store;
  std::mutex d_mut;

  /// Shuffles d_store with `rnd`, or with a freshly seeded mt19937 when empty.
  void randomize(std::optional<std::mt19937> rnd = std::optional<std::mt19937>())
  {
    if(!rnd) {
      std::random_device rd;
      rnd.emplace(rd());
    }
    std::shuffle(d_store.begin(), d_store.end(), *rnd);
  }
};
/// Lap timer on std::chrono::steady_clock.
struct DTime
{
  /// (Re)starts the current lap at the present instant.
  void start()
  {
    d_start = std::chrono::steady_clock::now();
  }

  /// Returns the microseconds elapsed since the lap start, then starts a new
  /// lap at that same instant.
  ///
  /// The clock is sampled only once, so consecutive laps add up exactly to
  /// the total elapsed time. The result is reduced modulo 2^32, so a lap
  /// longer than about 71.6 minutes wraps around.
  uint32_t lapUsec()
  {
    const auto now = std::chrono::steady_clock::now();
    const auto usec = std::chrono::duration_cast<std::chrono::microseconds>(now - d_start).count();
    d_start = now;
    return static_cast<uint32_t>(usec);
  }

  /// Start of the current lap. It is set at construction, so a lapUsec()
  /// before any start() measures time since construction rather than since
  /// the clock's epoch.
  std::chrono::time_point<std::chrono::steady_clock> d_start{std::chrono::steady_clock::now()};
};