import numpy as np
import ctypes
from tinygrad import Tensor, GlobalCounters, Context
from tinygrad.engine.realize import lower_schedule, CompiledRunner
from tinygrad.device import CPUProgram
from dataclasses import replace
from keystone import Ks, KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN

# only the memory access, over 100 GB/s! (sometimes)
# note: the loop below never accumulates the loaded q registers, so the value it stores
# is 0.0; it exists purely to measure load bandwidth. the allclose check at the bottom
# is only expected to pass with the C kernel further down.
reduce_asm = """
movi v0.2d, #0000000000000000
mov w9, #0x30
mov w10, #0x20
mov x8, #-0x10
movi v1.2d, #0000000000000000
movk w9, #0x300, lsl #16
movi v2.2d, #0000000000000000
movk w10, #0x200, lsl #16
movi v3.2d, #0000000000000000
mov w11, #0x1000000
mov w12, #0x3ffff0
loop:
ldp q4, q5, [x1]
add x13, x1, x11
add x15, x1, x10
add x14, x1, x9
add x8, x8, #0x10
cmp x8, x12
ldp q6, q7, [x1, #0x20]
add x1, x1, #0x40
ldp q4, q5, [x13]
ldp q6, q7, [x13, #0x20]
ldp q4, q5, [x15, #-0x20]
ldp q6, q7, [x15]
ldp q4, q5, [x14, #-0x30]
ldp q6, q7, [x14, #-0x10]
b.lo loop
fadd v0.4s, v1.4s, v0.4s
fadd v0.4s, v2.4s, v0.4s
fadd v0.4s, v3.4s, v0.4s
dup v1.4s, v0.s[1]
dup v2.4s, v0.s[2]
fadd v1.4s, v0.4s, v1.4s
dup v0.4s, v0.s[3]
fadd v1.4s, v2.4s, v1.4s
fadd v0.4s, v0.4s, v1.4s
str s0, [x0]
ret
"""
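# reading off the constants above: x1 is the input buffer, walked as four streams
# 0x1000000 bytes (16 MiB) apart; each iteration issues eight ldp q-pair loads
# (8*32 = 256 bytes) and the loop runs for 0x40000 iterations, i.e. the full
# 4096*4096*4 bytes = 64 MiB of input.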

# assemble the kernel above into raw arm64 machine code with keystone
ks = Ks(KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN)
arm_bytecode, _ = ks.asm(reduce_asm)
arm_bytecode = bytes(arm_bytecode)
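
# (optional sketch, not part of the original script) round-trip the assembled bytes
# through capstone's disassembler as a sanity check, if capstone is installed.
try:
  from capstone import Cs, CS_ARCH_ARM64, CS_MODE_LITTLE_ENDIAN
  for insn in Cs(CS_ARCH_ARM64, CS_MODE_LITTLE_ENDIAN).disasm(arm_bytecode, 0):
    print(f"{insn.address:#05x}: {insn.mnemonic} {insn.op_str}")
except Exception:
  pass
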
reduce_src = """
// data1 is 16M inputs
typedef float float4 __attribute__((aligned(32),vector_size(16)));
void reduce(float* restrict data0, float* restrict data1) {
  float4 acc0 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc1 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc2 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc3 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc4 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc5 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc6 = {0.0f, 0.0f, 0.0f, 0.0f};
  float4 acc7 = {0.0f, 0.0f, 0.0f, 0.0f};
  float* data1_1 = data1+4194304;
  float* data1_2 = data1+(4194304*2);
  float* data1_3 = data1+(4194304*3);
  for (int ridx0 = 0; ridx0 < 16777216/4; ridx0+=16) {
    float4 val0 = *(float4*)((data1+(ridx0+0)));
    float4 val1 = *(float4*)((data1+(ridx0+4)));
    float4 val2 = *(float4*)((data1+(ridx0+8)));
    float4 val3 = *(float4*)((data1+(ridx0+12)));
    acc0 += val0;
    acc1 += val1;
    acc2 += val2;
    acc3 += val3;
    val0 = *(float4*)((data1_1+(ridx0+0)));
    val1 = *(float4*)((data1_1+(ridx0+4)));
    val2 = *(float4*)((data1_1+(ridx0+8)));
    val3 = *(float4*)((data1_1+(ridx0+12)));
    acc4 += val0;
    acc5 += val1;
    acc6 += val2;
    acc7 += val3;
    val0 = *(float4*)((data1_2+(ridx0+0)));
    val1 = *(float4*)((data1_2+(ridx0+4)));
    val2 = *(float4*)((data1_2+(ridx0+8)));
    val3 = *(float4*)((data1_2+(ridx0+12)));
    acc0 += val0;
    acc1 += val1;
    acc2 += val2;
    acc3 += val3;
    val0 = *(float4*)((data1_3+(ridx0+0)));
    val1 = *(float4*)((data1_3+(ridx0+4)));
    val2 = *(float4*)((data1_3+(ridx0+8)));
    val3 = *(float4*)((data1_3+(ridx0+12)));
    acc4 += val0;
    acc5 += val1;
    acc6 += val2;
    acc7 += val3;
  }
  float4 out = acc0+acc1+acc2+acc3+acc4+acc5+acc6+acc7;
  *(data0+0) = out[0]+out[1]+out[2]+out[3];
}
"""
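# the C version mirrors the assembly: four pointers 16 MiB apart and eight independent
# float4 accumulators so the adds don't pile up into a single dependency chain. each
# call still reads 64 MiB, so 100 GB/s corresponds to roughly 0.67 ms per reduce.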

if __name__ == "__main__":
  a = Tensor(np_array:=(np.random.default_rng().random((4096, 4096), dtype=np.float32)-0.5)).realize()
  with Context(SPLIT_REDUCEOP=0):
    # TODO: make it easy to alter the OptOps for a ScheduleItem
    GlobalCounters.reset()
    out = a.sum()
    sis = out.schedule()
    for i,ei in enumerate(lower_schedule(sis)):
      if i == 0:
        # change the source code
        prg_spec = ei.prg.p
        prg_spec = replace(prg_spec, name="reduce", src=reduce_src)
        prg = CompiledRunner(prg_spec)
        # change the assembly
        #prg._prg = CPUProgram(prg_spec.name, arm_bytecode)
        print("buffer at:", hex(ctypes.addressof(ei.bufs[1]._buf)))
        ei = replace(ei, prg=prg)
      ei.run()
    print(out.item())
    np.testing.assert_allclose(out.item(), np_array.sum(), atol=1, rtol=1e-4)
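
    # (sketch, not part of the original script) rough bandwidth check for the replaced
    # kernel; assumes the schedule was a single item, so `ei` is still the custom reduce,
    # and that ei.run() can be re-invoked on the already-realized buffers.
    import time
    for _ in range(5):
      st = time.perf_counter()
      ei.run()
      print(f"~{4096*4096*4/(time.perf_counter()-st)/1e9:.1f} GB/s")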