Skip to content

Commit

Permalink
Move private function to the bottom of the file.
Browse files Browse the repository at this point in the history
  • Loading branch information
alxmrs committed Jan 15, 2025
1 parent 034024f commit 4364d78
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions cubed/array_api/statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,33 +113,6 @@ def min(x, /, *, axis=None, keepdims=False, split_every=None):
)


def _validate_and_define_numeric_or_bool_dtype(x, dtype=None, *, fname=None, device=None):
"""Validate the type of the numeric function. If it's None, provide a good default dtype."""
dtypes = __array_namespace_info__().default_dtypes(device=device)

# Validate.
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
errmsg = "Only numeric or boolean dtypes are allowed"
if fname:
errmsg += f" in {fname}"
raise TypeError(errmsg)

# Choose a good default dtype, when None
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = dtypes['integral']
elif x.dtype in _signed_integer_dtypes:
dtype = dtypes['integral']
elif x.dtype in _unsigned_integer_dtypes:
#TODO(#658): I don't think "indexing" --> uint64; is this correct?
dtype = dtypes['indexing']
else:
dtype = x.dtype

return dtype


def prod(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None, device=None):
dtype = _validate_and_define_numeric_or_bool_dtype(x, dtype, fname="prod", device=device)
extra_func_kwargs = dict(dtype=dtype)
Expand Down Expand Up @@ -249,3 +222,30 @@ def _var_combine(a, axis=None, correction=None, **kwargs):

def _var_aggregate(a, correction=None, **kwargs):
return nxp.divide(a["M2"], a["n"] - correction)


def _validate_and_define_numeric_or_bool_dtype(x, dtype=None, *, fname=None, device=None):
"""Validate the type of the numeric function. If it's None, provide a good default dtype."""
dtypes = __array_namespace_info__().default_dtypes(device=device)

# Validate.
# boolean is allowed by numpy
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
errmsg = "Only numeric or boolean dtypes are allowed"
if fname:
errmsg += f" in {fname}"
raise TypeError(errmsg)

# Choose a good default dtype, when None
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = dtypes['integral']
elif x.dtype in _signed_integer_dtypes:
dtype = dtypes['integral']
elif x.dtype in _unsigned_integer_dtypes:
#TODO(#658): I don't think "indexing" --> uint64; is this correct?
dtype = dtypes['indexing']
else:
dtype = x.dtype

return dtype

0 comments on commit 4364d78

Please sign in to comment.