     WEIGHTED_KERNEL_MULTIPLIER,
 )
 from torchrec.distributed.planner.types import (
+    CollectiveType,
+    GeneralizedCommsBandwidth,
     ParameterConstraints,
     Perf,
     PlannerError,
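Note: these two imports are the whole surface area the hunks below rely on — a `CollectiveType` enum and a `get_bw` method on `GeneralizedCommsBandwidth`. As a hedged sketch (the names and `get_bw` signature come from this diff; the bodies are assumptions, not the actual `torchrec.distributed.planner.types` implementation), a legacy-equivalent bandwidth model could look like:

```python
from enum import Enum


class CollectiveType(Enum):
    # Members inferred from the call sites in this diff.
    ALL_TO_ALL = "all_to_all"
    REDUCE_SCATTER = "reduce_scatter"
    ALL_GATHER = "all_gather"
    ALL_REDUCE = "all_reduce"


class LegacyCommsBandwidth:
    """Hypothetical GeneralizedCommsBandwidth reproducing the old two-tier
    rule this diff replaces: any process group that spans hosts is rated
    at the inter-host bandwidth, otherwise at the intra-host bandwidth."""

    def __init__(self, intra_host_bw: float, inter_host_bw: float) -> None:
        self._intra_host_bw = intra_host_bw
        self._inter_host_bw = inter_host_bw

    def get_bw(
        self,
        world_size: int,
        local_world_size: int,
        collective_type: CollectiveType,
    ) -> float:
        # collective_type is unused here; a richer model can rate each
        # collective differently, which is the point of the new interface.
        if world_size > local_world_size:
            return self._inter_host_bw
        return self._intra_host_bw
```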
@@ -231,8 +233,7 @@ def estimate(
                 hbm_mem_bw=self._topology.hbm_mem_bw,
                 ddr_mem_bw=self._topology.ddr_mem_bw,
                 hbm_to_ddr_mem_bw=self._topology.hbm_to_ddr_mem_bw,
-                intra_host_bw=self._topology.intra_host_bw,
-                inter_host_bw=self._topology.inter_host_bw,
+                comms_bandwidths=self._topology.comms_bandwidths,
                 bwd_compute_multiplier=self._topology.bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=self._topology.weighted_feature_bwd_compute_multiplier,
                 is_pooled=sharding_option.is_pooled,
@@ -269,8 +270,7 @@ def perf_func_emb_wall_time(
         hbm_mem_bw: float,
         ddr_mem_bw: float,
         hbm_to_ddr_mem_bw: float,
-        intra_host_bw: float,
-        inter_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -359,8 +359,7 @@ def perf_func_emb_wall_time(
                 num_poolings=num_poolings,
                 hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
-                intra_host_bw=intra_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -385,8 +384,7 @@ def perf_func_emb_wall_time(
                 num_poolings=num_poolings,
                 hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
-                intra_host_bw=intra_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -414,8 +412,7 @@ def perf_func_emb_wall_time(
                 num_poolings=num_poolings,
                 hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
-                intra_host_bw=intra_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -435,7 +432,7 @@ def perf_func_emb_wall_time(
                 output_data_type_size=output_data_type_size,
                 num_poolings=num_poolings,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -477,15 +474,15 @@ def _get_tw_sharding_perf(
         num_poolings: List[float],
         hbm_to_ddr_mem_bw: float,
         device_bw: float,
-        inter_host_bw: float,
-        intra_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
         is_inference: bool = False,
         expected_cache_fetches: float = 0,
     ) -> Perf:
+
         batch_inputs = sum(
             [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]
         )
@@ -518,8 +515,11 @@ def _get_tw_sharding_perf(
             block_usage_penalty = HALF_BLOCK_PENALTY
         else:  # emb_dim >= 32
             block_usage_penalty = QUARTER_BLOCK_PENALTY
-
-        comms_bw = inter_host_bw if world_size > local_world_size else intra_host_bw
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_TO_ALL,
+        )
         fwd_comms = fwd_output_write_size / comms_bw

         fwd_compute = (
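Note: table-wise sharding redistributes the pooled output with a single all-to-all over every rank, so `get_bw` is queried with the global `world_size`. With the legacy-equivalent model sketched above, any job spanning hosts resolves to the inter-host bandwidth; the figures below are illustrative, not torchrec defaults:

```python
# Illustrative bandwidths in bytes/sec.
comms = LegacyCommsBandwidth(intra_host_bw=600e9, inter_host_bw=25e9)

# 16 ranks across 2 hosts of 8: the all-to-all crosses hosts.
bw = comms.get_bw(
    world_size=16, local_world_size=8, collective_type=CollectiveType.ALL_TO_ALL
)
assert bw == 25e9
```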
@@ -576,8 +576,7 @@ def _get_rw_sharding_perf(
         num_poolings: List[float],
         hbm_to_ddr_mem_bw: float,
         device_bw: float,
-        inter_host_bw: float,
-        intra_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -615,8 +614,12 @@ def _get_rw_sharding_perf(
             if is_pooled
             else batch_outputs * world_size * emb_dim * bwd_a2a_comm_data_type_size
         )
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.REDUCE_SCATTER,
+        )

-        comms_bw = inter_host_bw if world_size > local_world_size else intra_host_bw
         fwd_comms = fwd_output_write_size / comms_bw

         fwd_compute = (
@@ -628,7 +631,11 @@ def _get_rw_sharding_perf(
             return Perf(
                 fwd_compute=fwd_compute, fwd_comms=fwd_comms, bwd_compute=0, bwd_comms=0
             )
-
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_GATHER,
+        )
         bwd_comms = bwd_output_write_size / comms_bw

         bwd_batched_copy = bwd_output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw
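Note: the row-wise path now names its collectives explicitly — the forward output is produced by a reduce-scatter over all ranks, while the backward pass all-gathers gradients (the all-gather query sits after the inference-only early return, so it is only reached in training). The two queries differ only in `collective_type`, which is what lets a bandwidth model rate them separately:

```python
# Same group shape, different collectives (reusing the sketch above).
fwd_bw = comms.get_bw(
    world_size=8, local_world_size=8, collective_type=CollectiveType.REDUCE_SCATTER
)
bwd_bw = comms.get_bw(
    world_size=8, local_world_size=8, collective_type=CollectiveType.ALL_GATHER
)
# The legacy-equivalent model returns the same intra-host figure for both;
# a richer GeneralizedCommsBandwidth can now distinguish them.
assert fwd_bw == bwd_bw == 600e9
```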
@@ -675,8 +682,7 @@ def _get_twrw_sharding_perf(
         num_poolings: List[float],
         hbm_to_ddr_mem_bw: float,
         device_bw: float,
-        inter_host_bw: float,
-        intra_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -709,28 +715,43 @@ def _get_twrw_sharding_perf(
         bwd_output_write_size = (
             batch_outputs * world_size * emb_dim * bwd_sr_comm_data_type_size
         )
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=local_world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.REDUCE_SCATTER,
+        )

         # intra host comm
-        fwd_comms = fwd_output_write_size / intra_host_bw
+        fwd_comms = fwd_output_write_size / comms_bw

         # inter host comm
         if world_size > local_world_size:
             inter_host_fwd_output_write_size = (
                 batch_outputs
                 * (
                     world_size / local_world_size
-                )  # this is the size of the process group.
+                )  # this is the size of the process group.
                 * emb_dim
                 * fwd_a2a_comm_data_type_size
             )
-            fwd_comms += inter_host_fwd_output_write_size / inter_host_bw
+            comms_bw = comms_bandwidths.get_bw(
+                world_size=int(world_size / local_world_size),
+                local_world_size=1,
+                collective_type=CollectiveType.ALL_TO_ALL,
+            )
+            fwd_comms += inter_host_fwd_output_write_size / comms_bw

         fwd_compute = (
             input_read_size + embedding_lookup_size + fwd_output_write_size
         ) / device_bw

         # intra host comm (i.e. all gather)
-        bwd_comms = bwd_output_write_size / intra_host_bw
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=local_world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_GATHER,
+        )
+        bwd_comms = bwd_output_write_size / comms_bw

         # inter host comm (i.e. all to all)
         if world_size > local_world_size:
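Note: table-wise row-wise sharding splits each collective into a host-local stage and a cross-host stage, and the new `get_bw` arguments encode the group shape of each. The host-local reduce-scatter and all-gather are queried with `world_size=local_world_size`, while the cross-host all-to-all runs among one rank per host, hence `world_size=int(world_size / local_world_size)` with `local_world_size=1`. A worked example using the sketch above:

```python
# 16 ranks on 2 hosts of 8.
intra = comms.get_bw(
    world_size=8, local_world_size=8, collective_type=CollectiveType.REDUCE_SCATTER
)
inter = comms.get_bw(
    world_size=2,  # int(16 / 8): one rank per host
    local_world_size=1,
    collective_type=CollectiveType.ALL_TO_ALL,
)
# 2 > 1, so the cross-host stage is rated at inter-host bandwidth.
assert (intra, inter) == (600e9, 25e9)
```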
@@ -742,7 +763,12 @@ def _get_twrw_sharding_perf(
                 * emb_dim
                 * bwd_a2a_comm_data_type_size
             )
-            bwd_comms += inter_host_bwd_output_write_size / inter_host_bw
+            comms_bw = comms_bandwidths.get_bw(
+                world_size=int(world_size / local_world_size),
+                local_world_size=1,
+                collective_type=CollectiveType.ALL_TO_ALL,
+            )
+            bwd_comms += inter_host_bwd_output_write_size / comms_bw

         bwd_grad_indice_weights_kernel = (
             fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0
@@ -784,7 +810,7 @@ def _get_dp_sharding_perf(
         output_data_type_size: float,
         num_poolings: List[float],
         device_bw: float,
-        inter_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -815,12 +841,12 @@ def _get_dp_sharding_perf(
         num_nodes = min(world_size / local_world_size, 2)

         # all-reduce data transfer: https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf
-        all_reduce = (
-            table_size
-            * (2 * num_nodes - 1)
-            / num_nodes
-            / (inter_host_bw * local_world_size)  # 1 NIC per GPU
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_REDUCE,
         )
+        all_reduce = table_size * (2 * num_nodes - 1) / num_nodes / comms_bw
         # inter host communication constraint
         if world_size > 2 * local_world_size:
             all_reduce *= 2
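Note: the data-parallel term keeps the NCCL ring-style cost model (data moved scales as `(2 * num_nodes - 1) / num_nodes`), but the old denominator `inter_host_bw * local_world_size`, which credited one NIC per GPU, now has to live inside the bandwidth model's `ALL_REDUCE` handling. A hedged legacy-equivalent sketch (an assumption, not torchrec code; numbers illustrative):

```python
# Hypothetical ALL_REDUCE handling that folds the old "1 NIC per GPU"
# credit into the returned bandwidth.
def legacy_all_reduce_bw(inter_host_bw: float, local_world_size: int) -> float:
    return inter_host_bw * local_world_size

world_size, local_world_size = 16, 8
num_nodes = min(world_size / local_world_size, 2)  # = 2.0
table_size = 4 * 1024**3  # 4 GiB of replicated parameters
comms_bw = legacy_all_reduce_bw(25e9, local_world_size)  # = 200e9 bytes/sec
all_reduce = table_size * (2 * num_nodes - 1) / num_nodes / comms_bw  # ~32 ms
```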