File tree 1 file changed +2
-4
lines changed
1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -102,9 +102,6 @@ def copy(self):
102
102
@property
103
103
def membufs(self) -> list[UOp]:
  """Collect `src[0]` of every LOAD/STORE entry in `self.bufs`, passed through `dedup`.

  NOTE(review): `dedup` is defined outside this view; presumably it drops repeated
  entries -- confirm its ordering guarantees against the helper itself.
  """
  first_srcs = []
  for buf in self.bufs:
    # only entries whose op is a memory access (LOAD or STORE) contribute
    if buf.op in {Ops.LOAD, Ops.STORE}:
      first_srcs.append(buf.src[0])
  return dedup(first_srcs)
104
104
105
- # TODO: these need more tests or it might silently be no-op
106
- def float4_axis (self , i :int ): return [x - self .first_upcast for x in self .sts [i ].unit_stride_axes () if x >= self .first_upcast and self .sts [i ].shape [x ]% 4 == 0 ] # noqa: E501
107
-
108
105
def upcasted_axis (self , i :int ) -> list [tuple [int , Optional [sint ], bool ]]:
109
106
upcasted_shape , upcasted_stride = self .sts [i ].shape [self .first_upcast :], self .sts [i ].real_strides ()[self .first_upcast :]
110
107
assert all_int (upcasted_shape ), f"cannot upcast a symbolic amount { upcasted_shape = } "
@@ -461,7 +458,8 @@ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve
461
458
462
459
if self .opts .has_local and self .opts .has_shared and all_int (self .sts [0 ].shape [:self .first_reduce ]):
463
460
# are we grouping? (requires local shape support)
464
- if not self .float4_axis (0 ) and self .first_reduce <= 2 and self .first_reduce + 1 <= self .shape_len and prod (self .sts [0 ].shape [:self .first_reduce ]) <= 2048 : # noqa: E501
461
+ if not [x for x in self .sts [0 ].unit_stride_axes () if x >= self .first_upcast and self .sts [0 ].shape [x ]% 4 == 0 ] and \
462
+ self .first_reduce <= 2 and self .first_reduce < self .shape_len and prod (self .sts [0 ].shape [:self .first_reduce ]) <= 2048 :
465
463
# TODO: use 1024 if it's allowed in a smarter way
466
464
for sz in ([256 , 16 ] if prod (self .sts [0 ].shape [:self .first_reduce ]) <= 32 else [16 ]):
467
465
if all (st .shape [self .first_reduce ] % sz == 0 or st .shape [self .first_reduce ] == 1 for st in self .sts ):
You can’t perform that action at this time.
0 commit comments