<!DOCTYPE html>
<!-- Generated by pkgdown: do not edit by hand --><html lang="en">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<title>Get Started • mlr3torch</title>
<!-- favicons --><link rel="icon" type="image/png" sizes="16x16" href="../favicon-16x16.png">
<link rel="icon" type="image/png" sizes="32x32" href="../favicon-32x32.png">
<link rel="apple-touch-icon" type="image/png" sizes="180x180" href="../apple-touch-icon.png">
<link rel="apple-touch-icon" type="image/png" sizes="120x120" href="../apple-touch-icon-120x120.png">
<link rel="apple-touch-icon" type="image/png" sizes="76x76" href="../apple-touch-icon-76x76.png">
<link rel="apple-touch-icon" type="image/png" sizes="60x60" href="../apple-touch-icon-60x60.png">
<script src="../lightswitch.js"></script><script src="../deps/jquery-3.6.0/jquery-3.6.0.min.js"></script><meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<link href="../deps/bootstrap-5.3.1/bootstrap.min.css" rel="stylesheet">
<script src="../deps/bootstrap-5.3.1/bootstrap.bundle.min.js"></script><link href="../deps/Roboto-0.4.9/font.css" rel="stylesheet">
<link href="../deps/JetBrains_Mono-0.4.9/font.css" rel="stylesheet">
<link href="../deps/Roboto_Slab-0.4.9/font.css" rel="stylesheet">
<link href="../deps/font-awesome-6.5.2/css/all.min.css" rel="stylesheet">
<link href="../deps/font-awesome-6.5.2/css/v4-shims.min.css" rel="stylesheet">
<script src="../deps/headroom-0.11.0/headroom.min.js"></script><script src="../deps/headroom-0.11.0/jQuery.headroom.min.js"></script><script src="../deps/bootstrap-toc-1.0.1/bootstrap-toc.min.js"></script><script src="../deps/clipboard.js-2.0.11/clipboard.min.js"></script><script src="../deps/search-1.0.0/autocomplete.jquery.min.js"></script><script src="../deps/search-1.0.0/fuse.min.js"></script><script src="../deps/search-1.0.0/mark.min.js"></script><script src="../deps/MathJax-3.2.2/tex-chtml.min.js"></script><!-- pkgdown --><script src="../pkgdown.js"></script><meta property="og:title" content="Get Started">
<meta name="robots" content="noindex">
</head>
<body>
<a href="#main" class="visually-hidden-focusable">Skip to contents</a>
<nav class="navbar navbar-expand-lg fixed-top " aria-label="Site navigation"><div class="container">
<a class="navbar-brand me-2" href="../index.html">mlr3torch</a>
<small class="nav-text text-default me-auto" data-bs-toggle="tooltip" data-bs-placement="bottom" title="In-development version">0.1.2-9000</small>
<button class="navbar-toggler" type="button" data-bs-toggle="collapse" data-bs-target="#navbar" aria-controls="navbar" aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div id="navbar" class="collapse navbar-collapse ms-3">
<ul class="navbar-nav me-auto">
<li class="nav-item"><a class="nav-link" href="../reference/index.html"><span class="fa fa-file-alt"></span> Reference</a></li>
<li class="nav-item"><a class="nav-link" href="../news/index.html">Changelog</a></li>
<li class="active nav-item dropdown">
<button class="nav-link dropdown-toggle" type="button" id="dropdown-articles" data-bs-toggle="dropdown" aria-expanded="false" aria-haspopup="true">Articles</button>
<ul class="dropdown-menu" aria-labelledby="dropdown-articles">
<li><a class="dropdown-item" href="../articles/callbacks.html">Custom Callbacks</a></li>
<li><a class="dropdown-item" href="../articles/get_started.html">Get Started</a></li>
<li><a class="dropdown-item" href="../articles/internals_pipeop_torch.html">Internals</a></li>
<li><a class="dropdown-item" href="../articles/lazy_tensor.html">Lazy Tensor</a></li>
<li><a class="dropdown-item" href="../articles/pipeop_torch.html">Defining an Architecture</a></li>
</ul>
</li>
<li class="nav-item"><a class="external-link nav-link" href="https://mlr3book.mlr-org.com"><span class="fa fa-link"></span> mlr3book</a></li>
</ul>
<ul class="navbar-nav">
<li class="nav-item"><form class="form-inline" role="search">
<input class="form-control" type="search" name="search-input" id="search-input" autocomplete="off" aria-label="Search site" placeholder="Search for" data-search-index="../search.json">
</form></li>
<li class="nav-item"><a class="external-link nav-link" href="https://github.com/mlr-org/mlr3torch"><span class="fa fa-github"></span></a></li>
<li class="nav-item"><a class="external-link nav-link" href="https://lmmisld-lmu-stats-slds.srv.mwn.de/mlr_invite/"><span class="fa fa-comments"></span></a></li>
<li class="nav-item"><a class="external-link nav-link" href="https://stackoverflow.com/questions/tagged/mlr3"><span class="fa fab fa-stack-overflow"></span></a></li>
<li class="nav-item"><a class="external-link nav-link" href="https://mlr-org.com/"><span class="fa fa-rss"></span></a></li>
<li class="nav-item dropdown">
<button class="nav-link dropdown-toggle" type="button" id="dropdown-lightswitch" data-bs-toggle="dropdown" aria-expanded="false" aria-haspopup="true" aria-label="Light switch"><span class="fa fa-sun"></span></button>
<ul class="dropdown-menu dropdown-menu-end" aria-labelledby="dropdown-lightswitch">
<li><button class="dropdown-item" data-bs-theme-value="light"><span class="fa fa-sun"></span> Light</button></li>
<li><button class="dropdown-item" data-bs-theme-value="dark"><span class="fa fa-moon"></span> Dark</button></li>
<li><button class="dropdown-item" data-bs-theme-value="auto"><span class="fa fa-adjust"></span> Auto</button></li>
</ul>
</li>
</ul>
</div>
</div>
</nav><div class="container template-article">
<div class="row">
<main id="main" class="col-md-9"><div class="page-header">
<img src="../logo.svg" class="logo" alt=""><h1>Get Started</h1>
<small class="dont-index">Source: <a href="https://github.com/mlr-org/mlr3torch/blob/main/vignettes/articles/get_started.Rmd" class="external-link"><code>vignettes/articles/get_started.Rmd</code></a></small>
<div class="d-none name"><code>get_started.Rmd</code></div>
</div>
<div class="section level2">
<h2 id="quickstart">Quickstart<a class="anchor" aria-label="anchor" href="#quickstart"></a>
</h2>
<p>In this vignette we will show how to get started with
<code>mlr3torch</code> by training a simple neural network on a tabular
regression problem. We assume that you are familiar with the
<code>mlr3</code> framework; see e.g. the <a href="https://mlr3book.mlr-org.com/" class="external-link">mlr3 book</a>. As a first example,
we will train a simple multi-layer perceptron (MLP) on the well-known
“mtcars” task, where the goal is to predict the miles per gallon
(‘mpg’) of cars. This architecture comes as a predefined learner with
<code>mlr3torch</code>, but you can also easily create new network
architectures; see the <em>Neural Networks as Graphs</em> vignette for a
detailed introduction. We first set a seed for reproducibility, load
the library and construct the task.</p>
<div class="sourceCode" id="cb1"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="fu"><a href="https://rdrr.io/r/base/Random.html" class="external-link">set.seed</a></span><span class="op">(</span><span class="fl">314</span><span class="op">)</span></span>
<span><span class="kw"><a href="https://rdrr.io/r/base/library.html" class="external-link">library</a></span><span class="op">(</span><span class="va"><a href="https://mlr3torch.mlr-org.com/">mlr3torch</a></span><span class="op">)</span></span>
<span><span class="va">task</span> <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">tsk</a></span><span class="op">(</span><span class="st">"mtcars"</span><span class="op">)</span></span>
<span><span class="va">task</span><span class="op">$</span><span class="fu">head</span><span class="op">(</span><span class="op">)</span></span>
<span><span class="co">#> mpg am carb cyl disp drat gear hp qsec vs wt</span></span>
<span><span class="co">#> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num> <num></span></span>
<span><span class="co">#> 1: 21.0 1 4 6 160 3.90 4 110 16.46 0 2.620</span></span>
<span><span class="co">#> 2: 21.0 1 4 6 160 3.90 4 110 17.02 0 2.875</span></span>
<span><span class="co">#> 3: 22.8 1 1 4 108 3.85 4 93 18.61 1 2.320</span></span>
<span><span class="co">#> 4: 21.4 0 1 6 258 3.08 3 110 19.44 1 3.215</span></span>
<span><span class="co">#> 5: 18.7 0 2 8 360 3.15 3 175 17.02 0 3.440</span></span>
<span><span class="co">#> 6: 18.1 0 1 6 225 2.76 3 105 20.22 1 3.460</span></span></code></pre></div>
<p>Learners in <code>mlr3torch</code> work very similarly to other
<code>mlr3</code> learners. Below, we construct a simple multi-layer
perceptron for regression. We do this as usual by calling
<code><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">lrn()</a></code> and configuring the parameters: we use two hidden
layers with 50 neurons each. For training, we set the batch size to 32, the
number of training epochs to 30, and the device to <code>"cpu"</code>.
For a complete description of the available parameters, see
<code><a href="../reference/mlr_learners.mlp.html">?mlr3torch::LearnerTorchMLP</a></code>.</p>
<div class="sourceCode" id="cb2"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlp</span> <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">lrn</a></span><span class="op">(</span><span class="st">"regr.mlp"</span>,</span>
<span> <span class="co"># architecture parameters</span></span>
<span> neurons <span class="op">=</span> <span class="fu"><a href="https://rdrr.io/r/base/c.html" class="external-link">c</a></span><span class="op">(</span><span class="fl">50</span>, <span class="fl">50</span><span class="op">)</span>,</span>
<span> <span class="co"># training arguments</span></span>
<span> batch_size <span class="op">=</span> <span class="fl">32</span>, epochs <span class="op">=</span> <span class="fl">30</span>, device <span class="op">=</span> <span class="st">"cpu"</span></span>
<span><span class="op">)</span></span>
<span><span class="va">mlp</span></span>
<span><span class="co">#> <LearnerTorchMLP[regr]:regr.mlp>: My Little Powny</span></span>
<span><span class="co">#> * Model: -</span></span>
<span><span class="co">#> * Parameters: epochs=30, device=cpu, num_threads=1,</span></span>
<span><span class="co">#> num_interop_threads=1, seed=random, eval_freq=1,</span></span>
<span><span class="co">#> measures_train=<list>, measures_valid=<list>, patience=0,</span></span>
<span><span class="co">#> min_delta=0, batch_size=32, neurons=50,50, p=0.5,</span></span>
<span><span class="co">#> activation=<nn_relu>, activation_args=<list></span></span>
<span><span class="co">#> * Validate: NULL</span></span>
<span><span class="co">#> * Packages: mlr3, mlr3torch, torch</span></span>
<span><span class="co">#> * Predict Types: [response]</span></span>
<span><span class="co">#> * Feature Types: integer, numeric, lazy_tensor</span></span>
<span><span class="co">#> * Properties: internal_tuning, marshal, validation</span></span>
<span><span class="co">#> * Optimizer: adam</span></span>
<span><span class="co">#> * Loss: mse</span></span>
<span><span class="co">#> * Callbacks: -</span></span></code></pre></div>
<p>We can use this learner for training and prediction just like any
other regression learner. Below, we split the observations into a
training and test set, train the learner on the training set and create
predictions for the test set. Finally, we compute the mean squared error
of the predictions.</p>
<div class="sourceCode" id="cb3"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="co"># Split the observations into training and test set</span></span>
<span><span class="va">splits</span> <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/partition.html" class="external-link">partition</a></span><span class="op">(</span><span class="va">task</span><span class="op">)</span></span>
<span><span class="co"># Train the learner on the train set</span></span>
<span><span class="va">mlp</span><span class="op">$</span><span class="fu">train</span><span class="op">(</span><span class="va">task</span>, row_ids <span class="op">=</span> <span class="va">splits</span><span class="op">$</span><span class="va">train</span><span class="op">)</span></span>
<span><span class="co"># Predict the test set</span></span>
<span><span class="va">prediction</span> <span class="op">=</span> <span class="va">mlp</span><span class="op">$</span><span class="fu">predict</span><span class="op">(</span><span class="va">task</span>, row_ids <span class="op">=</span> <span class="va">splits</span><span class="op">$</span><span class="va">test</span><span class="op">)</span></span>
<span><span class="co"># Compute the mse</span></span>
<span><span class="va">prediction</span><span class="op">$</span><span class="fu">score</span><span class="op">(</span><span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">msr</a></span><span class="op">(</span><span class="st">"regr.mse"</span><span class="op">)</span><span class="op">)</span></span>
<span><span class="co">#> regr.mse </span></span>
<span><span class="co">#> 283.838</span></span></code></pre></div>
</div>
<div class="section level2">
<h2 id="configuring-a-learner">Configuring a Learner<a class="anchor" aria-label="anchor" href="#configuring-a-learner"></a>
</h2>
<p>Although torch learners are quite similar to other <code>mlr3</code>
learners, there are some differences. One is that all
<code>LearnerTorch</code> classes have <em>construction arguments</em>,
i.e. torch learners are more modular than other learners. While learners
are free to implement their own construction arguments, some are
common to all torch learners, namely the <code>loss</code>,
<code>optimizer</code>, and <code>callbacks</code>. Each of these objects
can have its own parameters, which are included in the
<code>LearnerTorch</code>’s parameter set.</p>
<p>In the previous example, we did not specify any of these explicitly
and used the default values: the Adam optimizer, MSE as the
loss, and no callbacks. We will now show how to configure these three
aspects of a learner through the <code><a href="../reference/TorchOptimizer.html">mlr3torch::TorchOptimizer</a></code>,
<code><a href="../reference/TorchLoss.html">mlr3torch::TorchLoss</a></code>, and
<code><a href="../reference/TorchCallback.html">mlr3torch::TorchCallback</a></code> classes.</p>
<div class="section level3">
<h3 id="loss">Loss<a class="anchor" aria-label="anchor" href="#loss"></a>
</h3>
<p>The loss function, also known as the objective function or cost
function, measures the discrepancy between the predicted output and the
true output. It quantifies how well the model is performing during
training. The R package <code>torch</code>, which underpins the
<code>mlr3torch</code> framework, already provides a number of
predefined loss functions such as the Mean Squared Error
(<code>nn_mse_loss</code>), the Mean Absolute Error
(<code>nn_l1_loss</code>), or the cross entropy loss
(<code>nn_cross_entropy_loss</code>). In <code>mlr3torch</code>, we
represent loss functions using the <code><a href="../reference/TorchLoss.html">mlr3torch::TorchLoss</a></code>
class. It provides a thin wrapper around the torch loss functions and
annotates them with meta information, most importantly a
<code><a href="https://paradox.mlr-org.com/reference/ParamSet.html" class="external-link">paradox::ParamSet</a></code> that allows configuring the loss
function. Such an object can be constructed using
<code>t_loss(<key>)</code>. Below, we construct the L1 loss
function, which is also known as the Mean Absolute Error (MAE). The printed
output below informs us about the wrapped loss function
(<code>nn_l1_loss</code>), the configured parameters, the packages it
depends on, and the task types it can be used for.</p>
<div class="sourceCode" id="cb4"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">l1</span> <span class="op">=</span> <span class="fu"><a href="../reference/t_loss.html">t_loss</a></span><span class="op">(</span><span class="st">"l1"</span><span class="op">)</span></span>
<span><span class="va">l1</span></span>
<span><span class="co">#> <TorchLoss:l1> Absolute Error</span></span>
<span><span class="co">#> * Generator: nn_l1_loss</span></span>
<span><span class="co">#> * Parameters: list()</span></span>
<span><span class="co">#> * Packages: torch,mlr3torch</span></span>
<span><span class="co">#> * Task Types: regr</span></span></code></pre></div>
<p>Its <code>ParamSet</code> contains only one parameter, namely
<code>reduction</code>, which specifies how the loss is reduced over the
batch.</p>
<div class="sourceCode" id="cb5"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="co"># the paradox::ParamSet of the loss</span></span>
<span><span class="va">l1</span><span class="op">$</span><span class="va">param_set</span></span>
<span><span class="co">#> <ParamSet(1)></span></span>
<span><span class="co">#> id class lower upper nlevels default value</span></span>
<span><span class="co">#> <char> <char> <num> <num> <num> <list> <list></span></span>
<span><span class="co">#> 1: reduction ParamFct NA NA 2 mean [NULL]</span></span></code></pre></div>
<p>The wrapped loss module generator is accessible through the slot
<code>$generator</code>.</p>
<div class="sourceCode" id="cb6"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">l1</span><span class="op">$</span><span class="va">generator</span></span>
<span><span class="co">#> <nn_l1_loss> object generator</span></span>
<span><span class="co">#> Inherits from: <inherit></span></span>
<span><span class="co">#> Public:</span></span>
<span><span class="co">#> .classes: nn_l1_loss nn_loss nn_module</span></span>
<span><span class="co">#> initialize: function (reduction = "mean") </span></span>
<span><span class="co">#> forward: function (input, target) </span></span>
<span><span class="co">#> clone: function (deep = FALSE, ..., replace_values = TRUE) </span></span>
<span><span class="co">#> Private:</span></span>
<span><span class="co">#> .__clone_r6__: function (deep = FALSE) </span></span>
<span><span class="co">#> Parent env: <environment: 0x56101df4bf40></span></span>
<span><span class="co">#> Locked objects: FALSE</span></span>
<span><span class="co">#> Locked class: FALSE</span></span>
<span><span class="co">#> Portable: TRUE</span></span></code></pre></div>
<p>We can pass the <code>TorchLoss</code> as the argument
<code>loss</code> during initialization of the learner. The parameters
of the loss are added to the learner’s <code>ParamSet</code>, prefixed
with <code>"loss."</code>.</p>
<div class="sourceCode" id="cb7"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlp_l1</span> <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">lrn</a></span><span class="op">(</span><span class="st">"regr.mlp"</span>, loss <span class="op">=</span> <span class="va">l1</span><span class="op">)</span></span>
<span><span class="va">mlp_l1</span><span class="op">$</span><span class="va">param_set</span><span class="op">$</span><span class="va">values</span><span class="op">$</span><span class="va">loss.reduction</span></span>
<span><span class="co">#> NULL</span></span></code></pre></div>
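<p>Because the loss parameters are part of the learner’s
<code>ParamSet</code>, they can be configured like any other
hyperparameter. The following minimal sketch sets the
<code>reduction</code> of the L1 loss through the learner constructed
above:</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R"># configure the wrapped loss via the learner's parameter set
mlp_l1$param_set$set_values(loss.reduction = "sum")
mlp_l1$param_set$values$loss.reduction</code></pre></div>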
<p>All predefined loss functions are stored in the
<code>mlr3torch_losses</code> dictionary, from which they can be
retrieved using <code>t_loss(<key>)</code>.</p>
<div class="sourceCode" id="cb8"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlr3torch_losses</span></span>
<span><span class="co">#> <DictionaryMlr3torchLosses> with 3 stored values</span></span>
<span><span class="co">#> Keys: cross_entropy, l1, mse</span></span></code></pre></div>
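<p>Note that <code>t_loss()</code> also accepts parameter values
directly during construction; the following sketch assumes this
shorthand:</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R"># retrieve the L1 loss with 'reduction' configured in one step
l1_sum = t_loss("l1", reduction = "sum")
l1_sum$param_set$values$reduction</code></pre></div>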
</div>
<div class="section level3">
<h3 id="optimizer">Optimizer<a class="anchor" aria-label="anchor" href="#optimizer"></a>
</h3>
<p>The optimizer determines how the model’s weights are updated based on
the calculated loss. It adjusts the parameters of the model to minimize
the loss function, optimizing the model’s performance. Optimizers work
analogously to loss functions, i.e. <code>mlr3torch</code> provides a thin
wrapper – the <code>TorchOptimizer</code> class – around torch optimizers
such as Adam (<code>optim_adam</code>) or SGD (<code>optim_sgd</code>).
<code>TorchOptimizer</code> objects can be constructed using
<code>t_opt(<key>)</code>. For optimizers, the associated
<code>ParamSet</code> is more interesting, as we see below:
<div class="sourceCode" id="cb9"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">sgd</span> <span class="op">=</span> <span class="fu"><a href="../reference/t_opt.html">t_opt</a></span><span class="op">(</span><span class="st">"sgd"</span><span class="op">)</span></span>
<span><span class="va">sgd</span></span>
<span><span class="co">#> <TorchOptimizer:sgd> Stochastic Gradient Descent</span></span>
<span><span class="co">#> * Generator: optim_sgd</span></span>
<span><span class="co">#> * Parameters: list()</span></span>
<span><span class="co">#> * Packages: torch,mlr3torch</span></span>
<span></span>
<span><span class="va">sgd</span><span class="op">$</span><span class="va">param_set</span></span>
<span><span class="co">#> <ParamSet(5)></span></span>
<span><span class="co">#> id class lower upper nlevels default value</span></span>
<span><span class="co">#> <char> <char> <num> <num> <num> <list> <list></span></span>
<span><span class="co">#> 1: lr ParamDbl 0 Inf Inf <NoDefault[0]> [NULL]</span></span>
<span><span class="co">#> 2: momentum ParamDbl 0 1 Inf 0 [NULL]</span></span>
<span><span class="co">#> 3: dampening ParamDbl 0 1 Inf 0 [NULL]</span></span>
<span><span class="co">#> 4: weight_decay ParamDbl 0 1 Inf 0 [NULL]</span></span>
<span><span class="co">#> 5: nesterov ParamLgl NA NA 2 FALSE [NULL]</span></span></code></pre></div>
<p>The wrapped torch optimizer can be accessed through the slot
<code>$generator</code>.</p>
<p>Parameters of <code>TorchOptimizer</code> (but also
<code>TorchLoss</code> and <code>TorchCallback</code>) can be set in the
usual <code>mlr3</code> way, i.e. either during construction, or
afterwards using the <code>$set_values()</code> method of the parameter
set.</p>
<div class="sourceCode" id="cb10"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">sgd</span><span class="op">$</span><span class="va">param_set</span><span class="op">$</span><span class="fu">set_values</span><span class="op">(</span></span>
<span> lr <span class="op">=</span> <span class="fl">0.5</span>, <span class="co"># increase learning rate</span></span>
<span> nesterov <span class="op">=</span> <span class="cn">FALSE</span> <span class="co"># no nesterov momentum</span></span>
<span><span class="op">)</span></span></code></pre></div>
<p>Below we see that the optimizer’s parameters are added to the
learner’s <code>ParamSet</code> (prefixed with <code>"opt."</code>) and
that the values are set to the values we specified.</p>
<div class="sourceCode" id="cb11"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlp_sgd</span> <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">lrn</a></span><span class="op">(</span><span class="st">"regr.mlp"</span>, optimizer <span class="op">=</span> <span class="va">sgd</span><span class="op">)</span></span>
<span><span class="fu"><a href="https://rdrr.io/pkg/data.table/man/as.data.table.html" class="external-link">as.data.table</a></span><span class="op">(</span><span class="va">mlp_sgd</span><span class="op">$</span><span class="va">param_set</span><span class="op">)</span><span class="op">[</span></span>
<span> <span class="fu"><a href="https://rdrr.io/r/base/startsWith.html" class="external-link">startsWith</a></span><span class="op">(</span><span class="va">id</span>, <span class="st">"opt."</span><span class="op">)</span><span class="op">]</span><span class="op">[[</span><span class="fl">1L</span><span class="op">]</span><span class="op">]</span></span>
<span><span class="co">#> [1] "opt.lr" "opt.momentum" "opt.dampening" "opt.weight_decay"</span></span>
<span><span class="co">#> [5] "opt.nesterov"</span></span>
<span><span class="va">mlp_sgd</span><span class="op">$</span><span class="va">param_set</span><span class="op">$</span><span class="va">values</span><span class="op">[</span><span class="fu"><a href="https://rdrr.io/r/base/c.html" class="external-link">c</a></span><span class="op">(</span><span class="st">"opt.lr"</span>, <span class="st">"opt.nesterov"</span><span class="op">)</span><span class="op">]</span></span>
<span><span class="co">#> $opt.lr</span></span>
<span><span class="co">#> [1] 0.5</span></span>
<span><span class="co">#> </span></span>
<span><span class="co">#> $opt.nesterov</span></span>
<span><span class="co">#> [1] FALSE</span></span></code></pre></div>
<p>By exposing the optimizer’s parameters, they can be conveniently
tuned using <a href="https://github.com/mlr-org/mlr3tuning" class="external-link"><code>mlr3tuning</code></a>.</p>
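<p>As a sketch of how this might look (assuming familiarity with
the <code>to_tune()</code> tuning tokens from <code>paradox</code>),
the learning rate can be marked for tuning directly when constructing
the learner:</p>
<div class="sourceCode"><pre class="downlit sourceCode r">
<code class="sourceCode R"># mark the SGD learning rate as a tuning parameter on a log scale
mlp_tunable = lrn("regr.mlp",
  optimizer = t_opt("sgd"),
  neurons = c(50, 50),
  batch_size = 32, epochs = 30,
  opt.lr = to_tune(1e-4, 1e-1, logscale = TRUE)
)</code></pre></div>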
<p>All available optimizers are stored in the
<code>mlr3torch_optimizers</code> dictionary.</p>
<div class="sourceCode" id="cb12"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlr3torch_optimizers</span></span>
<span><span class="co">#> <DictionaryMlr3torchOptimizers> with 7 stored values</span></span>
<span><span class="co">#> Keys: adadelta, adagrad, adam, asgd, rmsprop, rprop, sgd</span></span></code></pre></div>
</div>
<div class="section level3">
<h3 id="callbacks">Callbacks<a class="anchor" aria-label="anchor" href="#callbacks"></a>
</h3>
<p>The third important configuration option is the callbacks, which
allow customizing the training process. This enables saving model
checkpoints, logging metrics, or implementing custom functionality for
specific training scenarios. For a tutorial on how to implement a custom
callback, see the <em>Custom Callbacks</em> vignette. Here, we only
show how to use predefined callbacks. Below, we retrieve the
<code>"history"</code> callback using <code><a href="../reference/t_clbk.html">t_clbk()</a></code>, which has
no parameters and merely saves the training and validation history in
the learner so it can be accessed afterwards.</p>
<div class="sourceCode" id="cb13"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">history</span> <span class="op">=</span> <span class="fu"><a href="../reference/t_clbk.html">t_clbk</a></span><span class="op">(</span><span class="st">"history"</span><span class="op">)</span></span>
<span><span class="va">history</span></span>
<span><span class="co">#> <TorchCallback:history> History</span></span>
<span><span class="co">#> * Generator: CallbackSetHistory</span></span>
<span><span class="co">#> * Parameters: list()</span></span>
<span><span class="co">#> * Packages: mlr3torch,torch</span></span></code></pre></div>
<p>To learn what the callback does, we can access the
help page of the wrapped object using the <code>$help()</code> method.
Note that this is also possible for the loss and optimizer.</p>
<div class="sourceCode" id="cb14"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">history</span><span class="op">$</span><span class="fu">help</span><span class="op">(</span><span class="op">)</span></span></code></pre></div>
<p>All predefined callbacks are stored in the
<code>mlr3torch_callbacks</code> dictionary.</p>
<div class="sourceCode" id="cb15"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlr3torch_callbacks</span></span>
<span><span class="co">#> <DictionaryMlr3torchCallbacks> with 4 stored values</span></span>
<span><span class="co">#> Keys: checkpoint, history, progress, tb</span></span></code></pre></div>
</div>
<div class="section level3">
<h3 id="putting-it-together">Putting it Together<a class="anchor" aria-label="anchor" href="#putting-it-together"></a>
</h3>
<p>We now define our customized MLP learner using the loss, optimizer,
and callback we have just covered. To keep track of the performance, we
use 30% of the training data for validation and evaluate it using the
MAE <code>Measure</code>. Note that the <code>measures_valid</code> and
<code>measures_train</code> parameters of <code>LearnerTorch</code> take
common <code><a href="https://mlr3.mlr-org.com/reference/Measure.html" class="external-link">mlr3::Measure</a></code>s, whereas the loss function must be a
<code>TorchLoss</code>.</p>
<div class="sourceCode" id="cb16"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlp_custom</span> <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">lrn</a></span><span class="op">(</span><span class="st">"regr.mlp"</span>,</span>
<span> <span class="co"># construction arguments</span></span>
<span> optimizer <span class="op">=</span> <span class="va">sgd</span>, loss <span class="op">=</span> <span class="va">l1</span>, callbacks <span class="op">=</span> <span class="va">history</span>,</span>
<span> <span class="co"># scores to keep track of</span></span>
<span> measures_valid <span class="op">=</span> <span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">msr</a></span><span class="op">(</span><span class="st">"regr.mae"</span><span class="op">)</span>,</span>
<span> <span class="co"># other parameters are left as-is:</span></span>
<span> <span class="co"># architecture</span></span>
<span> neurons <span class="op">=</span> <span class="fu"><a href="https://rdrr.io/r/base/c.html" class="external-link">c</a></span><span class="op">(</span><span class="fl">50</span>, <span class="fl">50</span><span class="op">)</span>,</span>
<span> <span class="co"># training arguments</span></span>
<span> batch_size <span class="op">=</span> <span class="fl">32</span>, epochs <span class="op">=</span> <span class="fl">30</span>, device <span class="op">=</span> <span class="st">"cpu"</span>,</span>
<span> <span class="co"># validation proportion</span></span>
<span> validate <span class="op">=</span> <span class="fl">0.3</span></span>
<span><span class="op">)</span></span>
<span></span>
<span><span class="va">mlp_custom</span></span>
<span><span class="co">#> <LearnerTorchMLP[regr]:regr.mlp>: My Little Powny</span></span>
<span><span class="co">#> * Model: -</span></span>
<span><span class="co">#> * Parameters: epochs=30, device=cpu, num_threads=1,</span></span>
<span><span class="co">#> num_interop_threads=1, seed=random, eval_freq=1,</span></span>
<span><span class="co">#> measures_train=<list>, measures_valid=<MeasureRegrSimple>,</span></span>
<span><span class="co">#> patience=0, min_delta=0, batch_size=32, neurons=50,50, p=0.5,</span></span>
<span><span class="co">#> activation=<nn_relu>, activation_args=<list>, opt.lr=0.5,</span></span>
<span><span class="co">#> opt.nesterov=FALSE</span></span>
<span><span class="co">#> * Validate: 0.3</span></span>
<span><span class="co">#> * Packages: mlr3, mlr3torch, torch</span></span>
<span><span class="co">#> * Predict Types: [response]</span></span>
<span><span class="co">#> * Feature Types: integer, numeric, lazy_tensor</span></span>
<span><span class="co">#> * Properties: internal_tuning, marshal, validation</span></span>
<span><span class="co">#> * Optimizer: sgd</span></span>
<span><span class="co">#> * Loss: l1</span></span>
<span><span class="co">#> * Callbacks: history</span></span></code></pre></div>
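<p>The printed parameters above also include <code>patience</code> and
<code>min_delta</code>, which control early stopping on the validation data.
As a sketch (reusing the <code>sgd</code>, <code>l1</code>, and
<code>history</code> objects from above; the exact stopping behavior may
depend on your mlr3torch version), training can be stopped once the
validation MAE stops improving:</p>
<div class="sourceCode"><pre class="sourceCode r"><code class="sourceCode R"># Sketch: stop training early if the validation MAE does not improve
# by at least 1 for 5 consecutive epochs; epochs = 30 becomes an upper bound.
mlp_early = lrn("regr.mlp",
  optimizer = sgd, loss = l1, callbacks = history,
  measures_valid = msr("regr.mae"),
  neurons = c(50, 50),
  batch_size = 32, epochs = 30, device = "cpu",
  validate = 0.3,
  patience = 5, min_delta = 1
)</code></pre></div>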
<p>We now train the learner on the “mtcars” task again, using the same
train-test split as before.</p>
<div class="sourceCode" id="cb17"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">mlp_custom</span><span class="op">$</span><span class="fu">train</span><span class="op">(</span><span class="va">task</span>, row_ids <span class="op">=</span> <span class="va">splits</span><span class="op">$</span><span class="va">train</span><span class="op">)</span></span>
<span><span class="va">prediction_custom</span> <span class="op">=</span> <span class="va">mlp_custom</span><span class="op">$</span><span class="fu">predict</span><span class="op">(</span><span class="va">task</span>, row_ids <span class="op">=</span> <span class="va">splits</span><span class="op">$</span><span class="va">test</span><span class="op">)</span></span></code></pre></div>
<p>Below we compare the scores of both learners on the unseen test
data. Because we directly optimized the L1 loss (which is equivalent to
the MAE) and tweaked the learning rate, our configured
<code>mlp_custom</code> learner achieves a lower MAE than the default
<code>mlp</code> learner.</p>
<div class="sourceCode" id="cb18"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="va">prediction_custom</span><span class="op">$</span><span class="fu">score</span><span class="op">(</span><span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">msr</a></span><span class="op">(</span><span class="st">"regr.mae"</span><span class="op">)</span><span class="op">)</span></span>
<span><span class="co">#> regr.mae </span></span>
<span><span class="co">#> 7.122375</span></span>
<span><span class="va">prediction</span><span class="op">$</span><span class="fu">score</span><span class="op">(</span><span class="fu"><a href="https://mlr3.mlr-org.com/reference/mlr_sugar.html" class="external-link">msr</a></span><span class="op">(</span><span class="st">"regr.mae"</span><span class="op">)</span><span class="op">)</span></span>
<span><span class="co">#> regr.mae </span></span>
<span><span class="co">#> 15.27983</span></span></code></pre></div>
<p>Because we configured the learner to use the history callback, we can
find the validation history in its <code>$model</code> slot:</p>
<div class="sourceCode" id="cb19"><pre class="downlit sourceCode r">
<code class="sourceCode R"><span><span class="fu"><a href="https://rdrr.io/r/utils/head.html" class="external-link">head</a></span><span class="op">(</span><span class="va">mlp_custom</span><span class="op">$</span><span class="va">model</span><span class="op">$</span><span class="va">callbacks</span><span class="op">$</span><span class="va">history</span><span class="op">$</span><span class="va">valid</span><span class="op">)</span></span>
<span><span class="co">#> epoch regr.mae</span></span>
<span><span class="co">#> <num> <num></span></span>
<span><span class="co">#> 1: 1 1.777395e+04</span></span>
<span><span class="co">#> 2: 2 3.955504e+08</span></span>
<span><span class="co">#> 3: 3 1.143863e+04</span></span>
<span><span class="co">#> 4: 4 1.792927e+01</span></span>
<span><span class="co">#> 5: 5 1.759594e+01</span></span>
<span><span class="co">#> 6: 6 1.726260e+01</span></span></code></pre></div>
<p>The plot below shows the validation MAE for epochs 6 to 30.</p>
<p><img src="get_started_files/figure-html/unnamed-chunk-21-1.png" width="50%" style="display: block; margin: auto;"></p>
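<p>Such a plot can be reproduced from the history table itself. A minimal
sketch, assuming <code>ggplot2</code> is installed (the vignette may use
different plotting code):</p>
<div class="sourceCode"><pre class="sourceCode r"><code class="sourceCode R">library(ggplot2)
# The validation history is a data.table, so we can subset it directly
# to drop the unstable first epochs.
history_valid = mlp_custom$model$callbacks$history$valid
ggplot(history_valid[epoch &gt;= 6], aes(x = epoch, y = regr.mae)) +
  geom_line() +
  labs(x = "Epoch", y = "Validation MAE")</code></pre></div>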
<p>Another important piece of information stored in the
<code>Learner</code>’s model is the <code>$network</code>, the
underlying <code>nn_module</code>. For a full description of the model,
see <code><a href="../reference/mlr_learners_torch.html">?LearnerTorch</a></code>.</p>
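<p>As a brief illustration, the network can be retrieved and inspected
directly; the parameter count below is a sketch using standard
<code>torch</code> tensor methods:</p>
<div class="sourceCode"><pre class="sourceCode r"><code class="sourceCode R"># Retrieve the trained nn_module and print its structure.
network = mlp_custom$model$network
network
# Count the trainable parameters of the network.
sum(sapply(network$parameters, function(p) p$numel()))</code></pre></div>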
</div>
</div>
</main><aside class="col-md-3"><nav id="toc" aria-label="Table of contents"><h2>On this page</h2>
</nav></aside>
</div>
<footer><div class="pkgdown-footer-left">
<p>Developed by Sebastian Fischer, Martin Binder.</p>
</div>
<div class="pkgdown-footer-right">
<p>Site built with <a href="https://pkgdown.r-lib.org/" class="external-link">pkgdown</a> 2.1.1.</p>
</div>
</footer>
</div>
</body>
</html>