Skip to content

Commit 220278c

Browse files
authored
Merge pull request #24 from JuliaTeachingCTU/macha/refactor-lectures-6-7
Macha/fixes-neural-networks
2 parents 9392514 + c1b102b commit 220278c

12 files changed

+174
-172
lines changed

docs/Project.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
55
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
6-
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
76
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
87
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
98
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
@@ -13,9 +12,11 @@ GLPK = "60bf3e95-4087-53dc-ae20-288a0d20c6a6"
1312
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
1413
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
1514
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
15+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1616
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
19+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1920
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2021
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
2122
Query = "1a8c2f83-1ff3-5112-b086-8aa67b057ba1"
@@ -29,8 +30,7 @@ StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
2930
BSON = "0.3"
3031
BenchmarkTools = "1.5"
3132
CSV = "0.10"
32-
DataFrames = "1.6"
33-
DifferentialEquations = "7.14"
33+
DataFrames = "1.7"
3434
Distributions = "0.25"
3535
Documenter = "1.7"
3636
Flux = "0.14"
@@ -39,8 +39,10 @@ GLPK = "1.2"
3939
GR = "0.73"
4040
HypothesisTests = "0.11"
4141
Ipopt = "1.6"
42+
JLD2 = "0.5"
4243
JuMP = "1.23"
4344
MLDatasets = "0.7"
45+
MLUtils = "0.4"
4446
Plots = "1.40"
4547
ProgressMeter = "1.10"
4648
Query = "1.0"

docs/src/lecture_11/data/mnist.bson

-20.2 KB
Binary file not shown.

docs/src/lecture_11/data/mnist.jl

+27-81
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,54 @@
1-
using BSON
21
using Flux
3-
using Flux: onehotbatch, onecold
2+
using Flux: onecold
43
using MLDatasets
54

6-
Core.eval(Main, :(using Flux)) # hide
7-
8-
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
9-
s = size(X)
10-
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
11-
end
12-
13-
function train_or_load!(file_name, m, X, y; force=false, kwargs...)
14-
15-
!isdir(dirname(file_name)) && mkpath(dirname(file_name))
16-
17-
if force || !isfile(file_name)
18-
train_model!(m, X, y; file_name=file_name, kwargs...)
19-
else
20-
m_loaded = BSON.load(file_name)[:m]
21-
Flux.loadparams!(m, params(m_loaded))
22-
end
23-
end
24-
25-
function load_data(dataset; onehot=false, T=Float32)
26-
classes = 0:9
27-
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
28-
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
29-
y_train = T.(y_train)
30-
y_test = T.(y_test)
31-
32-
if onehot
33-
y_train = onehotbatch(y_train[:], classes)
34-
y_test = onehotbatch(y_test[:], classes)
35-
end
36-
37-
return X_train, y_train, X_test, y_test
38-
end
39-
40-
using Plots
41-
42-
plot_image(x::AbstractArray{T, 2}) where T = plot(Gray.(x'), axis=nothing)
43-
44-
function plot_image(x::AbstractArray{T, 4}) where T
45-
@assert size(x,4) == 1
46-
plot_image(x[:,:,:,1])
47-
end
48-
49-
function plot_image(x::AbstractArray{T, 3}) where T
50-
@assert size(x,3) == 1
51-
plot_image(x[:,:,1])
52-
end
53-
5+
include(joinpath(dirname(@__FILE__), "utilities.jl"))
546

557
T = Float32
568
dataset = MLDatasets.MNIST
579

5810
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)
5911

12+
model = Chain(
13+
Conv((2, 2), 1 => 16, sigmoid),
14+
MaxPool((2, 2)),
15+
Conv((2, 2), 16 => 8, sigmoid),
16+
MaxPool((2, 2)),
17+
Flux.flatten,
18+
Dense(288, size(y_train, 1)),
19+
softmax,
20+
)
6021

61-
62-
63-
64-
m = Chain(
65-
Conv((2,2), 1=>16, sigmoid),
66-
MaxPool((2,2)),
67-
Conv((2,2), 16=>8, sigmoid),
68-
MaxPool((2,2)),
69-
flatten,
70-
Dense(288, size(y_train,1)), softmax)
71-
72-
file_name = joinpath("data", "mnist_sigmoid.bson")
73-
train_or_load!(file_name, m, X_train, y_train)
74-
75-
76-
22+
file_name = joinpath("data", "mnist_sigmoid.jld2")
23+
train_or_load!(file_name, model, X_train, y_train)
7724

7825
ii1 = findall(onecold(y_train, 0:9) .== 1)[1:5]
7926
ii2 = findall(onecold(y_train, 0:9) .== 9)[1:5]
8027

81-
8228
for qwe = 0:9
8329
ii0 = findall(onecold(y_train, 0:9) .== qwe)[1:5]
8430

85-
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii0]
86-
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0]
87-
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0]
31+
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii0]
32+
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0]
33+
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0]
8834

89-
p = plot(p0..., p1..., p2...; layout=(3,5))
35+
p = plot(p0..., p1..., p2...; layout=(3, 5))
9036
display(p)
9137
end
9238

