@@ -103,11 +103,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
103
103
vec4_t <scalar_t > const * vectorized_in =
104
104
reinterpret_cast <vec4_t <scalar_t > const *>(input);
105
105
106
- int const num_vec_elems = num_elems >> 2 ;
106
+ int64_t const num_vec_elems = num_elems >> 2 ;
107
107
float absmax_val = 0 .0f ;
108
108
109
109
#pragma unroll 4
110
- for (int i = tid; i < num_vec_elems; i += step) {
110
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
111
111
vec4_t <scalar_t > in_vec = vectorized_in[i];
112
112
absmax_val = max (absmax_val, fabs (in_vec.x ));
113
113
absmax_val = max (absmax_val, fabs (in_vec.y ));
@@ -116,7 +116,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
116
116
}
117
117
118
118
// Handle the remaining elements if num_elems is not divisible by 4
119
- for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
119
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
120
120
absmax_val = max (absmax_val, fabs (input[i]));
121
121
}
122
122
@@ -134,10 +134,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
134
134
reinterpret_cast <vec4_t <scalar_t > const *>(input);
135
135
float8x4_t * vectorized_out = reinterpret_cast <float8x4_t *>(out);
136
136
137
- int const num_vec_elems = num_elems >> 2 ;
137
+ int64_t const num_vec_elems = num_elems >> 2 ;
138
138
139
139
#pragma unroll 4
140
- for (int i = tid; i < num_vec_elems; i += step) {
140
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
141
141
vec4_t <scalar_t > in_vec = vectorized_in[i];
142
142
float8x4_t out_vec;
143
143
@@ -153,7 +153,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
153
153
}
154
154
155
155
// Handle the remaining elements if num_elems is not divisible by 4
156
- for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
156
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
157
157
out[i] = scaled_fp8_conversion<is_scale_inverted>(
158
158
static_cast <float >(input[i]), scale);
159
159
}
0 commit comments