Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using default_dtypes instead of hard-coding dtypes. #666

Merged
merged 17 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import TYPE_CHECKING, Iterable, List

from cubed.array_api import __array_namespace_info__
from cubed.backend_array_api import namespace as nxp
from cubed.core import Plan, gensym
from cubed.core.ops import map_blocks
Expand All @@ -25,6 +26,7 @@ def arange(
num = int(max(math.ceil((stop - start) / step), 0))
if dtype is None:
dtype = nxp.arange(start, stop, step * num if num else step).dtype

chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

Expand Down Expand Up @@ -62,8 +64,8 @@ def asarray(
): # pragma: no cover
return asarray(a.data)
elif not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
a = nxp.asarray(a, dtype=dtype)

if dtype is None:
dtype = a.dtype

Expand All @@ -89,8 +91,9 @@ def empty_like(x, /, *, dtype=None, device=None, chunks=None, spec=None) -> "Arr
def empty_virtual_array(
shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
) -> "Array":
dtypes = __array_namespace_info__().default_dtypes(device=device)
if dtype is None:
dtype = nxp.float64
dtype = dtypes["real floating"]

chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
name = gensym()
Expand All @@ -105,10 +108,11 @@ def empty_virtual_array(
def eye(
n_rows, n_cols=None, /, *, k=0, dtype=None, device=None, chunks="auto", spec=None
) -> "Array":
dtypes = __array_namespace_info__().default_dtypes(device=device)
if n_cols is None:
n_cols = n_rows
if dtype is None:
dtype = nxp.float64
dtype = dtypes["real floating"]

shape = (n_rows, n_cols)
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
Expand Down Expand Up @@ -136,17 +140,18 @@ def _eye(x, k=None, chunksize=None, block_id=None):
def full(
shape, fill_value, *, dtype=None, device=None, chunks="auto", spec=None
) -> "Array":
dtypes = __array_namespace_info__().default_dtypes(device=device)
shape = normalize_shape(shape)
if dtype is None:
# check bool first since True/False are instances of int and float
if isinstance(fill_value, bool):
dtype = nxp.bool
elif isinstance(fill_value, int):
dtype = nxp.int64
dtype = dtypes["integral"]
elif isinstance(fill_value, float):
dtype = nxp.float64
dtype = dtypes["real floating"]
elif isinstance(fill_value, complex):
dtype = nxp.complex128
dtype = dtypes["complex floating"]
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
Expand Down Expand Up @@ -187,13 +192,15 @@ def linspace(
chunks="auto",
spec=None,
) -> "Array":
dtypes = __array_namespace_info__().default_dtypes(device=device)

range_ = stop - start
div = (num - 1) if endpoint else num
if div == 0:
div = 1
step = float(range_) / div
if dtype is None:
dtype = nxp.float64
dtype = dtypes["real floating"]
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

Expand All @@ -210,15 +217,23 @@ def linspace(
step=step,
endpoint=endpoint,
linspace_dtype=dtype,
device=device,
)


def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None):
def _linspace(
x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None
):
dtypes = __array_namespace_info__().default_dtypes(device=device)

bs = x.shape[0]
i = block_id[0]
adjusted_bs = bs - 1 if endpoint else bs
blockstart = start + (i * size * step)
blockstop = blockstart + (adjusted_bs * step)

# float_ is a type casting function.
float_ = dtypes["real floating"].type
blockstart = float_(start + (i * size * step))
blockstop = float_(blockstart + float_(adjusted_bs * step))
return nxp.linspace(
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
)
Expand Down Expand Up @@ -256,8 +271,10 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:


def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
dtypes = __array_namespace_info__().default_dtypes(device=device)

if dtype is None:
dtype = nxp.float64
dtype = dtypes["real floating"]
return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)


Expand Down Expand Up @@ -302,8 +319,10 @@ def _tri_mask(N, M, k, chunks, spec):


def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
dtypes = __array_namespace_info__().default_dtypes(device=device)

if dtype is None:
dtype = nxp.float64
dtype = dtypes["real floating"]
return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)


