Skip to content

Commit a73bd6c

Browse files
authored
[Fix] Set div_mode to False and fix view_as position (#912)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> Set div_mode to False to use the ACLNN kernel, which is crucial when using ACL Graph. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 58b4137 commit a73bd6c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

vllm_ascend/ops/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def vanilla_chunked_prefill(
131131

132132
attn_output = (attn_output[q_mask].view([-1, num_query_heads,
133133
head_dim]).to(output.dtype))
134-
output = output.view_as(attn_output)
135134
output.copy_(attn_output)
136135
return attn_output
137136

@@ -248,6 +247,7 @@ def vanilla_chunked_prefill_mla(
248247

249248
attn_output = (attn_output[q_mask].view([-1, num_heads,
250249
v_head_dim]).to(output.dtype))
250+
output = output.view_as(attn_output)
251251
output.copy_(attn_output)
252252
return attn_output
253253

vllm_ascend/quantization/w8a8.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
2525
input_offset: torch.Tensor):
2626
return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
27-
torch.qint8, -1, True)
27+
torch.qint8, -1, False)
2828

2929

3030
class AscendW8A8LinearMethod:
@@ -102,12 +102,12 @@ def apply(
102102

103103
def process_weights_after_loading(self, layer):
104104
expanding_factor = layer.weight.data.shape[1]
105-
layer.aclnn_input_scale = torch.nn.Parameter(
105+
layer.aclnn_input_scale = 1 / torch.nn.Parameter(
106106
layer.input_scale.data.repeat(expanding_factor),
107107
requires_grad=False)
108108
layer.aclnn_input_offset = torch.nn.Parameter(
109109
layer.input_offset.data.repeat(expanding_factor),
110-
requires_grad=False)
110+
requires_grad=False).to(layer.aclnn_input_scale.dtype)
111111
if self.transpose_weight:
112112
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
113113
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)

0 commit comments

Comments
 (0)