93-
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii1]
94-
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1]
95-
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1]
39+
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii1]
40+
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1]
41+
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1]
9642

97-
plot(p0..., p1..., p2...; layout=(3,5))
43+
plot(p0..., p1..., p2...; layout=(3, 5))
9844

9945

100-
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii2]
101-
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2]
102-
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2]
46+
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii2]
47+
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2]
48+
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2]
10349

104-
plot(p0..., p1..., p2...; layout=(3,5))
50+
plot(p0..., p1..., p2...; layout=(3, 5))
10551

106-
for i in 1:length(m)
107-
println(size(m[1:i](X_train[:,:,:,1:1])))
52+
for i in 1:length(model)
53+
println(size(model[1:i](X_train[:, :, :, 1:1])))
10854
end

docs/src/lecture_11/data/mnist.jld2

25.8 KB
Binary file not shown.

docs/src/lecture_11/data/mnist_gpu.jl

+13-59
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,23 @@
11
using MLDatasets
22
using Flux
3-
using BSON
4-
using Random
5-
using Statistics
6-
using Base.Iterators: partition
7-
using Flux: crossentropy, onehotbatch, onecold
83

9-
10-
accuracy(x, y) = mean(onecold(cpu(m(x))) .== onecold(cpu(y)))
11-
12-
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
13-
s = size(X)
14-
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
15-
end
16-
17-
function train_model!(m, X, y;
18-
opt=ADAM(0.001),
19-
batch_size=128,
20-
n_epochs=10,
21-
file_name="")
22-
23-
loss(x, y) = crossentropy(m(x), y)
24-
25-
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds
26-
return (gpu(X[:, :, :, inds]), gpu(y[:, inds]))
27-
end
28-
29-
for i in 1:n_epochs
30-
println("Iteration " * string(i))
31-
Flux.train!(loss, params(m), batches_train, opt)
32-
end
33-
34-
!isempty(file_name) && BSON.bson(file_name, m=m|>cpu)
35-
36-
return
37-
end
38-
39-
function load_data(dataset; onehot=false, T=Float32)
40-
classes = 0:9
41-
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
42-
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
43-
y_train = T.(y_train)
44-
y_test = T.(y_test)
45-
46-
if onehot
47-
y_train = onehotbatch(y_train[:], classes)
48-
y_test = onehotbatch(y_test[:], classes)
49-
end
50-
51-
return X_train, y_train, X_test, y_test
52-
end
4+
include(joinpath(dirname(@__FILE__), "utilities.jl"))
535

546
dataset = MLDatasets.MNIST
557
T = Float32
568
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)
579

58-
m = Chain(
59-
Conv((2,2), 1=>16, sigmoid),
60-
MaxPool((2,2)),
61-
Conv((2,2), 16=>8, sigmoid),
62-
MaxPool((2,2)),
63-
flatten,
64-
Dense(288, size(y_train,1)), softmax) |> gpu
10+
model = Chain(
11+
Conv((2, 2), 1 => 16, sigmoid),
12+
MaxPool((2, 2)),
13+
Conv((2, 2), 16 => 8, sigmoid),
14+
MaxPool((2, 2)),
15+
Flux.flatten,
16+
Dense(288, size(y_train, 1)),
17+
softmax,
18+
) |> gpu
6519

66-
file_name = joinpath("data", "mnist_sigmoid.bson")
67-
train_model!(m, X_train, y_train; file_name=file_name, n_epochs=100)
20+
file_name = joinpath("data", "mnist_sigmoid.jld2")
21+
train_model!(model, X_train, y_train; file_name=file_name, n_epochs=100)
6822

69-
accuracy(X_test |> gpu, y_test |> gpu)
23+
# `model` was moved with `|> gpu` when it was constructed, so the test data has to be
# moved the same way before the forward pass; otherwise the model parameters and the
# inputs may live on different devices. `gpu` leaves the data untouched when no GPU
# backend is in use.
accuracy(model, X_test |> gpu, y_test |> gpu)
-20.1 KB
Binary file not shown.
25.8 KB
Binary file not shown.

docs/src/lecture_11/data/utilities.jl

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using MLDatasets
2+
using Flux
3+
using JLD2
4+
using Random
5+
using Statistics
6+
using Base.Iterators: partition
7+
using Flux: crossentropy, onehotbatch, onecold
8+
using Plots
9+
using Pkg
10+
11+
# Load CUDA only when it is listed as a dependency of the active project.
#
# No fallback definition of `gpu` is provided on purpose: `using Flux` already exports
# `gpu`, and defining a second `gpu(x) = x` in this module would shadow that export
# (and fails outright if `gpu` has been used in the module before this file is
# included). NOTE(review): relies on Flux's `gpu` returning its input unchanged when no
# GPU backend is loaded -- confirm against the Flux version in use.
if haskey(Pkg.project().dependencies, "CUDA")
    using CUDA
end
16+
17+
"""
    accuracy(model, x, y)

Return the fraction of samples for which the class decoded from `model(x)` with
`onecold` equals the class decoded from `y`.

Both the model output and `y` are passed through `cpu` before decoding.
"""
function accuracy(model, x, y)
    predicted = onecold(cpu(model(x)))
    expected = onecold(cpu(y))
    return mean(predicted .== expected)
