@@ -89,6 +89,7 @@ def sqrt(self): return self.alu(Ops.SQRT)
89
89
def sin (self ): return self .alu (Ops .SIN )
90
90
def log2 (self ): return self .alu (Ops .LOG2 )
91
91
def exp2 (self ): return self .alu (Ops .EXP2 )
92
+ def pow (self , x ): return self .alu (Ops .POW , x )
92
93
93
94
# the order of these Ops controls the order of the toposort
94
95
class Ops (FastEnum ):
@@ -133,7 +134,7 @@ class Ops(FastEnum):
133
134
134
135
# BinaryOps
135
136
ADD = auto (); MUL = auto (); IDIV = auto (); MAX = auto (); MOD = auto (); CMPLT = auto (); CMPNE = auto (); XOR = auto () # noqa: E702
136
- SHL = auto (); SHR = auto (); OR = auto (); AND = auto (); THREEFRY = auto (); SUB = auto (); FDIV = auto () # noqa: E702
137
+ SHL = auto (); SHR = auto (); OR = auto (); AND = auto (); THREEFRY = auto (); SUB = auto (); FDIV = auto (); POW = auto () # noqa: E702
137
138
138
139
# TernaryOps
139
140
WHERE = auto (); MULACC = auto () # noqa: E702
@@ -155,7 +156,7 @@ class Ops(FastEnum):
155
156
class GroupOp :
156
157
Unary = {Ops .EXP2 , Ops .LOG2 , Ops .SIN , Ops .SQRT , Ops .RECIP , Ops .NEG }
157
158
Binary = {Ops .ADD , Ops .MUL , Ops .IDIV , Ops .MAX , Ops .MOD , Ops .CMPLT , Ops .CMPNE , Ops .XOR , Ops .SHL , Ops .SHR , Ops .OR , Ops .AND , Ops .THREEFRY ,
158
- Ops .SUB , Ops .FDIV }
159
+ Ops .SUB , Ops .FDIV , Ops . POW }
159
160
Ternary = {Ops .WHERE , Ops .MULACC }
160
161
ALU = set .union (Unary , Binary , Ternary )
161
162
@@ -175,7 +176,7 @@ class GroupOp:
175
176
Idempotent = {Ops .OR , Ops .AND , Ops .MAX }
176
177
177
178
# do not preserve f(0) = 0
178
- UnsafePad = {Ops .RECIP , Ops .LOG2 , Ops .EXP2 , Ops .IDIV }
179
+ UnsafePad = {Ops .RECIP , Ops .LOG2 , Ops .EXP2 , Ops .IDIV , Ops . POW }
179
180
180
181
All = set (Ops )
181
182
@@ -675,10 +676,15 @@ def safe_exp2(x):
675
676
try : return 2 ** x
676
677
except OverflowError : return math .inf
677
678
679
+ def safe_pow (x , y ):
680
+ try : return math .nan if isinstance (p := pow (x , y ), complex ) else p
681
+ except ZeroDivisionError : return math .inf
682
+ except ValueError : return math .inf if x > 0 else - math .inf
683
+
678
684
python_alu : dict [Ops , Callable ] = {
679
685
Ops .LOG2 : lambda x : math .log2 (x ) if x > 0 else - math .inf if x == 0 else math .nan , Ops .EXP2 : safe_exp2 ,
680
686
Ops .SQRT : lambda x : math .sqrt (x ) if x >= 0 else math .nan , Ops .RECIP : lambda x : 1 / x if x != 0 else math .copysign (math .inf , x ),
681
- Ops .SIN : lambda x : math .sin (x ) if not math .isinf (x ) else math .nan ,
687
+ Ops .SIN : lambda x : math .sin (x ) if not math .isinf (x ) else math .nan , Ops . POW : safe_pow ,
682
688
Ops .NEG : operator .neg , Ops .ADD : operator .add , Ops .SUB : operator .sub , Ops .MUL : operator .mul , Ops .CMPNE : operator .ne , Ops .CMPLT : operator .lt ,
683
689
Ops .XOR : operator .xor , Ops .OR : operator .or_ , Ops .AND : operator .and_ , Ops .SHR : operator .rshift , Ops .SHL : operator .lshift , Ops .MAX : max ,
684
690
Ops .MOD : lambda x ,y : abs (int (x ))% abs (int (y ))* (1 ,- 1 )[x < 0 ], Ops .IDIV : lambda x ,y : abs (x )// abs (y )* (1 ,- 1 )[x * y < 0 ] if y != 0 else 0 ,
0 commit comments