
[cp] set up load balancing testbed #120

Draft · wants to merge 4 commits into base: gh/XilunWu/1/base

3 changes: 3 additions & 0 deletions attn_gym/load_balance/__init__.py
@@ -0,0 +1,3 @@
from attn_gym.load_balance.load_balancer import load_balance_algo

__all__ = ["load_balance_algo"]
48 changes: 48 additions & 0 deletions attn_gym/load_balance/load_balancer.py
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod

import torch


__all__ = ["load_balance_algo"]


def load_balance_algo(S: int, size: int, block_size: int) -> torch.Tensor:
return HeadTail.gen_load_balance_plan(S, size, block_size)
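
# The returned plan has shape (size, S // (size * block_size)): row `rank` holds, in
# order, the global block indices of the query blocks assigned to that rank.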


class LoadAlgorithm(ABC):
@classmethod
@abstractmethod
def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor:
pass


class Noop(LoadAlgorithm):
@classmethod
def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor:
total_num_blk = S // block_size
assert S % (size * block_size) == 0
local_num_blk = total_num_blk // size
return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk)


class HeadTail(LoadAlgorithm):
@classmethod
def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor:
total_num_blk = S // block_size
assert S % (size * 2 * block_size) == 0
local_num_blk_pair = total_num_blk // (size * 2)
plan_tensor = torch.arange(total_num_blk, device="cuda").view(
-1, local_num_blk_pair
)
return torch.stack(
(
plan_tensor[:size],
plan_tensor[size:].flip(dims=(0,)),
),
dim=1,
).view(size, -1)


if __name__ == "__main__":
print(HeadTail.gen_load_balance_plan(32, 4, 1))
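    # For the example above (S=32, size=4, block_size=1) the head-tail plan pairs each
    # rank's block group from the head of the sequence with the mirrored group from the
    # tail, which roughly evens out causal-attention work across ranks:
    #   rank 0 -> [ 0,  1,  2,  3, 28, 29, 30, 31]
    #   rank 1 -> [ 4,  5,  6,  7, 24, 25, 26, 27]
    #   rank 2 -> [ 8,  9, 10, 11, 20, 21, 22, 23]
    #   rank 3 -> [12, 13, 14, 15, 16, 17, 18, 19]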
221 changes: 221 additions & 0 deletions examples/distributed_benchmark.py
@@ -0,0 +1,221 @@
from functools import lru_cache
from typing import Optional

import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Replicate, Shard


from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE,
create_block_mask,
flex_attention,
_mask_mod_signature,
)

from attn_gym.masks.document_mask import length_to_offsets
from attn_gym.masks import (
causal_mask,
generate_doc_mask_mod,
)
from attn_gym.load_balance import load_balance_algo


def get_device_type() -> str:
return "cuda"


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
return block_mask


# TODO: rewrite this as a wrapper?
def rewrite_mask_mod_for_cp(
mask_mod: _mask_mod_signature,
rank: int,
block_size: int,
load_balancer_output: torch.Tensor,
) -> _mask_mod_signature:
def local_q_idx_to_q_idx(local_q_idx) -> int:
# calculate local block_idx and block_offset
local_blk_idx, local_blk_offset = (
local_q_idx // block_size, local_q_idx % block_size
)
current_rank_blk_list = load_balancer_output[rank]
blk_idx = current_rank_blk_list[local_blk_idx]
return blk_idx * block_size + local_blk_offset

return lambda b, h, q_idx, kv_idx: mask_mod(
b, h, local_q_idx_to_q_idx(q_idx), kv_idx
)
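
# Worked example (values assumed for illustration): with block_size = 128 and
# load_balancer_output[rank] = [0, 1, 2, 3, 28, 29, 30, 31], local_q_idx = 514 falls
# in local block 4 at offset 2, so it maps to global block 28 and
# q_idx = 28 * 128 + 2 = 3586; kv_idx is passed through unchanged, so the original
# mask_mod is evaluated at the query's true position in the full sequence.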


def run_document_masking(device_mesh, max_seq_len, num_docs):
# initialize the document lengths
import random

random.seed(0)
torch.cuda.manual_seed(0)

def generate_random_lengths(total_length, num_documents):
# Initialize all lengths to 1 to ensure each document has at least one token
lengths = [1] * num_documents
remaining_length = total_length - num_documents

# Randomly distribute the remaining length
for _ in range(remaining_length):
index = random.randint(0, num_documents - 1)
lengths[index] += 1

return lengths

lengths = generate_random_lengths(max_seq_len, num_docs)
    offsets = length_to_offsets(
        lengths, torch.device(f'cuda:{torch.cuda.current_device():d}')
    )  # TODO: replace with a device mesh call
document_causal_mask = generate_doc_mask_mod(causal_mask, offsets)
test_mask_with_load_balance(device_mesh, mask_mod=document_causal_mask, S=max_seq_len)


def test_mask_with_load_balance(
device_mesh: DeviceMesh,
mask_mod: Optional[_mask_mod_signature] = None,
B: int = 16,
H: int = 16,
S: int = 8192,
D: int = 64,
skip_correctness: bool = False,
print_mask: bool = True,
device: str = "cuda",
):
data_type = torch.float16

# create block mask
block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device)
block_size = _DEFAULT_SPARSE_BLOCK_SIZE # TODO: get block size from block mask

# input initialization
qkv = [
torch.rand(
(B, H, S, D),
device=device_mesh.device_type,
dtype=data_type,
requires_grad=True,
)
for _ in range(3)
]

# NOTE: this shuffle op can be done in other ways
def shuffle_tensor_for_load_balancing(
x: torch.Tensor, shuffle_tensor: torch.Tensor, dim: int
) -> torch.Tensor:
# shuffle the tensor
num_chunks = shuffle_tensor.numel()
x_chunk_list = torch.chunk(x, num_chunks, dim=dim)
assert len(x_chunk_list) == num_chunks
new_x_chunk_list = [None] * num_chunks
for blk_idx in range(num_chunks):
new_x_chunk_list[blk_idx] = x_chunk_list[shuffle_tensor[blk_idx].item()]

return torch.cat(new_x_chunk_list, dim=dim)
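
    # For example (values assumed for illustration): with shuffle_tensor =
    # [0, 1, 6, 7, 2, 3, 4, 5] (the flattened HeadTail plan for 2 ranks and 8 blocks)
    # and dim = seq_dim, the output is the input re-assembled from its sequence chunks
    # in that order, i.e. blocks 0, 1, 6, 7, 2, 3, 4, 5 of the original tensor.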

def interchange_index_value_2d(tensor: torch.Tensor) -> torch.Tensor:
"""
Interchange the index and value in a PyTorch tensor. The input tensor has
structure: rank -> [block_idx, ...] and the output tensor will be:
block_idx -> block_idx_in_shuffled_tensor
"""
flattened_tensor = tensor.view(-1)
indices = torch.arange(
flattened_tensor.numel(), device=flattened_tensor.device
)
revert_tensor = torch.empty_like(flattened_tensor)
revert_tensor[flattened_tensor] = indices

return revert_tensor
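
    # For example (values assumed for illustration): for the 2-rank plan
    # [[0, 1, 6, 7], [2, 3, 4, 5]] the flattened order is [0, 1, 6, 7, 2, 3, 4, 5]
    # and the result is [0, 1, 4, 5, 6, 7, 2, 3]: original block b sits at position
    # revert_tensor[b] of the shuffled layout, so shuffling the gathered output with
    # this tensor restores the original sequence order.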

cp_mesh_size = device_mesh.size()
load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size)

seq_dim = 2
# copy QKV
qkv_copy = [t.detach().clone() for t in qkv]
# shuffle Q
qkv_copy[0] = shuffle_tensor_for_load_balancing(
qkv_copy[0], load_balancer_output.view(-1), dim=seq_dim
)
    # distribute the shuffled copy: Q is sharded along the sequence dim across the
    # CP mesh, while K and V are replicated on every rank
    qkv_dist = [
        distribute_tensor(
            t.requires_grad_(), device_mesh, [
                Shard(seq_dim) if i == 0 else Replicate()
            ]
        )
        for (i, t) in enumerate(qkv_copy)
    ]

q_local, k_full, v_full = (dt.to_local() for dt in qkv_dist)

# rewrite `block_mask`
mask_mod: _mask_mod_signature = block_mask.mask_mod
cp_rank = device_mesh.get_local_rank()
cp_mask_mod = rewrite_mask_mod_for_cp(
mask_mod, cp_rank, block_size, load_balancer_output
)
cp_block_mask = create_block_mask_cached(
cp_mask_mod, B=1, H=1, M=S // cp_mesh_size, N=S, device=device
)
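    # the local block mask built above covers this rank's Q shard (S // cp_mesh_size
    # rows) against the full-length, replicated K/V, using the rank-aware mask_mod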

# Compile the flex_attention function
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

# TODO: this doesn't address the return_lse=True case
cp_out = compiled_flex_attention(
q_local,
k_full,
v_full,
score_mod=None,
block_mask=cp_block_mask,
)
assert isinstance(cp_out, torch.Tensor)

# unshard
cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)])
full_cp_out_dist = cp_out_dist.full_tensor()
# rearrange
blk_idx_shuffled = interchange_index_value_2d(load_balancer_output)
full_cp_out_dist = shuffle_tensor_for_load_balancing(
full_cp_out_dist, blk_idx_shuffled, dim=seq_dim
)
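    # at this point every rank holds the complete attention output restored to the
    # original (unshuffled) sequence order, ready to compare against the
    # single-device reference below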

# local flex attention
expect_out = flex_attention(*qkv, block_mask=block_mask)
torch.testing.assert_close(full_cp_out_dist, expect_out, atol=1e-1, rtol=1e-2)


def load_balancing_example(world_size: int, rank: int) -> None:
device_type = get_device_type()
device_handle = getattr(torch, device_type, None)
assert device_handle is not None, f"Unsupported device type: {device_type}"
num_devices_per_host = device_handle.device_count()
device_handle.set_device(rank % num_devices_per_host)
torch._dynamo.config.cache_size_limit = 1000

# init device mesh
device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,))

run_document_masking(device_mesh, max_seq_len=4096, num_docs=12)


if __name__ == "__main__":
# this script is launched via torchrun which automatically manages ProcessGroup
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
# assert world_size == 4 # our example uses 4 worker ranks

try:
load_balancing_example(world_size, rank)
finally:
dist.barrier()
dist.destroy_process_group()
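
# One way to launch this example on a single host (assuming 4 GPUs; adjust the
# process count to the available hardware):
#   torchrun --nproc_per_node=4 examples/distributed_benchmark.py
# torchrun sets the RANK and WORLD_SIZE environment variables read above.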