end
18+
19+
"""
    reshape_data(X, y)

Return `(X4, y_row)` where `X4` is the 3-D array `X` reshaped to four dimensions with a
singleton third dimension, and `y_row` is the vector `y` reshaped to a `1 × length(y)`
matrix.
"""
function reshape_data(X::AbstractArray{T,3}, y::AbstractVector) where {T}
    n1, n2, n3 = size(X)
    X4 = reshape(X, n1, n2, 1, n3)
    y_row = reshape(y, 1, :)
    return X4, y_row
end
23+
24+
"""
    load_data(dataset; onehot=false, T=Float32)

Load the `:train` and `:test` splits of `dataset` and return
`(X_train, y_train, X_test, y_test)`.

Each split is obtained as `dataset(T, split)[:]` and passed through `reshape_data`; the
labels are then converted element-wise to `T`.

# Keywords
- `onehot`: when `true`, the labels are encoded with `onehotbatch` against the classes
  `0:9`.
- `T`: element type passed to `dataset` and used for the labels.
"""
function load_data(dataset; onehot=false, T=Float32)
    labels = 0:9

    # Load one split, bring it to the expected layout and encode its labels.
    function load_split(split)
        features, targets = reshape_data(dataset(T, split)[:]...)
        targets = T.(targets)
        return features, (onehot ? onehotbatch(targets[:], labels) : targets)
    end

    X_train, y_train = load_split(:train)
    X_test, y_test = load_split(:test)

    return X_train, y_train, X_test, y_test
end
38+
39+
"""
    train_model!(model, X, y; opt=Adam(0.001), batch_size=128, n_epochs=10, file_name="")

Train `model` in place with the cross-entropy loss and optionally save its state.

`X` is indexed as `X[:, :, :, inds]` and `y` as `y[:, inds]`, i.e. samples are taken
along the fourth dimension of `X` and along the columns of `y`.

# Keywords
- `opt`: optimiser passed to `Flux.setup`.
- `batch_size`: number of samples per mini-batch.
- `n_epochs`: number of passes over the mini-batches.
- `file_name`: when non-empty, the model state is saved to this path with `jldsave`
  under the key `model_state`.

Returns `nothing`.
"""
function train_model!(
    model,
    X,
    y;
    opt=Adam(0.001),
    batch_size=128,
    n_epochs=10,
    file_name="",
)
    # Explicit-style loss: the model is an argument instead of a captured global.
    loss(m, x, labels) = crossentropy(m(x), labels)

    # The sample order is shuffled once; the same mini-batches are reused every epoch.
    batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds
        return (gpu(X[:, :, :, inds]), gpu(y[:, inds]))
    end

    # Explicit optimiser state replaces the deprecated implicit `Flux.params` API.
    opt_state = Flux.setup(opt, model)

    for epoch in 1:n_epochs
        @show epoch
        Flux.train!(loss, model, batches_train, opt_state)
    end

    if !isempty(file_name)
        # Make sure the target directory exists so that a finished training run is not
        # lost to a failing save.
        save_dir = dirname(file_name)
        isempty(save_dir) || mkpath(save_dir)
        jldsave(file_name; model_state=cpu(Flux.state(model)))
    end

    return nothing
end
64+
65+
"""
    train_or_load!(file_name, model, args...; force=false, kwargs...)

Load the saved state from `file_name` into `model`, or train the model when no such file
exists or `force` is `true`.

The directory of `file_name` is created first if it does not exist. Training is
delegated to `train_model!(model, args...; file_name=file_name, kwargs...)`; loading
reads the `"model_state"` entry with `JLD2.load` and applies it with `Flux.loadmodel!`.
"""
function train_or_load!(file_name, model, args...; force=false, kwargs...)
    target_dir = dirname(file_name)
    isdir(target_dir) || mkpath(target_dir)

    # Reuse the stored state unless retraining is requested or nothing is stored yet.
    if !force && isfile(file_name)
        model_state = JLD2.load(file_name, "model_state")
        return Flux.loadmodel!(model, model_state)
    end

    return train_model!(model, args...; file_name=file_name, kwargs...)
end
76+
77+
"""
    plot_image(x)

Plot the matrix `x` as a grayscale image without axes. The matrix is transposed (`x'`)
before its entries are converted with `Gray`.
"""
function plot_image(x::AbstractArray{T,2}) where {T}
    return plot(Gray.(x'); axis=nothing)
end
78+
79+
# Four-dimensional input: the fourth dimension must have size one; it is dropped and
# the remaining three-dimensional array is plotted.
function plot_image(x::AbstractArray{T,4}) where {T}
    @assert size(x, 4) == 1
    return plot_image(dropdims(x; dims=4))
end
83+
84+
# Three-dimensional input: the third dimension must have size one; it is dropped and
# the remaining matrix is plotted.
function plot_image(x::AbstractArray{T,3}) where {T}
    @assert size(x, 3) == 1
    return plot_image(dropdims(x; dims=3))
end

0 commit comments

Comments
 (0)