Commit b1d43ca

gregmacnamara authored and facebook-github-bot committed
Enable More General Comms Bandwidth Estimates (#2905)
Summary:
Optimizing sharding plans relies on good perf estimates, which in turn rely on good bandwidth estimates. The bandwidth of a collective communication depends on many factors, including the world size, the type of collective (all-to-all, reduce-scatter, ...), and the hardware.

In this change we introduce an abstract class, GeneralizedCommsBandwidth, that exposes a get_bw method. Users can subclass it and define get_bw appropriately for the hardware or network setup being used. If passed to the Topology, it is called from the perf estimators and enables the bandwidth to be configured correctly as a function of hardware, collective type, and world size. We prefer a class with a method over a dictionary or similar data structure because it lets users specify bandwidth as precisely as is useful to them, for example a different bandwidth for every world size rather than only the two cases of at most local_world_size and larger than local_world_size.

For Topology objects the comms_bandwidths argument is optional, so existing behavior and perf estimates do not change. If it is not passed, we construct a GeneralizedCommsBandwidth object from the supplied inter- and intra-node comms bandwidths that replicates the existing behavior of the perf estimators. The signatures of the downstream perf estimators are updated accordingly, so a single, consistent comms bandwidth abstraction is used throughout.

Reviewed By: iamzainhuda

Differential Revision: D73229318
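As a rough illustration of the intended usage, here is a minimal sketch. The get_bw keyword arguments and the CollectiveType values are taken from the diff below; the Topology constructor usage, the class name MyClusterCommsBandwidth, and all bandwidth numbers are assumptions.

# Minimal sketch only -- class name and numbers are made up; get_bw's keyword
# arguments and the CollectiveType members come from the diff below.
from torchrec.distributed.planner.types import (
    CollectiveType,
    GeneralizedCommsBandwidth,
    Topology,
)


class MyClusterCommsBandwidth(GeneralizedCommsBandwidth):
    """Hypothetical bandwidth model for a specific cluster."""

    def get_bw(
        self,
        world_size: int,
        local_world_size: int,
        collective_type: CollectiveType,
    ) -> float:
        # Same units as Topology's intra/inter host bandwidths (values made up).
        if collective_type == CollectiveType.ALL_REDUCE:
            return 300e9
        if world_size <= local_world_size:
            return 250e9  # single-host collectives
        return 50e9 / (world_size / local_world_size)  # degrade with host count


# comms_bandwidths is optional; omitting it preserves today's perf estimates.
topology = Topology(
    world_size=16,
    compute_device="cuda",
    comms_bandwidths=MyClusterCommsBandwidth(),
)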
1 parent a28ac22 commit b1d43ca

File tree

3 files changed: +486 -40 lines

torchrec/distributed/planner/shard_estimators.py: +58 -32
@@ -27,6 +27,8 @@
     WEIGHTED_KERNEL_MULTIPLIER,
 )
 from torchrec.distributed.planner.types import (
+    CollectiveType,
+    GeneralizedCommsBandwidth,
     ParameterConstraints,
     Perf,
     PlannerError,
@@ -231,8 +233,7 @@ def estimate(
                 hbm_mem_bw=self._topology.hbm_mem_bw,
                 ddr_mem_bw=self._topology.ddr_mem_bw,
                 hbm_to_ddr_mem_bw=self._topology.hbm_to_ddr_mem_bw,
-                intra_host_bw=self._topology.intra_host_bw,
-                inter_host_bw=self._topology.inter_host_bw,
+                comms_bandwidths=self._topology.comms_bandwidths,
                 bwd_compute_multiplier=self._topology.bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=self._topology.weighted_feature_bwd_compute_multiplier,
                 is_pooled=sharding_option.is_pooled,
@@ -269,8 +270,7 @@ def perf_func_emb_wall_time(
         hbm_mem_bw: float,
         ddr_mem_bw: float,
         hbm_to_ddr_mem_bw: float,
-        intra_host_bw: float,
-        inter_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -359,8 +359,7 @@ def perf_func_emb_wall_time(
                 num_poolings=num_poolings,
                 hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
-                intra_host_bw=intra_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -385,8 +384,7 @@ def perf_func_emb_wall_time(
                 num_poolings=num_poolings,
                 hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
-                intra_host_bw=intra_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -414,8 +412,7 @@ def perf_func_emb_wall_time(
                 num_poolings=num_poolings,
                 hbm_to_ddr_mem_bw=hbm_to_ddr_mem_bw,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
-                intra_host_bw=intra_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -435,7 +432,7 @@ def perf_func_emb_wall_time(
                 output_data_type_size=output_data_type_size,
                 num_poolings=num_poolings,
                 device_bw=device_bw,
-                inter_host_bw=inter_host_bw,
+                comms_bandwidths=comms_bandwidths,
                 bwd_compute_multiplier=bwd_compute_multiplier,
                 weighted_feature_bwd_compute_multiplier=weighted_feature_bwd_compute_multiplier,
                 is_pooled=is_pooled,
@@ -477,15 +474,15 @@ def _get_tw_sharding_perf(
         num_poolings: List[float],
         hbm_to_ddr_mem_bw: float,
         device_bw: float,
-        inter_host_bw: float,
-        intra_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
         is_weighted: bool = False,
         is_inference: bool = False,
         expected_cache_fetches: float = 0,
     ) -> Perf:
+
         batch_inputs = sum(
             [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]
         )
@@ -518,8 +515,11 @@ def _get_tw_sharding_perf(
             block_usage_penalty = HALF_BLOCK_PENALTY
         else: # emb_dim >= 32
             block_usage_penalty = QUARTER_BLOCK_PENALTY
-
-        comms_bw = inter_host_bw if world_size > local_world_size else intra_host_bw
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_TO_ALL,
+        )
         fwd_comms = fwd_output_write_size / comms_bw

         fwd_compute = (
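For reference, the fallback described in the summary (when no comms_bandwidths is passed, the Topology builds one from its intra/inter host bandwidths) behaves like the sketch below for the selection removed in this hunk. The class name and constructor here are hypothetical; the real default lives in torchrec.distributed.planner.types and, per the summary, also reproduces the all-reduce case handled later in this diff, which this sketch ignores.

# Hypothetical sketch of the replicated two-value behavior (not the real default class).
from torchrec.distributed.planner.types import CollectiveType, GeneralizedCommsBandwidth


class TwoLevelCommsBandwidth(GeneralizedCommsBandwidth):
    def __init__(self, intra_host_bw: float, inter_host_bw: float) -> None:
        self._intra_host_bw = intra_host_bw
        self._inter_host_bw = inter_host_bw

    def get_bw(
        self,
        world_size: int,
        local_world_size: int,
        collective_type: CollectiveType,
    ) -> float:
        # collective_type is ignored, exactly like the removed one-liner:
        #   comms_bw = inter_host_bw if world_size > local_world_size else intra_host_bw
        if world_size > local_world_size:
            return self._inter_host_bw
        return self._intra_host_bw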
@@ -576,8 +576,7 @@ def _get_rw_sharding_perf(
         num_poolings: List[float],
         hbm_to_ddr_mem_bw: float,
         device_bw: float,
-        inter_host_bw: float,
-        intra_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -615,8 +614,12 @@ def _get_rw_sharding_perf(
             if is_pooled
             else batch_outputs * world_size * emb_dim * bwd_a2a_comm_data_type_size
         )
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.REDUCE_SCATTER,
+        )

-        comms_bw = inter_host_bw if world_size > local_world_size else intra_host_bw
         fwd_comms = fwd_output_write_size / comms_bw

         fwd_compute = (
@@ -628,7 +631,11 @@ def _get_rw_sharding_perf(
             return Perf(
                 fwd_compute=fwd_compute, fwd_comms=fwd_comms, bwd_compute=0, bwd_comms=0
             )
-
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_GATHER,
+        )
         bwd_comms = bwd_output_write_size / comms_bw

         bwd_batched_copy = bwd_output_write_size * BATCHED_COPY_PERF_FACTOR / device_bw
@@ -675,8 +682,7 @@ def _get_twrw_sharding_perf(
         num_poolings: List[float],
         hbm_to_ddr_mem_bw: float,
         device_bw: float,
-        inter_host_bw: float,
-        intra_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -709,28 +715,43 @@ def _get_twrw_sharding_perf(
         bwd_output_write_size = (
             batch_outputs * world_size * emb_dim * bwd_sr_comm_data_type_size
         )
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=local_world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.REDUCE_SCATTER,
+        )

         # intra host comm
-        fwd_comms = fwd_output_write_size / intra_host_bw
+        fwd_comms = fwd_output_write_size / comms_bw

         # inter host comm
         if world_size > local_world_size:
             inter_host_fwd_output_write_size = (
                 batch_outputs
                 * (
                     world_size / local_world_size
-                ) # this is the size of the procress group.
+                ) # this is the size of the procees group.
                 * emb_dim
                 * fwd_a2a_comm_data_type_size
             )
-            fwd_comms += inter_host_fwd_output_write_size / inter_host_bw
+            comms_bw = comms_bandwidths.get_bw(
+                world_size=int(world_size / local_world_size),
+                local_world_size=1,
+                collective_type=CollectiveType.ALL_TO_ALL,
+            )
+            fwd_comms += inter_host_fwd_output_write_size / comms_bw

         fwd_compute = (
             input_read_size + embedding_lookup_size + fwd_output_write_size
         ) / device_bw

         # intra host comm (i.e. all gather)
-        bwd_comms = bwd_output_write_size / intra_host_bw
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=local_world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_GATHER,
+        )
+        bwd_comms = bwd_output_write_size / comms_bw

         # inter host comm (i.e. all to all)
         if world_size > local_world_size:
@@ -742,7 +763,12 @@ def _get_twrw_sharding_perf(
                 * emb_dim
                 * bwd_a2a_comm_data_type_size
             )
-            bwd_comms += inter_host_bwd_output_write_size / inter_host_bw
+            comms_bw = comms_bandwidths.get_bw(
+                world_size=int(world_size / local_world_size),
+                local_world_size=1,
+                collective_type=CollectiveType.ALL_TO_ALL,
+            )
+            bwd_comms += inter_host_bwd_output_write_size / comms_bw

         bwd_grad_indice_weights_kernel = (
             fwd_compute * WEIGHTED_KERNEL_MULTIPLIER if is_weighted else 0
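One detail worth calling out in the two TWRW hunks above: the intra-host reduce-scatter and all-gather are queried with world_size=local_world_size, while the cross-host all-to-all is queried over the host-level group (one rank per host), i.e. world_size=int(world_size / local_world_size) and local_world_size=1. An illustrative query with a made-up cluster shape, reusing the hypothetical TwoLevelCommsBandwidth sketch from earlier:

# Made-up shape: 16 ranks, 8 per host -> a 2-host job.
world_size, local_world_size = 16, 8
comms_bandwidths = TwoLevelCommsBandwidth(intra_host_bw=600e9, inter_host_bw=50e9)

# Intra-host phase (reduce-scatter / all-gather): the group is a single host.
intra_bw = comms_bandwidths.get_bw(
    world_size=local_world_size,
    local_world_size=local_world_size,
    collective_type=CollectiveType.ALL_GATHER,
)  # returns 600e9 because world_size == local_world_size

# Cross-host phase (all-to-all): one rank per host, so local_world_size is 1.
inter_bw = comms_bandwidths.get_bw(
    world_size=int(world_size / local_world_size),  # 2 hosts in the group
    local_world_size=1,
    collective_type=CollectiveType.ALL_TO_ALL,
)  # returns 50e9 because 2 > 1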
@@ -784,7 +810,7 @@ def _get_dp_sharding_perf(
         output_data_type_size: float,
         num_poolings: List[float],
         device_bw: float,
-        inter_host_bw: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
         bwd_compute_multiplier: float,
         weighted_feature_bwd_compute_multiplier: float,
         is_pooled: bool,
@@ -815,12 +841,12 @@ def _get_dp_sharding_perf(
         num_nodes = min(world_size / local_world_size, 2)

         # all-reduce data transfer: https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf
-        all_reduce = (
-            table_size
-            * (2 * num_nodes - 1)
-            / num_nodes
-            / (inter_host_bw * local_world_size) # 1 NIC per GPU
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_REDUCE,
         )
+        all_reduce = table_size * (2 * num_nodes - 1) / num_nodes / comms_bw
         # inter host communication constraint
         if world_size > 2 * local_world_size:
             all_reduce *= 2
