57 | 57 | from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
58 | 58 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
59 | 59 |
| 60 | +from vllm_ascend.utils import ProfileExecuteDuration
60 | 61 | from vllm_ascend.attention.attention import AttentionMaskBuilder
61 | 62 | from vllm_ascend.attention.attention_v1 import AscendAttentionState
62 | 63 | from vllm_ascend.platform import NPUPlatform
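
The only functional import change is `ProfileExecuteDuration` from `vllm_ascend.utils`. In the hunks below it is used as an async capture context, `capture_async(tag)`, around the "forward" and "post process" phases, and the accumulated timings are drained once per step with `pop_captured_sync`. As a rough mental model only, not the actual vllm-ascend implementation, such a helper could look like the sketch below, assuming torch_npu's `Event` API mirrors `torch.cuda.Event`:

```python
# Hypothetical sketch, NOT the real vllm_ascend.utils.ProfileExecuteDuration.
# Assumes `import torch_npu` has registered the NPU backend and that
# torch.npu.Event mirrors torch.cuda.Event (enable_timing / elapsed_time).
from contextlib import contextmanager

import torch


class ProfileExecuteDurationSketch:
    """Singleton that collects per-tag device timings without an eager sync."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._captures = []
        return cls._instance

    @contextmanager
    def capture_async(self, tag: str):
        # Record start/end events on the current stream; no host sync here.
        start = torch.npu.Event(enable_timing=True)
        end = torch.npu.Event(enable_timing=True)
        start.record()
        try:
            yield
        finally:
            end.record()
            self._captures.append((tag, start, end))

    def pop_captured_sync(self, prefix: str) -> dict:
        # One sync per step: wait for the device, convert events to ms, reset.
        torch.npu.synchronize()
        durations = {tag: start.elapsed_time(end)
                     for tag, start, end in self._captures}
        self._captures.clear()
        print(f"[{prefix}] " +
              ", ".join(f"{tag}: {ms:.2f} ms" for tag, ms in durations.items()))
        return durations
```

The key property is that `capture_async` only records events on the stream, so the timed regions themselves stay asynchronous; the single synchronize inside `pop_captured_sync` is where the cost of reading the timings is paid.
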
@@ -640,36 +641,37 @@ def _process_reqs(
640 | 641 |         with set_forward_context(attn_metadata,
641 | 642 |                                  self.vllm_config,
642 | 643 |                                  num_tokens=num_input_tokens):
643 | | -            model_kwargs = {}
644 | | -            if self.enable_torchair_graph_mode:
645 | | -                model_kwargs["kv_caches"] = self.kv_caches
646 | | -                model_kwargs["attn_metadata"] = attn_metadata
647 | | -            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
648 | | -                torch._dynamo.mark_static(input_ids)
649 | | -                torch._dynamo.mark_static(positions)
650 | | -                torch._dynamo.mark_static(attn_metadata.decode.block_table)
651 | | -                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
652 | | -                torch._dynamo.mark_static(attn_metadata.slot_mapping)
653 | | -                for kv in self.kv_caches:
654 | | -                    if isinstance(kv, tuple):
655 | | -                        torch._dynamo.mark_static(kv[0])
656 | | -                        torch._dynamo.mark_static(kv[1])
657 | | -                hidden_states = self.compile_model(
658 | | -                    input_ids=input_ids,
659 | | -                    positions=positions,
660 | | -                    intermediate_tensors=intermediate_tensors,
661 | | -                    inputs_embeds=None,
662 | | -                    **model_kwargs,
663 | | -                )
664 | | -            else:
665 | | -                assert self.model is not None
666 | | -                hidden_states = self.model(
667 | | -                    input_ids=input_ids,
668 | | -                    positions=positions,
669 | | -                    intermediate_tensors=intermediate_tensors,
670 | | -                    inputs_embeds=None,
671 | | -                    **model_kwargs,
672 | | -                )
| 644 | +            with ProfileExecuteDuration().capture_async("forward"):
| 645 | +                model_kwargs = {}
| 646 | +                if self.enable_torchair_graph_mode:
| 647 | +                    model_kwargs["kv_caches"] = self.kv_caches
| 648 | +                    model_kwargs["attn_metadata"] = attn_metadata
| 649 | +                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
| 650 | +                    torch._dynamo.mark_static(input_ids)
| 651 | +                    torch._dynamo.mark_static(positions)
| 652 | +                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
| 653 | +                    torch._dynamo.mark_static(attn_metadata.decode.input_positions)
| 654 | +                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
| 655 | +                    for kv in self.kv_caches:
| 656 | +                        if isinstance(kv, tuple):
| 657 | +                            torch._dynamo.mark_static(kv[0])
| 658 | +                            torch._dynamo.mark_static(kv[1])
| 659 | +                    hidden_states = self.compile_model(
| 660 | +                        input_ids=input_ids,
| 661 | +                        positions=positions,
| 662 | +                        intermediate_tensors=intermediate_tensors,
| 663 | +                        inputs_embeds=None,
| 664 | +                        **model_kwargs,
| 665 | +                    )
| 666 | +                else:
| 667 | +                    assert self.model is not None
| 668 | +                    hidden_states = self.model(
| 669 | +                        input_ids=input_ids,
| 670 | +                        positions=positions,
| 671 | +                        intermediate_tensors=intermediate_tensors,
| 672 | +                        inputs_embeds=None,
| 673 | +                        **model_kwargs,
| 674 | +                    )
673 | 675 |
674 | 676 |         use_spec_decode = len(
675 | 677 |             scheduler_output.scheduled_spec_decode_tokens) > 0
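
Note how the tags nest: the "forward" span added here in `_process_reqs` runs inside the "prepare input and forward" span that `execute_model` opens in the next hunk, so a single step reports the model forward both on its own and as part of the broader prepare-and-forward phase. A toy end-to-end use of the sketch above (the tag names come from this diff, everything else is hypothetical):

```python
# Toy driver for the sketch above; the real call sites are the two hunks in
# this commit, this only illustrates how the tags nest and get drained.
profiler = ProfileExecuteDurationSketch()

with profiler.capture_async("prepare input and forward"):
    # ... build input_ids, positions, attn_metadata ...
    with profiler.capture_async("forward"):
        pass  # model forward would run here

with profiler.capture_async("post process"):
    pass  # logits, sampling and logprobs post-processing would run here

# With the sketch above this prints one line tagged [Decode] listing each span in ms.
profiler.pop_captured_sync("Decode")
```
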
@@ -856,103 +858,109 @@ def execute_model(
856 | 858 |         scheduler_output: "SchedulerOutput",
857 | 859 |         intermediate_tensors: Optional[IntermediateTensors] = None,
858 | 860 |     ) -> Union[ModelRunnerOutput, torch.Tensor]:
859 | | -        self._update_states(scheduler_output)
860 | | -        if not scheduler_output.total_num_scheduled_tokens:
861 | | -            # Return empty ModelRunnerOuptut if there's no work to do.
862 | | -            return EMPTY_MODEL_RUNNER_OUTPUT
863 | | -        (attn_metadata, hidden_states, spec_decode_metadata, positions,
864 | | -         num_scheduled_tokens,
865 | | -         sample_indices) = (self._process_reqs(scheduler_output,
866 | | -                                               intermediate_tensors))
867 | | -        logits = self.model.compute_logits(hidden_states[sample_indices], None)
868 | | -
869 | | -        # Apply structured output bitmasks if present
870 | | -        if scheduler_output.grammar_bitmask is not None:
871 | | -            logits = self.apply_grammar_bitmask(scheduler_output, logits)
872 | | -
873 | | -        # Sample the next token and get logprobs if needed.
874 | | -        sampling_metadata = self.input_batch.sampling_metadata
875 | | -        if spec_decode_metadata is None:
876 | | -            sampler_output = self.sampler(
877 | | -                logits=logits,
878 | | -                sampling_metadata=sampling_metadata,
879 | | -            )
880 | | -        else:
881 | | -            # When indexing with a tensor (bonus_logits_indices), PyTorch
882 | | -            # creates a new tensor with separate storage from the original
883 | | -            # logits tensor. This means any in-place operations on bonus_logits
884 | | -            # won't affect the original logits tensor.
885 | | -            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
886 | | -            sampler_output = self.sampler(
887 | | -                logits=bonus_logits,
888 | | -                sampling_metadata=sampling_metadata,
889 | | -            )
890 | | -            bonus_token_ids = sampler_output.sampled_token_ids
| 861 | +        with ProfileExecuteDuration().capture_async("prepare input and forward"):
| 862 | +            self._update_states(scheduler_output)
| 863 | +            if not scheduler_output.total_num_scheduled_tokens:
| 864 | +                # Return empty ModelRunnerOuptut if there's no work to do.
| 865 | +                return EMPTY_MODEL_RUNNER_OUTPUT
| 866 | +            (attn_metadata, hidden_states, spec_decode_metadata, positions,
| 867 | +             num_scheduled_tokens,
| 868 | +             sample_indices) = (self._process_reqs(scheduler_output,
| 869 | +                                                   intermediate_tensors))
| 870 | +
| 871 | +        with ProfileExecuteDuration().capture_async("post process"):
| 872 | +            logits = self.model.compute_logits(hidden_states[sample_indices], None)
| 873 | +
| 874 | +            # Apply structured output bitmasks if present
| 875 | +            if scheduler_output.grammar_bitmask is not None:
| 876 | +                logits = self.apply_grammar_bitmask(scheduler_output, logits)
| 877 | +
| 878 | +            # Sample the next token and get logprobs if needed.
| 879 | +            sampling_metadata = self.input_batch.sampling_metadata
| 880 | +            if spec_decode_metadata is None:
| 881 | +                sampler_output = self.sampler(
| 882 | +                    logits=logits,
| 883 | +                    sampling_metadata=sampling_metadata,
| 884 | +                )
| 885 | +            else:
| 886 | +                # When indexing with a tensor (bonus_logits_indices), PyTorch
| 887 | +                # creates a new tensor with separate storage from the original
| 888 | +                # logits tensor. This means any in-place operations on bonus_logits
| 889 | +                # won't affect the original logits tensor.
| 890 | +                bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
| 891 | +                sampler_output = self.sampler(
| 892 | +                    logits=bonus_logits,
| 893 | +                    sampling_metadata=sampling_metadata,
| 894 | +                )
| 895 | +                bonus_token_ids = sampler_output.sampled_token_ids
| 896 | +
| 897 | +                # Just like `bonus_logits`, `target_logits` is a new tensor with
| 898 | +                # separate storage from the original `logits` tensor. Therefore,
| 899 | +                # it is safe to update `target_logits` in place.
| 900 | +                target_logits = logits[spec_decode_metadata.target_logits_indices]
| 901 | +                output_token_ids = self.rejection_sampler(
| 902 | +                    spec_decode_metadata,
| 903 | +                    None,  # draft_probs
| 904 | +                    target_logits,
| 905 | +                    bonus_token_ids,
| 906 | +                    sampling_metadata,
| 907 | +                )
| 908 | +                sampler_output.sampled_token_ids = output_token_ids
| 909 | +
| 910 | +            # TODO(woosuk): The following loop can be slow since it iterates over
| 911 | +            # the requests one by one. Optimize.
| 912 | +            for i, req_id in enumerate(self.input_batch.req_ids):
| 913 | +                req_state = self.requests[req_id]
| 914 | +                seq_len = (req_state.num_computed_tokens +
| 915 | +                           scheduler_output.num_scheduled_tokens[req_id])
| 916 | +                if seq_len < req_state.num_tokens:
| 917 | +                    # Ignore the sampled token.
| 918 | +                    # Rewind the generator state as if the token was not sampled.
| 919 | +                    generator = self.input_batch.generators.get(i)
| 920 | +                    if generator is not None:
| 921 | +                        generator.set_offset(generator.get_offset() - 4)
| 922 | +
| 923 | +            # NOTE: NPU -> CPU Sync happens here.
| 924 | +            # Move as many CPU operations as possible before this sync point.
| 925 | +            logprobs_tensors = sampler_output.logprobs_tensors
| 926 | +            logprobs_lists = logprobs_tensors.tolists() \
| 927 | +                if logprobs_tensors is not None else None
| 928 | +
| 929 | +            # Get the valid generated tokens.
| 930 | +            sampled_token_ids = sampler_output.sampled_token_ids
| 931 | +            max_gen_len = sampled_token_ids.shape[-1]
| 932 | +            if max_gen_len == 1:
| 933 | +                # No spec decode tokens.
| 934 | +                valid_sampled_token_ids = sampled_token_ids.tolist()
| 935 | +            else:
| 936 | +                # Includes spec decode tokens.
| 937 | +                valid_sampled_token_ids = self.rejection_sampler.parse_output(
| 938 | +                    sampled_token_ids,
| 939 | +                    self.input_batch.vocab_size,
| 940 | +                )
891 | 941 |
892 | | -            # Just like `bonus_logits`, `target_logits` is a new tensor with
893 | | -            # separate storage from the original `logits` tensor. Therefore,
894 | | -            # it is safe to update `target_logits` in place.
895 | | -            target_logits = logits[spec_decode_metadata.target_logits_indices]
896 | | -            output_token_ids = self.rejection_sampler(
897 | | -                spec_decode_metadata,
898 | | -                None,  # draft_probs
899 | | -                target_logits,
900 | | -                bonus_token_ids,
| 942 | +            spec_token_ids = self._get_spec_token_ids(
| 943 | +                valid_sampled_token_ids,
901 | 944 |                 sampling_metadata,
| 945 | +                scheduler_output,
| 946 | +                spec_decode_metadata,
| 947 | +                positions,
| 948 | +                num_scheduled_tokens,
| 949 | +                hidden_states,
| 950 | +                attn_metadata,
902 | 951 |             )
903 | | -            sampler_output.sampled_token_ids = output_token_ids
904 | 952 |
905 | | -        # TODO(woosuk): The following loop can be slow since it iterates over
906 | | -        # the requests one by one. Optimize.
907 | | -        for i, req_id in enumerate(self.input_batch.req_ids):
908 | | -            req_state = self.requests[req_id]
909 | | -            seq_len = (req_state.num_computed_tokens +
910 | | -                       scheduler_output.num_scheduled_tokens[req_id])
911 | | -            if seq_len < req_state.num_tokens:
912 | | -                # Ignore the sampled token.
913 | | -                # Rewind the generator state as if the token was not sampled.
914 | | -                generator = self.input_batch.generators.get(i)
915 | | -                if generator is not None:
916 | | -                    generator.set_offset(generator.get_offset() - 4)
917 | | -
918 | | -        # NOTE: NPU -> CPU Sync happens here.
919 | | -        # Move as many CPU operations as possible before this sync point.
920 | | -        logprobs_tensors = sampler_output.logprobs_tensors
921 | | -        logprobs_lists = logprobs_tensors.tolists() \
922 | | -            if logprobs_tensors is not None else None
923 | | -
924 | | -        # Get the valid generated tokens.
925 | | -        sampled_token_ids = sampler_output.sampled_token_ids
926 | | -        max_gen_len = sampled_token_ids.shape[-1]
927 | | -        if max_gen_len == 1:
928 | | -            # No spec decode tokens.
929 | | -            valid_sampled_token_ids = sampled_token_ids.tolist()
930 | | -        else:
931 | | -            # Includes spec decode tokens.
932 | | -            valid_sampled_token_ids = self.rejection_sampler.parse_output(
933 | | -                sampled_token_ids,
934 | | -                self.input_batch.vocab_size,
| 953 | +            model_runner_output = ModelRunnerOutput(
| 954 | +                req_ids=self.input_batch.req_ids,
| 955 | +                req_id_to_index=self.input_batch.req_id_to_index,
| 956 | +                sampled_token_ids=valid_sampled_token_ids,
| 957 | +                spec_token_ids=spec_token_ids,
| 958 | +                logprobs=logprobs_lists,
| 959 | +                prompt_logprobs_dict={},
935 | 960 |             )
936 | 961 |
937 | | -        spec_token_ids = self._get_spec_token_ids(
938 | | -            valid_sampled_token_ids,
939 | | -            sampling_metadata,
940 | | -            scheduler_output,
941 | | -            spec_decode_metadata,
942 | | -            positions,
943 | | -            num_scheduled_tokens,
944 | | -            hidden_states,
945 | | -            attn_metadata,
946 | | -        )
947 | | -
948 | | -        model_runner_output = ModelRunnerOutput(
949 | | -            req_ids=self.input_batch.req_ids,
950 | | -            req_id_to_index=self.input_batch.req_id_to_index,
951 | | -            sampled_token_ids=valid_sampled_token_ids,
952 | | -            spec_token_ids=spec_token_ids,
953 | | -            logprobs=logprobs_lists,
954 | | -            prompt_logprobs_dict={},
955 | | -        )
| 962 | +        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
| 963 | +        ProfileExecuteDuration().pop_captured_sync(capture_name)
956 | 964 |         return model_runner_output
957 | 965 |
958 | 966 |     def _profile_multimodal(self) -> None:
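
At the end of `execute_model` the step is labelled `"Decode"` when the batch is decode-only and `"Prefill"` otherwise, and `pop_captured_sync` drains everything captured during the step. Because draining synchronizes with the device, a deployment would normally want this behind a switch so the extra sync is only paid when execute-time observation is enabled; the flag name below is made up for illustration and is not vllm-ascend's actual option:

```python
import os

# Hypothetical gate; the env var name is an assumption, not a real vllm-ascend
# flag. A real implementation would also skip the capture_async() bookkeeping
# entirely when the switch is off, so captured events cannot pile up.
OBSERVE_EXECUTE_TIME = os.getenv("MODEL_EXECUTE_TIME_OBSERVE", "0") == "1"


def maybe_report(profiler, decode_only: bool) -> None:
    """Drain captured spans only when execute-time observation is enabled."""
    if not OBSERVE_EXECUTE_TIME:
        return
    capture_name = "Decode" if decode_only else "Prefill"
    profiler.pop_captured_sync(capture_name)
```
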
|
|