|
| 1 | +from functools import lru_cache |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import os |
| 5 | +import torch |
| 6 | +import torch.distributed as dist |
| 7 | +from torch.distributed.device_mesh import init_device_mesh |
| 8 | +from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Replicate, Shard |
| 9 | + |
| 10 | + |
| 11 | +from torch.nn.attention.flex_attention import ( |
| 12 | + _DEFAULT_SPARSE_BLOCK_SIZE, |
| 13 | + create_block_mask, |
| 14 | + flex_attention, |
| 15 | + _mask_mod_signature, |
| 16 | +) |
| 17 | + |
| 18 | +from attn_gym.masks.document_mask import length_to_offsets |
| 19 | +from attn_gym.masks import ( |
| 20 | + causal_mask, |
| 21 | + generate_doc_mask_mod, |
| 22 | +) |
| 23 | +from attn_gym.load_balance import load_balance_algo |
| 24 | + |
| 25 | + |
| 26 | +def get_device_type() -> str: |
| 27 | + return "cuda" |
| 28 | + |
| 29 | + |
| 30 | +@lru_cache |
| 31 | +def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"): |
| 32 | + block_mask = create_block_mask(score_mod, B, H, M, N, device=device) |
| 33 | + return block_mask |
| 34 | + |
| 35 | + |
| 36 | +# TODO: re-write it into a wrapper??? |
| 37 | +def rewrite_mask_mod_for_cp( |
| 38 | + mask_mod: _mask_mod_signature, |
| 39 | + rank: int, |
| 40 | + block_size: int, |
| 41 | + load_balancer_output: torch.Tensor, |
| 42 | +) -> _mask_mod_signature: |
| 43 | + def local_q_idx_to_q_idx(local_q_idx) -> int: |
| 44 | + # calculate local block_idx and block_offset |
| 45 | + local_blk_idx, local_blk_offset = ( |
| 46 | + local_q_idx // block_size, local_q_idx % block_size |
| 47 | + ) |
| 48 | + current_rank_blk_list = load_balancer_output[rank] |
| 49 | + blk_idx = current_rank_blk_list[local_blk_idx] |
| 50 | + return blk_idx * block_size + local_blk_offset |
| 51 | + |
| 52 | + return lambda b, h, q_idx, kv_idx: mask_mod( |
| 53 | + b, h, local_q_idx_to_q_idx(q_idx), kv_idx |
| 54 | + ) |
| 55 | + |
| 56 | + |
| 57 | +def run_document_masking(device_mesh, max_seq_len, num_docs): |
| 58 | + # initialize the document lengths |
| 59 | + import random |
| 60 | + |
| 61 | + random.seed(0) |
| 62 | + torch.cuda.manual_seed(0) |
| 63 | + |
| 64 | + def generate_random_lengths(total_length, num_documents): |
| 65 | + # Initialize all lengths to 1 to ensure each document has at least one token |
| 66 | + lengths = [1] * num_documents |
| 67 | + remaining_length = total_length - num_documents |
| 68 | + |
| 69 | + # Randomly distribute the remaining length |
| 70 | + for _ in range(remaining_length): |
| 71 | + index = random.randint(0, num_documents - 1) |
| 72 | + lengths[index] += 1 |
| 73 | + |
| 74 | + return lengths |
| 75 | + |
| 76 | + lengths = generate_random_lengths(max_seq_len, num_docs) |
| 77 | + offsets = length_to_offsets(lengths, torch.device(f'cuda:{torch.cuda.current_device():d}')) # TODO: replace with a device mesh call |
| 78 | + document_causal_mask = generate_doc_mask_mod(causal_mask, offsets) |
| 79 | + test_mask_with_load_balance(device_mesh, mask_mod=document_causal_mask, S=max_seq_len) |
| 80 | + |
| 81 | + |
| 82 | +def test_mask_with_load_balance( |
| 83 | + device_mesh: DeviceMesh, |
| 84 | + mask_mod: Optional[_mask_mod_signature] = None, |
| 85 | + B: int = 16, |
| 86 | + H: int = 16, |
| 87 | + S: int = 8192, |
| 88 | + D: int = 64, |
| 89 | + skip_correctness: bool = False, |
| 90 | + print_mask: bool = True, |
| 91 | + device: str = "cuda", |
| 92 | +): |
| 93 | + data_type = torch.float16 |
| 94 | + |
| 95 | + # create block mask |
| 96 | + block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device) |
| 97 | + block_size = _DEFAULT_SPARSE_BLOCK_SIZE # TODO: get block size from block mask |
| 98 | + |
| 99 | + # input initialization |
| 100 | + qkv = [ |
| 101 | + torch.rand( |
| 102 | + (B, H, S, D), |
| 103 | + device=device_mesh.device_type, |
| 104 | + dtype=data_type, |
| 105 | + requires_grad=True, |
| 106 | + ) |
| 107 | + for _ in range(3) |
| 108 | + ] |
| 109 | + |
| 110 | + # NOTE: this shuffle op can be done in other ways |
| 111 | + def shuffle_tensor_for_load_balancing( |
| 112 | + x: torch.Tensor, shuffle_tensor: torch.Tensor, dim: int |
| 113 | + ) -> torch.Tensor: |
| 114 | + # shuffle the tensor |
| 115 | + num_chunks = shuffle_tensor.numel() |
| 116 | + x_chunk_list = torch.chunk(x, num_chunks, dim=dim) |
| 117 | + assert len(x_chunk_list) == num_chunks |
| 118 | + new_x_chunk_list = [None] * num_chunks |
| 119 | + for blk_idx in range(num_chunks): |
| 120 | + new_x_chunk_list[blk_idx] = x_chunk_list[shuffle_tensor[blk_idx].item()] |
| 121 | + |
| 122 | + return torch.cat(new_x_chunk_list, dim=dim) |
| 123 | + |
| 124 | + def interchange_index_value_2d(tensor: torch.Tensor) -> torch.Tensor: |
| 125 | + """ |
| 126 | + Interchange the index and value in a PyTorch tensor. The input tensor has |
| 127 | + structure: rank -> [block_idx, ...] and the output tensor will be: |
| 128 | + block_idx -> block_idx_in_shuffled_tensor |
| 129 | + """ |
| 130 | + flattened_tensor = tensor.view(-1) |
| 131 | + indices = torch.arange( |
| 132 | + flattened_tensor.numel(), device=flattened_tensor.device |
| 133 | + ) |
| 134 | + revert_tensor = torch.empty_like(flattened_tensor) |
| 135 | + revert_tensor[flattened_tensor] = indices |
| 136 | + |
| 137 | + return revert_tensor |
| 138 | + |
| 139 | + cp_mesh_size = device_mesh.size() |
| 140 | + load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size) |
| 141 | + |
| 142 | + seq_dim = 2 |
| 143 | + # copy QKV |
| 144 | + qkv_copy = [t.detach().clone() for t in qkv] |
| 145 | + # shuffle Q |
| 146 | + qkv_copy[0] = shuffle_tensor_for_load_balancing( |
| 147 | + qkv_copy[0], load_balancer_output.view(-1), dim=seq_dim |
| 148 | + ) |
| 149 | + qkv_dist = [ |
| 150 | + distribute_tensor( |
| 151 | + t.requires_grad_(), device_mesh, [ |
| 152 | + Shard(seq_dim) if i == 0 else Replicate() |
| 153 | + ] |
| 154 | + ) |
| 155 | + for (i, t) in enumerate(qkv) |
| 156 | + ] |
| 157 | + |
| 158 | + q_local, k_full, v_full = (dt.to_local() for dt in qkv_dist) |
| 159 | + |
| 160 | + # rewrite `block_mask` |
| 161 | + mask_mod: _mask_mod_signature = block_mask.mask_mod |
| 162 | + cp_rank = device_mesh.get_local_rank() |
| 163 | + cp_mask_mod = rewrite_mask_mod_for_cp( |
| 164 | + mask_mod, cp_rank, block_size, load_balancer_output |
| 165 | + ) |
| 166 | + cp_block_mask = create_block_mask_cached( |
| 167 | + cp_mask_mod, B=1, H=1, M=S // cp_mesh_size, N=S, device=device |
| 168 | + ) |
| 169 | + |
| 170 | + # Compile the flex_attention function |
| 171 | + compiled_flex_attention = torch.compile(flex_attention, dynamic=False) |
| 172 | + |
| 173 | + # TODO: this doesn't address the return_lse=True case |
| 174 | + cp_out = compiled_flex_attention( |
| 175 | + q_local, |
| 176 | + k_full, |
| 177 | + v_full, |
| 178 | + score_mod=None, |
| 179 | + block_mask=cp_block_mask, |
| 180 | + ) |
| 181 | + assert isinstance(cp_out, torch.Tensor) |
| 182 | + |
| 183 | + # unshard |
| 184 | + cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)]) |
| 185 | + full_cp_out_dist = cp_out_dist.full_tensor() |
| 186 | + # rearrange |
| 187 | + blk_idx_shuffled = interchange_index_value_2d(load_balancer_output) |
| 188 | + full_cp_out_dist = shuffle_tensor_for_load_balancing( |
| 189 | + full_cp_out_dist, blk_idx_shuffled, dim=seq_dim |
| 190 | + ) |
| 191 | + |
| 192 | + # local flex attention |
| 193 | + expect_out = flex_attention(*qkv, block_mask=block_mask) |
| 194 | + torch.testing.assert_close(full_cp_out_dist, expect_out, atol=1e-1, rtol=1e-2) |
| 195 | + |
| 196 | + |
| 197 | +def load_balancing_example(world_size: int, rank: int) -> None: |
| 198 | + device_type = get_device_type() |
| 199 | + device_handle = getattr(torch, device_type, None) |
| 200 | + assert device_handle is not None, f"Unsupported device type: {device_type}" |
| 201 | + num_devices_per_host = device_handle.device_count() |
| 202 | + device_handle.set_device(rank % num_devices_per_host) |
| 203 | + torch._dynamo.config.cache_size_limit = 1000 |
| 204 | + |
| 205 | + # init device mesh |
| 206 | + device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,)) |
| 207 | + |
| 208 | + run_document_masking(device_mesh, max_seq_len=4096, num_docs=12) |
| 209 | + |
| 210 | + |
| 211 | +if __name__ == "__main__": |
| 212 | + # this script is launched via torchrun which automatically manages ProcessGroup |
| 213 | + rank = int(os.environ["RANK"]) |
| 214 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 215 | + # assert world_size == 4 # our example uses 4 worker ranks |
| 216 | + |
| 217 | + try: |
| 218 | + load_balancing_example(world_size, rank) |
| 219 | + finally: |
| 220 | + dist.barrier() |
| 221 | + dist.destroy_process_group() |
0 commit comments