_posts/2025-03-23-DynamicBatching.md
To make the most of GPU memory when training large models, we want to pack tokens so that sequence padding is minimised and memory utilisation is maximised.

PyTorch's data pipeline gives us three relevant building blocks (a short sketch of how they fit together follows the list):

1. `torch.utils.data.DataLoader` is a Python iterable over a PyTorch dataset.
2. `torch.utils.data.Dataset` implements `__getitem__()`, which maps keys to data samples.
3. `torch.utils.data.Sampler` specifies the sequence of keys used in data loading.
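
As a quick orientation, here is a minimal sketch of how the three pieces fit together; the toy `ToyTextDataset` and the token ids are made up for illustration and are not from this post.

{% highlight python %}
import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler

class ToyTextDataset(Dataset):
    """A map-style dataset: __getitem__ maps an integer key to one tokenised sample."""
    def __init__(self, sequences):
        self.sequences = sequences            # list of lists of token ids

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx])

dataset = ToyTextDataset([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
sampler = SequentialSampler(dataset)          # yields the keys 0, 1, 2, ... in order
loader = DataLoader(dataset, sampler=sampler, batch_size=2,
                    collate_fn=lambda batch: batch)  # keep variable-length samples as a plain list

for batch in loader:
    print([seq.tolist() for seq in batch])    # [[1, 2, 3], [4, 5]] then [[6, 7, 8, 9]]
{% endhighlight %}
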
#### **Default Approach**
By default, the `DataLoader` collates individually fetched samples into batches using the arguments `batch_size`, `drop_last`, `batch_sampler`, and `collate_fn`. Alternatively, if `batch_size` is None, we can construct a `BatchSampler` which yields a list of keys at a time.
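
As a minimal sketch of that route (again with a toy dataset, not the post's code), a `BatchSampler` wraps an ordinary sampler and is handed to the `DataLoader` via `batch_sampler`:

{% highlight python %}
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

# Any object with __len__ and __getitem__ works as a map-style dataset, so a plain list will do here.
dataset = [torch.tensor(ids) for ids in ([1, 2, 3], [4, 5], [6, 7, 8, 9])]

# The BatchSampler yields a list of keys at a time; the DataLoader then fetches
# dataset[key] for each key in that list and hands the samples to collate_fn.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=False)

loader = DataLoader(dataset,
                    batch_sampler=batch_sampler,     # batch_size, shuffle, and drop_last must stay unset
                    collate_fn=lambda batch: batch)  # in practice: pad and stack the batch here

for batch in loader:
    print([seq.tolist() for seq in batch])
{% endhighlight %}
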
We have several options, starting with the default.
The most basic approach is to pad every sequence to the maximum context window and return a fixed batch size. However, this is incredibly wasteful. Imagine a batch size of 2, where sequence X1 of length 10 and sequence X2 of length 1000 land in the same batch: X1 will be padded for 990 token positions, which is nearly 50% wasted GPU memory.
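
To put rough numbers on the waste, here is a toy calculation (not from the post) comparing batches of mismatched lengths against batches that group similar lengths together:

{% highlight python %}
def padding_waste(batches_of_lengths):
    """Fraction of token positions spent on padding when each batch is padded to its longest sequence."""
    real = sum(sum(batch) for batch in batches_of_lengths)
    total = sum(len(batch) * max(batch) for batch in batches_of_lengths)
    return 1 - real / total

# Two ways of splitting sequences of length 10, 1000, 12, and 998 into batches of size 2.
naive = [[10, 1000], [12, 998]]     # mixed lengths in each batch
grouped = [[10, 12], [1000, 998]]   # similar lengths batched together

print(f"naive batching:    {padding_waste(naive):.1%} padding")    # ~49.4%
print(f"grouped by length: {padding_waste(grouped):.1%} padding")  # ~0.2%
{% endhighlight %}
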
<br>
Then we can simply do `trainer = MyTrainer(..); trainer.train()`. Because we used a `BatchSampler`, the `batch_size` argument given to the trainer should be left unset, or an error will be thrown about a conflicting `batch_size`.
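
The `MyTrainer` definition itself is not shown in this excerpt; as a rough sketch of the general pattern only, an override of `get_train_dataloader` could look like the following, assuming samples are dicts with an `input_ids` field and using a simple length-sorted `BatchSampler` in place of the post's actual sampler:

{% highlight python %}
from torch.utils.data import BatchSampler, DataLoader
from transformers import Trainer

class MyTrainer(Trainer):
    def get_train_dataloader(self):
        # Sort dataset keys by sequence length so each batch holds similar lengths
        # (a stand-in for the post's sampler; assumes samples expose "input_ids").
        keys = sorted(range(len(self.train_dataset)),
                      key=lambda i: len(self.train_dataset[i]["input_ids"]))
        batch_sampler = BatchSampler(keys, batch_size=8, drop_last=False)
        return DataLoader(self.train_dataset,
                          batch_sampler=batch_sampler,    # hence batch_size must stay unset on the trainer
                          collate_fn=self.data_collator)  # pads each batch to its own max length
{% endhighlight %}
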
<br>
#### **References**
[PyTorch Data Utils Reference](https://pytorch.org/docs/stable/data.html)