-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtree_test.lua
83 lines (59 loc) · 2.43 KB
/
tree_test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
-- Extend package.path with the directory that contains this .lua file, so the
-- modules living next to it can be require()d regardless of the current
-- working directory.
-- debug.getinfo reports the chunk source as "@<path>" when the chunk was
-- loaded from a file; for any other source the match yields nil.
local filepath = debug.getinfo(1, "S").source:match("^@(.*)$")
-- Keep everything in front of the last path separator. If the script was
-- started without a directory component (e.g. `luajit tree_test.lua`) there is
-- no separator at all, so fall back to "." instead of mistaking the file name
-- for a directory. Note: an empty capture ("" for "/file.lua") is truthy in
-- Lua and correctly produces the root entry "/?.lua".
local dir = filepath and filepath:match("^(.*)[/\\][^/\\]*$") or "."
package.path = dir .. "/?.lua;" .. package.path
local math = require("math")
local ffi = require("ffi")
local bitop = require("bit")
require("luarocks.loader")
local array = require("ljarray.array")
local helpers = require("ljarray.helpers")
local tree = require("tree")
local forest = require("forest")
-- TEST 1: train a single tree on two constant-valued groups of samples and
-- require that predicting on the very same data reproduces every label.
do
  local features = array.create({100, 2}, array.float32)
  local n_samples = features.shape[0]
  local labels = array.create({n_samples}, array.int32)

  -- First group: feature value 1, label 1.
  -- NOTE(review): bind(0, a, b) presumably selects the range [a, b) along
  -- axis 0 -- confirm against ljarray.
  features:bind(0, 0, 50):assign(1)
  labels:bind(0, 0, 50):assign(1)

  -- Second group: feature value 100, label 2.
  features:bind(0, 50, n_samples):assign(100)
  labels:bind(0, 50, n_samples):assign(2)

  local classifier = tree.create({f_subsample = 1.0, n_classes = 2})
  classifier:learn(features, labels)

  -- The training set doubles as the test set.
  local predicted = classifier:predict(features)
  for row = 0, n_samples - 1 do
    assert(predicted:get(row) == labels:get(row))
  end
end
print("FINISHED TEST1")
-- TEST 2: fit a single tree on random features with random labels and report
-- how many training samples it predicts correctly. Nothing is asserted here;
-- only the hit count and the hit ratio are printed.
do
  local features = array.rand({100, 100}, array.float32)
  features:add(17)
  -- Random labels shifted by +1 (exact randint bounds: see ljarray).
  local labels = array.randint(0, 2, {features.shape[0]}, array.int32):add(1)

  print("LEARNING TEST2")
  -- m_try is set to the number of feature columns.
  local classifier = tree.create({n_classes = 2, m_try = features.shape[1]})
  classifier:learn(features, labels)

  -- Evaluate on the training data itself.
  local predicted = classifier:predict(features)
  local n_samples = features.shape[0]
  local hits = 0
  for row = 0, n_samples - 1 do
    hits = hits + (predicted:get(row) == labels:get(row) and 1 or 0)
    -- Strict per-sample check, currently disabled:
    -- assert(predicted:get(row) == labels:get(row), "prediction failed for index " .. row .. " prediction = " .. predicted:get(row) .. ", GT = "..labels:get(row))
  end
  print("TOTAL CORRECCT COUNT", hits, hits / n_samples)
end
-- BENCHMARK: time training and prediction of a 3-tree forest on one million
-- random samples with 10 features, then report the accuracy on that same data.
print("BEGIN BENCHMARKING")
math.randomseed(os.time())
do
  local features = array.rand({1000000, 10}, array.float32)
  local labels = array.randint(0, 2, {features.shape[0]}):add(1)
  local rf = forest.create({n_trees = 3, n_classes = 2, m_try = 3})

  -- Workloads handed to the benchmark helper; their results are discarded.
  local function train()
    rf:learn(features, labels)
  end
  local function infer()
    rf:predict(features)
  end

  -- NOTE(review): the second argument is presumably the repetition count --
  -- confirm against ljarray.helpers.
  helpers.benchmark(train, 1, "training RF")
  helpers.benchmark(infer, 1, "predicting RF")

  -- Predict once more outside the timer to obtain the labels for scoring.
  local predicted = rf:predict(features)
  local n_samples = features.shape[0]
  local hits = 0
  for row = 0, n_samples - 1 do
    hits = hits + (predicted:get(row) == labels:get(row) and 1 or 0)
    -- Strict per-sample check, currently disabled:
    -- assert(predicted:get(row) == labels:get(row), "prediction failed for index " .. row .. " prediction = " .. predicted:get(row) .. ", GT = "..labels:get(row))
  end
  print("TOTAL CORRECCT COUNT", hits, hits / n_samples)
end