Commit fea59c7

[Bugfix][Kernel] Use int64_t for indices in fp8 quant kernels (#6649)
1 parent 739b61a

File tree: 2 files changed, +31 -6 lines

  csrc/quantization/fp8/common.cu
  tests/kernels/test_fp8_quant.py

csrc/quantization/fp8/common.cu

Lines changed: 6 additions & 6 deletions
@@ -103,11 +103,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   vec4_t<scalar_t> const* vectorized_in =
       reinterpret_cast<vec4_t<scalar_t> const*>(input);
 
-  int const num_vec_elems = num_elems >> 2;
+  int64_t const num_vec_elems = num_elems >> 2;
   float absmax_val = 0.0f;
 
 #pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += step) {
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     absmax_val = max(absmax_val, fabs(in_vec.x));
     absmax_val = max(absmax_val, fabs(in_vec.y));
@@ -116,7 +116,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
     absmax_val = max(absmax_val, fabs(input[i]));
   }
 
@@ -134,10 +134,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
       reinterpret_cast<vec4_t<scalar_t> const*>(input);
   float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
 
-  int const num_vec_elems = num_elems >> 2;
+  int64_t const num_vec_elems = num_elems >> 2;
 
 #pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += step) {
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;
 
@@ -153,7 +153,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
     out[i] = scaled_fp8_conversion<is_scale_inverted>(
         static_cast<float>(input[i]), scale);
   }
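The change is mechanical: every index that walks the flattened tensor is widened from int to int64_t, so the loop arithmetic cannot wrap for very large inputs. As a standalone illustration of the idiom (a minimal sketch only; the kernel name, element type, and launch math below are illustrative, not vLLM's actual signatures):

#include <cstdint>

// Grid-stride loop with 64-bit indexing, mirroring the pattern the patch
// switches these kernels to: the element count, the stride, and the loop
// counter are all int64_t, so no intermediate value narrows back to 32 bits.
__global__ void scale_inplace(float* __restrict__ data, float scale,
                              int64_t num_elems) {
  int64_t const tid =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t const step = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = tid; i < num_elems; i += step) {
    data[i] *= scale;
  }
}

Note that in the patched remainder loops the start offset num_vec_elems * 4 + tid is now also computed in 64-bit arithmetic, since num_vec_elems itself became int64_t.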

tests/kernels/test_fp8_quant.py

Lines changed: 25 additions & 0 deletions
@@ -60,3 +60,28 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
     assert torch.allclose(ref_scale, ops_scale)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
                           ops_out.to(dtype=torch.float32))
+
+
+# Regression test for a case with large activations where an int32 index cannot
+# represent the number of elements.
+@torch.inference_mode()
+@pytest.mark.parametrize("seed", SEEDS)
+def test_fp8_quant_large(seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    num_tokens = 1024000  # Mistral-Nemo's max_position_embeddings
+    hidden_size = 1152  # Smallest hidden_size to reproduce the error
+    dtype = torch.bfloat16
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
+    ops_out, _ = ops.scaled_fp8_quant(x, scale)
+
+    # Minimize memory footprint in this test by freeing x and upconverting
+    # the outputs in place. (torch.allclose does not support fp8)
+    del x
+    ref_out = ref_out.to(dtype=dtype)
+    ops_out = ops_out.to(dtype=dtype)
+
+    assert torch.allclose(ref_out, ops_out)
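For a sense of the sizes involved (my arithmetic, not part of the patch): 1024000 tokens x 1152 hidden dims is 1,179,648,000 bf16 elements, roughly 2.36 GB of input, which is the regime where 32-bit index arithmetic becomes unsafe. A tiny host-side check, with the shapes hard-coded from the test above:

#include <cstdint>
#include <cstdio>

int main() {
  // Shapes taken from test_fp8_quant_large above.
  int64_t const num_tokens = 1024000;
  int64_t const hidden_size = 1152;
  int64_t const num_elems = num_tokens * hidden_size;  // 1,179,648,000
  int64_t const in_bytes = num_elems * 2;              // bf16 input: 2,359,296,000
  std::printf("elements = %lld, input bytes = %lld, INT32_MAX = %d\n",
              static_cast<long long>(num_elems),
              static_cast<long long>(in_bytes), INT32_MAX);
  return 0;
}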