Expand Down
30 changes: 30 additions & 0 deletions cubed/array_api/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copied from numpy.array_api
from cubed.array_api.inspection import __array_namespace_info__
from cubed.backend_array_api import namespace as nxp

int8 = nxp.int8
Expand Down Expand Up @@ -86,3 +87,32 @@
"complex floating-point": _complex_floating_dtypes,
"floating-point": _floating_dtypes,
}


# A Cubed-specific utility.
def _upcast_integral_dtypes(x, dtype=None, *, allowed_dtypes=("numeric",), fname=None, device=None):
"""Ensure the input dtype is allowed. If it's None, provide a good default dtype."""
dtypes = __array_namespace_info__().default_dtypes(device=device)

# Validate.
is_invalid = all(x.dtype not in _dtype_categories[a] for a in allowed_dtypes)
if is_invalid:
errmsg = f"Only {' or '.join(allowed_dtypes)} dtypes are allowed"
if fname:
errmsg += f" in {fname}"
raise TypeError(errmsg)

# Choose a good default dtype, when None
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = dtypes["integral"]
elif x.dtype in _signed_integer_dtypes:
dtype = dtypes["integral"]
elif x.dtype in _unsigned_integer_dtypes:
# Type arithmetic to produce an unsigned integer dtype at the same default precision.
default_bits = nxp.iinfo(dtypes["integral"]).bits
dtype = nxp.dtype(f"u{default_bits // 8}")
else:
dtype = x.dtype

return dtype
39 changes: 7 additions & 32 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import math

from cubed.array_api.dtypes import (
_boolean_dtypes,
_numeric_dtypes,
_real_floating_dtypes,
_real_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
int64,
uint64,
_upcast_integral_dtypes,
)
from cubed.array_api.elementwise_functions import sqrt
from cubed.backend_array_api import namespace as nxp
Expand All @@ -35,6 +30,7 @@ def mean(x, /, *, axis=None, keepdims=False, split_every=None):
# pair of fields needed to keep per-chunk counts and totals for computing
# the mean.
dtype = x.dtype
#TODO(#658): Should these be default dtypes?
intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype)
return reduction(
Expand Down Expand Up @@ -113,19 +109,8 @@ def min(x, /, *, axis=None, keepdims=False, split_every=None):
)


def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
else:
dtype = x.dtype
def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="prod", device=device)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand All @@ -150,19 +135,8 @@ def std(x, /, *, axis=None, correction=0.0, keepdims=False, split_every=None):
)


def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
else:
dtype = x.dtype
def sum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric", "boolean",), fname="sum", device=device)
extra_func_kwargs = dict(dtype=dtype)
return reduction(
x,
Expand All @@ -189,6 +163,7 @@ def var(
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in var")
dtype = x.dtype
#TODO(#658): Should these be default dtypes?
intermediate_dtype = [("n", nxp.int64), ("mu", nxp.float64), ("M2", nxp.float64)]
extra_func_kwargs = dict(dtype=intermediate_dtype, correction=correction)
return reduction(
Expand Down
28 changes: 3 additions & 25 deletions cubed/nan_functions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import numpy as np

from cubed.array_api.dtypes import (
_numeric_dtypes,
_signed_integer_dtypes,
_unsigned_integer_dtypes,
complex64,
complex128,
float32,
float64,
int64,
uint64,
)
from cubed.array_api.dtypes import _upcast_integral_dtypes
from cubed.backend_array_api import namespace as nxp
from cubed.core import reduction

Expand Down Expand Up @@ -60,21 +50,9 @@ def _nannumel(x, **kwargs):
return nxp.sum(~(nxp.isnan(x)), **kwargs)


def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None):
def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
"""Return the sum of array elements over a given axis treating NaNs as zero."""
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in nansum")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
dtype = int64
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = float64
elif x.dtype == complex64:
dtype = complex128
else:
dtype = x.dtype
dtype = _upcast_integral_dtypes(x, dtype, allowed_dtypes=("numeric",), fname="nansum", device=device)
return reduction(
x,
nxp.nansum,
Expand Down
Loading