Skip to content

Commit 8ddb135

Browse files
authored
fix UPat.location after pickle (tinygrad#9763)
* fix UPat.location after pickle [pr] * named upat test
1 parent 4cd27aa commit 8ddb135

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

test/test_pickle.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def test_pickle_pattern_matcher(self):
2121

2222
def test_pickle_main_pattern_matcher(self):
2323
from tinygrad.codegen.devectorizer import sym
24-
pickle.dumps(sym)
24+
ssym = pickle.dumps(sym)
25+
dsym = pickle.loads(ssym)
26+
self.assertEqual(dsym.patterns[0][0].location, sym.patterns[0][0].location)
2527

2628
def test_pickle_realized_tensor(self):
2729
print("** init")

test/test_uops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,9 @@ def test_location(self):
485485
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "spec.py")
486486
test_upat = UPat(Ops.CONST, dtypes.bool)
487487
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
488+
test_upat_named = test_upat.named("test_name")
489+
self.assertEqual(test_upat.location[0], test_upat_named.location[0])
490+
self.assertNotEqual(test_upat.location[1], test_upat_named.location[1])
488491

489492
class TestUopsObject(unittest.TestCase):
490493
# LOL, running this test breaks all instances of "4"

tinygrad/ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ class UPat(MathTrait):
708708
__slots__ = ("op", "dtype", "arg", "name", "src")
709709
def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None,
710710
src:Optional[Union[tuple[UPat, ...], list[UPat], UPat]]=None, arg:Any=None,
711-
name:Optional[str]=None, allow_any_len:bool=False, location=None, custom_early_reject:Optional[set[Ops]]=None):
711+
name:Optional[str]=None, allow_any_len:bool=False, custom_early_reject:Optional[set[Ops]]=None, location=None):
712712
assert op is None or isinstance(op, (Ops, tuple, set)), "op must be Ops or tuple of Ops"
713713
self.op: Optional[tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else (tuple(op) if isinstance(op, set) else op)
714714
self.dtype: Optional[tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
@@ -733,7 +733,8 @@ def __init__(self, op:Optional[Union[Ops, tuple[Ops, ...], set[Ops]]]=None, dtyp
733733
upat_match = [src] if isinstance(src, UPat) else ([] if src is None else self.src[0])
734734
self.early_reject = {pp.op[0] for pp in upat_match if pp.op is not None and len(pp.op) == 1}
735735

736-
def __reduce__(self): return UPat,(self.op, self.dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject)
736+
def __reduce__(self):
737+
return UPat, (self.op, self.dtype, self._in_src, self.arg, self.name, not self.strict_length, self.custom_early_reject, self.location)
737738
def named(self, name:str): return UPat(self.op, self.dtype, self._in_src, self.arg, name, not self.strict_length, self.custom_early_reject)
738739

739740
@staticmethod

0 commit comments

Comments
 (0)