61 | 61 | from vllm_ascend.attention.attention_v1 import AscendAttentionState
62 | 62 | from vllm_ascend.platform import NPUPlatform
63 | 63 | from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
   | 64 | +from vllm_ascend.utils import ProfileExecuteDuration
64 | 65 |
65 | 66 | if TYPE_CHECKING:
66 | 67 |     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -628,36 +629,38 @@ def _process_reqs(
628 | 629 |         with set_forward_context(attn_metadata,
629 | 630 |                                  self.vllm_config,
630 | 631 |                                  num_tokens=num_input_tokens):
631 |     | -            model_kwargs = {}
632 |     | -            if self.enable_torchair_graph_mode:
633 |     | -                model_kwargs["kv_caches"] = self.kv_caches
634 |     | -                model_kwargs["attn_metadata"] = attn_metadata
635 |     | -            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
636 |     | -                torch._dynamo.mark_static(input_ids)
637 |     | -                torch._dynamo.mark_static(positions)
638 |     | -                torch._dynamo.mark_static(attn_metadata.decode.block_table)
639 |     | -                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
640 |     | -                torch._dynamo.mark_static(attn_metadata.slot_mapping)
641 |     | -                for kv in self.kv_caches:
642 |     | -                    if isinstance(kv, tuple):
643 |     | -                        torch._dynamo.mark_static(kv[0])
644 |     | -                        torch._dynamo.mark_static(kv[1])
645 |     | -                hidden_states = self.compile_model(
646 |     | -                    input_ids=input_ids,
647 |     | -                    positions=positions,
648 |     | -                    intermediate_tensors=intermediate_tensors,
649 |     | -                    inputs_embeds=None,
650 |     | -                    **model_kwargs,
651 |     | -                )
652 |     | -            else:
653 |     | -                assert self.model is not None
654 |     | -                hidden_states = self.model(
655 |     | -                    input_ids=input_ids,
656 |     | -                    positions=positions,
657 |     | -                    intermediate_tensors=intermediate_tensors,
658 |     | -                    inputs_embeds=None,
659 |     | -                    **model_kwargs,
660 |     | -                )
    | 632 | +            with ProfileExecuteDuration().capture_async("forward"):
    | 633 | +                model_kwargs = {}
    | 634 | +                if self.enable_torchair_graph_mode:
    | 635 | +                    model_kwargs["kv_caches"] = self.kv_caches
    | 636 | +                    model_kwargs["attn_metadata"] = attn_metadata
    | 637 | +                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
    | 638 | +                    torch._dynamo.mark_static(input_ids)
    | 639 | +                    torch._dynamo.mark_static(positions)
    | 640 | +                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
    | 641 | +                    torch._dynamo.mark_static(
    | 642 | +                        attn_metadata.decode.input_positions)
    | 643 | +                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
    | 644 | +                    for kv in self.kv_caches:
    | 645 | +                        if isinstance(kv, tuple):
    | 646 | +                            torch._dynamo.mark_static(kv[0])
    | 647 | +                            torch._dynamo.mark_static(kv[1])
    | 648 | +                    hidden_states = self.compile_model(
    | 649 | +                        input_ids=input_ids,
    | 650 | +                        positions=positions,
    | 651 | +                        intermediate_tensors=intermediate_tensors,
    | 652 | +                        inputs_embeds=None,
    | 653 | +                        **model_kwargs,
    | 654 | +                    )
    | 655 | +                else:
    | 656 | +                    assert self.model is not None
    | 657 | +                    hidden_states = self.model(
    | 658 | +                        input_ids=input_ids,
    | 659 | +                        positions=positions,
    | 660 | +                        intermediate_tensors=intermediate_tensors,
    | 661 | +                        inputs_embeds=None,
    | 662 | +                        **model_kwargs,
    | 663 | +                    )
661 | 664 |
662 | 665 |         use_spec_decode = len(
663 | 666 |             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -844,103 +847,113 @@ def execute_model(
844 | 847 |         scheduler_output: "SchedulerOutput",
845 | 848 |         intermediate_tensors: Optional[IntermediateTensors] = None,
846 | 849 |     ) -> Union[ModelRunnerOutput, torch.Tensor]:
847 |     | -        self._update_states(scheduler_output)
848 |     | -        if not scheduler_output.total_num_scheduled_tokens:
849 |     | -            # Return empty ModelRunnerOuptut if there's no work to do.
850 |     | -            return EMPTY_MODEL_RUNNER_OUTPUT
851 |     | -        (attn_metadata, hidden_states, spec_decode_metadata, positions,
852 |     | -         num_scheduled_tokens,
853 |     | -         sample_indices) = (self._process_reqs(scheduler_output,
854 |     | -                                               intermediate_tensors))
855 |     | -        logits = self.model.compute_logits(hidden_states[sample_indices], None)
856 |     | -
857 |     | -        # Apply structured output bitmasks if present
858 |     | -        if scheduler_output.grammar_bitmask is not None:
859 |     | -            logits = self.apply_grammar_bitmask(scheduler_output, logits)
860 |     | -
861 |     | -        # Sample the next token and get logprobs if needed.
862 |     | -        sampling_metadata = self.input_batch.sampling_metadata
863 |     | -        if spec_decode_metadata is None:
864 |     | -            sampler_output = self.sampler(
865 |     | -                logits=logits,
866 |     | -                sampling_metadata=sampling_metadata,
867 |     | -            )
868 |     | -        else:
869 |     | -            # When indexing with a tensor (bonus_logits_indices), PyTorch
870 |     | -            # creates a new tensor with separate storage from the original
871 |     | -            # logits tensor. This means any in-place operations on bonus_logits
872 |     | -            # won't affect the original logits tensor.
873 |     | -            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
874 |     | -            sampler_output = self.sampler(
875 |     | -                logits=bonus_logits,
876 |     | -                sampling_metadata=sampling_metadata,
877 |     | -            )
878 |     | -            bonus_token_ids = sampler_output.sampled_token_ids
    | 850 | +        with ProfileExecuteDuration().capture_async(
    | 851 | +                "prepare input and forward"):
    | 852 | +            self._update_states(scheduler_output)
    | 853 | +            if not scheduler_output.total_num_scheduled_tokens:
    | 854 | +                # Return empty ModelRunnerOuptut if there's no work to do.
    | 855 | +                return EMPTY_MODEL_RUNNER_OUTPUT
    | 856 | +            (attn_metadata, hidden_states, spec_decode_metadata, positions,
    | 857 | +             num_scheduled_tokens,
    | 858 | +             sample_indices) = (self._process_reqs(scheduler_output,
    | 859 | +                                                   intermediate_tensors))
    | 860 | +
    | 861 | +        with ProfileExecuteDuration().capture_async("post process"):
    | 862 | +            logits = self.model.compute_logits(hidden_states[sample_indices],
    | 863 | +                                               None)
    | 864 | +
    | 865 | +            # Apply structured output bitmasks if present
    | 866 | +            if scheduler_output.grammar_bitmask is not None:
    | 867 | +                logits = self.apply_grammar_bitmask(scheduler_output, logits)
    | 868 | +
    | 869 | +            # Sample the next token and get logprobs if needed.
    | 870 | +            sampling_metadata = self.input_batch.sampling_metadata
    | 871 | +            if spec_decode_metadata is None:
    | 872 | +                sampler_output = self.sampler(
    | 873 | +                    logits=logits,
    | 874 | +                    sampling_metadata=sampling_metadata,
    | 875 | +                )
    | 876 | +            else:
    | 877 | +                # When indexing with a tensor (bonus_logits_indices), PyTorch
    | 878 | +                # creates a new tensor with separate storage from the original
    | 879 | +                # logits tensor. This means any in-place operations on bonus_logits
    | 880 | +                # won't affect the original logits tensor.
    | 881 | +                bonus_logits = logits[
    | 882 | +                    spec_decode_metadata.bonus_logits_indices]
    | 883 | +                sampler_output = self.sampler(
    | 884 | +                    logits=bonus_logits,
    | 885 | +                    sampling_metadata=sampling_metadata,
    | 886 | +                )
    | 887 | +                bonus_token_ids = sampler_output.sampled_token_ids
    | 888 | +
    | 889 | +                # Just like `bonus_logits`, `target_logits` is a new tensor with
    | 890 | +                # separate storage from the original `logits` tensor. Therefore,
    | 891 | +                # it is safe to update `target_logits` in place.
    | 892 | +                target_logits = logits[
    | 893 | +                    spec_decode_metadata.target_logits_indices]
    | 894 | +                output_token_ids = self.rejection_sampler(
    | 895 | +                    spec_decode_metadata,
    | 896 | +                    None,  # draft_probs
    | 897 | +                    target_logits,
    | 898 | +                    bonus_token_ids,
    | 899 | +                    sampling_metadata,
    | 900 | +                )
    | 901 | +                sampler_output.sampled_token_ids = output_token_ids
    | 902 | +
    | 903 | +            # TODO(woosuk): The following loop can be slow since it iterates over
    | 904 | +            # the requests one by one. Optimize.
    | 905 | +            for i, req_id in enumerate(self.input_batch.req_ids):
    | 906 | +                req_state = self.requests[req_id]
    | 907 | +                seq_len = (req_state.num_computed_tokens +
    | 908 | +                           scheduler_output.num_scheduled_tokens[req_id])
    | 909 | +                if seq_len < req_state.num_tokens:
    | 910 | +                    # Ignore the sampled token.
    | 911 | +                    # Rewind the generator state as if the token was not sampled.
    | 912 | +                    generator = self.input_batch.generators.get(i)
    | 913 | +                    if generator is not None:
    | 914 | +                        generator.set_offset(generator.get_offset() - 4)
    | 915 | +
    | 916 | +            # NOTE: NPU -> CPU Sync happens here.
    | 917 | +            # Move as many CPU operations as possible before this sync point.
    | 918 | +            logprobs_tensors = sampler_output.logprobs_tensors
    | 919 | +            logprobs_lists = logprobs_tensors.tolists() \
    | 920 | +                if logprobs_tensors is not None else None
    | 921 | +
    | 922 | +            # Get the valid generated tokens.
    | 923 | +            sampled_token_ids = sampler_output.sampled_token_ids
    | 924 | +            max_gen_len = sampled_token_ids.shape[-1]
    | 925 | +            if max_gen_len == 1:
    | 926 | +                # No spec decode tokens.
    | 927 | +                valid_sampled_token_ids = sampled_token_ids.tolist()
    | 928 | +            else:
    | 929 | +                # Includes spec decode tokens.
    | 930 | +                valid_sampled_token_ids = self.rejection_sampler.parse_output(
    | 931 | +                    sampled_token_ids,
    | 932 | +                    self.input_batch.vocab_size,
    | 933 | +                )
879 | 934 |
880 |     | -            # Just like `bonus_logits`, `target_logits` is a new tensor with
881 |     | -            # separate storage from the original `logits` tensor. Therefore,
882 |     | -            # it is safe to update `target_logits` in place.
883 |     | -            target_logits = logits[spec_decode_metadata.target_logits_indices]
884 |     | -            output_token_ids = self.rejection_sampler(
885 |     | -                spec_decode_metadata,
886 |     | -                None,  # draft_probs
887 |     | -                target_logits,
888 |     | -                bonus_token_ids,
    | 935 | +            spec_token_ids = self._get_spec_token_ids(
    | 936 | +                valid_sampled_token_ids,
889 | 937 |                 sampling_metadata,
    | 938 | +                scheduler_output,
    | 939 | +                spec_decode_metadata,
    | 940 | +                positions,
    | 941 | +                num_scheduled_tokens,
    | 942 | +                hidden_states,
    | 943 | +                attn_metadata,
890 | 944 |             )
891 |     | -            sampler_output.sampled_token_ids = output_token_ids
892 | 945 |
893 |     | -        # TODO(woosuk): The following loop can be slow since it iterates over
894 |     | -        # the requests one by one. Optimize.
895 |     | -        for i, req_id in enumerate(self.input_batch.req_ids):
896 |     | -            req_state = self.requests[req_id]
897 |     | -            seq_len = (req_state.num_computed_tokens +
898 |     | -                       scheduler_output.num_scheduled_tokens[req_id])
899 |     | -            if seq_len < req_state.num_tokens:
900 |     | -                # Ignore the sampled token.
901 |     | -                # Rewind the generator state as if the token was not sampled.
902 |     | -                generator = self.input_batch.generators.get(i)
903 |     | -                if generator is not None:
904 |     | -                    generator.set_offset(generator.get_offset() - 4)
905 |     | -
906 |     | -        # NOTE: NPU -> CPU Sync happens here.
907 |     | -        # Move as many CPU operations as possible before this sync point.
908 |     | -        logprobs_tensors = sampler_output.logprobs_tensors
909 |     | -        logprobs_lists = logprobs_tensors.tolists() \
910 |     | -            if logprobs_tensors is not None else None
911 |     | -
912 |     | -        # Get the valid generated tokens.
913 |     | -        sampled_token_ids = sampler_output.sampled_token_ids
914 |     | -        max_gen_len = sampled_token_ids.shape[-1]
915 |     | -        if max_gen_len == 1:
916 |     | -            # No spec decode tokens.
917 |     | -            valid_sampled_token_ids = sampled_token_ids.tolist()
918 |     | -        else:
919 |     | -            # Includes spec decode tokens.
920 |     | -            valid_sampled_token_ids = self.rejection_sampler.parse_output(
921 |     | -                sampled_token_ids,
922 |     | -                self.input_batch.vocab_size,
    | 946 | +            model_runner_output = ModelRunnerOutput(
    | 947 | +                req_ids=self.input_batch.req_ids,
    | 948 | +                req_id_to_index=self.input_batch.req_id_to_index,
    | 949 | +                sampled_token_ids=valid_sampled_token_ids,
    | 950 | +                spec_token_ids=spec_token_ids,
    | 951 | +                logprobs=logprobs_lists,
    | 952 | +                prompt_logprobs_dict={},
923 | 953 |             )
924 | 954 |
925 |     | -        spec_token_ids = self._get_spec_token_ids(
926 |     | -            valid_sampled_token_ids,
927 |     | -            sampling_metadata,
928 |     | -            scheduler_output,
929 |     | -            spec_decode_metadata,
930 |     | -            positions,
931 |     | -            num_scheduled_tokens,
932 |     | -            hidden_states,
933 |     | -            attn_metadata,
934 |     | -        )
935 |     | -
936 |     | -        model_runner_output = ModelRunnerOutput(
937 |     | -            req_ids=self.input_batch.req_ids,
938 |     | -            req_id_to_index=self.input_batch.req_id_to_index,
939 |     | -            sampled_token_ids=valid_sampled_token_ids,
940 |     | -            spec_token_ids=spec_token_ids,
941 |     | -            logprobs=logprobs_lists,
942 |     | -            prompt_logprobs_dict={},
943 |     | -        )
    | 955 | +        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
    | 956 | +        ProfileExecuteDuration().pop_captured_sync(capture_name)
944 | 957 |         return model_runner_output
945 | 958 |
946 | 959 |     def _profile_multimodal(self) -> None:
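Note for readers: the diff shows only the call sites of the new `ProfileExecuteDuration` helper, not its implementation in `vllm_ascend.utils`. The usage pattern is to wrap each phase of a model step in `capture_async(tag)` ("prepare input and forward" around state update plus `_process_reqs`, which itself captures "forward", and "post process" around logits, sampling, and output assembly), then call `pop_captured_sync(name)` once per step with a "Decode" or "Prefill" label to flush the captured durations. The snippet below is a minimal, hypothetical sketch of that interface only: the class name `DurationProfilerSketch`, the host-side `time.perf_counter()` timing, and the print-based reporting are illustrative assumptions, and the real utility presumably records device-side events so that `capture_async` does not force an NPU sync before `pop_captured_sync`.

```python
# Illustrative sketch only: mirrors the observed interface of
# ProfileExecuteDuration (capture_async as a context manager plus
# pop_captured_sync to flush results). Not the vllm_ascend implementation.
import time
from contextlib import contextmanager


class DurationProfilerSketch:
    """Hypothetical stand-in that collects named wall-clock spans per step."""

    _instance = None

    def __new__(cls):
        # Shared instance so every call site appends to the same capture list.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._captures = []
        return cls._instance

    @contextmanager
    def capture_async(self, tag: str):
        # A real implementation would enqueue start/stop events on the device
        # stream here instead of reading a blocking host timer.
        start = time.perf_counter()
        try:
            yield
        finally:
            self._captures.append((tag, time.perf_counter() - start))

    def pop_captured_sync(self, capture_name: str) -> None:
        # Drain everything captured so far and report it under one label,
        # e.g. "Decode" or "Prefill" as in the diff above.
        captures, self._captures = self._captures, []
        summary = ", ".join(f"{tag}: {dur * 1000:.2f} ms" for tag, dur in captures)
        print(f"[{capture_name}] {summary}")


if __name__ == "__main__":
    profiler = DurationProfilerSketch()
    with profiler.capture_async("prepare input and forward"):
        with profiler.capture_async("forward"):
            time.sleep(0.01)  # stand-in for the model forward pass
    with profiler.capture_async("post process"):
        time.sleep(0.005)  # stand-in for sampling / output assembly
    profiler.pop_captured_sync("Decode")
```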