
Commit 020bc85

[cp] set up load balancing testbed
ghstack-source-id: f28b448bab842304813ff7464ca61e71494ec552
Pull Request resolved: #120

3 files changed, +272 -0 lines changed

attn_gym/load_balance/__init__.py (new file, +3 lines)

from attn_gym.load_balance.load_balancer import load_balance_algo

__all__ = ["load_balance_algo"]
attn_gym/load_balance/load_balancer.py (new file, +48 lines)

from abc import ABC, abstractmethod

import torch


__all__ = ["load_balance_algo"]


def load_balance_algo(S: int, size: int, block_size: int) -> torch.Tensor:
    return HeadTail.gen_load_balance_plan(S, size, block_size)


class LoadAlgorithm(ABC):
    @classmethod
    @abstractmethod
    def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor:
        pass


class Noop(LoadAlgorithm):
    # Identity plan: rank r simply keeps the r-th contiguous slice of blocks.
    @classmethod
    def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor:
        total_num_blk = S // block_size
        assert S % (size * block_size) == 0
        local_num_blk = total_num_blk // size
        return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk)


class HeadTail(LoadAlgorithm):
    # Head-tail plan: rank r gets the r-th chunk of blocks from the front plus the
    # matching chunk from the back, so causal-attention work is balanced across ranks.
    @classmethod
    def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor:
        total_num_blk = S // block_size
        assert S % (size * 2 * block_size) == 0
        local_num_blk_pair = total_num_blk // (size * 2)
        plan_tensor = torch.arange(total_num_blk, device="cuda").view(
            -1, local_num_blk_pair
        )
        return torch.stack(
            (
                plan_tensor[:size],
                plan_tensor[size:].flip(dims=(0,)),
            ),
            dim=1,
        ).view(size, -1)


if __name__ == "__main__":
    print(HeadTail.gen_load_balance_plan(32, 4, 1))
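To make the head-tail pairing concrete, here is a small CPU-only sketch (an illustration, not part of the commit) that mirrors the arithmetic of HeadTail.gen_load_balance_plan for the S=32, size=4, block_size=1 case exercised by the __main__ check; the printed plan is what the CUDA version produces as well.

# Illustration only: reproduce the HeadTail layout on CPU.
import torch

S, size, block_size = 32, 4, 1
total_num_blk = S // block_size
chunks = torch.arange(total_num_blk).view(-1, total_num_blk // (size * 2))
plan = torch.stack((chunks[:size], chunks[size:].flip(dims=(0,))), dim=1).view(size, -1)
print(plan)
# tensor([[ 0,  1,  2,  3, 28, 29, 30, 31],
#         [ 4,  5,  6,  7, 24, 25, 26, 27],
#         [ 8,  9, 10, 11, 20, 21, 22, 23],
#         [12, 13, 14, 15, 16, 17, 18, 19]])

Each row is one rank's block list: an early chunk paired with a mirrored late chunk, the usual trick for evening out causal-attention work in context parallelism.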

examples/distributed_benchmark.py (new file, +221 lines)
from functools import lru_cache
from typing import Optional

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Replicate, Shard


from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    flex_attention,
    _mask_mod_signature,
)

from attn_gym.masks.document_mask import length_to_offsets
from attn_gym.masks import (
    causal_mask,
    generate_doc_mask_mod,
)
from attn_gym.load_balance import load_balance_algo


def get_device_type() -> str:
    return "cuda"


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
    return block_mask


# TODO: re-write it into a wrapper???
def rewrite_mask_mod_for_cp(
    mask_mod: _mask_mod_signature,
    rank: int,
    block_size: int,
    load_balancer_output: torch.Tensor,
) -> _mask_mod_signature:
    def local_q_idx_to_q_idx(local_q_idx) -> int:
        # calculate local block_idx and block_offset
        local_blk_idx, local_blk_offset = (
            local_q_idx // block_size, local_q_idx % block_size
        )
        current_rank_blk_list = load_balancer_output[rank]
        blk_idx = current_rank_blk_list[local_blk_idx]
        return blk_idx * block_size + local_blk_offset

    return lambda b, h, q_idx, kv_idx: mask_mod(
        b, h, local_q_idx_to_q_idx(q_idx), kv_idx
    )
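rewrite_mask_mod_for_cp translates a local query index on one context-parallel rank back into its global position before consulting the original mask_mod. A CPU-only sketch of that remapping for rank 1 (illustration only, using the hypothetical S=32, size=4, block_size=1 HeadTail plan shown earlier):

# Illustration only: local -> global query-index remapping for one rank.
import torch

block_size = 1
plan = torch.tensor([
    [ 0,  1,  2,  3, 28, 29, 30, 31],
    [ 4,  5,  6,  7, 24, 25, 26, 27],
    [ 8,  9, 10, 11, 20, 21, 22, 23],
    [12, 13, 14, 15, 16, 17, 18, 19],
])
rank = 1
local_q_idx = torch.arange(plan.shape[1] * block_size)
blk_idx = plan[rank][local_q_idx // block_size]  # same arithmetic as local_q_idx_to_q_idx
global_q_idx = blk_idx * block_size + local_q_idx % block_size
print(global_q_idx)  # tensor([ 4,  5,  6,  7, 24, 25, 26, 27])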
def run_document_masking(device_mesh, max_seq_len, num_docs):
    # initialize the document lengths
    import random

    random.seed(0)
    torch.cuda.manual_seed(0)

    def generate_random_lengths(total_length, num_documents):
        # Initialize all lengths to 1 to ensure each document has at least one token
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents

        # Randomly distribute the remaining length
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1

        return lengths

    lengths = generate_random_lengths(max_seq_len, num_docs)
    offsets = length_to_offsets(lengths, torch.device(f'cuda:{torch.cuda.current_device():d}'))  # TODO: replace with a device mesh call
    document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
    test_mask_with_load_balance(device_mesh, mask_mod=document_causal_mask, S=max_seq_len)
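The offsets passed to generate_doc_mask_mod mark where each document starts in the packed sequence. A hedged sketch of that shape, assuming length_to_offsets follows the usual leading-zero cumulative-sum convention its name suggests (the helper below is a hypothetical stand-in, not the attn_gym implementation):

# Hypothetical illustration of the assumed offsets layout.
import torch

def lengths_to_offsets_sketch(lengths):
    # e.g. [3, 2, 4] -> [0, 3, 5, 9]; offsets[i] is the first token index of document i
    return torch.cumsum(torch.tensor([0] + list(lengths)), dim=0)

print(lengths_to_offsets_sketch([3, 2, 4]))  # tensor([0, 3, 5, 9])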
def test_mask_with_load_balance(
    device_mesh: DeviceMesh,
    mask_mod: Optional[_mask_mod_signature] = None,
    B: int = 16,
    H: int = 16,
    S: int = 8192,
    D: int = 64,
    skip_correctness: bool = False,
    print_mask: bool = True,
    device: str = "cuda",
):
    data_type = torch.float16

    # create block mask
    block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
    block_size = _DEFAULT_SPARSE_BLOCK_SIZE  # TODO: get block size from block mask

    # input initialization
    qkv = [
        torch.rand(
            (B, H, S, D),
            device=device_mesh.device_type,
            dtype=data_type,
            requires_grad=True,
        )
        for _ in range(3)
    ]

    # NOTE: this shuffle op can be done in other ways
    def shuffle_tensor_for_load_balancing(
        x: torch.Tensor, shuffle_tensor: torch.Tensor, dim: int
    ) -> torch.Tensor:
        # shuffle the tensor
        num_chunks = shuffle_tensor.numel()
        x_chunk_list = torch.chunk(x, num_chunks, dim=dim)
        assert len(x_chunk_list) == num_chunks
        new_x_chunk_list = [None] * num_chunks
        for blk_idx in range(num_chunks):
            new_x_chunk_list[blk_idx] = x_chunk_list[shuffle_tensor[blk_idx].item()]

        return torch.cat(new_x_chunk_list, dim=dim)

    def interchange_index_value_2d(tensor: torch.Tensor) -> torch.Tensor:
        """
        Interchange the index and value in a PyTorch tensor. The input tensor has
        structure: rank -> [block_idx, ...] and the output tensor will be:
        block_idx -> block_idx_in_shuffled_tensor
        """
        flattened_tensor = tensor.view(-1)
        indices = torch.arange(
            flattened_tensor.numel(), device=flattened_tensor.device
        )
        revert_tensor = torch.empty_like(flattened_tensor)
        revert_tensor[flattened_tensor] = indices

        return revert_tensor

    cp_mesh_size = device_mesh.size()
    load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size)

    seq_dim = 2
    # copy QKV
    qkv_copy = [t.detach().clone() for t in qkv]
    # shuffle Q
    qkv_copy[0] = shuffle_tensor_for_load_balancing(
        qkv_copy[0], load_balancer_output.view(-1), dim=seq_dim
    )
    # distribute the shuffled copy (shard Q along the sequence dim, replicate K/V);
    # the original qkv stays unshuffled for the reference computation below
    qkv_dist = [
        distribute_tensor(
            t.requires_grad_(), device_mesh, [
                Shard(seq_dim) if i == 0 else Replicate()
            ]
        )
        for (i, t) in enumerate(qkv_copy)
    ]

    q_local, k_full, v_full = (dt.to_local() for dt in qkv_dist)

    # rewrite `block_mask`
    mask_mod: _mask_mod_signature = block_mask.mask_mod
    cp_rank = device_mesh.get_local_rank()
    cp_mask_mod = rewrite_mask_mod_for_cp(
        mask_mod, cp_rank, block_size, load_balancer_output
    )
    cp_block_mask = create_block_mask_cached(
        cp_mask_mod, B=1, H=1, M=S // cp_mesh_size, N=S, device=device
    )

    # Compile the flex_attention function
    compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

    # TODO: this doesn't address the return_lse=True case
    cp_out = compiled_flex_attention(
        q_local,
        k_full,
        v_full,
        score_mod=None,
        block_mask=cp_block_mask,
    )
    assert isinstance(cp_out, torch.Tensor)

    # unshard
    cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)])
    full_cp_out_dist = cp_out_dist.full_tensor()
    # rearrange
    blk_idx_shuffled = interchange_index_value_2d(load_balancer_output)
    full_cp_out_dist = shuffle_tensor_for_load_balancing(
        full_cp_out_dist, blk_idx_shuffled, dim=seq_dim
    )

    # local flex attention
    expect_out = flex_attention(*qkv, block_mask=block_mask)
    torch.testing.assert_close(full_cp_out_dist, expect_out, atol=1e-1, rtol=1e-2)
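The reassembly at the end relies on interchange_index_value_2d producing the inverse of the flattened plan: shuffling with the plan and then with its inverse restores the original block order. A minimal CPU sketch (illustration only, with a hypothetical 2-rank, 4-block plan and one element per block):

# Illustration only: the inverse-permutation trick used by interchange_index_value_2d.
import torch

plan = torch.tensor([[0, 3], [1, 2]])       # rank -> block indices
flat = plan.view(-1)                        # shuffled block order: 0, 3, 1, 2
inverse = torch.empty_like(flat)
inverse[flat] = torch.arange(flat.numel())  # inverse[block_idx] = position in shuffled order
print(inverse)                              # tensor([0, 2, 3, 1])

x = torch.arange(4)                         # stand-in for 4 sequence blocks
shuffled = x[flat]                          # forward shuffle, as applied to Q
restored = shuffled[inverse]                # reverse shuffle, as applied to the gathered output
print(torch.equal(restored, x))             # True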
def load_balancing_example(world_size: int, rank: int) -> None:
    device_type = get_device_type()
    device_handle = getattr(torch, device_type, None)
    assert device_handle is not None, f"Unsupported device type: {device_type}"
    num_devices_per_host = device_handle.device_count()
    device_handle.set_device(rank % num_devices_per_host)
    torch._dynamo.config.cache_size_limit = 1000

    # init device mesh
    device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

    run_document_masking(device_mesh, max_seq_len=4096, num_docs=12)
if __name__ == "__main__":
    # this script is launched via torchrun which automatically manages ProcessGroup
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # assert world_size == 4  # our example uses 4 worker ranks

    try:
        load_balancing_example(world_size, rank)
    finally:
        dist.barrier()
        dist.destroy_process_group()
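Since the entry point reads RANK and WORLD_SIZE from the environment, the example is meant to be launched with torchrun, for instance torchrun --nproc-per-node=4 examples/distributed_benchmark.py on a single host with four GPUs; the exact flags and rank count here are an assumption (matching the commented-out world_size check), not part of the commit.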
