|
| 1 | +################################# |
| 2 | +Speed up models by compiling them |
| 3 | +################################# |
| 4 | + |
| 5 | +Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs. |
| 6 | +This guide shows you how to apply ``torch.compile`` correctly in your code. |
| 7 | + |
| 8 | +.. note:: |
| 9 | + |
| 10 | + This requires PyTorch >= 2.0. |
| 11 | + |
| 12 | + |
| 13 | +---- |
| 14 | + |
| 15 | + |
| 16 | +******************************************* |
| 17 | +Apply torch.compile to your LightningModule |
| 18 | +******************************************* |
| 19 | + |
| 20 | +Compiling a LightningModule is as simple as adding one line of code, calling :func:`torch.compile`: |
| 21 | + |
| 22 | +.. code-block:: python |
| 23 | +
|
| 24 | + import torch |
| 25 | + import lightning as L |
| 26 | +
|
| 27 | + # Define the model |
| 28 | + model = MyLightningModule() |
| 29 | +
|
| 30 | + # Compile the model |
| 31 | + model = torch.compile(model) |
| 32 | +
|
| 33 | + # Run with the Trainer |
| 34 | + trainer = L.Trainer() |
| 35 | + trainer.fit(model) |
| 36 | +
|
| 37 | +
|
| 38 | +.. important:: |
| 39 | + |
| 40 | + You should compile the model **before** calling ``trainer.fit()`` as shown above for an optimal integration with features in Trainer. |
| 41 | + |
| 42 | +The newly added call to ``torch.compile()`` by itself doesn't do much. It just wraps the model in a "compiled model". |
| 43 | +The actual optimization will start when calling the ``forward()`` method for the first time: |
| 44 | + |
| 45 | +.. code-block:: python |
| 46 | +
|
| 47 | + # 1st execution compiles the model (slow) |
| 48 | + output = model(input) |
| 49 | +
|
| 50 | + # All future executions will be fast (for inputs of the same size) |
| 51 | + output = model(input) |
| 52 | + output = model(input) |
| 53 | + ... |
| 54 | +
|
| 55 | +**When you pass the LightningModule to the Trainer, it will automatically also compile the ``*_step()`` methods.** |
| 56 | + |
| 57 | +When measuring the speed of a compiled model and comparing it to a regular model, it is important to |
| 58 | +always exclude the first call to ``forward()``/``*_step()`` from your measurements, since it includes the compilation time. |
| 59 | + |
| 60 | + |
| 61 | +.. collapse:: Full example with benchmark |
| 62 | + |
| 63 | + Below is an example that measures the speedup you get when compiling the InceptionV3 from TorchVision. |
| 64 | + |
| 65 | + .. code-block:: python |
| 66 | +
|
| 67 | + import statistics |
| 68 | + import torch |
| 69 | + import torchvision.models as models |
| 70 | + import lightning as L |
| 71 | + from torch.utils.data import DataLoader |
| 72 | +
|
| 73 | +
|
| 74 | + class MyLightningModule(L.LightningModule): |
| 75 | + def __init__(self): |
| 76 | + super().__init__() |
| 77 | + self.model = models.inception_v3() |
| 78 | +
|
| 79 | + def training_step(self, batch): |
| 80 | + return self.model(batch).logits.sum() |
| 81 | +
|
| 82 | + def train_dataloader(self): |
| 83 | + return DataLoader([torch.randn(3, 512, 512) for _ in range(256)], batch_size=16) |
| 84 | +
|
| 85 | + def configure_optimizers(self): |
| 86 | + return torch.optim.SGD(self.parameters(), lr=0.01) |
| 87 | +
|
| 88 | +
|
| 89 | + class Benchmark(L.Callback): |
| 90 | + """A callback that measures the median execution time between the start and end of a batch.""" |
| 91 | + def __init__(self): |
| 92 | + self.start = torch.cuda.Event(enable_timing=True) |
| 93 | + self.end = torch.cuda.Event(enable_timing=True) |
| 94 | + self.times = [] |
| 95 | +
|
| 96 | + def median_time(self): |
| 97 | + return statistics.median(self.times) |
| 98 | +
|
| 99 | + def on_train_batch_start(self, trainer, *args, **kwargs): |
| 100 | + self.start.record() |
| 101 | +
|
| 102 | + def on_train_batch_end(self, trainer, *args, **kwargs): |
| 103 | + # Exclude the first iteration to let the model warm up |
| 104 | + if trainer.global_step > 1: |
| 105 | + self.end.record() |
| 106 | + torch.cuda.synchronize() |
| 107 | + self.times.append(self.start.elapsed_time(self.end) / 1000) |
| 108 | +
|
| 109 | +
|
| 110 | + model = MyLightningModule() |
| 111 | +
|
| 112 | + # Compile! |
| 113 | + compiled_model = torch.compile(model) |
| 114 | +
|
| 115 | + # Measure the median iteration time with uncompiled model |
| 116 | + benchmark = Benchmark() |
| 117 | + trainer = L.Trainer(accelerator="cuda", devices=1, max_steps=10, callbacks=[benchmark]) |
| 118 | + trainer.fit(model) |
| 119 | + eager_time = benchmark.median_time() |
| 120 | +
|
| 121 | + # Measure the median iteration time with compiled model |
| 122 | + benchmark = Benchmark() |
| 123 | + trainer = L.Trainer(accelerator="cuda", devices=1, max_steps=10, callbacks=[benchmark]) |
| 124 | + trainer.fit(compiled_model) |
| 125 | + compile_time = benchmark.median_time() |
| 126 | +
|
| 127 | + # Compare the speedup for the compiled execution |
| 128 | + speedup = eager_time / compile_time |
| 129 | + print(f"Eager median time: {eager_time:.4f} seconds") |
| 130 | + print(f"Compile median time: {compile_time:.4f} seconds") |
| 131 | + print(f"Speedup: {speedup:.1f}x") |
| 132 | +
|
| 133 | +
|
| 134 | + On an NVIDIA A100 SXM4 40GB with PyTorch 2.2.0, CUDA 12.1, we get the following speedup: |
| 135 | + |
| 136 | + .. code-block:: text |
| 137 | +
|
| 138 | + Eager median time: 0.0863 seconds |
| 139 | + Compile median time: 0.0709 seconds |
| 140 | + Speedup: 1.2x |
| 141 | +
|
| 142 | +
|
| 143 | +---- |
| 144 | + |
| 145 | + |
| 146 | +****************** |
| 147 | +Avoid graph breaks |
| 148 | +****************** |
| 149 | + |
| 150 | +When ``torch.compile`` looks at the code in your model's ``forward()`` or ``*_step()`` method, it will try to compile as much of the code as possible. |
| 151 | +If there are regions in the code that it doesn't understand, it will introduce a so-called "graph break" that essentially splits the code in optimized and unoptimized parts. |
| 152 | +Graph breaks aren't a deal breaker, since the optimized parts should still run faster. |
| 153 | +But if you want to get the most out of ``torch.compile``, you might want to invest rewriting the problematic section of the code that produce the breaks. |
| 154 | + |
| 155 | +You can check whether your model produces graph breaks by calling ``torch.compile`` with ``fullgraph=True``: |
| 156 | + |
| 157 | +.. code-block:: python |
| 158 | +
|
| 159 | + # Force an error if there is a graph break in the model |
| 160 | + model = torch.compile(model, fullgraph=True) |
| 161 | +
|
| 162 | +Be aware that the error messages produced here are often quite cryptic, so you will likely have to do some `troubleshooting <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html>`_ to fully optimize your model. |
| 163 | + |
| 164 | + |
| 165 | +---- |
| 166 | + |
| 167 | + |
| 168 | +******************* |
| 169 | +Avoid recompilation |
| 170 | +******************* |
| 171 | + |
| 172 | +As mentioned before, the compilation of the model happens the first time you call ``forward()`` or the first time the Trainer calls the ``*_step()`` methods. |
| 173 | +At this point, PyTorch will inspect the input tensor(s) and optimize the compiled code for the particular shape, data type and other properties the input has. |
| 174 | +If the shape of the input remains the same across all calls, PyTorch will reuse the compiled code it generated and you will get the best speedup. |
| 175 | +However, if these properties change across subsequent calls to ``forward()``/``*_step()``, PyTorch will be forced to recompile the model for the new shapes, and this will significantly slow down your training if it happens on every iteration. |
| 176 | + |
| 177 | +**When your training suddenly becomes slow, it's probably because PyTorch is recompiling the model!** |
| 178 | +Here are some common scenarios when this can happen: |
| 179 | + |
| 180 | +- You are using dataset with different inputs or shapes for validation than for training, causing a recompilation whenever the Trainer switches between training and validation. |
| 181 | +- Your dataset size is not divisible by the batch size, and the dataloader has ``drop_last=False`` (the default). |
| 182 | + The last batch in your training loop will be smaller and trigger a recompilation. |
| 183 | + |
| 184 | +Ideally, you should try to make the input shape(s) to ``forward()`` static. |
| 185 | +However, when this is not possible, you can request PyTorch to compile the code by taking into account possible changes to the input shapes. |
| 186 | + |
| 187 | +.. code-block:: python |
| 188 | +
|
| 189 | + # On PyTorch < 2.2 |
| 190 | + model = torch.compile(model, dynamic=True) |
| 191 | +
|
| 192 | +A model compiled with ``dynamic=True`` will typically be slower than a model compiled with static shapes, but it will avoid the extreme cost of recompilation every iteration. |
| 193 | +On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically and you should no longer need to set this. |
| 194 | + |
| 195 | + |
| 196 | +---- |
| 197 | + |
| 198 | + |
| 199 | +*********************************** |
| 200 | +Experiment with compilation options |
| 201 | +*********************************** |
| 202 | + |
| 203 | +There are optional settings that, depending on your model, can give additional speedups. |
| 204 | + |
| 205 | +**CUDA Graphs:** By enabling CUDA Graphs, CUDA will record all computations in a graph and replay it every time forward and backward is called. |
| 206 | +The requirement is that your model must be static, i.e., the input shape must not change and your model must execute the same operations every time. |
| 207 | +Enabling CUDA Graphs often results in a significant speedup, but sometimes also increases the memory usage of your model. |
| 208 | + |
| 209 | +.. code-block:: python |
| 210 | +
|
| 211 | + # Enable CUDA Graphs |
| 212 | + compiled_model = torch.compile(model, mode="reduce-overhead") |
| 213 | +
|
| 214 | + # This does the same |
| 215 | + compiled_model = torch.compile(model, options={"triton.cudagraphs": True}) |
| 216 | +
|
| 217 | +| |
| 218 | +
|
| 219 | +**Shape padding:** The specific shape/size of the tensors involved in the computation of your model (input, activations, weights, gradients, etc.) can have an impact on the performance. |
| 220 | +With shape padding enabled, ``torch.compile`` can extend the tensors by padding to a size that gives a better memory alignment. |
| 221 | +Naturally, the tradoff here is that it will consume a bit more memory. |
| 222 | + |
| 223 | +.. code-block:: python |
| 224 | +
|
| 225 | + # Default is False |
| 226 | + compiled_model = torch.compile(model, options={"shape_padding": True}) |
| 227 | +
|
| 228 | +
|
| 229 | +You can find a full list of compile options in the `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.compile.html>`_. |
| 230 | + |
| 231 | + |
| 232 | +---- |
| 233 | + |
| 234 | + |
| 235 | +************************************** |
| 236 | +A note about torch.compile in practice |
| 237 | +************************************** |
| 238 | + |
| 239 | +In practice, you will find that ``torch.compile`` may not work well at first or may be counter-productive to performance. |
| 240 | +Compilation may fail with cryptic error messages that are hard to debug, luckily the PyTorch team is responsive and it's likely that messaging will improve in time. |
| 241 | +It is not uncommon that ``torch.compile`` will produce a significantly *slower* model or one with higher memory usage. You'll need to invest time in this phase if the model is not among the ones that have a happy path. |
| 242 | +As a note, the compilation phase itself will take some time, taking up to several minutes. |
| 243 | +For these reasons, we recommend that you don't invest too much time trying to apply ``torch.compile`` during development, and rather evaluate its effectiveness toward the end when you are about to launch long-running, expensive experiments. |
| 244 | +Always compare the speed and memory usage of the compiled model against the original model! |
| 245 | + |
| 246 | + |
| 247 | +---- |
| 248 | + |
| 249 | + |
| 250 | +*********** |
| 251 | +Limitations |
| 252 | +*********** |
| 253 | + |
| 254 | +There are a few limitations you should be aware of when using ``torch.compile`` in conjunction with the Trainer: |
| 255 | + |
| 256 | +* ``torch.compile`` currently does not get reapplied over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment. |
| 257 | + This limitation will be lifted in the future. |
| 258 | + |
| 259 | +* In some cases, using ``self.log()`` in your LightningModule will cause compilation errors. |
| 260 | + Until addressed, you can work around these issues by applying ``torch.compile`` to the submodule(s) of your LightningModule rather than to the entire LightningModule at once. |
| 261 | + |
| 262 | + .. code-block:: python |
| 263 | +
|
| 264 | + import lightning as L |
| 265 | +
|
| 266 | + class MyLightningModule(L.LightningModule): |
| 267 | + def __init__(self): |
| 268 | + super().__init__() |
| 269 | + self.model = MySubModule() |
| 270 | + self.model = torch.compile(self.model) |
| 271 | + ... |
| 272 | +
|
| 273 | +| |
0 commit comments