Skip to content

Commit 891322f

Browse files
authored
split into grouper.py (tinygrad#9768)
* split into grouper.py * update tests * reorder
1 parent 219b8c9 commit 891322f

File tree

5 files changed

+454
-445
lines changed

5 files changed

+454
-445
lines changed

test/test_rewrite_tracked_childen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from tinygrad import Tensor
33
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp, merge_views
4-
from tinygrad.engine.schedule import sym
4+
from tinygrad.engine.grouper import sym
55

66
class TestRewriteTrackedChildren(unittest.TestCase):
77
def test_children_in_context(self):

test/test_schedule.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from tinygrad.device import is_dtype_supported
1212
from tinygrad.dtype import DType, ImageDType
1313
from tinygrad.shape.shapetracker import ShapeTracker
14-
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
14+
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp, view_left
1515
from tinygrad.codegen.symbolic import symbolic_simple
1616
from tinygrad.spec import type_verify, shape_spec
1717
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
18-
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
18+
from tinygrad.engine.grouper import view_right, sym
19+
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
1920
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
2021
from extra.models.llama import precompute_freqs_cis
2122
remove_movement_ops = merge_views

test/test_uops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu # noqa F401
1111
from tinygrad.spec import spec
1212
from tinygrad.renderer import ProgramSpec
13-
from tinygrad.engine.schedule import fix_kernel_ops
13+
from tinygrad.engine.grouper import fix_kernel_ops
1414
from tinygrad.engine.realize import CompiledRunner, get_kernel
1515
from tinygrad.codegen.linearize import linearize_uop
1616
from tinygrad.codegen.devectorizer import full_graph_rewrite
@@ -481,7 +481,7 @@ def test_simple_order_with_special(self):
481481
class TestUPatHelpers(unittest.TestCase):
482482
def test_location(self):
483483
self.assertEqual(sym.patterns[-1][0].location[0].replace("\\", "/").split("/")[-1], "symbolic.py")
484-
self.assertEqual(fix_kernel_ops.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
484+
self.assertEqual(fix_kernel_ops.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "grouper.py")
485485
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
486486
test_upat = UPat(Ops.CONST, dtypes.bool)
487487
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])

0 commit comments

Comments
 (0)