Skip to content

Commit d657d5f

Browse files
eitanturok, chenyuxyz, geohot
authored
[Bounty] Vectorize Transcendental (tinygrad#9058)
* init * cast everything right * more casting * install pillow in test * quick tests * simplify * quick tests * delete test * tests * fix import error * add vec to ldexp3k * vec for bitcast * some helper tests * high level tests * clean tests * change tolerance so cuda passes * ruff passes * remove tests for transcendental helpers * ruff passes * make exponent in power vectorized * fix pow test * add newline * add vec dtype to ilogb2k * comment + clean up * ruff --------- Co-authored-by: chenyu <chenyu@fastmail.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
1 parent 8ae215d commit d657d5f

File tree

4 files changed

+54
-24
lines changed

4 files changed

+54
-24
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ jobs:
447447
with:
448448
key: dsp-minimal
449449
deps: testing_minimal
450-
pydeps: "onnx==1.16.0 onnxruntime"
450+
pydeps: "onnx==1.16.0 onnxruntime pillow"
451451
llvm: "true"
452452
- name: Set up Docker Buildx
453453
uses: docker/setup-buildx-action@v3
@@ -466,6 +466,8 @@ jobs:
466466
run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
467467
- name: Test LLVM=1 DEVECTORIZE=0
468468
run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
469+
- name: Test LLVM=1 DEVECTORIZE=0 for model
470+
run: PYTHONPATH="." LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py
469471
- name: Test CPU=1 DEVECTORIZE=0
470472
run: CPU=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
471473

test/test_transcendental.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,33 @@ def test_transcendental_exp2_fusion(self):
128128
c = c.exp2()
129129
check_schedule(c, 1)
130130

131+
class TestTranscendentalVectorized(unittest.TestCase):
132+
def _vectorized_data(self, low, high, vec_size):
133+
np_data = np.linspace(low, high, num=(128 // vec_size) * vec_size, dtype=np.float32).reshape(-1, vec_size)
134+
data = Tensor(np_data, dtype=dtypes.float32.vec(vec_size))
135+
return data, np_data
136+
137+
def _test_vectorized_op(self, fxn, np_fxn, data_range, vec_size, param_range=None):
138+
data, np_data = self._vectorized_data(data_range[0], data_range[1], vec_size)
139+
if param_range:
140+
param, np_param = self._vectorized_data(param_range[0], param_range[1], vec_size)
141+
out, np_out = fxn(data, param), np_fxn(np_data, np_param)
142+
else:
143+
out, np_out = fxn(data), np_fxn(np_data)
144+
np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-4)
145+
146+
def test_exp2_vectorized(self):
147+
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.exp2, np.exp2, (-100, 100), vec_size)
148+
149+
def test_log2_vectorized(self):
150+
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.log2, np.log2, (0.001, 200), vec_size)
151+
152+
def test_sin_vectorized(self):
153+
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.sin, np.sin, (-100, 100), vec_size)
154+
155+
def test_pow_vectorized(self):
156+
# np.pow returns nan for negative values raised to a non-integral power
157+
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.pow, np.pow, (0.001, 200), vec_size, param_range=(-10, 10))
158+
131159
if __name__ == '__main__':
132160
unittest.main()

tinygrad/codegen/devectorizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def no_vectorized_acc(acc:UOp):
183183

184184
devectorize_load_store = PatternMatcher([
185185
# TODO: add vectorized support to transcendental
186-
(UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
186+
(UPat((Ops.INDEX), name="alu"), no_vectorized_alu),
187187
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
188188
])
189189

tinygrad/codegen/transcendental.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,51 +10,51 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
1010
return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
1111

1212
# *** helper functions for bit manipulation ***
13-
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
14-
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
15-
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
13+
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
14+
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
15+
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
1616

1717
# **** utils ****
1818
def shr(x:UOp, y:int) -> UOp: return x // (2**y)
1919
def shl(x:UOp, y:int) -> UOp: return x * (2**y)
2020

2121
def rintk(d:UOp) -> UOp:
2222
"""round d:float to int away from 0"""
23-
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
23+
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)
2424
return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
2525

2626
def pow2if(q:UOp, float_dtype:DType):
2727
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
28-
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
28+
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype.scalar()].vec(q.dtype.vcount)
2929
return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
3030

3131
def ilogb2k(d:UOp) -> UOp:
3232
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
33-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
34-
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
33+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
34+
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount))
3535
# -1 <= ilogb2k(d) <= 128
3636
return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
3737

3838
def ldexp3k(d:UOp, e:UOp) -> UOp:
3939
"""d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
40-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
41-
cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
42-
m1 = d.bitcast(cast_map[d.dtype])
43-
m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype))
40+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
41+
dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count)
42+
m1 = d.bitcast(dtype)
43+
m2 = shl(e.cast(dtype), mantissa_bits(d.dtype))
4444
return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
4545

4646
def ldexp2k(d:UOp, e:UOp) -> UOp:
4747
"""d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
48-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
48+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
4949
return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
5050

5151
def frexp(v:UOp) -> tuple[UOp, UOp]:
5252
"""frexp(v) -> (mantissa, exponent) assuming v != 0"""
53-
assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
53+
assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
5454
# m1 = masks for mantissa, m2 = masks to normalize the mantissa.
55-
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
56-
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
57-
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
55+
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
56+
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
57+
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
5858
exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
5959
# Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
6060
mantissa = ((bits & m1) | m2).bitcast(v.dtype)
@@ -70,7 +70,7 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
7070
- `r`[d.dtype] is the remainder value corresponding to `round_to_nearest(x % pi/2)`.
7171
- `q`[int32] is an integer, and q % 4 corresponds to the quadrant of the original angle `d`.
7272
"""
73-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
73+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
7474
# https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
7575
# 190 bits of 2/pi for Payne-Hanek style argument reduction
7676
two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
@@ -172,7 +172,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
172172
- fast=True assumes x <= switch_over.
173173
- switch_over is the threshold for switching to payne_hanek_reduction.
174174
"""
175-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
175+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
176176
# mask +-inf/nan as zero
177177
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
178178
# x_sign = sign(x)
@@ -194,7 +194,7 @@ def xexp2(d:UOp) -> UOp:
194194
Implements a 1.0 ULP approximation for Ops.EXP2
195195
- Paper: https://arxiv.org/pdf/2001.09258
196196
"""
197-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
197+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
198198
# mask +-inf/nan as zero.
199199
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
200200
q = rintk(x)
@@ -207,7 +207,7 @@ def xexp2(d:UOp) -> UOp:
207207
0.6931471805599452862e+0, 0.1000000000000000000e+1])
208208
else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
209209
u = ldexp2k(u, q) # u*2^q
210-
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
210+
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()]
211211
# Replace x >= upper with +inf
212212
u = (d >= upper).where(d.const_like(math.inf), u)
213213
# Replace x < lower with zero.
@@ -220,7 +220,7 @@ def xlog2(d:UOp) -> UOp:
220220
Implements a 1.0 ULP approximation for Ops.LOG2
221221
Paper: https://arxiv.org/pdf/2001.09258 5.5
222222
"""
223-
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
223+
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
224224
# TODO: float16 denormals need float32 to achieve precision
225225
if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
226226
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
@@ -248,7 +248,7 @@ def xlog2(d:UOp) -> UOp:
248248
r = (d<-0.0).where(r.const_like(math.nan), r)
249249
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
250250
# log2_zero = the value of unmasked xlog2(0.0).
251-
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
251+
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
252252
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
253253
# log2(NaN) = NaN
254254
r = d.ne(d).where(r.const_like(math.nan), r)

0 commit comments

Comments
 (0)