diff --git a/notebooks/heart_training.jl b/notebooks/heart_training.jl index f2638aa..382e831 100644 --- a/notebooks/heart_training.jl +++ b/notebooks/heart_training.jl @@ -17,6 +17,9 @@ end # ╔═╡ 8d4a6d5a-c437-43bb-a3db-ab961b218c2e using PlutoUI: TableOfContents, Slider, bind +# ╔═╡ 83b95cee-90ed-4522-b9a8-79c082fce02e +using Random: default_rng, seed! + # ╔═╡ 7353b7ce-8b33-4602-aed7-2aa24864aca5 using HTTP: download @@ -35,9 +38,6 @@ using Glob: glob # ╔═╡ 8e2f2c6d-127d-42a6-9906-970c09a22e61 using CairoMakie: Figure, Axis, heatmap! -# ╔═╡ 83b95cee-90ed-4522-b9a8-79c082fce02e -using Random: default_rng, seed! - # ╔═╡ a3f44d7c-efa3-41d0-9509-b099ab7f09d4 using Lux @@ -60,6 +60,12 @@ using LuxCUDA # ╔═╡ c8d6553a-90df-4aeb-aa6d-a213e16fab48 TableOfContents() +# ╔═╡ af50e5f3-1a1c-47e5-a461-ffbee0329309 +begin + rng = default_rng() + seed!(rng, 0) +end + # ╔═╡ cdfd2412-897d-4642-bb69-f8031c418446 function download_dataset(heart_url, target_directory) if isempty(readdir(target_directory)) @@ -117,6 +123,11 @@ begin end +# ╔═╡ af798f6b-7549-4253-b02b-2ed20dc1125b +md""" +# Randomness +""" + # ╔═╡ f0e64ba5-5e11-4ddb-91d3-2a34c60dc6bf md""" # Data Preparation @@ -177,15 +188,19 @@ function preprocess_image_label_pair(pair, target_size) end end -# ╔═╡ 313b950b-9725-48ef-a1af-aa720a6eb28e -begin +# ╔═╡ ac2ed012-2b64-42b2-b97c-2a5352af9ec8 +if LuxCUDA.functional() target_size = (128, 128, 64) - transformed_data = mapobs( - x -> preprocess_image_label_pair(x, target_size), - data - ) +else + target_size = (64, 64, 64) end +# ╔═╡ c5539898-6b0c-4172-ba6c-9bfe2819c9fb +transformed_data = mapobs( + x -> preprocess_image_label_pair(x, target_size), + data +) + # ╔═╡ 03bab55a-6e5e-4b9f-b56a-7e9f993576eb md""" ## Dataloaders @@ -271,23 +286,12 @@ md""" # Model """ -# ╔═╡ af798f6b-7549-4253-b02b-2ed20dc1125b -md""" -## Randomness -""" - -# ╔═╡ af50e5f3-1a1c-47e5-a461-ffbee0329309 -begin - rng = default_rng() - seed!(rng, 0) -end - # ╔═╡ b3fc9578-6b40-4afc-bb58-c772a61a60a5 md""" ## Helper functions """ -# ╔═╡ b14a75c2-0b86-4364-b3dd-699446e2be62 +# ╔═╡ 3e938872-a390-40ba-8b00-b132f988e2d3 function create_unet_layers( kernel_size, de_kernel_size, channel_list; downsample = true) @@ -321,14 +325,14 @@ function create_unet_layers( return (conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample) end -# ╔═╡ ca45d18e-ad39-49cf-9810-bad96c1e3c31 +# ╔═╡ 1d65b1d1-82de-40ca-aaba-9eee23883cf3 md""" -## Unet module +## Contracting Block """ -# ╔═╡ 48c94969-16bc-4e69-bd3a-71e65b584dc8 +# ╔═╡ 40762509-b26e-47f5-8b49-e7100fdeb72a begin - struct UNetModule <: Lux.AbstractExplicitContainerLayer{ + struct ContractBlock <: Lux.AbstractExplicitContainerLayer{ (:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample) } conv1::Conv @@ -341,7 +345,7 @@ begin sample::Chain end - function UNetModule( + function ContractBlock( kernel_size, de_kernel_size, channel_list; downsample = true ) @@ -352,10 +356,10 @@ begin downsample = downsample ) - UNetModule(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample) + ContractBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample) end - function (m::UNetModule)(x, ps, st::NamedTuple) + function (m::ContractBlock)(x, ps, st::NamedTuple) res, st_bridge_conv = m.bridge_conv(x, ps.bridge_conv, st.bridge_conv) x, st_conv1 = m.conv1(x, ps.conv1, st.conv1) x, st_bn1 = m.bn1(x, ps.bn1, st.bn1) @@ -374,14 +378,14 @@ begin end end -# ╔═╡ a58cbfd8-8210-4cd1-b761-ffe6ea8417e9 +# ╔═╡ 91e05c6c-e9b3-4a72-84a5-2ce4b1359b1a md""" -## Deconv module +## Expanding Block """ -# ╔═╡ 2fb54a8d-7972-4657-bd0b-c95b897c73a1 +# ╔═╡ 70614cac-2e06-48a9-9cf6-9078bc7436bc begin - struct DeConvModule <: Lux.AbstractExplicitContainerLayer{ + struct ExpandBlock <: Lux.AbstractExplicitContainerLayer{ (:conv1, :conv2, :bn1, :bn2, :bridge_conv, :sample) } conv1::Conv @@ -394,7 +398,7 @@ begin sample::Chain end - function DeConvModule( + function ExpandBlock( kernel_size, de_kernel_size, channel_list; downsample = false) @@ -403,10 +407,10 @@ begin downsample = downsample ) - DeConvModule(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample) + ExpandBlock(conv1, conv2, relu1, relu2, bn1, bn2, bridge_conv, sample) end - function (m::DeConvModule)(x, ps, st::NamedTuple) + function (m::ExpandBlock)(x, ps, st::NamedTuple) x, x1 = x[1], x[2] x = cat(x, x1; dims=4) @@ -429,29 +433,29 @@ begin end end -# ╔═╡ cbcdb3a6-eaff-4126-b5b5-c07c907e74b4 +# ╔═╡ 36885de0-aa0e-4037-929f-44e074fb17f5 md""" -## FCN +## U-Net """ -# ╔═╡ ca2e52c6-63d2-4126-aa62-69a48fa36dbc +# ╔═╡ af56e2f7-2ab8-4ff2-8295-038b3a565cbc begin - struct FCN <: Lux.AbstractExplicitContainerLayer{ + struct UNet <: Lux.AbstractExplicitContainerLayer{ (:conv1, :conv2, :conv3, :conv4, :conv5, :de_conv1, :de_conv2, :de_conv3, :de_conv4, :last_conv) } conv1::Chain conv2::Chain - conv3::UNetModule - conv4::UNetModule - conv5::UNetModule - de_conv1::UNetModule - de_conv2::DeConvModule - de_conv3::DeConvModule - de_conv4::DeConvModule + conv3::ContractBlock + conv4::ContractBlock + conv5::ContractBlock + de_conv1::ContractBlock + de_conv2::ExpandBlock + de_conv3::ExpandBlock + de_conv4::ExpandBlock last_conv::Conv end - function FCN(channel) + function UNet(channel) conv1 = Chain( Conv((5, 5, 5), 1 => channel, stride=1, pad=2), BatchNorm(channel), @@ -462,44 +466,44 @@ begin BatchNorm(2 * channel), relu ) - conv3 = UNetModule(5, 2, [2 * channel, 2 * channel, 4 * channel]) - conv4 = UNetModule(5, 2, [4 * channel, 4 * channel, 8 * channel]) - conv5 = UNetModule(5, 2, [8 * channel, 8 * channel, 16 * channel]) + conv3 = ContractBlock(5, 2, [2 * channel, 2 * channel, 4 * channel]) + conv4 = ContractBlock(5, 2, [4 * channel, 4 * channel, 8 * channel]) + conv5 = ContractBlock(5, 2, [8 * channel, 8 * channel, 16 * channel]) - de_conv1 = UNetModule( + de_conv1 = ContractBlock( 5, 2, [16 * channel, 32 * channel, 16 * channel]; downsample = false ) - de_conv2 = DeConvModule( + de_conv2 = ExpandBlock( 5, 2, [32 * channel, 8 * channel, 8 * channel]; downsample = false ) - de_conv3 = DeConvModule( + de_conv3 = ExpandBlock( 5, 2, [16 * channel, 4 * channel, 4 * channel]; downsample = false ) - de_conv4 = DeConvModule( + de_conv4 = ExpandBlock( 5, 2, [8 * channel, 2 * channel, channel]; downsample = false ) last_conv = Conv((1, 1, 1), 2 * channel => 2, stride=1, pad=0) - FCN(conv1, conv2, conv3, conv4, conv5, de_conv1, de_conv2, de_conv3, de_conv4, last_conv) + UNet(conv1, conv2, conv3, conv4, conv5, de_conv1, de_conv2, de_conv3, de_conv4, last_conv) end - function (m::FCN)(x, ps, st::NamedTuple) + function (m::UNet)(x, ps, st::NamedTuple) # Convolutional layers x, st_conv1 = m.conv1(x, ps.conv1, st.conv1) x_1 = x # Store for skip connection x, st_conv2 = m.conv2(x, ps.conv2, st.conv2) - # Downscaling UNet modules + # Downscaling Blocks x, x_2, st_conv3 = m.conv3(x, ps.conv3, st.conv3) x, x_3, st_conv4 = m.conv4(x, ps.conv4, st.conv4) x, x_4, st_conv5 = m.conv5(x, ps.conv5, st.conv5) - # Upscaling DeConv modules + # Upscaling Blocks x, _, st_de_conv1 = m.de_conv1(x, ps.de_conv1, st.de_conv1) x, st_de_conv2 = m.de_conv2((x, x_4), ps.de_conv2, st.de_conv2) x, st_de_conv3 = m.de_conv3((x, x_3), ps.de_conv3, st.de_conv3) @@ -558,19 +562,27 @@ function compute_loss(x, y, model, ps, st, epoch) # Compute loss loss = 0.0 - for b in axes(y, 5) + # for b in axes(y, 5) + # _y_pred = y_pred_binary[:, :, :, b] + # _y = y_binary[:, :, :, b] + + # local _y_dtm, _y_pred_dtm + # ignore_derivatives() do + # _y_dtm = transform(boolean_indicator(_y)) + # _y_pred_dtm = transform(boolean_indicator(_y_pred)) + # end + + # hd = hausdorff_loss(_y_pred, _y, _y_pred_dtm, _y_dtm) + # dsc = dice_loss(_y_pred, _y) + # loss += alpha * dsc + beta * hd + # end + + for b in axes(y, 5) _y_pred = y_pred_binary[:, :, :, b] _y = y_binary[:, :, :, b] - - local _y_dtm, _y_pred_dtm - ignore_derivatives() do - _y_dtm = transform(boolean_indicator(_y)) - _y_pred_dtm = transform(boolean_indicator(_y_pred)) - end - hd = hausdorff_loss(_y_pred, _y, _y_pred_dtm, _y_dtm) dsc = dice_loss(_y_pred, _y) - loss += alpha * dsc + beta * hd + loss += dsc end return loss / size(y, 5), y_pred_binary, st end @@ -584,7 +596,7 @@ md""" dev = gpu_device() # ╔═╡ bbdaf5c5-9faa-4b61-afab-c0242b8ca034 -model = FCN(4) +model = UNet(4) # ╔═╡ 6ec3e34b-1c57-4cfb-a50d-ee786c2e4559 begin @@ -637,7 +649,7 @@ function train_model(model, ps, st, train_loader, num_epochs, dev) end # ╔═╡ a2e88851-227a-4719-8828-6064f9d3ef81 -num_epochs = 3 +num_epochs = 20 # ╔═╡ 5cae73af-471c-4068-b9ff-5bc03dd0472d train_model(model, ps, st, train_loader, num_epochs, dev) @@ -3033,6 +3045,9 @@ version = "3.5.0+0" # ╔═╡ Cell order: # ╠═8d4a6d5a-c437-43bb-a3db-ab961b218c2e # ╠═c8d6553a-90df-4aeb-aa6d-a213e16fab48 +# ╟─af798f6b-7549-4253-b02b-2ed20dc1125b +# ╠═83b95cee-90ed-4522-b9a8-79c082fce02e +# ╠═af50e5f3-1a1c-47e5-a461-ffbee0329309 # ╟─f0e64ba5-5e11-4ddb-91d3-2a34c60dc6bf # ╠═7353b7ce-8b33-4602-aed7-2aa24864aca5 # ╠═de5efc37-db19-440e-9487-9a7bea84996d @@ -3051,7 +3066,9 @@ version = "3.5.0+0" # ╠═18b31959-9cdf-41d9-a389-7c18febf7b07 # ╠═72827ad5-4820-4545-8099-1033d962970e # ╠═8ad7b2bb-1672-473a-a7b5-bf505733f7a3 -# ╠═313b950b-9725-48ef-a1af-aa720a6eb28e +# ╠═317c1571-d232-4cab-ac10-9fc3b7ad33b0 +# ╠═ac2ed012-2b64-42b2-b97c-2a5352af9ec8 +# ╠═c5539898-6b0c-4172-ba6c-9bfe2819c9fb # ╟─03bab55a-6e5e-4b9f-b56a-7e9f993576eb # ╠═d40f19dc-f06e-44ef-b82b-9763ff1f1189 # ╠═4d75f114-225f-45e2-a683-e82ff137d909 @@ -3066,18 +3083,15 @@ version = "3.5.0+0" # ╟─6e2bfcfb-77e3-4532-a14d-10f4b91f2f54 # ╟─bae79c05-034a-4c39-801a-01229b618e94 # ╟─1494df6e-f407-42c4-8404-1f4871a2f817 -# ╠═83b95cee-90ed-4522-b9a8-79c082fce02e # ╠═a3f44d7c-efa3-41d0-9509-b099ab7f09d4 -# ╟─af798f6b-7549-4253-b02b-2ed20dc1125b -# ╠═af50e5f3-1a1c-47e5-a461-ffbee0329309 # ╟─b3fc9578-6b40-4afc-bb58-c772a61a60a5 -# ╠═b14a75c2-0b86-4364-b3dd-699446e2be62 -# ╟─ca45d18e-ad39-49cf-9810-bad96c1e3c31 -# ╠═48c94969-16bc-4e69-bd3a-71e65b584dc8 -# ╟─a58cbfd8-8210-4cd1-b761-ffe6ea8417e9 -# ╠═2fb54a8d-7972-4657-bd0b-c95b897c73a1 -# ╟─cbcdb3a6-eaff-4126-b5b5-c07c907e74b4 -# ╠═ca2e52c6-63d2-4126-aa62-69a48fa36dbc +# ╠═3e938872-a390-40ba-8b00-b132f988e2d3 +# ╟─1d65b1d1-82de-40ca-aaba-9eee23883cf3 +# ╠═40762509-b26e-47f5-8b49-e7100fdeb72a +# ╟─91e05c6c-e9b3-4a72-84a5-2ce4b1359b1a +# ╠═70614cac-2e06-48a9-9cf6-9078bc7436bc +# ╟─36885de0-aa0e-4037-929f-44e074fb17f5 +# ╠═af56e2f7-2ab8-4ff2-8295-038b3a565cbc # ╟─df2dd9a7-045c-44a5-a62c-8d9f2541dc14 # ╠═a6669580-de24-4111-a7cb-26d3e727a12e # ╠═dfc9377a-7cc1-43ba-bb43-683d24e67d79 @@ -3090,7 +3104,6 @@ version = "3.5.0+0" # ╟─a25bdfe6-b24d-446b-926f-6e0727d647a2 # ╠═8598dfca-8929-4ec3-9eb5-09c240c3fdba # ╟─45949f7f-4e4a-4857-af43-ff013dbdd137 -# ╠═317c1571-d232-4cab-ac10-9fc3b7ad33b0 # ╠═402ba194-350e-4ff3-832b-6651be1d9ce7 # ╠═bbdaf5c5-9faa-4b61-afab-c0242b8ca034 # ╠═6ec3e34b-1c57-4cfb-a50d-ee786c2e4559