
Commit 9307572

Ops.POW and transcendental (tinygrad#8911)

1 parent: bff7c70

File tree

5 files changed: +26 -11 lines

tinygrad/codegen/rewriter.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,7 @@
 from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
 from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
-from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
+from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
 from tinygrad.renderer import Renderer
 
 # ***** float4/image store handling *****

@@ -124,6 +124,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
 def get_late_rewrite_patterns(ops, force_transcendental=False):
   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
+  pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src)))
   # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
   if Ops.AND in ops:
     pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
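
Unlike the EXP2/LOG2/SIN patterns, which only fire when the target backend lacks the op natively (or when force_transcendental is set), the POW pattern is appended unconditionally, so every Ops.POW gets expanded through xpow at late-rewrite time. The identity the expansion relies on, sketched with plain Python floats for a positive base (pow_via_exp2 is a hypothetical name, not part of the commit):

import math

def pow_via_exp2(b: float, e: float) -> float:
  # for b > 0: b ** e == exp2(e * log2(b)); this is the core of the UOp chain xpow builds
  return 2.0 ** (e * math.log2(b))

assert math.isclose(pow_via_exp2(3.0, 2.5), 3.0 ** 2.5)
assert math.isclose(pow_via_exp2(0.5, -3.0), 0.5 ** -3.0)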

tinygrad/codegen/transcendental.py

Lines changed: 10 additions & 0 deletions
@@ -254,3 +254,13 @@ def xlog2(d:UOp) -> UOp:
   r = d.ne(d).where(r.const_like(math.nan), r)
   # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
   return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
+
+def xpow(base:UOp, exponent:UOp) -> UOp:
+  # start with b ** e = exp2(e * log2(b))
+  ret = (base < 0).where(-base, base).log2().mul(exponent).exp2()
+  # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
+  adj = (base < 0).where((exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)).where(
+    ret.const_like(math.nan),
+    (exponent.cast(dtypes.int32).cast(dtypes.uint32)%2).eq(1).where(ret.const_like(-1), ret.const_like(1))), ret.const_like(1))
+  # fix 0 ** 0 = 1
+  return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * adj)
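
The same three steps, replayed with ordinary Python floats so the semantics are easy to check. This is a sketch only: xpow_model is a hypothetical name, and it simplifies the int32/uint32 parity casts of the UOp version down to plain int():

import math

def xpow_model(base: float, exponent: float) -> float:
  # step 1: compute on the absolute base, b ** e = exp2(e * log2(|b|))
  ret = 2.0 ** (exponent * math.log2(abs(base))) if base != 0 else 0.0
  # step 2: negative base adjustment: nan for non-integer exponent, -1 for odd exponent
  adj = 1.0
  if base < 0:
    if exponent != float(int(exponent)): adj = math.nan
    elif int(exponent) % 2 == 1: adj = -1.0
  # step 3: fix 0 ** 0 = 1
  if base == 0 and exponent == 0: return 1.0
  return ret * adj

assert xpow_model(2.0, 10.0) == 1024.0
assert xpow_model(-2.0, 3.0) == -8.0      # odd integer exponent flips the sign
assert math.isnan(xpow_model(-2.0, 0.5))  # fractional exponent on a negative base
assert xpow_model(0.0, 0.0) == 1.0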

tinygrad/gradient.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ def reduce_gradient(ctx:UOp, ret:UOp):
   (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
   (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
   (UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
+  (UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
   (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
                                                 (ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
   (UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),
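
The two gradients added here are the standard power-rule identities: with ret = b ** e, d(ret)/db = ret * e / b and d(ret)/de = ret * ln(b), where ln(b) is spelled log2(b) * ln(2) because LOG2 is the only log primitive. A finite-difference sanity check of both formulas in plain Python:

import math

b, e, eps = 3.0, 2.5, 1e-6
ret = b ** e
d_base = ret * e / b                        # matches ctx*ret*ret.src[1]/ret.src[0] with ctx == 1
d_exp = ret * math.log2(b) * math.log(2.0)  # matches ctx*ret*ret.src[0].log2()*math.log(2.0)
assert math.isclose(d_base, ((b + eps)**e - (b - eps)**e) / (2 * eps), rel_tol=1e-4)
assert math.isclose(d_exp, (b**(e + eps) - b**(e - eps)) / (2 * eps), rel_tol=1e-4)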

tinygrad/ops.py

Lines changed: 10 additions & 4 deletions
@@ -89,6 +89,7 @@ def sqrt(self): return self.alu(Ops.SQRT)
   def sin(self): return self.alu(Ops.SIN)
   def log2(self): return self.alu(Ops.LOG2)
   def exp2(self): return self.alu(Ops.EXP2)
+  def pow(self, x): return self.alu(Ops.POW, x)
 
 # the order of these Ops controls the order of the toposort
 class Ops(FastEnum):

@@ -133,7 +134,7 @@ class Ops(FastEnum):
 
   # BinaryOps
   ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
-  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
+  SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
 
   # TernaryOps
   WHERE = auto(); MULACC = auto() # noqa: E702

@@ -155,7 +156,7 @@ class Ops(FastEnum):
 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
   Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
-            Ops.SUB, Ops.FDIV}
+            Ops.SUB, Ops.FDIV, Ops.POW}
   Ternary = {Ops.WHERE, Ops.MULACC}
   ALU = set.union(Unary, Binary, Ternary)

@@ -175,7 +176,7 @@ class GroupOp:
   Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
 
   # do not preserve f(0) = 0
-  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
+  UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
 
   All = set(Ops)

@@ -675,10 +676,15 @@ def safe_exp2(x):
   try: return 2 ** x
   except OverflowError: return math.inf
 
+def safe_pow(x, y):
+  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
+  except ZeroDivisionError: return math.inf
+  except ValueError: return math.inf if x > 0 else -math.inf
+
 python_alu: dict[Ops, Callable] = {
   Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
   Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
-  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
+  Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
   Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
   Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
   Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
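
safe_pow papers over the ways CPython's built-in pow diverges from the float semantics the interpreter ALU wants: a negative base with a fractional exponent returns a complex number (mapped to nan), and 0.0 raised to a negative power raises ZeroDivisionError (mapped to inf). Replaying the function on those cases:

import math

def safe_pow(x, y):  # copied verbatim from the hunk above
  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
  except ZeroDivisionError: return math.inf
  except ValueError: return math.inf if x > 0 else -math.inf

assert safe_pow(2.0, 10.0) == 1024.0
assert math.isnan(safe_pow(-1.0, 0.5))   # pow(-1.0, 0.5) is complex in CPython -> nan
assert safe_pow(0.0, -1.0) == math.inf   # ZeroDivisionError path -> inf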

tinygrad/tensor.py

Lines changed: 3 additions & 6 deletions
@@ -3313,12 +3313,9 @@ def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     base, exponent = self._broadcasted(x, reverse=reverse)
     # TODO: int pow
     if not base.is_floating_point(): raise RuntimeError("base needs to be float")
-    # start with b ** e = exp(e * log(b))
-    ret = base.abs().log().mul(exponent).exp()
-    # negative base adjustment: nan for non-integer exponent and -1 for odd exponent
-    adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
-    # fix 0 ** 0 = 1
-    ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
+
+    # NOTE: pow(int, float) -> int
+    ret = base._apply_uop(UOp.pow, exponent)
     return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
 
   def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
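
With the old abs/log/exp construction removed, Tensor.pow is now a thin wrapper that emits a single Ops.POW UOp and leaves the expansion to the late rewriter (or to safe_pow on the Python interpreter device). A small usage sketch; the commented results are the expected semantics, assuming a tinygrad checkout at this commit:

from tinygrad import Tensor

a = Tensor([-2.0, 0.0, 3.0])
print(a.pow(2.0).numpy())              # [4. 0. 9.]
print(a.pow(0.5).numpy())              # [nan 0. 1.7320508], fractional exponent on a negative base is nan
print(Tensor([0.0]).pow(0.0).numpy())  # [1.], since 0 ** 0 = 1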
