-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain-adders.js
71 lines (60 loc) · 1.86 KB
/
train-adders.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
var brain = require("brain.js");
var NN = brain.NeuralNetwork;
var fs = require("fs");
var args = require("yargs").argv;
// Read command-line options; each accepts a long name or a short alias.
// Options left unset here fall back to defaults where they are consumed,
// except n_bits: it defaults to DEFAULT_BITS right away so the data file,
// the hidden-layer size, and the saved file name/metadata all agree.
const DEFAULT_BITS = 4;
const n_bits = args.bits || args.b || DEFAULT_BITS;
const error = args.error || args.e;
const learningRate = args.learningRate || args.l;
const multiplier = args.multiplier || args.m;
const logPeriod = args.logPeriod || args.lp;
// Training set for n_bits-bit addition; it is read below as parallel
// `inputs` / `outputs` arrays.
const data = require("./Data/" + n_bits + "bit_add_in_out.json");
// A single hidden layer of (multiplier × n_bits) neurons: 4 × 4 = 16 by
// default. A non-numeric bit count falls back to DEFAULT_BITS for sizing.
const network = new NN({
  hiddenLayers: [(multiplier || 4) * (parseInt(n_bits, 10) || DEFAULT_BITS)]
});
// Zip the parallel `inputs` / `outputs` arrays into the list of
// { input, output } samples that network.train() consumes.
const Data = data.inputs.map((sample, idx) => ({
  input: sample,
  output: data.outputs[idx]
}));
console.log("data preprocessing done");
// Training options for network.train(); any value not given on the command
// line falls back to the default written here.
// There is deliberately no `activation` entry: brain.js takes the activation
// from the NeuralNetwork constructor options (as a name such as "sigmoid")
// and ignores unknown train() options. A 0/1 step function would also give
// backpropagation no usable gradient.
const trainingOptions = {
  errorThresh: error || 0.01, // stop once the training error drops below this
  iterations: 100000000, // hard upper bound on training iterations
  log: true, // print progress periodically
  logPeriod: logPeriod || 100, // iterations between progress lines
  learningRate: learningRate || 0.01 // step size for weight updates
};
console.log(trainingOptions);
// train() is synchronous: it returns once errorThresh or the iteration
// limit is reached.
network.train(Data, trainingOptions);
// Serialize the trained network and attach the run's settings, so the saved
// file records how it was produced.
const json = network.toJSON();
json.networkData = {
  n_bits: n_bits,
  error: error,
  learningRate: learningRate,
  trainingOptions: trainingOptions
};
// toISOString() contains ':' characters, which Windows does not allow in
// file names; replace them so the output path is portable.
const timestamp = new Date().toISOString().replace(/:/g, "-");
const outPath = `./precomputed-net-${timestamp}-${n_bits}Bit.json`;
fs.writeFile(outPath, JSON.stringify(json), err => {
  if (err) {
    // Report the failure and make the process exit non-zero instead of
    // just printing the error object.
    console.error("failed to save network to " + outPath, err);
    process.exitCode = 1;
    return;
  }
  console.log("network saved to " + outPath);
});