[Bugfix] Pad hidden_states to avoid cross-ring AllGatherV #963

Open · wants to merge 2 commits into main from community-fix-tp16-error

Conversation

ApsarasX
Collaborator

What this PR does / why we need it?

On the 910B2C hardware platform, each machine is equipped with 16 NPU cards. When running with tensor parallelism (TP) = 16, if num_tokens (the input sequence length) is not divisible by tp_size, the uneven gather triggers a cross-ring AllGatherV communication operation, which can fail with unexpected errors (see the figure below). This PR pads hidden_states so that the token dimension is divisible by tp_size, avoiding the AllGatherV path.

(Figure: screenshot of the unexpected error triggered by the cross-ring AllGatherV.)
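For reference, a minimal sketch of the padding idea (editorial illustration; the helper name pad_for_tp_gather and the call sites are assumptions, not the exact code in this PR): round the token dimension of hidden_states up to a multiple of tp_size before the tensor-parallel gather, then drop the padding afterwards, so every rank exchanges equally sized chunks and the regular AllGather path is used instead of AllGatherV.

import torch
import torch.nn.functional as F

def pad_for_tp_gather(hidden_states: torch.Tensor, tp_size: int):
    # Hypothetical helper: pad dim 0 (num_tokens) up to a multiple of tp_size.
    num_tokens = hidden_states.shape[0]
    pad_size = (-num_tokens) % tp_size
    if pad_size > 0:
        # F.pad pads the last dimension first, so (0, 0, 0, pad_size) pads dim 0 at the end.
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_size))
    return hidden_states, pad_size

# Usage sketch: pad, split into equal per-rank chunks, gather, then strip the padding.
# hidden_states, pad_size = pad_for_tp_gather(hidden_states, tp_size)
# chunks = torch.tensor_split(hidden_states, tp_size, dim=0)   # all chunks now equal
# dist.all_gather(list(chunks), chunks[rank])                  # plain AllGather, no AllGatherV
# if pad_size > 0:
#     hidden_states = hidden_states[:-pad_size]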

Does this PR introduce any user-facing change?

No

How was this patch tested?

Signed-off-by: ApsarasX <apsarax@outlook.com>
@ApsarasX ApsarasX force-pushed the community-fix-tp16-error branch from f03a0ae to bb6ea4e Compare May 26, 2025 16:18
@ApsarasX ApsarasX added the ready read for review label May 26, 2025
Signed-off-by: ApsarasX <apsarax@outlook.com>
@ApsarasX ApsarasX force-pushed the community-fix-tp16-error branch from c87fac1 to cc7ec02 Compare May 27, 2025 02:00
@ApsarasX ApsarasX requested review from ganyi1996ppo, wangxiyuan, Yikun and MengqingCao and removed request for ganyi1996ppo May 27, 2025 03:55
@ganyi1996ppo
Collaborator

Have you tried torch.distributed.all_gather and then concatenating (torch.cat) the data collected from all ranks?

@ApsarasX
Collaborator Author

> Have you tried torch.distributed.all_gather and then concatenating (torch.cat) the data collected from all ranks?

import torch
import torch_npu
import torch.distributed as dist
import torch.multiprocessing as mp

def verify_tp_o_method(rank: int, world_size: int, hidden_states: torch.Tensor):
    torch.npu.set_device(rank)

    dist.init_process_group(
        backend='hccl',
        init_method='tcp://127.0.0.1:55223',
        world_size=world_size,
        rank=rank
    )

    hidden_states = hidden_states.to(f"npu:{rank}")

    # With inter_dim = 16369 and world_size = 16, tensor_split produces uneven chunks:
    # the first chunk has 1024 rows, the remaining fifteen have 1023.
    chunk_hidden_states = torch.tensor_split(hidden_states, world_size, dim=0)

    router_hidden_states = chunk_hidden_states[rank]

    print(f"[rank={rank}] router_hidden_states.shape = {router_hidden_states.shape}")

    # Gathering unequally sized chunks turns this into an AllGatherV-style collective.
    dist.all_gather(list(chunk_hidden_states), router_hidden_states)

    dist.barrier()

    dist.destroy_process_group()

def main():
    world_size = 16
    inter_dim = 16369        # not divisible by 16 -> uneven chunks
    # inter_dim = 16384      # divisible by 16 -> equal chunks
    output_dim = 7168
    hidden_states = torch.randn((inter_dim, output_dim), dtype=torch.half)

    mp.spawn(verify_tp_o_method, args=(world_size, hidden_states), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()

The above code will cause the following error:

(Screenshot 2025-05-28 17:13:24 showing the resulting error.)
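The chunk sizes explain the failure. As a quick sanity check (editorial sketch, runs on CPU, no NPU required), splitting 16369 rows 16 ways gives mixed sizes, while the padded 16384 from the commented-out inter_dim line splits evenly; per the PR description, the uneven case is what falls onto the cross-ring AllGatherV path.

import torch

for inter_dim in (16369, 16384):
    sizes = [c.shape[0] for c in torch.tensor_split(torch.empty(inter_dim, 1), 16, dim=0)]
    print(inter_dim, sorted(set(sizes)))

# 16369 -> [1023, 1024]  (uneven chunks: gather becomes AllGatherV)
# 16384 -> [1024]        (equal chunks: regular AllGather)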

@ganyi1996ppo
Collaborator

> [quoted the reproduction script and error from the previous comment]

Strange, this works fine on my machine.....

@ApsarasX
Collaborator Author

ApsarasX commented Jun 3, 2025

@ganyi1996ppo Any progress on this PR?

@github-actions github-actions bot added merge-conflicts and removed ready read for review labels Jun 4, 2025

github-actions bot commented Jun 4, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@jianzs
Collaborator

jianzs commented Jun 14, 2025

> [quoted the reproduction script, the error, and the "Strange, this works fine on my machine..." reply above]

Is it because your device has only 8 cards instead of 16?
