Skip to content

Commit

Permalink
Support any valid slice, including use of np.s_ and np.index_exp.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tyler Coles committed Aug 30, 2024
1 parent 68ce04f commit 6c5eb68
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 33 deletions.
144 changes: 141 additions & 3 deletions doc/devlog/2024-08-28.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -22,7 +22,7 @@
"array([1, 2, 3])"
]
},
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -34,6 +34,7 @@
"import numpy as np\n",
"\n",
"from epymorph.adrio.numpy import NPY\n",
"from epymorph.error import DataResourceException\n",
"\n",
"array = np.array([1, 2, 3])\n",
"np.save('./scratch/npy_test.npy', arr=array)\n",
Expand All @@ -44,7 +45,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -103,6 +104,143 @@
"\n",
"adrio.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([['This', 'is', 'an'],\n",
" ['for', 'testing', '.']], dtype='<U7')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# single slice\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', slice(0, None, 2))\n",
"adrio.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([['This', 'an'],\n",
" ['array', 'strings']], dtype='<U7')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# np.s_ support\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', np.s_[0:2, ::2])\n",
"adrio.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([['This', 'is', 'an'],\n",
" ['array', 'of', 'strings'],\n",
" ['for', 'testing', '.']], dtype='<U7')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ellipsis support: note -- fewer indices than axes is actually allowable\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', np.s_[...])\n",
"adrio.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([['This', 'is', 'an'],\n",
" ['array', 'of', 'strings'],\n",
" ['for', 'testing', '.']], dtype='<U7')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# np.index_exp support\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', np.index_exp[...])\n",
"adrio.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataResourceException('Specified array slice is invalid for the shape of this data.')\n"
]
}
],
"source": [
"# ERROR: too many indices\n",
"try:\n",
" adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', np.index_exp[1:3, 1:3, 1:3])\n",
" adrio.evaluate()\n",
"except DataResourceException as e:\n",
" print(repr(e))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DataResourceException('Incorrect file type. Only .npy files can be loaded through NPY ADRIOs.')\n"
]
}
],
"source": [
"# ERROR: wrong file type\n",
"try:\n",
" adrio = NPY(Path('./scratch/npz_test2.npz'))\n",
"except DataResourceException as e:\n",
" print(repr(e))"
]
}
],
"metadata": {
Expand Down
56 changes: 26 additions & 30 deletions epymorph/adrio/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,41 @@
from epymorph.adrio.adrio import Adrio
from epymorph.error import DataResourceException

# Index for a single array axis: a `slice` object or a literal Ellipsis (`...`).
_SliceLike = slice | type(Ellipsis)
# A full index expression: either one single-axis index, or a tuple of them
# such as the values produced by `np.s_[...]` and `np.index_exp[...]`.
_ArraySlice = _SliceLike | tuple[_SliceLike, ...]


class NPY(Adrio[Any]):
    """Retrieves an array of data from a user-provided .npy file.

    The file extension is checked when the ADRIO is constructed; the file
    itself is only read when `evaluate` is called.
    """

    file_path: PathLike
    """The path to the .npy file containing data."""
    array_slice: _ArraySlice | None
    """Optional slice(s) of the array to load."""

    def __init__(
        self,
        file_path: PathLike,
        array_slice: _ArraySlice | None = None,
    ) -> None:
        """Create an ADRIO for the given .npy file.

        Parameters
        ----------
        file_path:
            Path to the file; its suffix must be exactly `.npy`.
        array_slice:
            Optional index expression applied to the loaded array: a slice,
            an Ellipsis, or a tuple of those (e.g. built with `np.s_`).
            `None` loads the whole array.

        Raises
        ------
        DataResourceException
            If `file_path` does not have a `.npy` suffix.
        """
        if Path(file_path).suffix != '.npy':
            msg = 'Incorrect file type. Only .npy files can be loaded through NPY ADRIOs.'
            raise DataResourceException(msg)
        self.file_path = file_path
        self.array_slice = array_slice

    @override
    def evaluate(self) -> NDArray:
        """Load the array from disk and apply the optional slice.

        Raises
        ------
        DataResourceException
            If the file cannot be read, if it holds an object array, or if
            the requested slice cannot be applied to the loaded data.
        """
        # Loading and slicing are handled separately so that an error raised
        # while indexing is never misreported as a file-loading problem.
        try:
            data = cast(NDArray, np.load(self.file_path))
        except OSError as e:
            msg = 'File not found.'
            raise DataResourceException(msg) from e
        except ValueError as e:
            # np.load refuses pickled (object) arrays by default.
            msg = 'Object arrays cannot be loaded.'
            raise DataResourceException(msg) from e

        if self.array_slice is None:
            return data

        try:
            return data[self.array_slice]
        except (IndexError, ValueError) as e:
            # IndexError: e.g. more indices than axes.
            # ValueError: e.g. a slice whose step is zero.
            msg = 'Specified array slice is invalid for the shape of this data.'
            raise DataResourceException(msg) from e


class NPZ(Adrio[Any]):
Expand All @@ -54,33 +53,30 @@ class NPZ(Adrio[Any]):
"""The path to the .npz file containing data."""
array_name: str
"""The name of the array in the .npz file to load."""
array_slice: tuple[slice, ...] | None
array_slice: _ArraySlice | None
"""Optional slice(s) of the array to load."""

def __init__(
    self,
    file_path: PathLike,
    array_name: str,
    array_slice: _ArraySlice | None = None,
) -> None:
    """Create an ADRIO for one named array inside the given .npz file.

    Parameters
    ----------
    file_path:
        Path to the file; its suffix must be exactly `.npz`.
    array_name:
        Name of the array within the .npz archive to load.
    array_slice:
        Optional index expression applied to the loaded array: a slice,
        an Ellipsis, or a tuple of those (e.g. built with `np.s_`).
        `None` loads the whole array.

    Raises
    ------
    DataResourceException
        If `file_path` does not have a `.npz` suffix.
    """
    if Path(file_path).suffix != '.npz':
        msg = 'Incorrect file type. Only .npz files can be loaded through NPZ ADRIOs.'
        raise DataResourceException(msg)
    self.file_path = file_path
    self.array_name = array_name
    self.array_slice = array_slice

@override
def evaluate(self) -> NDArray:
    """Load the named array from the .npz file and apply the optional slice.

    Raises
    ------
    DataResourceException
        If the file cannot be read, if the archive has no array called
        `array_name`, if that array is an object array, or if the requested
        slice cannot be applied to the loaded data.
    """
    # Loading and slicing are handled separately so that an error raised
    # while indexing is never misreported as a file-loading problem.
    try:
        # The archive object holds an open file handle; the `with` block
        # closes it as soon as the requested array has been read.
        with np.load(self.file_path) as npz_file:
            data = cast(NDArray, npz_file[self.array_name])
    except OSError as e:
        msg = 'File not found.'
        raise DataResourceException(msg) from e
    except KeyError as e:
        msg = f"No array named '{self.array_name}' was found in the file."
        raise DataResourceException(msg) from e
    except ValueError as e:
        # Pickled (object) arrays are refused by default when read.
        msg = 'Object arrays cannot be loaded.'
        raise DataResourceException(msg) from e

    if self.array_slice is None:
        return data

    try:
        return data[self.array_slice]
    except (IndexError, ValueError) as e:
        # IndexError: e.g. more indices than axes.
        # ValueError: e.g. a slice whose step is zero.
        msg = 'Specified array slice is invalid for the shape of this data.'
        raise DataResourceException(msg) from e

0 comments on commit 6c5eb68

Please sign in to comment.