Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
Dale-Black committed Dec 27, 2023
1 parent 6dc6def commit acf5572
Showing 1 changed file with 94 additions and 81 deletions.
175 changes: 94 additions & 81 deletions notebooks/heart_training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ end
# ╔═╡ 8d4a6d5a-c437-43bb-a3db-ab961b218c2e
using PlutoUI: TableOfContents, Slider, bind

# ╔═╡ 83b95cee-90ed-4522-b9a8-79c082fce02e
using Random: default_rng, seed!

# ╔═╡ 7353b7ce-8b33-4602-aed7-2aa24864aca5
using HTTP: download

Expand All @@ -35,9 +38,6 @@ using Glob: glob
# ╔═╡ 8e2f2c6d-127d-42a6-9906-970c09a22e61
using CairoMakie: Figure, Axis, heatmap!

# ╔═╡ 83b95cee-90ed-4522-b9a8-79c082fce02e
using Random: default_rng, seed!

# ╔═╡ a3f44d7c-efa3-41d0-9509-b099ab7f09d4
using Lux

Expand All @@ -60,6 +60,12 @@ using LuxCUDA
# ╔═╡ c8d6553a-90df-4aeb-aa6d-a213e16fab48
TableOfContents()

# ╔═╡ af50e5f3-1a1c-47e5-a461-ffbee0329309
begin
rng = default_rng()
seed!(rng, 0)
end

# ╔═╡ cdfd2412-897d-4642-bb69-f8031c418446
function download_dataset(heart_url, target_directory)
if isempty(readdir(target_directory))
Expand Down Expand Up @@ -117,6 +123,11 @@ begin

end

# ╔═╡ af798f6b-7549-4253-b02b-2ed20dc1125b
md"""
# Randomness
"""

# ╔═╡ f0e64ba5-5e11-4ddb-91d3-2a34c60dc6bf
md"""
# Data Preparation
Expand Down Expand Up @@ -177,15 +188,19 @@ function preprocess_image_label_pair(pair, target_size)
end
end

# ╔═╡ 313b950b-9725-48ef-a1af-aa720a6eb28e
begin
# ╔═╡ ac2ed012-2b64-42b2-b97c-2a5352af9ec8
if LuxCUDA.functional()
target_size = (128, 128, 64)
transformed_data = mapobs(
x -> preprocess_image_label_pair(x, target_size),
data
)
else
target_size = (64, 64, 64)
end

# ╔═╡ c5539898-6b0c-4172-ba6c-9bfe2819c9fb
transformed_data = mapobs(
x -> preprocess_image_label_pair(x, target_size),
data
)

# ╔═╡ 03bab55a-6e5e-4b9f-b56a-7e9f993576eb
md"""
## Dataloaders
Expand Down Expand Up @@ -271,23 +286,12 @@ md"""
# Model
"""

# ╔═╡ af798f6b-7549-4253-b02b-2ed20dc1125b
md"""
## Randomness
"""

# ╔═╡ af50e5f3-1a1c-47e5-a461-ffbee0329309
begin
rng = default_rng()
seed!(rng, 0)
end

# ╔═╡ b3fc9578-6b40-4afc-bb58-c772a61a60a5
md"""
## Helper functions
"""

