|
| 1 | +<div align="center"> |
| 2 | + |
| 3 | +# DeepCompile: Unlocking Compiler Optimization for Distributed Training |
| 4 | + |
| 5 | +</div> |
| 6 | + |
| 7 | +# Introduction |
| 8 | + |
| 9 | +<div align="center"> |
| 10 | + |
| 11 | +<img src="media/perf_summary.png" width="1000"> |
| 12 | + |
| 13 | +</div> |
| 14 | + |
| 15 | +Distributed training has become essential for scaling today’s massive deep learning models. While deep learning compilers like PyTorch compiler dramatically improved single-GPU training performance through optimizations like kernel fusion and operator scheduling, they fall short when it comes to distributed workloads. |
| 16 | +Existing distributed training frameworks such as DeepSpeed and FSDP have made large-scale model training feasible through advanced parallelization strategies. While powerful, their optimizations are implemented at the PyTorch framework level, which limits the ability to apply compiler-style techniques like dependency analysis or operator scheduling. |
| 17 | + |
| 18 | +DeepCompile addresses this gap by enabling compiler-level optimizations for distributed training. It takes a standard single-GPU model implementation and transforms it into an optimized multi-GPU training graph without requiring changes to the model code. Unlike existing approaches, DeepCompile automatically applies parameter sharding, communication scheduling, and memory-aware execution at the compiler IR level, enabling global analysis and optimization that are difficult to express in traditional frameworks. Furthermore, during training, DeepCompile employs profile-guided optimization techniques to dynamically tune these parallelization strategies and improve training performance. |
| 19 | + |
| 20 | +Our evaluation demonstrates that DeepCompile improves training performance over ZeRO-3 baselines, achieving up to 1.5x speedup when sufficient GPU resources are available, and up to 7x speedup in GPU-constrained settings that require offloading. DeepCompile is available in DeepSpeed versions >= [0.16.6](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.16.6). |
| 21 | + |
| 22 | +# Design Overview |
| 23 | + |
| 24 | +DeepCompile extends the capabilities of deep learning compilers to support distributed training. It starts from a standard single-GPU model implementation, such as those available on the Hugging Face model hub, and automatically transforms it by inserting necessary distributed training operations such as parameter sharding and communication primitives. Users are not required to embed any distributed logic into the model code. |
| 25 | + |
| 26 | +The process begins by compiling the model into an intermediate representation (IR), which forms a computation graph. DeepCompile then applies a sequence of *optimization passes*, each responsible for a specific transformation of the computation graph or a targeted performance improvement, to incrementally introduce distributed behavior and optimize the graph. These include operations such as all-gather for sharded parameters or offloading of optimizer states, all while preserving the original computation semantics (Fig. 1). |
| 27 | + |
| 28 | +<div align="center"> |
| 29 | + |
| 30 | +<img src="media/workflow.png" width="400"> |
| 31 | + |
| 32 | +*Figure 1: Workflow of compilation and optimization with DeepCompile.* |
| 33 | + |
| 34 | +</div> |
| 35 | + |
| 36 | +At its core, DeepCompile builds on two key capabilities: |
| 37 | + |
| 38 | +- **Automatic parallelization**: DeepCompile allows optimization passes to rewrite the single-GPU computation graph into a distributed multi-GPU version, incorporating strategies such as ZeRO, FSDP, and more. This eliminates the need for manual implementation of distributed training logic, drastically reducing engineering effort. |
| 39 | +- **Profile-guided performance tuning**: At runtime, DeepCompile collects profiling data such as operator-level memory usage and execution latency. It uses this information to dynamically schedule computation and communication operators. This enables effects such as an improved overlap between communication and computation, and an avoidance of memory bottlenecks. Fine-grained tuning through these optimization passes often leads to better performance than even manually engineered implementations. |
| 40 | + |
| 41 | +Figure 2 illustrates the optimization cycle employed by DeepCompile. After the initial computation graph is generated by the compiler, DeepCompile profiles its behavior by measuring operator execution time, communication overhead, and memory usage throughout the forward and backward passes. |
| 42 | + |
| 43 | +<div align="center"> |
| 44 | + |
| 45 | +<img src="media/opt_loop.png" width="600"> |
| 46 | + |
| 47 | +*Figure 2. Optimization cycle.* |
| 48 | + |
| 49 | +</div> |
| 50 | + |
| 51 | +Based on the collected profiling data, DeepCompile applies a sequence of optimization passes. These passes modify the computation graph by inserting, removing, or reordering operators to improve overall efficiency. The modified graph is then re-profiled, and this cycle of profiling and optimization is repeated. |
| 52 | + |
| 53 | +Once a stable set of optimizations has been applied, the graph is deployed for the remaining training iterations. During execution, memory usage and other runtime characteristics may change. In such cases, DeepCompile can resume the profiling and optimization cycle according to the predefined schedule of passes, allowing the graph to adapt and maintain high performance. |
| 54 | + |
| 55 | +# Optimizations |
| 56 | + |
| 57 | +DeepCompile is designed as a general compiler framework for applying and optimizing a wide range of parallelization strategies. In the following, we describe several optimizations that have been implemented as optimization passes within DeepCompile. |
| 58 | + |
| 59 | +## ZeRO3 |
| 60 | + |
| 61 | +As an initial step, we have used DeepCompile to implement and enhance ZeRO-3-style optimizations at the compiler level. ZeRO-3 partitions model parameters, gradients, and optimizer states across devices, reducing memory usage and enabling large-scale training. |
| 62 | + |
| 63 | +In conventional ZeRO-3 implementations, operations such as all-gather, reduce-scatter, and buffer release are typically inserted using Python hooks at runtime. DeepCompile replaces this approach by injecting these operations directly into the computation graph during compilation. This allows the compiler to determine their placement precisely, guided by both the static structure of the graph and runtime profiling information. |
| 64 | + |
| 65 | +One of the key optimizations is **proactive prefetching**, which launches all-gather operations earlier in the computation based on memory usage profiling. This reordering increases the overlap between communication and computation thereby improving throughput, while avoiding OOMs. In addition, small communication operations are often fused to reduce launch latency and improve efficiency. |
| 66 | + |
| 67 | +Another optimization is **selective unsharding**, which keeps certain parameters in an unsharded form during the forward and backward passes when memory conditions permit. This reduces the frequency of all-gather operations and avoids redundant communication, particularly in scenarios where gradient accumulation is enabled. |
| 68 | + |
| 69 | +## Offloading |
| 70 | + |
| 71 | +DeepCompile also supports **adaptive offloading**, which offloads optimizer states to reduce GPU memory pressure. Unlike approaches that offload all the optimizer states, adaptive offloading identifies only the portions that exceed the memory limit—such as momentum and variance used by the Adam optimizer—and schedules data transfers to overlap with computation. This selective and asynchronous strategy minimizes overhead and enables efficient training even in memory-constrained environments. |
| 72 | + |
| 73 | +## ZeRO1 |
| 74 | + |
| 75 | +ZeRO-1 differs from ZeRO-3 in that it shards only the optimizer states across devices, while keeping parameters and gradients fully replicated. This approach reduces memory usage with minimal changes to computation flow, making it a lightweight alternative for certain training scenarios. |
| 76 | +DeepCompile implements ZeRO-1-style optimization by inserting reduce-scatter operations directly into the computation graph. By avoiding Python-level hooks, this graph-level integration reduces overhead and improves execution efficiency. |
| 77 | + |
| 78 | +# Performance Improvements |
| 79 | + |
| 80 | +## ZeRO-3 |
| 81 | + |
| 82 | +We evaluated DeepCompile on Llama-3-70B and Mixtral 8x7B using parameter sharding on top of Hugging Face model implementations. |
| 83 | +Figure 3 shows training throughput (TFLOPs/GPU) across different gradient accumulation steps, using 32 H100 GPUs with a sequence length of 1024. |
| 84 | +We compare DeepCompile against two DeepSpeed ZeRO-3 baselines: (i) an eager-mode version without compiler support (labelled ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (labelled ZeRO3+Compile). For DeepCompile, we enabled both proactive prefetching and selective unsharding to demonstrate the combined effect of these optimization passes. |
| 85 | + |
| 86 | +<div align="center"> <img src="media/perf_zero3.png" width="800"> |
| 87 | + |
| 88 | +*Figure 3. Achieved throughputs for ZeRO3 training of Llama-3 70B and Mixtral 8x7B models.* |
| 89 | + |
| 90 | +</div> |
| 91 | +Across both models, DeepCompile consistently delivers higher throughput. The benefit becomes more pronounced at higher accumulation steps, where the reduced frequency of parameter updates makes selective unsharding more effective. DeepCompile with proactive prefetching and selective unsharding achieves up to 1.28× speedup over ZeRO-3 on Llama-3-70B and 1.54× on Mixtral 8x7B. |
| 92 | + |
| 93 | +Meanwhile, enabling the PyTorch compiler with ZeRO-3, i.e., ZeRO3+Compile introduces minor overheads in some settings. This is because ZeRO-3 includes many conditional branches for runtime features such as prefetching. When the compiler encounters branches that cannot be statically resolved, it splits the computation into multiple graph segments. These fragmented segments can reduce optimization opportunities and introduce additional overheads during execution. |
| 94 | + |
| 95 | +## Offloading |
| 96 | + |
| 97 | +Training models as large as Llama-3 70B with ZeRO-3 typically requires 32 GPUs with 80GB of memory. |
| 98 | +DeepSpeed addresses this challenge by offering offloading capabilities, which transfer optimizer states and optionally model parameters to CPU memory to reduce GPU memory usage. DeepCompile also supports offloading through a dedicated optimization pass, but with a few key differences in design. |
| 99 | + |
| 100 | +Unlike the traditional approach of offloading both optimizer computation and memory, DeepCompile offloads only optimizer memory (e.g., momentum, variance, and master weights of Adam optimizer) while the optimizer computation remains on GPU. DeepCompile profiles memory usage during both forward and backward passes to identify when offloading is necessary, and transfers only the required data. This fine-grained approach avoids unnecessary overhead and helps maintain high computational throughput. |
| 101 | +Furthermore, DeepCompile overlaps data transfers with computation whenever possible, dynamically adjusting the timing based on observed memory usage patterns. This asynchronous behavior is a crucial aspect of DeepCompile’s offloading strategy, allowing it to reduce GPU memory pressure without stalling execution. |
| 102 | + |
| 103 | +We evaluated DeepCompile's offloading using Llama-3 70B on 16xH100-80GB (half the required GPU counts) and present the results in Figure 4. |
| 104 | + |
| 105 | +<div align="center"> |
| 106 | + |
| 107 | +<img src="media/perf_offload.png" width="400"> |
| 108 | + |
| 109 | +*Figure 4. Achieved throughput of optimizer offloading for Llama-3 70B on 16x80GB GPUs* |
| 110 | + |
| 111 | +</div> |
| 112 | + |
| 113 | +We compare against two ZeRO-3 offloading baselines: (i) an eager-mode version without compiler support (ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO3+Compile). As shown by the results, DeepCompile significantly improves offloading efficiency and provides up to 7× speedup over ZeRO3+Eager. In contrast, we see that ZeRO3+Compile achieves similar performance as ZeRO3+Eager. |
| 114 | + |
| 115 | + |
| 116 | +## ZeRO-1 |
| 117 | + |
| 118 | +We also evaluated DeepCompile with ZeRO-1 using the Llama-3-8B model. We compare DeepCompile against two ZeRO-1 baselines: (i) an eager-mode version without compiler support (ZeRO1+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO1+Compile). In our experiment with 8 GPUs and a batch size of 2, DeepCompile achieved consistent throughput improvements across different sequence lengths, as shown in Figure 5. |
| 119 | + |
| 120 | +<div align="center"> |
| 121 | + |
| 122 | +<img src="media/perf_zero1.png" width="800"> |
| 123 | + |
| 124 | +*Figure 5. Achieved throughput of ZeRO-1 training of Llama-3 8B* |
| 125 | + |
| 126 | +</div> |
| 127 | + |
| 128 | +The most significant speedup was observed with batch size 1 and sequence length 512, where DeepCompile outperformed ZeRO1+Eager by up to 1.9×, and ZeRO1+Compile by up to 2.5×. |
| 129 | + |
| 130 | +While compiler-based approaches can be effective for large batch sizes and long sequences by replacing suboptimal operations with more efficient kernels, they may also introduce overheads in ZeRO-1-style training in the form of *graph breaks* around the communication operations. These overheads become more pronounced with smaller batch sizes and sequence lengths, thus hurting performance compared to the non-compiled execution. In contrast, DeepCompile inserts communication operators directly into the computation graph during compilation, avoiding graph fragmentation and minimizing associated overhead. This makes DeepCompile more robust to small-scale workloads, while still benefiting from compiler-level optimizations. |
| 131 | + |
| 132 | +## Additional Results and Analysis |
| 133 | + |
| 134 | +Please refer to our [arXiv paper](https://arxiv.org/abs/2504.09983) for additional results, such as detailed comparisons across different batch sizes, sequence lengths, and memory usage. |
| 135 | + |
| 136 | +# Looking Ahead |
| 137 | + |
| 138 | +DeepCompile brings the power of compiler-based optimizations to distributed deep learning. By transforming computation graphs and applying profile-guided optimization passes, it enables more efficient training without requiring changes to model code. |
| 139 | + |
| 140 | +This release is just the beginning. We’re actively working on expanding the set of optimization passes and improving integration with a broader range of distributed training strategies. Future directions include automated parallelization (sequence/tensor parallelisms), smarter memory management, and dynamic adaptation to runtime behavior. |
| 141 | + |
| 142 | +We invite the community to try DeepCompile, explore its capabilities, and contribute to its evolution. Let’s build the next generation of scalable deep learning together. |
| 143 | + |
| 144 | +# Contributors |
| 145 | + |
| 146 | +This project is the result of a close collaboration between Microsoft and the University of Virginia. The contributors are: Masahiro Tanaka, Du Li, and Umesh Chand, Olatunji Ruwase (Microsoft); and Ali Zafar and Haiying Shen (University of Virginia). |
| 147 | + |
| 148 | +# Appendix |
| 149 | + |
| 150 | +## Examples and Benchmarks |
| 151 | + |
| 152 | +Our DeepSpeedExamples repository provides [example code](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/deepcompile) to enable DeepCompile. |
| 153 | + |
| 154 | +## Optimization Passes |
| 155 | + |
| 156 | +The following optimization passes are currently available in DeepCompile: |
| 157 | + |
| 158 | +- All-gather & reduce-scatter insertion (ZeRO3) |
| 159 | +- Proactive prefetching (ZeRO3) |
| 160 | +- Selective unsharding (ZeRO3) |
| 161 | +- Reduce-scatter insertion (ZeRO1) |
| 162 | +- Adaptive offloading |
| 163 | + |
| 164 | +We used the following combinations of passes in the experiments presented above: |
| 165 | + |
| 166 | +- Improved communication scheduling for ZeRO-3: All-gather & reduce-scatter → Proactive prefetching → Selective unsharding |
| 167 | +- Offloading optimizer states for ZeRO3: Adding all-gather & reduce-scatter → Adaptive offloading |
| 168 | +- Reduced overhead and improved overlap for ZeRO-1: Adding reduce-scatter |
0 commit comments