Multi-GPU Model Loading Issues - Issue & Code Fix #36

Open
pftq opened this issue Mar 17, 2025 · 4 comments

Comments

@pftq
Contributor

pftq commented Mar 17, 2025

With 2 GPUs, each loading step (the torch model, the checkpoints, etc.) takes 10-20 minutes, and the time grows with more GPUs. On a single GPU the same steps take only a couple of minutes. I suspect contention: all GPU processes try to read the model files at the same time.

Example log
Look at the lines for "Loading torch model" (about 9 minutes, 18:33:47 to 18:42:56) and "Loading text encoder model" (about 21 minutes, 18:46:44 to 19:08:08).

2025-03-16 18:33:46.289 | INFO     | hyvideo.inference:from_pretrained:170 - Got text-to-video model root path: ckpts
2025-03-16 18:33:46.289 | INFO     | hyvideo.inference:from_pretrained:170 - Got text-to-video model root path: ckpts
DEBUG 03-16 18:33:46 [parallel_state.py:200] world_size=2 rank=0 local_rank=-1 distributed_init_method=env:// backend=nccl
DEBUG 03-16 18:33:46 [parallel_state.py:200] world_size=2 rank=1 local_rank=-1 distributed_init_method=env:// backend=nccl
2025-03-16 18:33:46.352 | INFO     | hyvideo.inference:from_pretrained:192 - Building model...
2025-03-16 18:33:46.352 | INFO     | hyvideo.inference:from_pretrained:192 - Building model...
2025-03-16 18:33:47.023 | INFO     | hyvideo.inference:load_state_dict:334 - Loading torch model ckpts/hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt...
2025-03-16 18:33:47.024 | INFO     | hyvideo.inference:load_state_dict:334 - Loading torch model ckpts/hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt...
2025-03-16 18:42:56.412 | INFO     | hyvideo.vae:load_vae:29 - Loading 3D VAE model (884-16c-hy) from: ./ckpts/hunyuan-video-i2v-720p/vae
2025-03-16 18:42:56.453 | INFO     | hyvideo.vae:load_vae:29 - Loading 3D VAE model (884-16c-hy) from: ./ckpts/hunyuan-video-i2v-720p/vae
2025-03-16 18:46:43.852 | INFO     | hyvideo.vae:load_vae:55 - VAE to dtype: torch.float16
2025-03-16 18:46:43.962 | INFO     | hyvideo.vae:load_vae:55 - VAE to dtype: torch.float16
2025-03-16 18:46:43.973 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (llm-i2v) from: ./ckpts/text_encoder_i2v
2025-03-16 18:46:44.075 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (llm-i2v) from: ./ckpts/text_encoder_i2v
Loading checkpoint shards:  25%|███████████████████████████████████████▊                                                                                                                       | 1/4 [08:59<26:57, 539.10s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [20:05<00:00, 301.49s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [20:06<00:00, 301.52s/it]
2025-03-16 19:08:08.187 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 19:08:11.073 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (llm-i2v) from: ./ckpts/text_encoder_i2v
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-03-16 19:08:11.618 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (clipL) from: ./ckpts/text_encoder_2
2025-03-16 19:08:15.042 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 19:08:17.000 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 19:08:17.035 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (clipL) from: ./ckpts/text_encoder_2
2025-03-16 19:08:17.248 | INFO     | hyvideo.inference:predict:596 - Input (height, width, video_length) = (720, 720, 129)
2025-03-16 19:08:19.307 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (llm-i2v) from: ./ckpts/text_encoder_i2v
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-03-16 19:08:19.639 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (clipL) from: ./ckpts/text_encoder_2
2025-03-16 19:08:19.996 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 19:08:20.028 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (clipL) from: ./ckpts/text_encoder_2
2025-03-16 19:08:20.149 | INFO     | hyvideo.inference:predict:596 - Input (height, width, video_length) = (720, 720, 129)
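
The two gaps called out above can be checked directly from the timestamps. Below is a minimal sketch, assuming loguru-style lines (`timestamp | LEVEL | message`) like those in the log. `stage_gaps` and `SAMPLE` are hypothetical names, and the sample lines are copied from the log above.

```python
from datetime import datetime

# Four lines copied from the 2-GPU log above (one per stage boundary).
SAMPLE = """\
2025-03-16 18:33:47.024 | INFO     | hyvideo.inference:load_state_dict:334 - Loading torch model ckpts/hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt...
2025-03-16 18:42:56.412 | INFO     | hyvideo.vae:load_vae:29 - Loading 3D VAE model (884-16c-hy) from: ./ckpts/hunyuan-video-i2v-720p/vae
2025-03-16 18:46:44.075 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (llm-i2v) from: ./ckpts/text_encoder_i2v
2025-03-16 19:08:08.187 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
"""


def stage_gaps(log_text, min_seconds=60.0):
    """Return (seconds, message) for every timestamped line that is
    followed by a pause of at least min_seconds before the next one."""
    stamped = []
    for line in log_text.splitlines():
        parts = line.split(" | ", 2)
        if len(parts) != 3:
            continue  # progress bars, NCCL warnings, and other unstamped lines
        try:
            ts = datetime.strptime(parts[0].strip(), "%Y-%m-%d %H:%M:%S.%f")
        except ValueError:
            continue
        stamped.append((ts, parts[2].strip()))

    gaps = []
    for (t0, message), (t1, _) in zip(stamped, stamped[1:]):
        delta = (t1 - t0).total_seconds()
        if delta >= min_seconds:
            gaps.append((delta, message))
    return gaps


if __name__ == "__main__":
    for seconds, message in stage_gaps(SAMPLE):
        print(f"{seconds:8.1f}s  {message}")
```

On these sample lines, the transformer checkpoint accounts for roughly 549 s, the VAE step for roughly 228 s, and the llm-i2v text encoder for roughly 1284 s.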
@pftq
Contributor Author

pftq commented Mar 17, 2025

I changed from_pretrained in inference.py to load the models sequentially (one GPU at a time), and that seems to fix the problem. The log below shows the whole load finishing in under four minutes (20:53:16 to 20:56:59) instead of taking up to 20 minutes per model.

The updated code is below, and I also opened a pull request in case you want to incorporate the fix into the main repo:
#37
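
The idea of taking turns can be sketched independently of the repo. This is an illustration, not the code from the pull request: `load_in_turns` and `simulate` are hypothetical helpers, and the simulation uses threads with `threading.Barrier` in place of real GPU processes. In a real multi-process run, the `barrier` argument would be something like `torch.distributed.barrier`.

```python
import threading


def load_in_turns(rank, world_size, load_fn, barrier):
    """Call load_fn on exactly one rank at a time.

    Every rank calls barrier() once per turn, so no rank starts its own
    load until the rank before it has finished.
    """
    result = None
    for turn in range(world_size):
        if turn == rank:
            result = load_fn()
        barrier()  # all ranks wait here until the current turn is done
    return result


def simulate(world_size=3):
    """Stand-in for a multi-process launch: one thread per rank."""
    barrier = threading.Barrier(world_size)
    lock = threading.Lock()
    order = []                      # ranks in the order they loaded
    results = [None] * world_size   # what each rank ended up with

    def worker(rank):
        def fake_load():
            with lock:
                order.append(rank)
            return f"weights-for-rank-{rank}"

        results[rank] = load_in_turns(rank, world_size, fake_load, barrier.wait)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order, results


if __name__ == "__main__":
    print(simulate(3))
```

The trade-off is that total load time becomes the sum of the per-rank load times rather than the maximum, which only pays off when concurrent reads are much slower than sequential ones, as in the logs here.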

Log:

2025-03-16 20:53:16.782 | INFO     | hyvideo.inference:from_pretrained:281 - Building model...
[rank1]:[W316 20:53:16.059829872 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
2025-03-16 20:53:17.267 | INFO     | hyvideo.inference:load_state_dict:517 - Loading torch model ckpts/hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt...
2025-03-16 20:54:57.389 | INFO     | hyvideo.vae:load_vae:29 - Loading 3D VAE model (884-16c-hy) from: ./ckpts/hunyuan-video-i2v-720p/vae
2025-03-16 20:55:02.900 | INFO     | hyvideo.vae:load_vae:55 - VAE to dtype: torch.float16
2025-03-16 20:55:03.037 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (llm-i2v) from: ./ckpts/text_encoder_i2v
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:51<00:00, 12.83s/it]
2025-03-16 20:56:12.104 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 20:56:14.920 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (llm-i2v) from: ./ckpts/text_encoder_i2v
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-03-16 20:56:15.392 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (clipL) from: ./ckpts/text_encoder_2
2025-03-16 20:56:19.649 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 20:56:19.685 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (clipL) from: ./ckpts/text_encoder_2
2025-03-16 20:56:19.911 | INFO     | hyvideo.inference:from_pretrained:373 - Rank 0: Starting broadcast synchronization
[rank0]:[W316 20:56:19.188660486 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
2025-03-16 20:56:21.094 | INFO     | hyvideo.inference:from_pretrained:436 - Rank 0: Broadcasting model parameters
2025-03-16 20:56:21.144 | INFO     | hyvideo.inference:from_pretrained:440 - Rank 0: Broadcasting VAE parameters
2025-03-16 20:56:21.220 | INFO     | hyvideo.vae:load_vae:29 - Loading 3D VAE model (884-16c-hy) from: ./ckpts/hunyuan-video-i2v-720p/vae
2025-03-16 20:56:24.206 | INFO     | hyvideo.vae:load_vae:55 - VAE to dtype: torch.float16
2025-03-16 20:56:24.321 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (llm-i2v) from: ./ckpts/text_encoder_i2v
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:17<00:00,  4.27s/it]
2025-03-16 20:56:55.901 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 20:56:58.288 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (llm-i2v) from: ./ckpts/text_encoder_i2v
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
2025-03-16 20:56:58.600 | INFO     | hyvideo.text_encoder:load_text_encoder:35 - Loading text encoder model (clipL) from: ./ckpts/text_encoder_2
2025-03-16 20:56:58.897 | INFO     | hyvideo.text_encoder:load_text_encoder:61 - Text encoder to dtype: torch.float16
2025-03-16 20:56:58.926 | INFO     | hyvideo.text_encoder:load_tokenizer:75 - Loading tokenizer (clipL) from: ./ckpts/text_encoder_2
2025-03-16 20:56:59.018 | INFO     | hyvideo.inference:from_pretrained:436 - Rank 1: Broadcasting model parameters
2025-03-16 20:56:59.024 | INFO     | hyvideo.inference:from_pretrained:444 - Rank 0: Broadcasting vae_kwargs
2025-03-16 20:56:59.041 | INFO     | hyvideo.inference:from_pretrained:440 - Rank 1: Broadcasting VAE parameters
2025-03-16 20:56:59.046 | INFO     | hyvideo.inference:from_pretrained:444 - Rank 1: Broadcasting vae_kwargs
2025-03-16 20:56:59.109 | INFO     | hyvideo.inference:from_pretrained:448 - Rank 0: Broadcasting text_encoder parameters
2025-03-16 20:56:59.134 | INFO     | hyvideo.inference:from_pretrained:452 - Rank 0: Broadcasting text_encoder_2 parameters
2025-03-16 20:56:59.160 | INFO     | hyvideo.inference:predict:802 - Input (height, width, video_length) = (720, 720, 129)
2025-03-16 20:56:59.182 | INFO     | hyvideo.inference:from_pretrained:448 - Rank 1: Broadcasting text_encoder parameters
2025-03-16 20:56:59.217 | INFO     | hyvideo.inference:from_pretrained:452 - Rank 1: Broadcasting text_encoder_2 parameters
2025-03-16 20:56:59.238 | INFO     | hyvideo.inference:predict:802 - Input (height, width, video_length) = (720, 720, 129)

Here's the new from_pretrained code if you want to incorporate it:

    # 20250316 pftq: Fixed multi-GPU loading times going up to 20 min due to loading contention by loading models only to one GPU and broadcasting to the rest.
    @classmethod
    def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
        """
        Initialize the Inference pipeline.
    
        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device (str or torch.device, optional): The device for inference. Default is None.
        """
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
        
        # ========================================================================
        # Initialize Distributed Environment
        # ========================================================================
        # 20250316 pftq: Modified to extract rank and world_size early for sequential loading
        if args.ulysses_degree > 1 or args.ring_degree > 1:
            assert xfuser is not None, "Ulysses Attention and Ring Attention requires xfuser package."
            assert args.use_cpu_offload is False, "Cannot enable use_cpu_offload in the distributed environment."
            # 20250316 pftq: Set local rank and device explicitly for NCCL
            local_rank = int(os.environ['LOCAL_RANK'])
            device = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(local_rank)  # 20250316 pftq: Set CUDA device explicitly
            dist.init_process_group("nccl")  # 20250316 pftq: Removed device_id, rely on set_device
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert world_size == args.ring_degree * args.ulysses_degree, \
                "number of GPUs should be equal to ring_degree * ulysses_degree."
            init_distributed_environment(rank=rank, world_size=world_size)
            initialize_model_parallel(
                sequence_parallel_degree=world_size,
                ring_degree=args.ring_degree,
                ulysses_degree=args.ulysses_degree,
            )
        else:
            rank = 0  # 20250316 pftq: Default rank for single GPU
            world_size = 1  # 20250316 pftq: Default world_size for single GPU
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
    
        parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
        torch.set_grad_enabled(False)
    
        # ========================================================================
        # Build main model, VAE, and text encoder sequentially on rank 0
        # ========================================================================
        # 20250316 pftq: Load models only on rank 0, then broadcast
        if rank == 0:
            logger.info("Building model...")
            factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
            if args.i2v_mode and args.i2v_condition_type == "latent_concat":
                in_channels = args.latent_channels * 2 + 1
                image_embed_interleave = 2
            elif args.i2v_mode and args.i2v_condition_type == "token_replace":
                in_channels = args.latent_channels
                image_embed_interleave = 4
            else:
                in_channels = args.latent_channels
                image_embed_interleave = 1
            out_channels = args.latent_channels
    
            if args.embedded_cfg_scale:
                factor_kwargs["guidance_embed"] = True
    
            model = load_model(
                args,
                in_channels=in_channels,
                out_channels=out_channels,
                factor_kwargs=factor_kwargs,
            )
    
            if args.use_fp8:
                convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
            model = model.to(device)
            model = Inference.load_state_dict(args, model, pretrained_model_path)
            model.eval()
    
            # VAE
            vae, _, s_ratio, t_ratio = load_vae(
                args.vae,
                args.vae_precision,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
            )
            vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
    
            # Text encoder
            if args.i2v_mode:
                args.text_encoder = "llm-i2v"
                args.tokenizer = "llm-i2v"
                args.prompt_template = "dit-llm-encode-i2v"
                args.prompt_template_video = "dit-llm-encode-video-i2v"
    
            if args.prompt_template_video is not None:
                crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
            elif args.prompt_template is not None:
                crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
            else:
                crop_start = 0
            max_length = args.text_len + crop_start
    
            prompt_template = PROMPT_TEMPLATE[args.prompt_template] if args.prompt_template is not None else None
            prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None
    
            text_encoder = TextEncoder(
                text_encoder_type=args.text_encoder,
                max_length=max_length,
                text_encoder_precision=args.text_encoder_precision,
                tokenizer_type=args.tokenizer,
                i2v_mode=args.i2v_mode,
                prompt_template=prompt_template,
                prompt_template_video=prompt_template_video,
                hidden_state_skip_layer=args.hidden_state_skip_layer,
                apply_final_norm=args.apply_final_norm,
                reproduce=args.reproduce,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
                image_embed_interleave=image_embed_interleave
            )
            text_encoder_2 = None
            if args.text_encoder_2 is not None:
                text_encoder_2 = TextEncoder(
                    text_encoder_type=args.text_encoder_2,
                    max_length=args.text_len_2,
                    text_encoder_precision=args.text_encoder_precision_2,
                    tokenizer_type=args.tokenizer_2,
                    reproduce=args.reproduce,
                    logger=logger,
                    device=device if not args.use_cpu_offload else "cpu",
                )
        else:
            # 20250316 pftq: Initialize as None on non-zero ranks
            model = None
            vae = None
            vae_kwargs = None
            text_encoder = None
            text_encoder_2 = None
    
        # 20250316 pftq: Broadcast models to all ranks
        if world_size > 1:
            logger.info(f"Rank {rank}: Starting broadcast synchronization")
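            # Rank 0 reaches this barrier only after it has finished loading. Passing device_ids avoids
            # the NCCL "devices used by this process are currently unknown" warning seen in the log.
            dist.barrier(device_ids=[torch.cuda.current_device()])
            if rank != 0:
                # Rebuild the modules on non-zero ranks. load_model only creates the transformer skeleton
                # (its checkpoint is not read here); load_vae and TextEncoder still read their weights
                # from disk, as the log shows, but only after rank 0 has finished.
                # Note: convert_fp8_linear is not applied here, so with args.use_fp8 the parameter dtypes
                # may not match rank 0's during the broadcast below.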
                factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
                if args.i2v_mode and args.i2v_condition_type == "latent_concat":
                    in_channels = args.latent_channels * 2 + 1
                    image_embed_interleave = 2
                elif args.i2v_mode and args.i2v_condition_type == "token_replace":
                    in_channels = args.latent_channels
                    image_embed_interleave = 4
                else:
                    in_channels = args.latent_channels
                    image_embed_interleave = 1
                out_channels = args.latent_channels
                if args.embedded_cfg_scale:
                    factor_kwargs["guidance_embed"] = True
                model = load_model(args, in_channels=in_channels, out_channels=out_channels, factor_kwargs=factor_kwargs).to(device)
                vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger, device=device if not args.use_cpu_offload else "cpu")
                vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
                vae = vae.to(device)
                if args.i2v_mode:
                    args.text_encoder = "llm-i2v"
                    args.tokenizer = "llm-i2v"
                    args.prompt_template = "dit-llm-encode-i2v"
                    args.prompt_template_video = "dit-llm-encode-video-i2v"
                if args.prompt_template_video is not None:
                    crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
                elif args.prompt_template is not None:
                    crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
                else:
                    crop_start = 0
                max_length = args.text_len + crop_start
                prompt_template = PROMPT_TEMPLATE[args.prompt_template] if args.prompt_template is not None else None
                prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None
                text_encoder = TextEncoder(
                    text_encoder_type=args.text_encoder,
                    max_length=max_length,
                    text_encoder_precision=args.text_encoder_precision,
                    tokenizer_type=args.tokenizer,
                    i2v_mode=args.i2v_mode,
                    prompt_template=prompt_template,
                    prompt_template_video=prompt_template_video,
                    hidden_state_skip_layer=args.hidden_state_skip_layer,
                    apply_final_norm=args.apply_final_norm,
                    reproduce=args.reproduce,
                    logger=logger,
                    device=device if not args.use_cpu_offload else "cpu",
                    image_embed_interleave=image_embed_interleave
                ).to(device)
                text_encoder_2 = None
                if args.text_encoder_2 is not None:
                    text_encoder_2 = TextEncoder(
                        text_encoder_type=args.text_encoder_2,
                        max_length=args.text_len_2,
                        text_encoder_precision=args.text_encoder_precision_2,
                        tokenizer_type=args.tokenizer_2,
                        reproduce=args.reproduce,
                        logger=logger,
                        device=device if not args.use_cpu_offload else "cpu",
                    ).to(device)
    
            # Broadcast model parameters with logging
            logger.info(f"Rank {rank}: Broadcasting model parameters")
            for param in model.parameters():
                dist.broadcast(param.data, src=0)
            model.eval()
            logger.info(f"Rank {rank}: Broadcasting VAE parameters")
            for param in vae.parameters():
                dist.broadcast(param.data, src=0)
            # 20250316 pftq: Use broadcast_object_list for vae_kwargs
            logger.info(f"Rank {rank}: Broadcasting vae_kwargs")
            vae_kwargs_list = [vae_kwargs] if rank == 0 else [None]
            dist.broadcast_object_list(vae_kwargs_list, src=0)
            vae_kwargs = vae_kwargs_list[0]
            logger.info(f"Rank {rank}: Broadcasting text_encoder parameters")
            for param in text_encoder.parameters():
                dist.broadcast(param.data, src=0)
            if text_encoder_2 is not None:
                logger.info(f"Rank {rank}: Broadcasting text_encoder_2 parameters")
                for param in text_encoder_2.parameters():
                    dist.broadcast(param.data, src=0)
    
        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args
        )
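
The channel setup appears twice in the code above, once for rank 0 and once for the other ranks, and the two copies have to stay identical for the broadcast shapes to match. One way to keep them in sync is a small shared helper. This is only a sketch: `ChannelConfig` and `resolve_channels` are hypothetical names that are not part of the repo, and the value 16 in the example call is an arbitrary illustration, not a confirmed setting.

```python
from typing import NamedTuple, Optional


class ChannelConfig(NamedTuple):
    in_channels: int
    out_channels: int
    image_embed_interleave: int


def resolve_channels(
    latent_channels: int,
    i2v_mode: bool,
    i2v_condition_type: Optional[str],
) -> ChannelConfig:
    """Mirror the if/elif/else chain used in both branches of from_pretrained."""
    if i2v_mode and i2v_condition_type == "latent_concat":
        # Conditioning latents are concatenated with the input, plus one extra channel.
        return ChannelConfig(latent_channels * 2 + 1, latent_channels, 2)
    if i2v_mode and i2v_condition_type == "token_replace":
        return ChannelConfig(latent_channels, latent_channels, 4)
    # Text-to-video, or any other condition type.
    return ChannelConfig(latent_channels, latent_channels, 1)


if __name__ == "__main__":
    # latent_channels=16 is an example value only.
    print(resolve_channels(16, True, "latent_concat"))
```

Both branches could then call the helper with `args.latent_channels`, `args.i2v_mode`, and `args.i2v_condition_type` instead of repeating the chain.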

@pftq changed the title from "Multi-GPU Model Loading Issues" to "Multi-GPU Model Loading Issues - Issue & Code Fix" on Mar 18, 2025
@TianQi-777
Collaborator

Thanks for your suggestion. In our test environment, multiple GPUs can load the model in parallel without hitting the problem described in this issue. We suspect the problem is specific to certain GPUs or test environments. To stay compatible with those setups, we have merged this solution into the main repository and added it to our community contribution list.

@pftq
Contributor Author

pftq commented Mar 21, 2025

Good to hear and glad to help!

@pftq
Contributor Author

pftq commented Mar 28, 2025

@TianQi-777 Just wanted to make sure you were aware of my comment here:
#37 (comment)

It should only be a good thing (it only kicks in past the 192-frame limit), but I wanted to make sure you were aware it's there. It might also be good to credit thu-ml for the original discovery (although I still did some work to implement it for this repo).