# ╔═╡ b14a75c2-0b86-4364-b3dd-699446e2be62
# ╔═╡ 3e938872-a390-40ba-8b00-b132f988e2d3
function create_unet_layers(
kernel_size, de_kernel_size, channel_list;
downsample = true)
Expand Down Expand Up @@ -321,14 +325,14 @@ function create_unet_layers(
return (conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
end

# ╔═╡ ca45d18e-ad39-49cf-9810-bad96c1e3c31
# ╔═╡ 1d65b1d1-82de-40ca-aaba-9eee23883cf3
md"""
## Unet module
## Contracting Block
"""

# ╔═╡ 48c94969-16bc-4e69-bd3a-71e65b584dc8
# ╔═╡ 40762509-b26e-47f5-8b49-e7100fdeb72a
begin
struct UNetModule <: Lux.AbstractExplicitContainerLayer{
struct ContractBlock <: Lux.AbstractExplicitContainerLayer{
(:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample)
}
conv1::Conv
Expand All @@ -341,7 +345,7 @@ begin
sample::Chain
end

function UNetModule(
function ContractBlock(
kernel_size, de_kernel_size, channel_list;
downsample = true
)
Expand All @@ -352,10 +356,10 @@ begin
downsample = downsample
)

UNetModule(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
ContractBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
end

function (m::UNetModule)(x, ps, st::NamedTuple)
function (m::ContractBlock)(x, ps, st::NamedTuple)
res, st_bridge_conv = m.bridge_conv(x, ps.bridge_conv, st.bridge_conv)
x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
x, st_bn1 = m.bn1(x, ps.bn1, st.bn1)
Expand All @@ -374,14 +378,14 @@ begin
end
end

# ╔═╡ a58cbfd8-8210-4cd1-b761-ffe6ea8417e9
# ╔═╡ 91e05c6c-e9b3-4a72-84a5-2ce4b1359b1a
md"""
## Deconv module
## Expanding Block
"""

# ╔═╡ 2fb54a8d-7972-4657-bd0b-c95b897c73a1
# ╔═╡ 70614cac-2e06-48a9-9cf6-9078bc7436bc
begin
struct DeConvModule <: Lux.AbstractExplicitContainerLayer{
struct ExpandBlock <: Lux.AbstractExplicitContainerLayer{
(:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample)
}
conv1::Conv
Expand All @@ -394,7 +398,7 @@ begin
sample::Chain
end

function DeConvModule(
function ExpandBlock(
kernel_size, de_kernel_size, channel_list;
downsample = false)

Expand All @@ -403,10 +407,10 @@ begin
downsample = downsample
)

DeConvModule(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
ExpandBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample)
end

function (m::DeConvModule)(x, ps, st::NamedTuple)
function (m::ExpandBlock)(x, ps, st::NamedTuple)
x, x1 = x[1], x[2]
x = cat(x, x1; dims=4)

Expand All @@ -429,29 +433,29 @@ begin
end
end

# ╔═╡ cbcdb3a6-eaff-4126-b5b5-c07c907e74b4
# ╔═╡ 36885de0-aa0e-4037-929f-44e074fb17f5
md"""
## FCN
## U-Net
"""

# ╔═╡ ca2e52c6-63d2-4126-aa62-69a48fa36dbc
# ╔═╡ af56e2f7-2ab8-4ff2-8295-038b3a565cbc
begin
struct FCN <: Lux.AbstractExplicitContainerLayer{
struct UNet <: Lux.AbstractExplicitContainerLayer{
(:conv1, :conv2, :conv3, :conv4, :conv5, :de_conv1, :de_conv2, :de_conv3, :de_conv4, :last_conv)
}
conv1::Chain
conv2::Chain
conv3::UNetModule
conv4::UNetModule
conv5::UNetModule
de_conv1::UNetModule
de_conv2::DeConvModule
de_conv3::DeConvModule
de_conv4::DeConvModule
conv3::ContractBlock
conv4::ContractBlock
conv5::ContractBlock
de_conv1::ContractBlock
de_conv2::ExpandBlock
de_conv3::ExpandBlock
de_conv4::ExpandBlock
last_conv::Conv
end

function FCN(channel)
function UNet(channel)
conv1 = Chain(
Conv((5, 5, 5), 1 => channel, stride=1, pad=2),
BatchNorm(channel),
Expand All @@ -462,44 +466,44 @@ begin
BatchNorm(2 * channel),
relu
)
conv3 = UNetModule(5, 2, [2 * channel, 2 * channel, 4 * channel])
conv4 = UNetModule(5, 2, [4 * channel, 4 * channel, 8 * channel])
conv5 = UNetModule(5, 2, [8 * channel, 8 * channel, 16 * channel])
conv3 = ContractBlock(5, 2, [2 * channel, 2 * channel, 4 * channel])
conv4 = ContractBlock(5, 2, [4 * channel, 4 * channel, 8 * channel])
conv5 = ContractBlock(5, 2, [8 * channel, 8 * channel, 16 * channel])

de_conv1 = UNetModule(
de_conv1 = ContractBlock(
5, 2, [16 * channel, 32 * channel, 16 * channel];
downsample = false
)
de_conv2 = DeConvModule(
de_conv2 = ExpandBlock(
5, 2, [32 * channel, 8 * channel, 8 * channel];
downsample = false
)
de_conv3 = DeConvModule(
de_conv3 = ExpandBlock(
5, 2, [16 * channel, 4 * channel, 4 * channel];
downsample = false
)
de_conv4 = DeConvModule(
de_conv4 = ExpandBlock(
5, 2, [8 * channel, 2 * channel, channel];
downsample = false
)

last_conv = Conv((1, 1, 1), 2 * channel => 2, stride=1, pad=0)

FCN(conv1, conv2, conv3, conv4, conv5, de_conv1, de_conv2, de_conv3, de_conv4, last_conv)
UNet(conv1, conv2, conv3, conv4, conv5, de_conv1, de_conv2, de_conv3, de_conv4, last_conv)
end

function (m::FCN)(x, ps, st::NamedTuple)
function (m::UNet)(x, ps, st::NamedTuple)
# Convolutional layers
x, st_conv1 = m.conv1(x, ps.conv1, st.conv1)
x_1 = x # Store for skip connection
x, st_conv2 = m.conv2(x, ps.conv2, st.conv2)

# Downscaling UNet modules
# Downscaling Blocks
x, x_2, st_conv3 = m.conv3(x, ps.conv3, st.conv3)
x, x_3, st_conv4 = m.conv4(x, ps.conv4, st.conv4)
x, x_4, st_conv5 = m.conv5(x, ps.conv5, st.conv5)

# Upscaling DeConv modules
# Upscaling Blocks
x, _, st_de_conv1 = m.de_conv1(x, ps.de_conv1, st.de_conv1)
x, st_de_conv2 = m.de_conv2((x, x_4), ps.de_conv2, st.de_conv2)
x, st_de_conv3 = m.de_conv3((x, x_3), ps.de_conv3, st.de_conv3)
Expand Down Expand Up @@ -558,19 +562,27 @@ function compute_loss(x, y, model, ps, st, epoch)

# Compute loss
loss = 0.0
for b in axes(y, 5)
# for b in axes(y, 5)
# _y_pred = y_pred_binary[:, :, :, b]
# _y = y_binary[:, :, :, b]

# local _y_dtm, _y_pred_dtm
# ignore_derivatives() do
# _y_dtm = transform(boolean_indicator(_y))
# _y_pred_dtm = transform(boolean_indicator(_y_pred))
# end

# hd = hausdorff_loss(_y_pred, _y, _y_pred_dtm, _y_dtm)
# dsc = dice_loss(_y_pred, _y)
# loss += alpha * dsc + beta * hd
# end

for b in axes(y, 5)
_y_pred = y_pred_binary[:, :, :, b]
_y = y_binary[:, :, :, b]

local _y_dtm, _y_pred_dtm
ignore_derivatives() do
_y_dtm = transform(boolean_indicator(_y))
_y_pred_dtm = transform(boolean_indicator(_y_pred))
end

hd = hausdorff_loss(_y_pred, _y, _y_pred_dtm, _y_dtm)
dsc = dice_loss(_y_pred, _y)
loss += alpha * dsc + beta * hd
loss += dsc
end
return loss / size(y, 5), y_pred_binary, st
end
Expand All @@ -584,7 +596,7 @@ md"""
dev = gpu_device()

# ╔═╡ bbdaf5c5-9faa-4b61-afab-c0242b8ca034
model = FCN(4)
model = UNet(4)

# ╔═╡ 6ec3e34b-1c57-4cfb-a50d-ee786c2e4559
begin
Expand Down Expand Up @@ -637,7 +649,7 @@ function train_model(model, ps, st, train_loader, num_epochs, dev)
end

# ╔═╡ a2e88851-227a-4719-8828-6064f9d3ef81
num_epochs = 3
num_epochs = 20

# ╔═╡ 5cae73af-471c-4068-b9ff-5bc03dd0472d
train_model(model, ps, st, train_loader, num_epochs, dev)
Expand Down Expand Up @@ -3033,6 +3045,9 @@ version = "3.5.0+0"
# ╔═╡ Cell order:
# ╠═8d4a6d5a-c437-43bb-a3db-ab961b218c2e
# ╠═c8d6553a-90df-4aeb-aa6d-a213e16fab48
# ╟─af798f6b-7549-4253-b02b-2ed20dc1125b
# ╠═83b95cee-90ed-4522-b9a8-79c082fce02e
# ╠═af50e5f3-1a1c-47e5-a461-ffbee0329309
# ╟─f0e64ba5-5e11-4ddb-91d3-2a34c60dc6bf
# ╠═7353b7ce-8b33-4602-aed7-2aa24864aca5
# ╠═de5efc37-db19-440e-9487-9a7bea84996d
Expand All @@ -3051,7 +3066,9 @@ version = "3.5.0+0"
# ╠═18b31959-9cdf-41d9-a389-7c18febf7b07
# ╠═72827ad5-4820-4545-8099-1033d962970e
# ╠═8ad7b2bb-1672-473a-a7b5-bf505733f7a3
# ╠═313b950b-9725-48ef-a1af-aa720a6eb28e
# ╠═317c1571-d232-4cab-ac10-9fc3b7ad33b0
# ╠═ac2ed012-2b64-42b2-b97c-2a5352af9ec8
# ╠═c5539898-6b0c-4172-ba6c-9bfe2819c9fb
# ╟─03bab55a-6e5e-4b9f-b56a-7e9f993576eb
# ╠═d40f19dc-f06e-44ef-b82b-9763ff1f1189
# ╠═4d75f114-225f-45e2-a683-e82ff137d909
Expand All @@ -3066,18 +3083,15 @@ version = "3.5.0+0"
# ╟─6e2bfcfb-77e3-4532-a14d-10f4b91f2f54
# ╟─bae79c05-034a-4c39-801a-01229b618e94
# ╟─1494df6e-f407-42c4-8404-1f4871a2f817
# ╠═83b95cee-90ed-4522-b9a8-79c082fce02e
# ╠═a3f44d7c-efa3-41d0-9509-b099ab7f09d4
# ╟─af798f6b-7549-4253-b02b-2ed20dc1125b
# ╠═af50e5f3-1a1c-47e5-a461-ffbee0329309
# ╟─b3fc9578-6b40-4afc-bb58-c772a61a60a5
# ╠═b14a75c2-0b86-4364-b3dd-699446e2be62
# ╟─ca45d18e-ad39-49cf-9810-bad96c1e3c31
# ╠═48c94969-16bc-4e69-bd3a-71e65b584dc8
# ╟─a58cbfd8-8210-4cd1-b761-ffe6ea8417e9
# ╠═2fb54a8d-7972-4657-bd0b-c95b897c73a1
# ╟─cbcdb3a6-eaff-4126-b5b5-c07c907e74b4
# ╠═ca2e52c6-63d2-4126-aa62-69a48fa36dbc
# ╠═3e938872-a390-40ba-8b00-b132f988e2d3
# ╟─1d65b1d1-82de-40ca-aaba-9eee23883cf3
# ╠═40762509-b26e-47f5-8b49-e7100fdeb72a
# ╟─91e05c6c-e9b3-4a72-84a5-2ce4b1359b1a
# ╠═70614cac-2e06-48a9-9cf6-9078bc7436bc
# ╟─36885de0-aa0e-4037-929f-44e074fb17f5
# ╠═af56e2f7-2ab8-4ff2-8295-038b3a565cbc
# ╟─df2dd9a7-045c-44a5-a62c-8d9f2541dc14
# ╠═a6669580-de24-4111-a7cb-26d3e727a12e
# ╠═dfc9377a-7cc1-43ba-bb43-683d24e67d79
Expand All @@ -3090,7 +3104,6 @@ version = "3.5.0+0"
# ╟─a25bdfe6-b24d-446b-926f-6e0727d647a2
# ╠═8598dfca-8929-4ec3-9eb5-09c240c3fdba
# ╟─45949f7f-4e4a-4857-af43-ff013dbdd137
# ╠═317c1571-d232-4cab-ac10-9fc3b7ad33b0
# ╠═402ba194-350e-4ff3-832b-6651be1d9ce7
# ╠═bbdaf5c5-9faa-4b61-afab-c0242b8ca034
# ╠═6ec3e34b-1c57-4cfb-a50d-ee786c2e4559
Expand Down

0 comments on commit acf5572

Please sign in to comment.