[Bug]: Inference failed using enable_prefix_caching=True in 0.7.3rc2 #447

Closed
Potabk opened this issue Mar 31, 2025 · 7 comments
Labels
bug Something isn't working

Comments

Potabk commented Mar 31, 2025

Your current environment

The output of `python collect_env.py`
Your output of above commands here

🐛 Describe the bug

I am using the vllm-ascend v0.7.3rc2 image (quay.io/ascend/vllm-ascend:v0.7.3rc2) to test the Automatic Prefix Caching feature. My test script is as follows:

import time
from vllm import LLM, SamplingParams


# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
| ID  | Name          | Age | Occupation    | Country       | Email                  | Phone Number   | Address                       |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1   | John Doe      | 29  | Engineer      | USA           | john.doe@example.com   | 555-1234       | 123 Elm St, Springfield, IL  |
| 2   | Jane Smith    | 34  | Doctor        | Canada        | jane.smith@example.com | 555-5678       | 456 Oak St, Toronto, ON      |
| 3   | Alice Johnson | 27  | Teacher       | UK            | alice.j@example.com    | 555-8765       | 789 Pine St, London, UK      |
| 4   | Bob Brown     | 45  | Artist        | Australia     | bob.b@example.com      | 555-4321       | 321 Maple St, Sydney, NSW    |
| 5   | Carol White   | 31  | Scientist     | New Zealand   | carol.w@example.com    | 555-6789       | 654 Birch St, Wellington, NZ |
| 6   | Dave Green    | 28  | Lawyer        | Ireland       | dave.g@example.com     | 555-3456       | 987 Cedar St, Dublin, IE     |
| 7   | Emma Black    | 40  | Musician      | USA           | emma.b@example.com     | 555-1111       | 246 Ash St, New York, NY     |
| 8   | Frank Blue    | 37  | Chef          | Canada        | frank.b@example.com    | 555-2222       | 135 Spruce St, Vancouver, BC |
| 9   | Grace Yellow  | 50  | Engineer      | UK            | grace.y@example.com    | 555-3333       | 864 Fir St, Manchester, UK   |
| 10  | Henry Violet  | 32  | Artist        | Australia     | henry.v@example.com    | 555-4444       | 753 Willow St, Melbourne, VIC|
| 11  | Irene Orange  | 26  | Scientist     | New Zealand   | irene.o@example.com    | 555-5555       | 912 Poplar St, Auckland, NZ  |
| 12  | Jack Indigo   | 38  | Teacher       | Ireland       | jack.i@example.com     | 555-6666       | 159 Elm St, Cork, IE         |
| 13  | Karen Red     | 41  | Lawyer        | USA           | karen.r@example.com    | 555-7777       | 357 Cedar St, Boston, MA     |
| 14  | Leo Brown     | 30  | Chef          | Canada        | leo.b@example.com      | 555-8888       | 246 Oak St, Calgary, AB      |
| 15  | Mia Green     | 33  | Musician      | UK            | mia.g@example.com      | 555-9999       | 975 Pine St, Edinburgh, UK   |
| 16  | Noah Yellow   | 29  | Doctor        | Australia     | noah.y@example.com     | 555-0000       | 864 Birch St, Brisbane, QLD  |
| 17  | Olivia Blue   | 35  | Engineer      | New Zealand   | olivia.b@example.com   | 555-1212       | 753 Maple St, Hamilton, NZ   |
| 18  | Peter Black   | 42  | Artist        | Ireland       | peter.b@example.com    | 555-3434       | 912 Fir St, Limerick, IE     |
| 19  | Quinn White   | 28  | Scientist     | USA           | quinn.w@example.com    | 555-5656       | 159 Willow St, Seattle, WA   |
| 20  | Rachel Red    | 31  | Teacher       | Canada        | rachel.r@example.com   | 555-7878       | 357 Poplar St, Ottawa, ON    |
| 21  | Steve Green   | 44  | Lawyer        | UK            | steve.g@example.com    | 555-9090       | 753 Elm St, Birmingham, UK   |
| 22  | Tina Blue     | 36  | Musician      | Australia     | tina.b@example.com     | 555-1213       | 864 Cedar St, Perth, WA      |
| 23  | Umar Black    | 39  | Chef          | New Zealand   | umar.b@example.com     | 555-3435       | 975 Spruce St, Christchurch, NZ|
| 24  | Victor Yellow | 43  | Engineer      | Ireland       | victor.y@example.com   | 555-5657       | 246 Willow St, Galway, IE    |
| 25  | Wendy Orange  | 27  | Artist        | USA           | wendy.o@example.com    | 555-7879       | 135 Elm St, Denver, CO       |
| 26  | Xavier Green  | 34  | Scientist     | Canada        | xavier.g@example.com   | 555-9091       | 357 Oak St, Montreal, QC     |
| 27  | Yara Red      | 41  | Teacher       | UK            | yara.r@example.com     | 555-1214       | 975 Pine St, Leeds, UK       |
| 28  | Zack Blue     | 30  | Lawyer        | Australia     | zack.b@example.com     | 555-3436       | 135 Birch St, Adelaide, SA   |
| 29  | Amy White     | 33  | Musician      | New Zealand   | amy.w@example.com      | 555-5658       | 159 Maple St, Wellington, NZ |
| 30  | Ben Black     | 38  | Chef          | Ireland       | ben.b@example.com      | 555-7870       | 246 Fir St, Waterford, IE    |
"""


def get_generation_time(llm, sampling_params, prompts):
    # time the generation
    start_time = time.time()
    output = llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    # print the output and generation time
    print(f"Output: {output[0].outputs[0].text}")
    print(f"Generation time: {end_time - start_time} seconds.")


# set enable_prefix_caching=True to enable APC
llm = LLM(
    model='lmsys/longchat-13b-16k',
    enable_prefix_caching=True
)

sampling_params = SamplingParams(temperature=0, max_tokens=100)

# Querying the age of John Doe
get_generation_time(
    llm,
    sampling_params,
    LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
)

# Querying the age of Zack Blue
# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again.
get_generation_time(
    llm,
    sampling_params,
    LONG_PROMPT + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
)

The error trace is as follows:

/usr/local/python3.10/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:247: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
WARNING 03-31 11:49:53 utils.py:2262] Methods add_lora,add_prompt_adapter,cache_config,compilation_config,current_platform,list_loras,list_prompt_adapters,load_config,pin_lora,pin_prompt_adapter,remove_lora,remove_prompt_adapter not implemented in <vllm_ascend.worker.worker.NPUWorker object at 0xfffd3499d6f0>
INFO 03-31 11:50:01 weight_utils.py:254] Using model weights format ['*.bin']
INFO 03-31 11:50:01 weight_utils.py:270] Time spent downloading weights for lmsys/longchat-13b-16k: 0.582296 seconds
Loading pt checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading pt checkpoint shards:  33% Completed | 1/3 [00:08<00:16,  8.32s/it]
Loading pt checkpoint shards:  67% Completed | 2/3 [00:21<00:11, 11.35s/it]
Loading pt checkpoint shards: 100% Completed | 3/3 [00:35<00:00, 12.35s/it]
Loading pt checkpoint shards: 100% Completed | 3/3 [00:35<00:00, 11.78s/it]

INFO 03-31 11:50:42 executor_base.py:111] # npu blocks: 284, # CPU blocks: 40
INFO 03-31 11:50:42 executor_base.py:116] Maximum concurrency for 16384 tokens per request: 2.22x
INFO 03-31 11:50:43 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 5.45 seconds
Processed prompts: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.30it/s, est. speed input: 4275.92 toks/s, output: 9.21 toks/s]
Output: 29.
Generation time: 0.45357656478881836 seconds.
Processed prompts:   0%|                                                           | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]mki_log mkdir /root/atb/
mki_log mkdir /root/atb/log
[rank0]:[E331 11:50:43.800583278 compiler_depend.ts:422] setup failed!
Exception raised from OperationSetup at build/third_party/op-plugin/op_plugin/CMakeFiles/op_plugin_atb.dir/compiler_depend.ts:131 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0xb8 (0xffff9798c908 in /usr/local/python3.10/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x70 (0xffff9793b4e0 in /usr/local/python3.10/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: atb::OperationSetup(atb::VariantPack, atb::Operation*, atb::Context*) + 0x8c (0xfffe8870426c in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libop_plugin_atb.so)
frame #3: <unknown function> + 0x34934 (0xfffe88704934 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libop_plugin_atb.so)
frame #4: <unknown function> + 0x15c95c4 (0xfffdefb195c4 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #5: <unknown function> + 0x758c84 (0xfffdeeca8c84 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #6: <unknown function> + 0x759498 (0xfffdeeca9498 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #7: <unknown function> + 0x75617c (0xfffdeeca617c in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #8: <unknown function> + 0x4c9e4c (0xffff979c9e4c in /usr/local/python3.10/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #9: <unknown function> + 0x7d5b8 (0xffffa232d5b8 in /lib/aarch64-linux-gnu/libc.so.6)
frame #10: <unknown function> + 0xe5edc (0xffffa2395edc in /lib/aarch64-linux-gnu/libc.so.6)

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/demo.py", line 69, in <module>
[rank0]:     get_generation_time(
[rank0]:   File "/workspace/demo.py", line 45, in get_generation_time
[rank0]:     output = llm.generate(prompts, sampling_params=sampling_params)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/utils.py", line 1057, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 469, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 1397, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1391, in step
[rank0]:     outputs = self.model_executor.execute_model(
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 139, in execute_model
[rank0]:     output = self.collective_rpc("execute_model",
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
[rank0]:     answer = run_method(self.driver_worker, method, args, kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/utils.py", line 2196, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 420, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm_ascend/worker/model_runner.py", line 1146, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 547, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 172, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 368, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 282, in forward
[rank0]:     hidden_states = self.self_attn(positions=positions,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 206, in forward
[rank0]:     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/attention/layer.py", line 198, in forward
[rank0]:     return self.impl.forward(self, query, key, value,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm_ascend/attention/attention.py", line 734, in forward
[rank0]:     torch_npu._npu_reshape_and_cache(key=key,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]: RuntimeError: The Inner error is reported as above. The process exits for this inner error, and the current working operator name is SelfAttentionOperation.
[rank0]: Since the operator is called asynchronously, the stacktrace may be inaccurate. If you want to get the accurate stacktrace, pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.
[rank0]: Note: ASCEND_LAUNCH_BLOCKING=1 will force ops to run in synchronous mode, resulting in performance degradation. Please unset ASCEND_LAUNCH_BLOCKING in time after debugging.
[rank0]: [ERROR] 2025-03-31-11:50:43 (PID:516, Device:0, RankID:-1) ERR00100 PTA call acl api failed.
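
As a rough cross-check of the numbers in the log above: 284 NPU blocks and a maximum concurrency of 2.22x at 16384 tokens per request are consistent with a KV-cache block size of 128 tokens, and the progress bar of the first request implies a prompt of roughly 1,860 tokens. The block size is inferred from the log rather than printed by it, and the concurrency formula used here (blocks times block size, divided by the maximum request length) is an assumption. A minimal sketch of the arithmetic:

```python
# Values copied from the log above.
npu_blocks = 284
max_model_len = 16384
reported_concurrency = 2.22

# Assumption: concurrency = npu_blocks * block_size / max_model_len.
implied_block_size = reported_concurrency * max_model_len / npu_blocks
block_size = 128  # inferred: the power of two closest to the implied value
concurrency = npu_blocks * block_size / max_model_len

# First request: the progress bar shows 2.30 it/s and 4275.92 input toks/s,
# so tokens per request is approximately their ratio (both are rounded).
approx_prompt_tokens = 4275.92 / 2.30

print(f"implied block size: {implied_block_size:.1f}")
print(f"concurrency at block_size={block_size}: {concurrency:.2f}x")
print(f"approximate prompt length: {approx_prompt_tokens:.0f} tokens")
print(f"full blocks in prompt: {int(approx_prompt_tokens // block_size)}")
```

Since prefix caching reuses whole blocks, the second request would be expected to hit the cache for the full blocks of the shared prefix, which is the step at which the failure appears.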

The device debug log:

[ERROR] APP(516,python):2025-03-31-11:50:43.628.481 [log_inner.cpp:76]aclAppLog call vsnprintf_s failed
[ERROR] APP(516,python):2025-03-31-11:50:43.628.507 [log_inner.cpp:76]881 build/CMakeFiles/torch_npu.dir/compiler_depend.ts:ExecFuncOpApi:428: "[PTA]:"Custom hand fail! name=SelfAttentionOperation, ret=0x0x186a0""
[ERROR] APP(516,python):2025-03-31-11:50:43.630.606 [log_inner.cpp:76]881 build/CMakeFiles/torch_npu.dir/compiler_depend.ts:ReadQueue:381: "[PTA]:"---Thread---281468302389536: device = 0, write_idx = 3509, read_idx = 3466, status = 1, ret = 100000""
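
The two return codes in the device log appear to be the same value: `0x186a0` in the first line is 100000 in decimal, which matches `ret = 100000` in the second. A quick check, assuming the doubled `0x0x` prefix is only a formatting slip in the log message:

```python
# Return code as printed in the first device-log line ("ret=0x0x186a0").
raw = "0x0x186a0"

# Assumption: the doubled "0x" is a logging artifact; keep only the hex digits.
hex_digits = raw.replace("0x", "")
code = int(hex_digits, 16)

print(code)
```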
@Potabk Potabk added the bug Something isn't working label Mar 31, 2025
Potabk commented Mar 31, 2025

When I reduce the prefix input length,

LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
| ID  | Name          | Age | Occupation    | Country       | Email                  | Phone Number   | Address                       |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 2   | Jane Smith    | 34  | Doctor        | Canada        | jane.smith@example.com | 555-5678       | 456 Oak St, Toronto, ON      |
"""

the error still occurs:

/usr/local/python3.10/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:247: RuntimeWarning: torch.jit.script and torch.jit.script_method will be disabled by transfer_to_npu, which currently does not support them, if you need to enable them, please do not use transfer_to_npu.
  warnings.warn(msg, RuntimeWarning)
WARNING 03-31 12:23:18 utils.py:2262] Methods add_lora,add_prompt_adapter,cache_config,compilation_config,current_platform,list_loras,list_prompt_adapters,load_config,pin_lora,pin_prompt_adapter,remove_lora,remove_prompt_adapter not implemented in <vllm_ascend.worker.worker.NPUWorker object at 0xfffd17a7d630>
INFO 03-31 12:23:26 weight_utils.py:254] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Loading pt checkpoint shards:  33% Completed | 1/3 [00:05<00:11,  5.81s/it]
Loading pt checkpoint shards:  67% Completed | 2/3 [00:15<00:08,  8.28s/it]
Loading pt checkpoint shards: 100% Completed | 3/3 [00:24<00:00,  8.54s/it]
Loading pt checkpoint shards: 100% Completed | 3/3 [00:24<00:00,  8.22s/it]

INFO 03-31 12:23:56 executor_base.py:111] # npu blocks: 284, # CPU blocks: 40
INFO 03-31 12:23:56 executor_base.py:116] Maximum concurrency for 16384 tokens per request: 2.22x
INFO 03-31 12:23:57 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 5.39 seconds
Processed prompts: 100%|████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.97it/s, est. speed input: 787.86 toks/s, output: 19.94 toks/s]
Output: 34.
Generation time: 0.21849370002746582 seconds.
Processed prompts:   0%|                                                           | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]mki_log log dir:/root/atb/log exist
[rank0]:[E331 12:23:57.516555563 compiler_depend.ts:422] setup failed!
Exception raised from OperationSetup at build/third_party/op-plugin/op_plugin/CMakeFiles/op_plugin_atb.dir/compiler_depend.ts:131 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0xb8 (0xffff7aa6c908 in /usr/local/python3.10/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x70 (0xffff7aa1b4e0 in /usr/local/python3.10/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #2: atb::OperationSetup(atb::VariantPack, atb::Operation*, atb::Context*) + 0x8c (0xfffe6b7e426c in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libop_plugin_atb.so)
frame #3: <unknown function> + 0x34934 (0xfffe6b7e4934 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libop_plugin_atb.so)
frame #4: <unknown function> + 0x15c95c4 (0xfffdd2bf95c4 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #5: <unknown function> + 0x758c84 (0xfffdd1d88c84 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #6: <unknown function> + 0x759498 (0xfffdd1d89498 in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #7: <unknown function> + 0x75617c (0xfffdd1d8617c in /usr/local/python3.10/lib/python3.10/site-packages/torch_npu/lib/libtorch_npu.so)
frame #8: <unknown function> + 0x4c9e4c (0xffff7aaa9e4c in /usr/local/python3.10/lib/python3.10/site-packages/torch/lib/libc10.so)
frame #9: <unknown function> + 0x7d5b8 (0xffff8540d5b8 in /lib/aarch64-linux-gnu/libc.so.6)
frame #10: <unknown function> + 0xe5edc (0xffff85475edc in /lib/aarch64-linux-gnu/libc.so.6)

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/demo.py", line 40, in <module>
[rank0]:     get_generation_time(
[rank0]:   File "/workspace/demo.py", line 16, in get_generation_time
[rank0]:     output = llm.generate(prompts, sampling_params=sampling_params)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/utils.py", line 1057, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 469, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 1397, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1391, in step
[rank0]:     outputs = self.model_executor.execute_model(
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 139, in execute_model
[rank0]:     output = self.collective_rpc("execute_model",
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
[rank0]:     answer = run_method(self.driver_worker, method, args, kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/utils.py", line 2196, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 420, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm_ascend/worker/model_runner.py", line 1146, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 547, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 172, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 368, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 282, in forward
[rank0]:     hidden_states = self.self_attn(positions=positions,
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 205, in forward
[rank0]:     q, k = self.rotary_emb(positions, q, k)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/custom_op.py", line 25, in forward
[rank0]:     return self._forward_method(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm_ascend/ops/rotary_embedding.py", line 43, in rope_forward_oot
[rank0]:     key = key.contiguous()
[rank0]: RuntimeError: The Inner error is reported as above. The process exits for this inner error, and the current working operator name is SelfAttentionOperation.
[rank0]: Since the operator is called asynchronously, the stacktrace may be inaccurate. If you want to get the accurate stacktrace, pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.
[rank0]: Note: ASCEND_LAUNCH_BLOCKING=1 will force ops to run in synchronous mode, resulting in performance degradation. Please unset ASCEND_LAUNCH_BLOCKING in time after debugging.
[rank0]: [ERROR] 2025-03-31-12:23:57 (PID:3396, Device:0, RankID:-1) ERR00100 PTA call acl api failed.
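
Both traces name `SelfAttentionOperation` as the failing operator, yet the Python frames end at different calls (`torch_npu._npu_reshape_and_cache` in the first run, `key.contiguous()` in the second). That fits the log's warning that asynchronous execution can make the stack trace inaccurate. A minimal sketch of enabling synchronous mode for a debug run, assuming the variable has to be set before the NPU runtime is initialised, so it goes at the very top of the script, ahead of the `vllm` import:

```python
import os

# Set this before importing vllm / torch_npu so the runtime can see it
# (assumption: the variable is read when the NPU runtime starts up).
os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"

# The rest of the test script (imports, LLM(...), generate calls) follows here.
# Remove the line above after debugging: the log notes that synchronous mode
# degrades performance.
print(os.environ["ASCEND_LAUNCH_BLOCKING"])
```

Exporting the variable in the shell before launching the script should have the same effect.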

[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/models/llama.py", line 205, in forward
[rank0]:     q, k = self.rotary_emb(positions, q, k)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm/model_executor/custom_op.py", line 25, in forward
[rank0]:     return self._forward_method(*args, **kwargs)
[rank0]:   File "/usr/local/python3.10/lib/python3.10/site-packages/vllm_ascend/ops/rotary_embedding.py", line 43, in rope_forward_oot
[rank0]:     key = key.contiguous()
[rank0]: RuntimeError: The Inner error is reported as above. The process exits for this inner error, and the current working operator name is SelfAttentionOperation.
[rank0]: Since the operator is called asynchronously, the stacktrace may be inaccurate. If you want to get the accurate stacktrace, pleace set the environment variable ASCEND_LAUNCH_BLOCKING=1.
[rank0]: Note: ASCEND_LAUNCH_BLOCKING=1 will force ops to run in synchronous mode, resulting in performance degradation. Please unset ASCEND_LAUNCH_BLOCKING in time after debugging.
[rank0]: [ERROR] 2025-03-31-12:23:57 (PID:3396, Device:0, RankID:-1) ERR00100 PTA call acl api failed.
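
The log above notes that operators run asynchronously, so the frame shown (rotary_embedding.py) may not be the real failure site, and it suggests setting ASCEND_LAUNCH_BLOCKING=1. A minimal sketch of doing that from the test script follows. It assumes the variable is read when torch_npu initializes, so it is set before any vllm import; exporting it in the shell before launching the script is the safer route if that assumption does not hold.

```python
import os

# Assumption: the NPU runtime reads this variable at initialization,
# so it must be set before torch_npu / vllm are imported.
os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"

# Import vllm only after the variable is set, e.g.:
# from vllm import LLM, SamplingParams

print(os.environ["ASCEND_LAUNCH_BLOCKING"])
```

As the log warns, synchronous mode degrades performance, so the variable should be unset again after debugging.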

Update: when I further reduce the length of the prefix input, the script looks like this:

import time
from vllm import LLM, SamplingParams


# A short prefix prompt (replaces the large markdown table from the first script).
LONG_PROMPT = "You are a helpful assistant, and my name is joe, i'm 18 years old, please answer me a question"


def get_generation_time(llm, sampling_params, prompts):
    # time the generation
    start_time = time.time()
    output = llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    # print the output and generation time
    print(f"Output: {output[0].outputs[0].text}")
    print(f"Generation time: {end_time - start_time} seconds.")


# set enable_prefix_caching=True to enable APC
llm = LLM(
    model='lmsys/longchat-13b-16k',
    enable_prefix_caching=True
)

sampling_params = SamplingParams(temperature=0, max_tokens=100)

# First query: asks for the age stated in the prefix.
get_generation_time(
    llm,
    sampling_params,
    LONG_PROMPT + "Question: How old am I ? ",
)

# Second query: asks for the name stated in the prefix.
# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again.
get_generation_time(
    llm,
    sampling_params,
    LONG_PROMPT + "Question: what is my name ? ",
)

It runs without the error, but APC does not appear to provide any speedup:

Answer: You are 18 years old.
Generation time: 0.5930778980255127 seconds.
Processed prompts: 100%|█████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.09it/s, est. speed input: 39.35 toks/s, output: 24.05 toks/s]
Output: 

Answer: Hello! My name is Joe, and I'm 18 years old.
Generation time: 0.9169011116027832 seconds.
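
One possible reason the short-prefix run shows no speedup: as far as I understand, vLLM's prefix caching reuses the KV cache only in whole blocks, so a shared prefix shorter than one block cannot be reused at all. The progress bar above implies roughly 39.35 / 1.09 ≈ 36 input tokens for the entire prompt, so the shared prefix is shorter still. The sketch below is plain arithmetic, not a vLLM API; the helper name and the block sizes 16 and 128 are illustrative assumptions, not confirmed defaults for this image.

```python
def cacheable_prefix_tokens(num_shared_tokens: int, block_size: int) -> int:
    """Return how many shared-prefix tokens fall into complete KV-cache blocks."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    return (num_shared_tokens // block_size) * block_size


# 36 is an upper bound on the shared prefix, estimated from the progress bar.
for block_size in (16, 128):  # hypothetical block sizes
    print(block_size, cacheable_prefix_tokens(36, block_size))
# prints:
# 16 32
# 128 0
```

Note also that each measurement includes decoding up to 100 output tokens, which probably dominates the wall time for such a short prompt. Timing with SamplingParams(temperature=0, max_tokens=1) and a much longer shared prefix would isolate the prefill cost that APC is meant to reduce.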

@rjg-lyh
Contributor

rjg-lyh commented Mar 31, 2025

@Potabk I tried your first long-sequence script with the Qwen2.5-7B model, and it generated results normally on an earlier commit. We will look into the root cause of the issue.

@Yikun Yikun reopened this Apr 1, 2025
@wangxiyuan
Collaborator

This needs a new version of NNAL. We'll address it once that version is released.

@matthewygf

Same here. It seems to come from nnal/atb Mki::LogSinkFile::DeleteOldestFile(). Any fixes?

@Potabk
Contributor Author

Potabk commented Apr 7, 2025

Same here. It seems to come from nnal/atb Mki::LogSinkFile::DeleteOldestFile(). Any fixes?

This will be fixed once the new version of NNAL is released.

@Potabk
Contributor Author

Potabk commented May 6, 2025

NNAL 8.1rc1 has been released and this bug is verified as fixed (see #644), so we are closing this issue.

@Potabk Potabk closed this as completed May 6, 2025