@@ -10,51 +10,51 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
10
10
return x .ne (math .inf ).where (x .ne (x ).where (nan , x .ne (- math .inf ).where (ratio , _inf )), inf )
11
11
12
12
# *** helper functions for bit manipulation ***
13
- def mantissa_bits (d :DType ) -> int : return dtypes .finfo (d )[1 ]
14
- def exponent_bias (d :DType ) -> int : return {dtypes .float64 : 1023 , dtypes .float32 : 127 , dtypes .float16 : 15 }[d ]
15
- def exponent_mask (d :DType ) -> int : return {dtypes .float64 : 2047 , dtypes .float32 : 255 , dtypes .float16 : 31 }[d ]
13
+ def mantissa_bits (d :DType ) -> int : return dtypes .finfo (d . scalar () )[1 ]
14
+ def exponent_bias (d :DType ) -> int : return {dtypes .float64 : 1023 , dtypes .float32 : 127 , dtypes .float16 : 15 }[d . scalar () ]
15
+ def exponent_mask (d :DType ) -> int : return {dtypes .float64 : 2047 , dtypes .float32 : 255 , dtypes .float16 : 31 }[d . scalar () ]
16
16
17
17
# **** utils ****
18
18
def shr (x :UOp , y :int ) -> UOp : return x // (2 ** y )
19
19
def shl (x :UOp , y :int ) -> UOp : return x * (2 ** y )
20
20
21
21
def rintk (d :UOp ) -> UOp :
22
22
"""round d:float to int away from 0"""
23
- out_dtype = {dtypes .float64 : dtypes .int64 , dtypes .float32 : dtypes .int32 , dtypes .float16 : dtypes .int16 }[d .dtype ]
23
+ out_dtype = {dtypes .float64 : dtypes .int64 , dtypes .float32 : dtypes .int32 , dtypes .float16 : dtypes .int16 }[d .dtype . scalar ()]. vec ( d . dtype . vcount )
24
24
return (d + (d < 0.0 ).where (d .const_like (- 0.5 ), d .const_like (0.5 ))).cast (out_dtype )
25
25
26
26
def pow2if (q :UOp , float_dtype :DType ):
27
27
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
28
- out_dtype = {dtypes .int64 : dtypes .float64 , dtypes .int32 : dtypes .float32 , dtypes .int16 : float_dtype }[q .dtype ]
28
+ out_dtype = {dtypes .int64 : dtypes .float64 , dtypes .int32 : dtypes .float32 , dtypes .int16 : float_dtype }[q .dtype . scalar ()]. vec ( q . dtype . vcount )
29
29
return shl (q + exponent_bias (out_dtype ), mantissa_bits (out_dtype )).bitcast (out_dtype )
30
30
31
31
def ilogb2k (d :UOp ) -> UOp :
32
32
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
33
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
34
- dint = d .bitcast ({dtypes .float64 : dtypes .int64 , dtypes .float32 : dtypes .int32 , dtypes .float16 : dtypes .int16 }[d .dtype ] )
33
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
34
+ dint = d .bitcast ({dtypes .float64 : dtypes .int64 , dtypes .float32 : dtypes .int32 , dtypes .float16 : dtypes .int16 }[d .dtype . scalar ()]. vec ( d . dtype . vcount ) )
35
35
# -1 <= ilog2bk(d) <= 128
36
36
return (shr (dint , mantissa_bits (d .dtype )) & exponent_mask (d .dtype )) - exponent_bias (d .dtype )
37
37
38
38
def ldexp3k (d :UOp , e :UOp ) -> UOp :
39
39
"""d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
40
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
41
- cast_map = {dtypes .float64 : dtypes .int64 , dtypes .float32 : dtypes .int32 , dtypes .float16 : dtypes .int16 }
42
- m1 = d .bitcast (cast_map [ d . dtype ] )
43
- m2 = shl (e .cast (cast_map [ d . dtype ] ), mantissa_bits (d .dtype ))
40
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES and e .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
41
+ dtype = {dtypes .float64 : dtypes .int64 , dtypes .float32 : dtypes .int32 , dtypes .float16 : dtypes .int16 }[ d . dtype . scalar ()]. vec ( d . dtype . count )
42
+ m1 = d .bitcast (dtype )
43
+ m2 = shl (e .cast (dtype ), mantissa_bits (d .dtype ))
44
44
return (m1 + m2 ).bitcast (d .dtype ).cast (d .dtype )
45
45
46
46
def ldexp2k (d :UOp , e :UOp ) -> UOp :
47
47
"""d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
48
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e .dtype in (dtypes .int16 , dtypes .int32 , dtypes .int64 )
48
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES and e .dtype . scalar () in (dtypes .int16 , dtypes .int32 , dtypes .int64 )
49
49
return (d * pow2if (shr (e , 1 ), d .dtype )) * pow2if (e - shr (e , 1 ), d .dtype )
50
50
51
51
def frexp (v :UOp ) -> tuple [UOp , UOp ]:
52
52
"""frexp(v) -> (mantissa, exponent) assuming v != 0"""
53
- assert v .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
53
+ assert v .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
54
54
# m1 = masks for mantissa, m2 = masks to normalize the mantissa.
55
- m1 = {dtypes .float64 : 0x000FFFFFFFFFFFFF , dtypes .float32 : 0x807FFFFF , dtypes .float16 : 0x83FF }[v .dtype ]
56
- m2 = {dtypes .float64 : 0x3FE0000000000000 , dtypes .float32 : 0x3F000000 , dtypes .float16 : 0x3800 }[v .dtype ]
57
- bits = v .bitcast ({dtypes .float64 : dtypes .uint64 , dtypes .float32 : dtypes .uint32 , dtypes .float16 : dtypes .uint16 }[v .dtype ] )
55
+ m1 = {dtypes .float64 : 0x000FFFFFFFFFFFFF , dtypes .float32 : 0x807FFFFF , dtypes .float16 : 0x83FF }[v .dtype . scalar () ]
56
+ m2 = {dtypes .float64 : 0x3FE0000000000000 , dtypes .float32 : 0x3F000000 , dtypes .float16 : 0x3800 }[v .dtype . scalar () ]
57
+ bits = v .bitcast ({dtypes .float64 : dtypes .uint64 , dtypes .float32 : dtypes .uint32 , dtypes .float16 : dtypes .uint16 }[v .dtype . scalar ()]. vec ( v . dtype . count ) )
58
58
exponent = shr (bits , mantissa_bits (v .dtype )) & exponent_mask (v .dtype )
59
59
# Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
60
60
mantissa = ((bits & m1 ) | m2 ).bitcast (v .dtype )
@@ -70,7 +70,7 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
70
70
- `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
71
71
- `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
72
72
"""
73
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
73
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
74
74
# https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
75
75
# 190 bits of 2/pi for Payne-Hanek style argument reduction
76
76
two_over_pi_f = [0x00000000 , 0x28be60db , 0x9391054a , 0x7f09d5f4 , 0x7d4d3770 , 0x36d8a566 , 0x4f10e410 ]
@@ -172,7 +172,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
172
172
- fast=True assumes x <= switch_over.
173
173
- switch_over is the threshold for switching to payne_hanek_reduction.
174
174
"""
175
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
175
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
176
176
# mask +-inf/nan as zero
177
177
x = _lazy_map_numbers (d , d .const_like (0.0 ), d .const_like (0.0 ), d .const_like (0.0 ), d )
178
178
# x_sign = sign(x)
@@ -194,7 +194,7 @@ def xexp2(d:UOp) -> UOp:
194
194
Implements a 1.0 ULP approximation for Ops.EXP2
195
195
- Paper: https://arxiv.org/pdf/2001.09258
196
196
"""
197
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
197
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
198
198
# mask +=inf/nan as zero.
199
199
x = _lazy_map_numbers (d , d .const_like (0.0 ), d .const_like (0.0 ), d .const_like (0.0 ), d )
200
200
q = rintk (x )
@@ -207,7 +207,7 @@ def xexp2(d:UOp) -> UOp:
207
207
0.6931471805599452862e+0 , 0.1000000000000000000e+1 ])
208
208
else : u = polyN (s , [0.1535920892e-3 , 0.1339262701e-2 , 0.9618384764e-2 , 0.5550347269e-1 , 0.2402264476e+0 , 0.6931471825e+0 , 1.0 ])
209
209
u = ldexp2k (u , q ) # u*2^q
210
- upper , lower = {dtypes .float64 : (1024 , - 2000 ), dtypes .float32 : (128 , - 150 ), dtypes .float16 : (23 , - 22 )}[d .dtype ]
210
+ upper , lower = {dtypes .float64 : (1024 , - 2000 ), dtypes .float32 : (128 , - 150 ), dtypes .float16 : (23 , - 22 )}[d .dtype . scalar () ]
211
211
# Replace x >= upper with +inf
212
212
u = (d >= upper ).where (d .const_like (math .inf ), u )
213
213
# Replace x < lower with zero.
@@ -220,7 +220,7 @@ def xlog2(d:UOp) -> UOp:
220
220
Implements a 1.0 ULP approximation for Ops.LOG2
221
221
Paper: https://arxiv.org/pdf/2001.09258 5.5
222
222
"""
223
- assert d .dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
223
+ assert d .dtype . scalar () in TRANSCENDENTAL_SUPPORTED_DTYPES
224
224
# TODO: float16 denormal need float32 to achieve precision
225
225
if d .dtype == dtypes .float16 : return xlog2 (d .cast (dtypes .float32 )).cast (dtypes .float16 )
226
226
FLT_MIN = d .const_like (1e-6 if d .dtype == dtypes .float16 else 1e-4 )
@@ -248,7 +248,7 @@ def xlog2(d:UOp) -> UOp:
248
248
r = (d < - 0.0 ).where (r .const_like (math .nan ), r )
249
249
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
250
250
# log2_zero = the value of unmasked xlog2(0.0).
251
- log2_zero = {dtypes .float64 : - 1087 , dtypes .float32 : - 191 , dtypes .float16 : - 79 }[d .dtype ]
251
+ log2_zero = {dtypes .float64 : - 1087 , dtypes .float32 : - 191 , dtypes .float16 : - 79 }[d .dtype . scalar () ]
252
252
r = r .ne (log2_zero ).where (r , r .const_like (- math .inf ))
253
253
# log2(NaN) = NaN
254
254
r = d .ne (d ).where (r .const_like (math .nan ), r )
0 commit comments