diff --git a/dev/articles/callbacks.html b/dev/articles/callbacks.html
index 63820c64..f9f01c68 100644
--- a/dev/articles/callbacks.html
+++ b/dev/articles/callbacks.html
@@ -225,7 +225,7 @@
torch Primer
input = torch_randn(2, 3)
input
#> torch_tensor
-#> -0.1934 1.3338 0.2307
-#> -0.3255 0.4996 -0.5817
+#> -0.4498 -1.7058 -1.9809
+#> 0.3088 -1.8150 -0.8680
#> [ CPUFloatType{2,3} ]
A nn_module is constructed from a nn_module_generator. nn_linear is one of the
@@ -117,8 +117,8 @@
torch Primer
output = module_1(input)
output
#> torch_tensor
-#> 0.2293 0.0607 0.7179 1.1731
-#> -0.1982 0.3674 0.0534 0.6567
+#> -0.5222 1.4472 1.0816 -1.5185
+#> -0.6949 0.4341 0.4113 -1.4010
#> [ CPUFloatType{2,4} ][ grad_fn = <AddmmBackward0> ]
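The vignette code that produces the hunk above is not part of this diff; as a hedged sketch of the step it reflects (the name module_1 and the 3-to-4 layer sizes are assumptions read off the printed shapes), the construction and forward pass could look as follows:
library(torch)
input = torch_randn(2, 3)                                # a 2x3 input batch, as printed above
module_1 = nn_linear(in_features = 3, out_features = 4)  # nn_linear is an nn_module_generator; calling it yields an nn_module
output = module_1(input)                                 # forward pass: shape (2, 3) -> (2, 4)
output$shape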
A neural network with one (4-unit) hidden layer and two outputs needs the following ingredients
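The ingredient list itself is elided from this hunk; purely as an illustrative sketch (the input size and the ReLU nonlinearity are assumptions, not taken from the vignette), such a network could be assembled with nn_sequential:
library(torch)
net = nn_sequential(
  nn_linear(3, 4),  # hidden layer with 4 units (input size 3 assumed)
  nn_relu(),        # nonlinearity (assumed; the vignette may use a different one)
  nn_linear(4, 2)   # two outputs
)
net(torch_randn(2, 3))  # forward pass for a batch of 2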
@@ -134,8 +134,8 @@
torch Primer
output = softmax(output)
output
#> torch_tensor
-#> 0.1706 0.1728 0.6566
-#> 0.1868 0.1952 0.6180
+#> 0.1116 0.3492 0.5392
+#> 0.1333 0.3734 0.4933
#> [ CPUFloatType{2,3} ][ grad_fn = <SoftmaxBackward0> ]
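The softmax object used above is presumably an nn_softmax module; a minimal sketch of that step (normalizing over the class dimension with dim = 2 is an assumption):
softmax = nn_softmax(dim = 2)       # turn each row into class probabilities (assumed dim)
probs = softmax(torch_randn(2, 3))
torch_sum(probs, dim = 2)           # each row sums to 1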
We will now show how such a neural network can be represented in mlr3torch.
Note that we only use $train(), since torch modules do not have anything that maps to the state (it is filled by
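A hedged sketch of what this looks like in mlr3torch (the pipeop key "module" and the wiring below are assumptions based on PipeOpModule; only $train() is called, and no model ends up in the state):
library(mlr3torch)
library(mlr3pipelines)
library(torch)
po_linear = po("module", id = "linear", module = nn_linear(3, 4))  # assumed key and constructor argument
out = po_linear$train(list(torch_randn(2, 3)))  # PipeOps are trained on a list of inputs
out[[1]]$shape                                  # 2 4: the module output; the state stays essentially empty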
@@ -196,8 +196,8 @@
While this object makes it easy to perform a forward pass, it does not inherit from nn_module, which is useful for various
@@ -245,8 +245,8 @@
graph_module(input)
#> torch_tensor
-#> 0.1706 0.1728 0.6566
-#> 0.1868 0.1952 0.6180
+#> 0.1116 0.3492 0.5392
+#> 0.1333 0.3734 0.4933
#> [ CPUFloatType{2,3} ][ grad_fn = <SoftmaxBackward0> ]
ModelDescriptor
to
small_module(batch$x[[1]])
#> torch_tensor
-#> 2.7076 -0.4141 -5.1192 -0.4539
-#> 2.4751 -0.3695 -4.7644 -0.4728
-#> 2.5174 -0.4087 -4.7471 -0.4259
+#> 0.9057 2.1483 -0.2835 -3.7977
+#> 1.0580 1.9881 -0.3755 -3.5451
+#> 0.8433 2.0029 -0.2918 -3.4974
#> [ CPUFloatType{3,4} ][ grad_fn = <AddmmBackward0> ]
The first linear layer that takes “Sepal” input ("linear1") creates a 2x4 tensor (batch size 2, 4 units),
@@ -689,14 +689,14 @@
We observe that the po("nn_merge_cat") concatenates these, as expected:
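At the tensor level, this merge amounts to a torch_cat along the feature dimension; a small sketch (the 2x4 shapes are assumptions for illustration):
a = torch_randn(2, 4)
b = torch_randn(2, 4)
torch_cat(list(a, b), dim = 2)$shape  # 2 8: the two inputs concatenated column-wise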
The printed output of the data descriptor informs us about:
What happens during materialize(lt[1]) is the following:
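The same mechanism is exercised in the DataDescriptor example further below; as a minimal sketch (the lazy tensor lt here is a stand-in constructed for illustration):
library(mlr3torch)
lt = as_lazy_tensor(torch_randn(10, 3))  # a lazy_tensor backed by a DataDescriptor
materialize(lt[1], rbind = FALSE)        # a list holding one 3-element torch_tensor
materialize(lt[1], rbind = TRUE)         # the same values, returned as a 1x3 tensor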
We see that the $graph has a new pipeop with id "poly.x" and the output pointer points to poly.x. We also see that the shape of the tensor is now
diff --git a/dev/pkgdown.yml b/dev/pkgdown.yml
index d10815f8..f2d6823d 100644
--- a/dev/pkgdown.yml
+++ b/dev/pkgdown.yml
@@ -7,7 +7,7 @@ articles:
articles/internals_pipeop_torch: internals_pipeop_torch.html
articles/lazy_tensor: lazy_tensor.html
articles/pipeop_torch: pipeop_torch.html
-last_built: 2024-08-22T09:58Z
+last_built: 2024-08-22T10:03Z
urls:
reference: https://mlr3torch.mlr-org.com/reference
article: https://mlr3torch.mlr-org.com/articles
diff --git a/dev/reference/DataDescriptor.html b/dev/reference/DataDescriptor.html
index 27541157..eb7805d1 100644
--- a/dev/reference/DataDescriptor.html
+++ b/dev/reference/DataDescriptor.html
@@ -263,14 +263,14 @@
lt1 = as_lazy_tensor(torch_randn(10, 3))
materialize(lt1, rbind = TRUE)
#> torch_tensor
-#> 0.4337 0.5389 0.4592
-#> 0.0992 0.7277 -0.6565
-#> 0.0979 -0.5693 1.1938
-#> -0.3715 -0.5852 0.6129
-#> -0.6930 -1.0994 0.3037
-#> -0.2728 -1.4391 -0.6987
-#> -1.6596 -0.9801 0.5558
-#> 0.3316 0.1100 -0.5444
-#> 0.5616 -0.6240 -0.1895
-#> -0.3808 0.7544 0.0610
+#> 0.6629 -1.2602 -2.5268
+#> -1.1277 -0.3862 -0.9214
+#> 0.5454 1.9272 -0.6538
+#> -0.1761 0.7305 0.2100
+#> -0.7350 -0.5816 -0.7902
+#> -1.0446 -1.3970 1.4858
+#> 1.0944 1.9134 1.0702
+#> 0.3584 0.4934 0.8639
+#> 0.4933 -0.3884 0.2562
+#> -0.0735 -1.2129 -0.9788
#> [ CPUFloatType{10,3} ]
materialize(lt1, rbind = FALSE)
#> [[1]]
#> torch_tensor
-#> 0.4337
-#> 0.5389
-#> 0.4592
+#> 0.6629
+#> -1.2602
+#> -2.5268
#> [ CPUFloatType{3} ]
#>
#> [[2]]
#> torch_tensor
-#> 0.0992
-#> 0.7277
-#> -0.6565
+#> -1.1277
+#> -0.3862
+#> -0.9214
#> [ CPUFloatType{3} ]
#>
#> [[3]]
#> torch_tensor
-#> 0.0979
-#> -0.5693
-#> 1.1938
+#> 0.5454
+#> 1.9272
+#> -0.6538
#> [ CPUFloatType{3} ]
#>
#> [[4]]
#> torch_tensor
-#> -0.3715
-#> -0.5852
-#> 0.6129
+#> -0.1761
+#> 0.7305
+#> 0.2100
#> [ CPUFloatType{3} ]
#>
#> [[5]]
#> torch_tensor
-#> -0.6930
-#> -1.0994
-#> 0.3037
+#> -0.7350
+#> -0.5816
+#> -0.7902
#> [ CPUFloatType{3} ]
#>
#> [[6]]
#> torch_tensor
-#> -0.2728
-#> -1.4391
-#> -0.6987
+#> -1.0446
+#> -1.3970
+#> 1.4858
#> [ CPUFloatType{3} ]
#>
#> [[7]]
#> torch_tensor
-#> -1.6596
-#> -0.9801
-#> 0.5558
+#> 1.0944
+#> 1.9134
+#> 1.0702
#> [ CPUFloatType{3} ]
#>
#> [[8]]
#> torch_tensor
-#> 0.3316
-#> 0.1100
-#> -0.5444
+#> 0.3584
+#> 0.4934
+#> 0.8639
#> [ CPUFloatType{3} ]
#>
#> [[9]]
#> torch_tensor
-#> 0.5616
-#> -0.6240
-#> -0.1895
+#> 0.4933
+#> -0.3884
+#> 0.2562
#> [ CPUFloatType{3} ]
#>
#> [[10]]
#> torch_tensor
-#> -0.3808
-#> 0.7544
-#> 0.0610
+#> 0.01 *
+#> -7.3548
+#> -121.2882
+#> -97.8838
#> [ CPUFloatType{3} ]
#>
lt2 = as_lazy_tensor(torch_randn(10, 4))
@@ -219,184 +220,185 @@ Examples
materialize(d, rbind = TRUE)
#> $lt1
#> torch_tensor
-#> 0.4337 0.5389 0.4592
-#> 0.0992 0.7277 -0.6565
-#> 0.0979 -0.5693 1.1938
-#> -0.3715 -0.5852 0.6129
-#> -0.6930 -1.0994 0.3037
-#> -0.2728 -1.4391 -0.6987
-#> -1.6596 -0.9801 0.5558
-#> 0.3316 0.1100 -0.5444
-#> 0.5616 -0.6240 -0.1895
-#> -0.3808 0.7544 0.0610
+#> 0.6629 -1.2602 -2.5268
+#> -1.1277 -0.3862 -0.9214
+#> 0.5454 1.9272 -0.6538
+#> -0.1761 0.7305 0.2100
+#> -0.7350 -0.5816 -0.7902
+#> -1.0446 -1.3970 1.4858
+#> 1.0944 1.9134 1.0702
+#> 0.3584 0.4934 0.8639
+#> 0.4933 -0.3884 0.2562
+#> -0.0735 -1.2129 -0.9788
#> [ CPUFloatType{10,3} ]
#>
#> $lt2
#> torch_tensor
-#> 0.8252 -0.5098 -0.0403 -0.5451
-#> -0.4994 -1.4268 -1.6522 1.2544
-#> -1.0478 -0.9131 0.7225 0.9815
-#> 0.2689 0.3012 -0.7065 2.5550
-#> 0.2321 -0.1517 0.9329 0.2702
-#> 0.2681 0.4025 0.2183 0.8963
-#> 0.3418 1.0520 -0.5711 1.2916
-#> -0.3119 -0.1132 0.1829 -1.9001
-#> 0.2311 1.4513 0.6272 0.4071
-#> 0.8199 -0.9724 0.5549 1.1070
+#> 0.1980 0.1669 2.1620 0.2777
+#> 2.4833 -0.5529 0.3775 -0.3833
+#> 0.7417 -0.6513 -1.5015 -1.0209
+#> -0.9850 0.9002 -1.0860 -1.8951
+#> 0.6284 -0.3454 -0.9807 0.0863
+#> 0.7138 -0.0474 1.3281 -1.6015
+#> 0.9501 -1.2958 -1.0117 1.6362
+#> 0.7318 -0.0975 -0.7521 0.6652
+#> 1.9554 0.6295 -1.4796 -0.3329
+#> -0.2136 -0.8614 1.3766 0.0079
#> [ CPUFloatType{10,4} ]
#>
materialize(d, rbind = FALSE)
#> $lt1
#> $lt1[[1]]
#> torch_tensor
-#> 0.4337
-#> 0.5389
-#> 0.4592
+#> 0.6629
+#> -1.2602
+#> -2.5268
#> [ CPUFloatType{3} ]
#>
#> $lt1[[2]]
#> torch_tensor
-#> 0.0992
-#> 0.7277
-#> -0.6565
+#> -1.1277
+#> -0.3862
+#> -0.9214
#> [ CPUFloatType{3} ]
#>
#> $lt1[[3]]
#> torch_tensor
-#> 0.0979
-#> -0.5693
-#> 1.1938
+#> 0.5454
+#> 1.9272
+#> -0.6538
#> [ CPUFloatType{3} ]
#>
#> $lt1[[4]]
#> torch_tensor
-#> -0.3715
-#> -0.5852
-#> 0.6129
+#> -0.1761
+#> 0.7305
+#> 0.2100
#> [ CPUFloatType{3} ]
#>
#> $lt1[[5]]
#> torch_tensor
-#> -0.6930
-#> -1.0994
-#> 0.3037
+#> -0.7350
+#> -0.5816
+#> -0.7902
#> [ CPUFloatType{3} ]
#>
#> $lt1[[6]]
#> torch_tensor
-#> -0.2728
-#> -1.4391
-#> -0.6987
+#> -1.0446
+#> -1.3970
+#> 1.4858
#> [ CPUFloatType{3} ]
#>
#> $lt1[[7]]
#> torch_tensor
-#> -1.6596
-#> -0.9801
-#> 0.5558
+#> 1.0944
+#> 1.9134
+#> 1.0702
#> [ CPUFloatType{3} ]
#>
#> $lt1[[8]]
#> torch_tensor
-#> 0.3316
-#> 0.1100
-#> -0.5444
+#> 0.3584
+#> 0.4934
+#> 0.8639
#> [ CPUFloatType{3} ]
#>
#> $lt1[[9]]
#> torch_tensor
-#> 0.5616
-#> -0.6240
-#> -0.1895
+#> 0.4933
+#> -0.3884
+#> 0.2562
#> [ CPUFloatType{3} ]
#>
#> $lt1[[10]]
#> torch_tensor
-#> -0.3808
-#> 0.7544
-#> 0.0610
+#> 0.01 *
+#> -7.3548
+#> -121.2882
+#> -97.8838
#> [ CPUFloatType{3} ]
#>
#>
#> $lt2
#> $lt2[[1]]
#> torch_tensor
-#> 0.8252
-#> -0.5098
-#> -0.0403
-#> -0.5451
+#> 0.1980
+#> 0.1669
+#> 2.1620
+#> 0.2777
#> [ CPUFloatType{4} ]
#>
#> $lt2[[2]]
#> torch_tensor
-#> -0.4994
-#> -1.4268
-#> -1.6522
-#> 1.2544
+#> 2.4833
+#> -0.5529
+#> 0.3775
+#> -0.3833
#> [ CPUFloatType{4} ]
#>
#> $lt2[[3]]
#> torch_tensor
-#> -1.0478
-#> -0.9131
-#> 0.7225
-#> 0.9815
+#> 0.7417
+#> -0.6513
+#> -1.5015
+#> -1.0209
#> [ CPUFloatType{4} ]
#>
#> $lt2[[4]]
#> torch_tensor
-#> 0.2689
-#> 0.3012
-#> -0.7065
-#> 2.5550
+#> -0.9850
+#> 0.9002
+#> -1.0860
+#> -1.8951
#> [ CPUFloatType{4} ]
#>
#> $lt2[[5]]
#> torch_tensor
-#> 0.2321
-#> -0.1517
-#> 0.9329
-#> 0.2702
+#> 0.6284
+#> -0.3454
+#> -0.9807
+#> 0.0863
#> [ CPUFloatType{4} ]
#>
#> $lt2[[6]]
#> torch_tensor
-#> 0.2681
-#> 0.4025
-#> 0.2183
-#> 0.8963
+#> 0.7138
+#> -0.0474
+#> 1.3281
+#> -1.6015
#> [ CPUFloatType{4} ]
#>
#> $lt2[[7]]
#> torch_tensor
-#> 0.3418
-#> 1.0520
-#> -0.5711
-#> 1.2916
+#> 0.9501
+#> -1.2958
+#> -1.0117
+#> 1.6362
#> [ CPUFloatType{4} ]
#>
#> $lt2[[8]]
#> torch_tensor
-#> -0.3119
-#> -0.1132
-#> 0.1829
-#> -1.9001
+#> 0.7318
+#> -0.0975
+#> -0.7521
+#> 0.6652
#> [ CPUFloatType{4} ]
#>
#> $lt2[[9]]
#> torch_tensor
-#> 0.2311
-#> 1.4513
-#> 0.6272
-#> 0.4071
+#> 1.9554
+#> 0.6295
+#> -1.4796
+#> -0.3329
#> [ CPUFloatType{4} ]
#>
#> $lt2[[10]]
#> torch_tensor
-#> 0.8199
-#> -0.9724
-#> 0.5549
-#> 1.1070
+#> -0.2136
+#> -0.8614
+#> 1.3766
+#> 0.0079
#> [ CPUFloatType{4} ]
#>
#>
diff --git a/dev/reference/mlr_learners_torch_model.html b/dev/reference/mlr_learners_torch_model.html
index 5dea13ac..7d7d9bdf 100644
--- a/dev/reference/mlr_learners_torch_model.html
+++ b/dev/reference/mlr_learners_torch_model.html
@@ -243,9 +243,9 @@ Examples
learner$predict(task, ids$test)
#> <PredictionClassif> for 48 observations:
#> row_ids truth response
-#> 1 setosa setosa
+#> 1 setosa versicolor
#> 6 setosa setosa
-#> 12 setosa setosa
+#> 12 setosa versicolor
#> ---
#> 142 virginica versicolor
#> 148 virginica versicolor
diff --git a/dev/reference/mlr_pipeops_torch_ingress_ltnsr.html b/dev/reference/mlr_pipeops_torch_ingress_ltnsr.html
index 2f58fa0a..184d10dd 100644
--- a/dev/reference/mlr_pipeops_torch_ingress_ltnsr.html
+++ b/dev/reference/mlr_pipeops_torch_ingress_ltnsr.html
@@ -271,35 +271,35 @@ Examples
x_batch2
#> torch_tensor
#> (1,1,.,.) =
-#> -0.1406 0.8897 -0.2001 0.3644 -0.4208 -0.0598
-#> -0.3679 -1.8062 -0.7720 -1.5540 0.1804 -0.6076
-#> -0.8970 0.8980 1.3070 0.4297 0.3637 -0.3515
-#> -1.2410 -1.5961 0.3854 -0.0092 -0.7477 0.1467
-#> -1.0128 -0.2173 -0.0256 0.3854 1.1102 0.8426
-#> -0.0015 1.1954 -0.4697 0.5775 0.3426 -0.0188
+#> 0.7933 -0.3588 -0.7738 0.3723 0.7557 0.0776
+#> 0.0358 -0.9190 -1.3731 -0.0963 0.0546 -0.5296
+#> 1.1337 0.6925 0.0384 -0.1584 0.8017 0.0582
+#> -0.4234 -0.3030 -0.8665 -0.3108 -1.0181 -0.6894
+#> -0.6912 -0.4810 -0.6834 1.7490 -0.0982 0.5875
+#> 0.0091 0.8948 0.2014 0.8367 -1.5817 -0.0070
#>
#> (2,1,.,.) =
-#> 0.2568 0.0487 0.2617 0.3904 -0.2544 -1.6331
-#> -1.7578 -0.2814 -0.2986 -0.3370 0.9651 -0.0859
-#> 0.7690 0.2000 -0.4288 -0.5388 -0.0642 -0.1045
-#> -0.6392 -0.3098 -1.1118 -0.6755 -1.6049 -0.7990
-#> 0.2963 -0.7406 -0.3669 0.7130 0.0003 0.2804
-#> 0.8757 0.6154 0.9026 0.7993 0.5094 -0.1241
+#> 0.9005 0.3602 0.3539 0.3428 -0.6728 0.1068
+#> 0.6483 -0.1965 -0.7305 -0.2560 -0.0998 0.3805
+#> 0.0595 0.2962 -0.4040 0.6842 -0.0724 0.4393
+#> 0.4653 0.4450 0.6407 0.1514 -0.7114 -1.7885
+#> -0.0627 0.7263 0.1146 -0.3130 0.3623 -0.5054
+#> 0.0456 0.7691 0.1020 -0.8185 0.3586 -0.8273
#>
#> (1,2,.,.) =
-#> 0.2013 -0.3693 -0.0751 0.4562 -0.1179 -0.3916
-#> -0.5855 -0.7288 0.4462 0.2854 0.7073 0.3175
-#> -0.4044 0.5819 0.5800 0.2930 0.1322 0.1686
-#> -0.1005 0.7866 0.6076 -0.0126 -0.7185 -0.7328
-#> 1.4532 -0.2163 -0.3449 -0.1239 0.4164 -0.2662
-#> 0.7296 -0.5222 -0.6160 0.5174 0.5350 0.0561
+#> 0.1312 -0.6251 0.7048 -0.1977 0.5561 0.5624
+#> 0.0286 -0.2785 -0.2405 0.4713 -0.1579 0.4806
+#> 1.1008 -0.6220 -0.2135 -0.5255 -0.0331 0.6568
+#> 0.0905 0.2452 1.2067 -1.7585 -0.8850 -0.2227
+#> -0.2304 -0.6141 1.6050 1.1661 -1.5937 -0.3271
+#> 0.0214 0.7292 0.6356 0.7637 0.8903 0.5817
#>
#> (2,2,.,.) =
-#> 1.5475 0.3243 0.5735 0.9426 0.5923 -0.5829
-#> 0.7316 -0.1349 -0.1187 -0.6156 -0.1094 -1.1927
-#> -0.3273 0.1350 -0.0156 -1.5596 -0.3006 1.3816
-#> 1.1662 -0.3874 -1.2282 0.0936 0.2554 0.2299
-#> 0.2141 0.3669 0.1565 -0.5532 -0.6292 -0.4646
+#> 1.6573 -0.5304 0.1094 0.8137 0.1025 0.5022
+#> 0.3329 -0.5590 0.7403 0.5722 0.2244 -0.0115
+#> 0.2311 1.3263 0.4767 0.2640 -0.4501 0.5960
+#> 0.1265 -0.4144 -0.9363 1.2215 -0.9120 -0.0889
+#> 0.9326 0.1697 -0.4180 0.0265 1.2333 0.3483
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{2,3,6,6} ]
diff --git a/dev/reference/mlr_pipeops_torch_model_classif.html b/dev/reference/mlr_pipeops_torch_model_classif.html
index 1f6c9e6a..73d4ec4a 100644
--- a/dev/reference/mlr_pipeops_torch_model_classif.html
+++ b/dev/reference/mlr_pipeops_torch_model_classif.html
@@ -271,17 +271,17 @@