
[Bugifx] Remove TritonPlaceholder from sys.modules #17317


Merged: 6 commits merged into vllm-project:main from fix-triton-placeholder on May 2, 2025

Conversation

@Isotr0py Isotr0py (Collaborator) commented Apr 28, 2025

FIX #17309

Tested on an x86 CPU environment by intentionally uninstalling Triton:

$ python examples/offline_inference/basic/basic.py 
INFO 04-29 01:23:25 [importing.py:19] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 04-29 01:23:27 [importing.py:60] Ignore import error when loading torch._inductor.runtime.triton_helpers: No module named 'triton'
WARNING 04-29 01:23:27 [importing.py:30] Triton is not installed. Using dummy decorators. Install it via `pip install triton` to enable kernelcompilation.
INFO 04-29 01:23:27 [importing.py:74] Triton module has been replaced with a placeholder.
INFO 04-29 01:23:30 [__init__.py:239] Automatically detected platform cpu.
INFO 04-29 01:23:41 [config.py:716] This model supports multiple tasks: {'embed', 'generate', 'reward', 'score', 'classify'}. Defaulting to 'generate'.
WARNING 04-29 01:23:41 [_logger.py:72] device type=cpu is not supported by the V1 Engine. Falling back to V0. 
INFO 04-29 01:23:41 [config.py:1772] Disabled the custom all-reduce kernel because it is not supported on current platform.
WARNING 04-29 01:23:41 [_logger.py:72] Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) for CPU backend is not set, using 4 by default.
WARNING 04-29 01:23:41 [_logger.py:72] uni is not supported on CPU, fallback to mp distributed executor backend.
INFO 04-29 01:23:41 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5.dev277+gb90093e42) with config: model='/data/LLM-model/opt-125m', speculative_config=None, tokenizer='/data/LLM-model/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=/data/LLM-model/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False, 
INFO 04-29 01:23:41 [cpu.py:45] Using Torch SDPA backend.
INFO 04-29 01:23:41 [parallel_state.py:1004] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.47it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.47it/s]

INFO 04-29 01:23:42 [loader.py:458] Loading weights took 0.30 seconds
INFO 04-29 01:23:42 [executor_base.py:112] # cpu blocks: 910, # CPU blocks: 0
INFO 04-29 01:23:42 [executor_base.py:117] Maximum concurrency for 2048 tokens per request: 56.88x
INFO 04-29 01:23:42 [llm_engine.py:437] init engine (profile, create kv cache, warmup model) took 0.45 seconds
Processed prompts:   0%|                                                                                                            | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]WARNING 04-29 01:23:42 [_logger.py:72] Pin memory is not supported on CPU.
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.45it/s, est. speed input: 22.42 toks/s, output: 55.19 toks/s]

Signed-off-by: Isotr0py <2037008807@qq.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@Isotr0py Isotr0py requested a review from youkaichao April 28, 2025 17:18
Signed-off-by: Isotr0py <2037008807@qq.com>
"Ignore import error when loading " \
"%s: %s", module_info.name, e)
continue

Collaborator

The problem seems to be that TritonPlaceholder is a non-standard way of doing things; it's generally not OK to monkey-patch a module like that. I don't know why we needed TritonPlaceholder -- is there something else we can do instead?

Collaborator Author

I don't know why we needed TritonPlaceholder

TritonPlaceholder was introduced in #15099 (comment) to avoid unintentional triton import errors on non-GPU backends when importing modules that contain unused Triton ops (triton imports have broken them several times).

The root issue is that TritonPlaceholder is added to sys.modules before torch.compile is called, which causes the torch triton check _is_triton_available to return a false positive (True).

In torch 2.7, I think we could simply patch _is_triton_available together with TritonPlaceholder, assuming torch has migrated all triton checks to that method. Unfortunately, torch 2.6 still has some triton checks that do not use this method, and those will always conflict with TritonPlaceholder: https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/_inductor/runtime/triton_heuristics.py#L52-L58
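
To make the failure mode concrete, here is a minimal stand-alone sketch (not vLLM's actual TritonPlaceholder, just a bare module) showing why a global entry in sys.modules fools availability probes:

    import sys
    import types

    # A bare stand-in registered globally, mimicking what a sys.modules placeholder does.
    sys.modules["triton"] = types.ModuleType("triton")

    # A typical availability probe elsewhere (for example inside torch) now "succeeds":
    try:
        import triton  # resolves to the placeholder, not to real Triton
        HAS_TRITON = True
    except ImportError:
        HAS_TRITON = False

    print(HAS_TRITON)              # True, even though real Triton is absent
    print(hasattr(triton, "jit"))  # False, so later attribute access blows up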

Collaborator Author

Another solution is to extend TritonPlaceholder so that it is compatible with torch's triton checks. Let me try to refactor TritonPlaceholder that way...

Contributor

Another solution is to extend TritonPlaceholder so that it is compatible with torch's triton checks. Let me try to refactor TritonPlaceholder that way...

@Isotr0py sorry for this bug; I noticed it too. I think we can simulate, inside TritonPlaceholder, the behavior of torch._inductor when it can't find triton. Any discussion is welcome, or feel free to assign this to me.
https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/_inductor/runtime/triton_heuristics.py#L81-L94
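
For reference, the fallback in the linked torch file is roughly the following local pattern (paraphrased from the link above, not a verbatim copy; exact names vary between torch versions):

    try:
        import triton
    except ImportError:
        triton = None

    if triton is not None:
        from triton.compiler import CompiledKernel
        from triton.runtime.autotuner import OutOfResources
        from triton.runtime.jit import KernelInterface
    else:
        # Local placeholders: they only exist inside this one module,
        # so no other library ever sees them through sys.modules.
        CompiledKernel = object
        OutOfResources = object
        KernelInterface = object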

@zou3519 zou3519 (Collaborator) Apr 29, 2025

Why can't vLLM wrap its triton imports in try/except blocks wherever it imports triton? Putting a placeholder that is not actually triton into sys.modules breaks all third-party libraries that do standard things with triton (e.g. their own try: import triton / except: blocks, or even just accessing fields that are not in your placeholder).

@Isotr0py Isotr0py (Collaborator, Author) Apr 29, 2025

Why can't vLLM wrap its triton imports in try/except blocks wherever it imports triton?

I found that even a script that contains only triton functions is referenced by many other files scattered across the project. Modifying all of these files would introduce much larger changes.

Hmmm, referring to #15099 (comment), using try/except import blocks would introduce large changes across the whole project, and it would also be easy to break non-GPU backends again whenever a developer forgets to add the block when implementing a new triton function.

I would prefer to simulate, inside TritonPlaceholder, the behavior of torch._inductor.runtime when it can't find triton, because the triton check in torch._inductor also puts in placeholders when triton is missing.

@zou3519 zou3519 (Collaborator) Apr 30, 2025

I would strongly prefer not monkey-patching torch._inductor.runtime.

I would prefer to simulate, inside TritonPlaceholder, the behavior of torch._inductor.runtime when it can't find triton, because the triton check in torch._inductor also puts in placeholders when triton is missing.

There's a difference between putting a placeholder locally (that is what torch._inductor does) vs globally (that is what vLLM does, by modifying sys.modules["triton"]). The difference is that local placeholders don't break other libraries, while the global placeholder (sys.modules["triton"]) does. We should fix this if we can. One idea is that we refactor all of vLLM's triton imports to import triton from a single file in vLLM that does the try-catch.
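
A minimal sketch of that single-file idea, assuming a hypothetical vllm/_triton_compat.py that every other file would import from instead of importing triton directly:

    # vllm/_triton_compat.py (hypothetical module name, for illustration only)
    try:
        import triton
        import triton.language as tl
        HAS_TRITON = True
    except ImportError:
        triton = None
        tl = None
        HAS_TRITON = False

Kernel-defining files would then guard their triton.jit functions behind HAS_TRITON, which is essentially the pattern shown further down in this thread.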

Contributor

I would strongly prefer not monkey-patching torch._inductor.runtime.

I would prefer to simulate, inside TritonPlaceholder, the behavior of torch._inductor.runtime when it can't find triton, because the triton check in torch._inductor also puts in placeholders when triton is missing.

There's a difference between putting a placeholder locally (that is what torch._inductor does) vs globally (that is what vLLM does, by modifying sys.modules["triton"]). The difference is that local placeholders don't break other libraries, while the global placeholder (sys.modules["triton"]) does. We should fix this if we can. One idea is that we refactor all of vLLM's triton imports to import triton from a single file in vLLM that does the try-catch.

This looks reasonable, and using try/except would be safer. However, importing triton from a single file in vLLM would break static code checks.

It would also need many more changes in vLLM, and we would have to provide dummy decorators for triton.jit, etc. You can see #17446 as a reference; I happened to keep the previous code there.

Collaborator Author

Hmm, I'm wondering if we can introduce an import hook instead of putting placeholders into sys.modules, so that the placeholder is only used within the vllm project.
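
A rough sketch of what such an import hook could look like (purely illustrative: the class name is made up, and restricting the fallback to imports that originate inside vLLM would additionally require inspecting the calling frame, which this sketch does not attempt):

    import importlib.abc
    import importlib.machinery
    import sys
    import types

    class _TritonFallbackFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
        """Serve an empty stand-in 'triton' module when the real one is missing."""

        def find_spec(self, fullname, path=None, target=None):
            if fullname != "triton":
                return None
            # Appended last to sys.meta_path, so this only runs after every
            # other finder (including the one for a real triton install) failed.
            return importlib.machinery.ModuleSpec(fullname, self)

        def create_module(self, spec):
            return types.ModuleType(spec.name)

        def exec_module(self, module):
            # Nothing to execute; dummy attributes could be attached here.
            pass

    # Append rather than prepend so a real triton installation always wins.
    sys.meta_path.append(_TritonFallbackFinder())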

Signed-off-by: Isotr0py <2037008807@qq.com>
Comment on lines 74 to 83
# Replace triton submodules with dummy objects to keep compatibility with
# the torch triton checks.
triton_modules_with_objects = {
    "triton.compiler": ["CompiledKernel"],
    "triton.runtime.autotuner": ["OutOfResources"],
    "triton.runtime.jit": ["KernelInterface"],
}
for module_name, dummy_objects in triton_modules_with_objects.items():
    sys.modules[module_name] = TritonModulePlaceholder(
        module_name, dummy_objects=dummy_objects)
Collaborator Author

@MengqingCao We can simulate the behavior of torch._inductor.runtime.triton_heuristics, but it's a little bit inflexible. Is there any better way to achieve this?

@MengqingCao MengqingCao (Contributor) Apr 30, 2025

I think we could use a recursive method to define the attributes of the placeholder. However, it's difficult to define its __call__ method, because the triton-related placeholders in torch._inductor.runtime.triton_heuristics differ (some of them are assigned None, others object). There may also be other kinds of placeholders in torch inductor.

    import types
    from typing import Optional

    class TritonModulePlaceholder(types.ModuleType):

        def __init__(
            self,
            name: str,
            dummy_objects: Optional[list[str]] = None,
        ):
            super().__init__(name)
            # Expose explicitly requested names as bare `object` stand-ins
            # (matching how torch._inductor falls back when triton is absent).
            if dummy_objects is not None:
                for obj_name in dummy_objects:
                    setattr(self, obj_name, object)

        def __getattr__(self, name):
            # Any unknown attribute resolves to another placeholder module,
            # so arbitrarily deep lookups like `triton.language.core` succeed.
            return TritonModulePlaceholder(f"{self.__name__}.{name}")
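
For illustration, the recursive lookup in this sketch would behave roughly like this (hypothetical usage):

    triton = TritonModulePlaceholder("triton", dummy_objects=["jit"])
    print(triton.jit)             # <class 'object'>: usable where torch expects a class
    tl = triton.language          # unknown attribute -> another placeholder module
    print(tl.__name__)            # "triton.language"
    print(tl.core.math.__name__)  # recursion keeps working: "triton.language.core.math"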

Do you have any idea based on this? @Isotr0py

@davidxia davidxia (Contributor) commented May 1, 2025

This PR (commit 8eccfbb) fixes vllm serve for me on Apple silicon. Specifically vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 and requests to /metrics, /docs, and /v1/chat/completions. 🙌

Signed-off-by: Isotr0py <2037008807@qq.com>
@Isotr0py Isotr0py changed the title [Bugifx] Fix conflicts between Triton placeholder and torch Triton checks [Bugifx] Remove TritonPlaceholder from sys.modules May 1, 2025
@Isotr0py Isotr0py (Collaborator, Author) commented May 1, 2025

@zou3519 @MengqingCao Since a lot of non-GPU backends are currently broken by this issue, I decided to remove the sys.modules patching in this PR (partially reverting #15099) as a quick fix first.

I kept TritonPlaceholder since it can still be used as a local placeholder like this:

from vllm.triton_utils.importing import HAS_TRITON, TritonPlaceholder, TritonLanguagePlaceholder

if HAS_TRITON:
    import triton
    import triton.language as tl
else:
    triton = TritonPlaceholder()
    tl = TritonLanguagePlaceholder()

Anyway, we can leave this modification to be done in #17446.

@zou3519 zou3519 (Collaborator) left a comment

Partial fix looks good to me, thank you.

@zou3519 zou3519 requested review from houseroad and WoosukKwon May 2, 2025 01:01
@MengqingCao MengqingCao (Contributor)

@zou3519 @MengqingCao Since a lot of non-GPU backends are currently broken by this issue, I decided to remove the sys.modules patching in this PR (partially reverting #15099) as a quick fix first.

I kept TritonPlaceholder since it can still be used as a local placeholder like this:

from vllm.triton_utils.importing import HAS_TRITON, TritonPlaceholder, TritonLanguagePlaceholder

if HAS_TRITON:
    import triton
    import triton.language as tl
else:
    triton = TritonPlaceholder()
    tl = TritonLanguagePlaceholder()

Anyway, we can leave this modification to be done in #17446.

Okay, I will make this change in #17446 after merging this PR. LGTM now, thanks!

@houseroad houseroad (Collaborator) left a comment

Can we update the test plan with this partial fix?

@houseroad houseroad added the ready ONLY add when PR is ready to merge/full CI is needed label May 2, 2025
@houseroad houseroad (Collaborator) left a comment

Fix the non-GPU breakage first.

@houseroad houseroad merged commit 9e2de9b into vllm-project:main May 2, 2025
61 checks passed
@Isotr0py Isotr0py deleted the fix-triton-placeholder branch May 2, 2025 07:57
radeksm pushed a commit to radeksm/vllm that referenced this pull request May 2, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
Shafi-Hussain pushed a commit to Shafi-Hussain/vllm that referenced this pull request May 30, 2025
Shafi-Hussain pushed a commit to odh-on-pz/vllm-downstream that referenced this pull request Jun 2, 2025
ckhordiasma added a commit to red-hat-data-services/vllm that referenced this pull request Jun 2, 2025
[Bugifx] Remove TritonPlaceholder from sys.modules (vllm-project#17317)
Labels
ready ONLY add when PR is ready to merge/full CI is needed

Successfully merging this pull request may close these issues.

[Bug]: triton placeholder is conflicting with pytorch's triton checks
5 participants