
Commit 2ceaa05

Small updates
1 parent 220278c commit 2ceaa05

6 files changed, +21 -9 lines changed

docs/src/lecture_11/data/mnist.jl

+2 -2

@@ -2,7 +2,7 @@ using Flux
 using Flux: onecold
 using MLDatasets
 
-include(joinpath(dirname(@__FILE__), "utilities.jl"))
+include(joinpath(dirname(@__FILE__), ("utilities.jl")))
 
 T = Float32
 dataset = MLDatasets.MNIST
@@ -19,7 +19,7 @@ model = Chain(
     softmax,
 )
 
-file_name = joinpath("data", "mnist_sigmoid.jld2")
+file_name = evaldir("mnist_sigmoid.jld2")
 train_or_load!(file_name, model, X_train, y_train)
 
 ii1 = findall(onecold(y_train, 0:9) .== 1)[1:5]

docs/src/lecture_11/data/mnist.jld2

1.47 KB
Binary file not shown.

docs/src/lecture_11/data/mnist_gpu.jl

+16 -6

@@ -1,23 +1,33 @@
 using MLDatasets
 using Flux
 
-include(joinpath(dirname(@__FILE__), "utilities.jl"))
+include(joinpath(dirname(@__FILE__), ("utilities.jl")))
 
 dataset = MLDatasets.MNIST
 T = Float32
 X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)
 
+# model = Chain(
+#     Conv((2, 2), 1 => 16, sigmoid),
+#     MaxPool((2, 2)),
+#     Conv((2, 2), 16 => 8, sigmoid),
+#     MaxPool((2, 2)),
+#     Flux.flatten,
+#     Dense(288, size(y_train, 1)),
+#     softmax,
+# ) |> gpu
+
 model = Chain(
-    Conv((2, 2), 1 => 16, sigmoid),
+    Conv((2, 2), 1 => 16, relu),
     MaxPool((2, 2)),
-    Conv((2, 2), 16 => 8, sigmoid),
+    Conv((2, 2), 16 => 8, relu),
     MaxPool((2, 2)),
     Flux.flatten,
     Dense(288, size(y_train, 1)),
     softmax,
-) |> gpu
+)
 
-file_name = joinpath("data", "mnist_sigmoid.jld2")
-train_model!(model, X_train, y_train; file_name=file_name, n_epochs=100)
+file_name = evaldir("mnist.jld2")
+train_or_load!(file_name, model, X_train, y_train; n_epochs=100, force=true)
 
 accuracy(model, X_test, y_test)
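The new call train_or_load!(file_name, model, X_train, y_train; n_epochs=100, force=true) replaces the unconditional train_model! call: the helper can reuse previously saved weights, and force=true retrains and overwrites them. A minimal sketch of what such a helper can look like follows; it is an assumption rather than the implementation in utilities.jl, and the optimiser, batch size, and loss below are illustrative choices (explicit-style Flux and JLD2 storage assumed).

using Flux
using JLD2

# Sketch of a train-or-load helper (assumed, not the utilities.jl implementation):
# reuse saved weights unless force=true, otherwise train and save them.
function train_or_load_sketch!(file_name, model, X, y; n_epochs=10, force=false)
    if !force && isfile(file_name)
        # Load previously stored parameters into the existing model structure.
        Flux.loadmodel!(model, JLD2.load(file_name, "model_state"))
    else
        # Train with an assumed optimiser, batch size, and loss.
        opt_state = Flux.setup(Adam(), model)
        loader = Flux.DataLoader((X, y); batchsize=128, shuffle=true)
        for _ in 1:n_epochs, (x, yb) in loader
            grads = Flux.gradient(m -> Flux.crossentropy(m(x), yb), model)
            Flux.update!(opt_state, model, grads[1])
        end
        # Store the trained parameters so later runs can skip training.
        JLD2.jldsave(file_name; model_state=Flux.state(model))
    end
    return model
end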
1.47 KB
Binary file not shown.

docs/src/lecture_11/data/utilities.jl

+2
@@ -14,6 +14,8 @@ else
     gpu(x) = x
 end
 
+evaldir(args...) = joinpath(dirname(@__FILE__), args...)
+
 accuracy(model, x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
 
 function reshape_data(X::AbstractArray{T,3}, y::AbstractVector) where {T}
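The evaldir helper added above is what the updated file_name = evaldir(...) lines in mnist.jl and mnist_gpu.jl rely on: it joins its arguments onto the directory containing utilities.jl, so the resulting path does not depend on the working directory from which the script is run. A small illustration (the results shown are example paths, not output from the repository):

# The helper as added in this commit:
evaldir(args...) = joinpath(dirname(@__FILE__), args...)

# Old style: relative to the current working directory, so it breaks when the
# script is launched from any other folder.
joinpath("data", "mnist_sigmoid.jld2")   # => "data/mnist_sigmoid.jld2"

# New style: always relative to the folder containing utilities.jl, and it
# accepts several path components, just like joinpath.
evaldir("mnist_sigmoid.jld2")            # => ".../lecture_11/data/mnist_sigmoid.jld2"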

docs/src/lecture_11/nn.md

+1 -1

@@ -405,7 +405,7 @@ The accuracy is over 93%, which is not bad for training for one epoch only. Let
 println("Test accuracy = " * string(accuracy(X_test, y_test))) # hide
 ```
 
-The externally trained model has an accuracy of more than 98% (it has the same architecture as the one defined above, but it was trained for 50 epochs.). Even though there are perfect models (with accuracy 100%) on MNIST, we are happy with this result. We will perform further analysis of the network in the exercises.
+The externally trained model has an accuracy of almost 98% (it has the same architecture as the one defined above, but it was trained for 100 epochs.). Even though there are perfect models (with accuracy 100%) on MNIST, we are happy with this result. We will perform further analysis of the network in the exercises.
 
 ```@setup nn
 using Plots
