[pull] master from tinygrad:master #93

Merged · 4 commits · Feb 10, 2025
8 changes: 6 additions & 2 deletions .github/actions/setup-tinygrad/action.yml
@@ -13,6 +13,10 @@ inputs:
    description: 'Extra dependency groups (comma separated)'
    required: false
    default: ''
+  pydeps:
+    description: 'Extra Python dependency groups (space separated)'
+    required: false
+    default: ''
  opencl:
    description: "Install OpenCL?"
    required: false
@@ -83,11 +87,11 @@ runs:
  - name: Install dependencies (with extra)
    if: inputs.deps != ''
    shell: bash
-    run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e ".[${{ inputs.deps }}]" --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
+    run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e ".[${{ inputs.deps }}]" ${{ inputs.pydeps }} --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
  - name: Install dependencies (without extra)
    if: inputs.deps == ''
    shell: bash
-    run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e .
+    run: pip install ${{ (runner.os == 'macOS' && '--user') || (runner.os != 'macOS' && '') }} -e . ${{ inputs.pydeps }}

# **** OpenCL ****

8 changes: 6 additions & 2 deletions .github/workflows/test.yml
@@ -42,8 +42,7 @@ jobs:
      uses: ./.github/actions/setup-tinygrad
      with:
        deps: docs
-    - name: Install capstone for CLANG disassembly
-      run: pip install capstone
+        pydeps: "capstone"
    - name: Use as an external package
      run: |
        mkdir $HOME/test_external_dir
@@ -403,6 +402,9 @@ jobs:
    - uses: actions/checkout@v4
    - name: Setup Environment
      uses: ./.github/actions/setup-tinygrad
+      with:
+        key: dsp
+        pydeps: "onnx==1.16.0 onnxruntime"
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v3
    - name: Build QEMU Docker with cache
@@ -416,6 +418,8 @@ jobs:
          cache-to: type=gha,mode=min
    - name: Run test_tiny on DSP
      run: DEBUG=2 DSP=1 python test/test_tiny.py
+    - name: Test quantize onnx
+      run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py

  testwebgpu:
    name: Linux (WebGPU)
4 changes: 2 additions & 2 deletions test/test_linearizer_failures.py
@@ -1368,7 +1368,7 @@ def test_failure_56(self):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=2, arg=32)]
-    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
+    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])

def test_failure_57(self):
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
@@ -1409,7 +1409,7 @@ def test_failure_57(self):
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)]
-    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"])
+    helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])

if __name__ == '__main__':
unittest.main()
88 changes: 88 additions & 0 deletions test/test_quantize_onnx.py
@@ -0,0 +1,88 @@
+import numpy as np
+import unittest
+from tinygrad import Tensor, Context, Device
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.engine.realize import CompiledRunner, ExecItem
+
+N = 1024
+
+def create_gemm_model(model_path:str, in_size=N, out_size=N):
+  import onnx
+  from onnx import helper, numpy_helper, TensorProto
+  # Define input and output
+  input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, in_size])
+  output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, out_size])
+
+  # Create random weights and bias
+  W_data = np.random.randn(in_size, out_size).astype(np.float32)
+  B_data = np.random.randn(out_size).astype(np.float32)
+
+  W_init = numpy_helper.from_array(W_data, name="W")
+  B_init = numpy_helper.from_array(B_data, name="B")
+
+  gemm_node = helper.make_node("Gemm", inputs=["input", "W", "B"], outputs=["output"], alpha=1.0, beta=1.0, transB=0)
+  graph_def = helper.make_graph([gemm_node], "SingleGemmGraph", [input_tensor], [output_tensor], initializer=[W_init, B_init])
+
+  # Create and save the model
+  model_def = helper.make_model(graph_def, producer_name="single_gemm_example")
+  onnx.save_model(model_def, model_path)
+  return model_path
+
+def sexec(out:Tensor, opts:list[Opt]):
+  si = out.schedule()[-1]
+  k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
+  #opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128)] #, Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+  for opt in opts: k.apply_opt(opt)
+  prg = k.to_program()
+  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
+  for _ in range(3): ei.run(wait=True)
+
+@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
+class TestQuantizeOnnx(unittest.TestCase):
+  def test_quant(self):
+    from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
+    from examples.benchmark_onnx import load_onnx_model
+    class FakeDataReader(CalibrationDataReader):
+      def __init__(self): self.cnt = 0
+      def get_next(self) -> dict:
+        self.cnt += 1
+        if self.cnt == 100: return None
+        return {"input": np.random.uniform(size=(1, N)).astype(np.float32)}
+    out_file = "/tmp/test_out.onnx"
+    quantize_static(create_gemm_model("/tmp/test_in.onnx"), out_file,
+                    FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False,
+                    activation_type=QuantType.QInt8, weight_type=QuantType.QInt8,
+                    extra_options={"ActivationSymmetric": True})
+    run_onnx_jit, _ = load_onnx_model(out_file)
+    with Context(NOOPT=1):
+      run_onnx_jit(input=Tensor(np.random.uniform(size=(1, N)).astype(np.float32)))
+
+  def test_prequant_conv2d_1x1(self):
+    X = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
+    W = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8))
+    out = X.conv2d(W, acc_dtype=X.dtype)
+    opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+    sexec(out, opts)
+
+  def test_prequant_gemm(self):
+    N = 512
+    # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
+    X = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
+    W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
+    out = X.matmul(W, acc_dtype=X.dtype)
+    opts = [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+    sexec(out, opts)
+
+  def test_prequant_gemv(self):
+    N = 2048
+    # ugh, it's so broken with those casts. need DONT_REALIZE_EXPAND=1 python3 test/test_quantize_onnx.py TestQuantizeOnnx.test_prequant
+    X = Tensor(np.random.uniform(0, 255, size=(1,N)).astype(np.uint8))
+    W = Tensor(np.random.uniform(0, 255, size=(N,N)).astype(np.uint8))
+    #out = X.cast(dtypes.int) @ W.cast(dtypes.int)
+    #out = X @ W
+    out = X.matmul(W, acc_dtype=X.dtype)
+    opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
+    sexec(out, opts)
+
+if __name__ == "__main__":
+  unittest.main()
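For a quick sanity check of what test_quant produces, the QDQ graph can be inspected directly; a minimal sketch (assumes onnx is installed and /tmp/test_out.onnx exists from a prior run of the test):

```python
import onnx

# load the statically quantized model written by test_quant
model = onnx.load("/tmp/test_out.onnx")
# QDQ format should wrap the Gemm in QuantizeLinear/DequantizeLinear pairs
print([n.op_type for n in model.graph.node])
```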
2 changes: 1 addition & 1 deletion tinygrad/codegen/kernel.py
@@ -411,7 +411,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
elif opt.op is OptOps.UPCAST: # yellow
check(axis < self.first_reduce, "upcast is for non-reduce")
check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
-      check(amt <= 16, "don't upcast more than 16")
+      check((self.opts is not None and self.opts.device == "DSP") or amt <= 16, "don't upcast more than 16")
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op is OptOps.NOLOCALS:
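This relaxed check is what lets the DSP tests above apply 128-wide upcasts; a sketch of a now-legal opt (assuming a DSP device is selected, any other backend still hits the amt <= 16 check):

```python
from tinygrad.codegen.kernel import Opt, OptOps

# 128-wide upcast, as applied by sexec() in test_quantize_onnx.py
opt = Opt(op=OptOps.UPCAST, axis=1, arg=128)
```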
11 changes: 9 additions & 2 deletions tinygrad/codegen/rewriter.py
@@ -1,6 +1,7 @@
from typing import Optional, Any, Callable
import functools, itertools, operator
from collections import defaultdict
+from tinygrad.device import Device
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
@@ -11,11 +12,18 @@
# ***** float4/image store handling *****

def fold_expanded(ex, buf):
-  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
  new_srcs = dedup(list(ex.src))
  old_new_srcs = new_srcs[:]
  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)

+  # TODO: get the device from the buffer somehow
+  if Device.DEFAULT == "DSP":
+    if buf.dtype.base == dtypes.bool: return None
+    lengths = [128,4]
+  else:
+    if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
+    lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))

# first, extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict] = defaultdict(dict)
for i,s in enumerate(new_srcs):
@@ -30,7 +38,6 @@ def fold_expanded(ex, buf):
offsets_rootsrc[root_src][arg] = i

# then rewrite everything we can
-  lengths = [4] if is_image else ([8,4,2] if buf.dtype.base == dtypes.half and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
used: set[tuple[UOp, UOp]] = set()
for rootsrc, offsets in offsets_rootsrc.items():
for o in offsets:
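For intuition, the DSP branch swaps in HVX-friendly fold widths (HVX vector registers are 128 bytes wide); a standalone sketch of the length-selection logic, illustrative only and not the tinygrad API:

```python
def fold_lengths(device:str, is_image:bool, is_half:bool, allow_half8:bool, amx:bool) -> list[int]:
  # DSP: try 128-element folds first, then fall back to 4
  if device == "DSP": return [128, 4]
  if is_image: return [4]
  if is_half and allow_half8: return [8, 4, 2]
  return [16, 8, 4, 2] if amx else [4, 2]

assert fold_lengths("DSP", False, False, False, False) == [128, 4]
assert fold_lengths("CLANG", False, False, False, False) == [4, 2]
```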
4 changes: 2 additions & 2 deletions tinygrad/ops.py
@@ -1170,8 +1170,8 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
])

symbolic = symbolic_simple+PatternMatcher([
-  # ** COMMUTATIVE flipping **
-  (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
+  # ** COMMUTATIVE flipping (only for ints) **
+  (UPat(GroupOp.Commutative, dtype=dtypes.int, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# ** combine terms **
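The flip canonicalizes operand order of commutative ops so structurally equal expressions normalize to one form; it now fires only for int dtypes. A plain-Python sketch of the idea (the key argument is a hypothetical stand-in for UOp.tuplize):

```python
# put the "smaller" operand first so add(y, x) and add(x, y) become the same tree
def flip_commutative(op:str, a, b, key=str):
  return (op, b, a) if key(b) < key(a) else (op, a, b)

assert flip_commutative("ADD", "y", "x") == flip_commutative("ADD", "x", "y")
```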
8 changes: 5 additions & 3 deletions tinygrad/renderer/cstyle.py
@@ -17,7 +17,7 @@
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CLANG', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.dtype.size}];"),
@@ -49,7 +49,8 @@
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \
f".{'xyzwabcd'[x.arg[0]]}")),
])

extra_pm = PatternMatcher([
@@ -104,7 +105,8 @@ def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType):
return (self.smem_prefix if dt.local and self.smem_prefix_for_cast else self.buffer_prefix) + self.render_dtype(dt.base) + "*"
-    return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
+    if dt.count > 1: return self.type_map.get(scalar:=dt.scalar(), scalar.name).replace(" ", "_") + str(dt.count)
+    return self.type_map.get(scalar:=dt.scalar(), scalar.name)

def __getitem__(self, key): return self.r[key] # hacky helper
def render(self, name:str, uops:list[UOp]) -> str:
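The new vector branch makes multi-word scalar names identifier-safe before appending the lane count; a small sketch of the naming rule (illustrative):

```python
def vec_name(scalar_name:str, count:int) -> str:
  # "long long" x 2 -> "long_long2", a valid C identifier for a vector typedef
  return scalar_name.replace(" ", "_") + str(count) if count > 1 else scalar_name

assert vec_name("long long", 2) == "long_long2"
assert vec_name("float", 1) == "float"
```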
39 changes: 31 additions & 8 deletions tinygrad/runtime/ops_dsp.py
@@ -1,5 +1,5 @@
from __future__ import annotations
-import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, time, struct
+import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys, subprocess, struct
assert sys.platform != 'win32'
from tinygrad.device import BufferSpec, Compiled, Allocator, Compiler, MallocAllocator
from tinygrad.dtype import dtypes, DType, PtrDType
@@ -9,9 +9,30 @@
from tinygrad.runtime.autogen import libc, qcom_dsp
if getenv("IOCTL"): import extra.dsp.run # noqa: F401 # pylint: disable=unused-import

+from tinygrad.helpers import all_same
+from tinygrad.ops import PatternMatcher, UPat, GroupOp
+
+def revectorize(v:UOp):
+  if not all_same([x.op for x in v.src]) or any(dtypes.is_bool(x.dtype) for x in v.src[0].src): return None
+  new_srcs = [UOp(Ops.VECTORIZE, v.src[0].src[i].dtype.vec(v.dtype.count), tuple(x.src[i] for x in v.src)) for i in range(len(v.src[0].src))]
+  return UOp(v.src[0].op, v.dtype, tuple(new_srcs), v.src[0].arg)
+
+revectorize_pm = PatternMatcher([
+  (UPat(Ops.VECTORIZE, src=UPat((*GroupOp.ALU, Ops.ASSIGN, Ops.CAST)), name="v"), revectorize),
+  # vectorize DEFINE_ACC (similar to expander)
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC), name="v"),
+   lambda v: UOp(Ops.DEFINE_ACC, v.dtype,
+                 (UOp.broadcast(UOp.const(v.dtype.scalar(), v.src[0].src[0].arg), v.dtype.count),)+v.src[0].src[1:], v.src[0].arg)),
+  # vectorize increasing GEPs = nothing (wrong if dtypes don't match!)
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP), name="v"),
+   lambda v: v.src[0].src[0] if all_same([x.src for x in v.src]) and \
+     [x.arg[0] if len(x.arg) == 1 else None for x in v.src] == list(range(v.dtype.count)) else None),
+])
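Schematically, the revectorize pattern above rebuilds one wide ALU op from a VECTORIZE of identical scalar ops; a standalone sketch with plain tuples (not real UOps):

```python
# VECTORIZE(f(a0,b0), f(a1,b1)) -> f(VECTORIZE(a0,a1), VECTORIZE(b0,b1))
# when every lane applies the same op f
def revectorize_sketch(lanes):
  if len({op for op, *_ in lanes}) != 1: return None   # lanes must share one op
  op = lanes[0][0]
  srcs = zip(*[args for _, *args in lanes])            # gather per-operand lanes
  return (op, *(("VECTORIZE", *s) for s in srcs))

assert revectorize_sketch([("ADD","a0","b0"), ("ADD","a1","b1")]) == \
  ("ADD", ("VECTORIZE","a0","a1"), ("VECTORIZE","b0","b1"))
```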

class DSPRenderer(ClangRenderer):
device = "DSP"
-  supports_float4 = False
+  supports_float4 = True
+  extra_matcher = revectorize_pm+ClangRenderer.extra_matcher
buffer_suffix = " restrict __attribute__((align_value(128)))"
kernel_prefix = "__attribute__((noinline)) "
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
@@ -231,12 +252,14 @@ class MockDSPRenderer(DSPRenderer):
def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str,tuple[DType,bool]]], uops:list[UOp], prefix=None) -> str:
ret = ClangRenderer.render_kernel(self, function_name, kernel, bufs, uops, prefix)
# https://gpages.juszkiewicz.com.pl/syscalls-table/syscalls.html
+    # control register 21 is HEX_REG_QEMU_INSN_CNT, 0x6a15c000 loads it
msrc = ['''static long syscall(long r0, long r1, long r2, long r3, long r4, long r5, long r6) {
long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = #%7; trap0(#1); %0 = r0" : "=r" (retval)
: "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "i" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; }
long retval; __asm__ volatile("r0 = %1; r1 = %2; r2 = %3; r3 = %4; r4 = %5; r5 = %6; r6 = %7; trap0(#1); %0 = r0" : "=r" (retval)
: "r" (r0), "r" (r1), "r" (r2), "r" (r3), "r" (r4), "r" (r5), "r" (r6) : "r0", "r1", "r2", "r3", "r4", "r5", "r6"); return retval; }
static int read(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 63); }}
static int write(int fd, void* buf, int len) {{ return syscall(fd, (long)buf, len, 0, 0, 0, 64); }}
static int exit(int ret) {{ return syscall(ret, 0, 0, 0, 0, 0, 93); }}
+static unsigned int inscount(void) {{ unsigned int ret; __asm__ volatile(".word 0x6a15c000; %0 = R0" : "=r" (ret) : : "r0"); return ret; }}
static void *mmap2(void *addr, unsigned int length, int prot, int flags, int fd, unsigned long offset) {{
return (void*)syscall((long)addr, length, prot, flags, fd, offset, 222); }}''', 'void _start(void) {']
for i,b in enumerate(bufs):
@@ -245,7 +268,9 @@ def render_kernel(self, function_name:str, kernel:list[str], bufs:list[tuple[str
msrc.append(f"void *buf{i} = mmap2(0, {sz}, 3, 0x21, -1, 0); read(0, buf{i}, {sz});")
else:
msrc.append(f"unsigned int val{i}; read(0, &val{i}, 4);")
msrc.append("unsigned int st = inscount();")
msrc.append(f"{function_name}({', '.join([(f'(void*)buf{i}' if isinstance(b[1][0], PtrDType) else f'val{i}') for i,b in enumerate(bufs)])});")
msrc.append("unsigned int et = inscount() - st; write(1, &et, sizeof(et));")
for i,b in enumerate(bufs):
if isinstance(b[1][0], PtrDType): msrc.append(f"write(1, buf{i}, {b[1][0].size*b[1][0].itemsize});")
msrc.append('exit(0); }')
@@ -259,13 +284,11 @@ def __call__(self, *bufs, vals:tuple[int, ...]=(), wait=False):
dsp_lib.flush()
os.chmod(dsp_lib.name, 0o0777)
# NOTE: this timing includes a docker launch
-    start = time.perf_counter()
proc = subprocess.run(["docker", "run", "--rm", "-i", "-v", f"{os.path.abspath(os.path.dirname(dsp_lib.name))}:/work", "-w", "/work",
"qemu-hexagon", "-c", f"qemu-hexagon {'-strace' if DEBUG >= 3 else ''} /work/"+os.path.basename(dsp_lib.name)],
input=b''.join([bytes(x) for x in bufs] + [struct.pack("I", x) for x in vals]), stdout=subprocess.PIPE, check=True)
-    elapsed = time.perf_counter() - start
-    offset = 0
+    offset = 4
for x in bufs:
x[:] = proc.stdout[offset:offset+len(x)]
offset += len(x)
-    return elapsed
+    return struct.unpack("I", proc.stdout[0:4])[0] / 1e9  # pretend it's 1 Ghz, but this is an inscount, not a time
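With these changes the mock runner's "timing" becomes a deterministic instruction count: the guest writes the HEX_REG_QEMU_INSN_CNT delta as the first four bytes of stdout and the host decodes it before the buffer contents. A sketch of the host-side decode, illustrative and mirroring the diff above:

```python
import struct

def decode_mock_stdout(stdout:bytes, buf_lens:list[int]) -> tuple[float, list[bytes]]:
  # first 4 bytes: instruction-count delta from the guest's inscount()
  (icount,) = struct.unpack("I", stdout[:4])
  bufs, off = [], 4
  for n in buf_lens:
    bufs.append(stdout[off:off+n])
    off += n
  return icount / 1e9, bufs  # "seconds" at a pretend 1 GHz, really an inscount
```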
4 changes: 1 addition & 3 deletions tinygrad/tensor.py
@@ -3307,9 +3307,7 @@ def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
```
"""
-    # NOTE: the mid-point is for backward, revisit after new gradient API
-    if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
-    return (self<x).detach().where(x, self)
+    return self._apply_broadcasted_uop(UOp.maximum, x)

def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""
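Forward behavior of maximum should be unchanged by this simplification; a quick check (assumes a working tinygrad install, expected outputs in comments):

```python
from tinygrad import Tensor

print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())  # [-1  2  9]
print(Tensor([-1, 2, 3]).maximum(0).numpy())                    # [0 2 3]
```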