
Commit b7b9248

tlrmchlsmth authored and Akshat-Tripathi committed
[Bugfix][CI][V1] Work around V1 + CUDA Graph + torch._scaled_mm fallback issue (vllm-project#13425)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
1 parent 0bbf7db commit b7b9248

File tree

4 files changed: +19, -11 lines changed

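Context: torch._scaled_mm requires its scale arguments starting with PyTorch 2.5, and the dummy TORCH_DEVICE_IDENTITY tensor used for them was previously allocated at import time and moved to the weight's device inside apply_fp8_linear, i.e. inside the forward path that V1 may capture into a CUDA graph. This commit works around the resulting issue by creating the tensor lazily from each FP8 scheme's create_weights() via a new maybe_create_device_identity() helper and dropping the in-forward device move. A minimal sketch of the resulting pattern (the global and the helper mirror the diff below; create_weights here is a simplified stand-in for the per-scheme methods):

import torch

# Dummy scale tensor for torch._scaled_mm; left unallocated at import time.
TORCH_DEVICE_IDENTITY = None


def maybe_create_device_identity():
    # Allocate the dummy ones tensor once, outside the forward/captured path.
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY is None:
        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)


def create_weights():
    # Simplified stand-in: each FP8 scheme's create_weights() now begins with
    # this call, so the tensor already exists before any CUDA graph capture.
    maybe_create_device_identity()
    # ... allocate quantized weights and scales as before ...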

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

Lines changed: 4 additions & 2 deletions
@@ -9,8 +9,8 @@
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
-    requantize_with_max_scale)
+    apply_fp8_linear, cutlass_fp8_supported, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -93,6 +93,8 @@ def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes

vllm/model_executor/layers/quantization/fbgemm_fp8.py

Lines changed: 3 additions & 1 deletion
@@ -17,7 +17,8 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz)
+    apply_fp8_linear, maybe_create_device_identity,
+    normalize_e4m3fn_to_e4m3fnuz)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            ModelWeightParameter)
 from vllm.platforms import current_platform
@@ -84,6 +85,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
         weight_loader = extra_weight_attrs.get("weight_loader")
         del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 4 additions & 2 deletions
@@ -24,8 +24,8 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     cutlass_block_fp8_supported, cutlass_fp8_supported,
-    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
-    requantize_with_max_scale)
+    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
+    per_tensor_dequantize, requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -162,6 +162,8 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        maybe_create_device_identity()
+
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")

vllm/model_executor/layers/quantization/utils/w8a8_utils.py

Lines changed: 8 additions & 6 deletions
@@ -9,7 +9,7 @@
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+TORCH_DEVICE_IDENTITY = None
 
 # The condition to determine if it is on a platform that supports
 # torch._scaled_mm rowwise feature.
@@ -113,6 +113,13 @@ def requantize_with_max_scale(
     return max_w_scale, weight
 
 
+def maybe_create_device_identity():
+    # Allocate dummy ones tensor for torch._scaled_mm
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
 def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -215,11 +222,6 @@ def apply_fp8_linear(
         # For the scaled_mm fallback case, we break this down, since it
         # does not support s_w being a vector.
 
-        # Making sure the dummy tensor is on the same device as the weight
-        global TORCH_DEVICE_IDENTITY
-        if TORCH_DEVICE_IDENTITY.device != weight.device:
-            TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
-
         # GEMM
         # This computes C = (X * W).
         # Output in fp32 to allow subsequent ops to happen in-place
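For reference, a hedged sketch of the kind of torch._scaled_mm fallback the comments above describe (function and parameter names are made up for illustration; the real apply_fp8_linear also handles things like padding and bias): the GEMM runs with the dummy identity scales, since _scaled_mm requires scale_a/scale_b but does not support a per-channel (vector) weight scale, and the real scales are applied to the fp32 output afterwards.

import torch


def scaled_mm_fallback(qinput: torch.Tensor, qweight: torch.Tensor,
                       x_scale: torch.Tensor, w_scale: torch.Tensor,
                       identity: torch.Tensor) -> torch.Tensor:
    # Illustrative only. qinput is (M, K) fp8, qweight is (K, N) fp8 laid out
    # as _scaled_mm expects, x_scale is the activation scale, w_scale is an
    # (N, 1) per-output-channel weight scale, and identity is the pre-created
    # ones(1) tensor (TORCH_DEVICE_IDENTITY) already on the right device.

    # GEMM with identity scales: computes C = X @ W in fp32.
    output = torch._scaled_mm(qinput, qweight,
                              scale_a=identity,
                              scale_b=identity,
                              out_dtype=torch.float32)

    # Apply the real scales afterwards: C = s_x * s_w * (X @ W).
    return output * x_scale * w_scale.t()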
