Commit d426136

Fix issue with fp8 ops on some models. (#8045)
torch._scaled_mm errors when an input tensor is non-contiguous.
Parent: 1b3bf0a

File tree: 1 file changed (+2 -2 lines)


comfy/ops.py

Lines changed: 2 additions & 2 deletions
@@ -308,10 +308,10 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype)
+            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
+            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

         if bias is not None:
             o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
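
For context, here is a minimal standalone sketch of the call pattern this commit fixes. It mirrors the reshape / cast-to-fp8 / contiguous / matmul sequence from fp8_linear above, but the shapes, scale values, and variable names (x, w, scale_a, scale_b) are illustrative assumptions, not code from the repository. It assumes a CUDA GPU with fp8 support (e.g. Ada or Hopper) and a recent PyTorch build exposing the private torch._scaled_mm op, whose exact signature has changed across releases.

import torch

# Assumed setup: CUDA device with fp8 matmul support, recent PyTorch.
device = "cuda"
compute_dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn

# Illustrative (batch, tokens, features) activation and an fp8 weight.
# The weight is transposed so it is column-major, as fp8_linear does with w.t().
x = torch.randn(2, 16, 64, device=device, dtype=compute_dtype)
w = torch.randn(128, 64, device=device, dtype=compute_dtype).to(fp8_dtype).t()

# Tensor-wise scales; _scaled_mm expects float32 scalar tensors here.
scale_a = torch.ones((), device=device, dtype=torch.float32)
scale_b = torch.ones((), device=device, dtype=torch.float32)

# Flatten to 2D, cast to fp8, and force a dense layout. The trailing
# .contiguous() is a no-op when the tensor is already contiguous; when an
# upstream op hands the layer a non-contiguous activation, it materializes
# a dense copy so _scaled_mm does not raise.
a = x.reshape(-1, x.shape[-1]).to(fp8_dtype).contiguous()

o = torch._scaled_mm(a, w, out_dtype=compute_dtype, scale_a=scale_a, scale_b=scale_b)
print(o.shape)  # torch.Size([32, 128])

Calling .contiguous() unconditionally, as the commit does, keeps the code simple: for the common case it costs nothing, and for the affected models it trades one extra copy for a working fp8 matmul.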
