Skip to content

Commit ca5064a

Browse files
authored
remove Kernel.float4_axis [pr] (tinygrad#9448)
1 parent 0e591ba commit ca5064a

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tinygrad/codegen/kernel.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,6 @@ def copy(self):
102102
@property
103103
def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
104104

105-
# TODO: these need more tests or it might silently be no-op
106-
def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
107-
108105
def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
109106
upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
110107
assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
@@ -461,7 +458,8 @@ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve
461458

462459
if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
463460
# are we grouping? (requires local shape support)
464-
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
461+
if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \
462+
self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
465463
# TODO: use 1024 if it's allowed in a smarter way
466464
for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
467465
if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):

0 commit comments

Comments
 (0)