Skip to content

Commit

Permalink
Additional error handling and typing adjustments.
Browse files Browse the repository at this point in the history
  • Loading branch information
TJohnsonAZ committed Aug 29, 2024
1 parent d3b73de commit 68ce04f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
10 changes: 5 additions & 5 deletions doc/devlog/2024-08-28.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -66,7 +66,7 @@
"np.savez('./scratch/npz_test.npz', arr=array, arr2=array2)\n",
"\n",
"# this adrio uses a slice to exclude the first element\n",
"adrio = NPZ(Path('./scratch/npz_test.npz'), 'arr', [slice(1, 3)])\n",
"adrio = NPZ(Path('./scratch/npz_test.npz'), 'arr', np.s_[1:3,])\n",
"adrio2 = NPZ(Path('./scratch/npz_test.npz'), 'arr2')\n",
"\n",
"\n",
Expand All @@ -76,7 +76,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -86,7 +86,7 @@
" ['of', 'strings']], dtype='<U7')"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -99,7 +99,7 @@
"np.savez('./scratch/npz_test2.npz', array)\n",
"\n",
"# slice each axis of the array individually\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', [slice(0, 2), slice(1, 3)])\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', (slice(0, 2), slice(1, 3)))\n",
"\n",
"adrio.evaluate()"
]
Expand Down
37 changes: 19 additions & 18 deletions epymorph/adrio/numpy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from os import PathLike
from typing import Any, Sequence, cast
from pathlib import Path
from typing import Any, cast

import numpy as np
from numpy.typing import NDArray
Expand All @@ -14,15 +15,19 @@ class NPY(Adrio[Any]):

file_path: PathLike
"""The path to the .npy file containing data."""
array_slice: Sequence[slice] | None
array_slice: tuple[slice, ...] | None
"""Optional slice(s) of the array to load."""

def __init__(self, file_path: PathLike, array_slice: Sequence[slice] | None = None) -> None:
def __init__(self, file_path: PathLike, array_slice: tuple[slice, ...] | None = None) -> None:
self.file_path = file_path
self.array_slice = array_slice

@override
def evaluate(self) -> NDArray:
if Path(self.file_path).suffix != '.npy':
msg = 'Incorrect file type. Only .npy files can be loaded through NPY ADRIOs.'
raise DataResourceException(msg)

try:
data = cast(NDArray, np.load(self.file_path))
except OSError as e:
Expand All @@ -32,16 +37,12 @@ def evaluate(self) -> NDArray:
msg = 'Object arrays cannot be loaded.'
raise DataResourceException(msg) from e

if self.array_slice is not None:
if self.array_slice is not None and len(self.array_slice) != 0:
if len(self.array_slice) != data.ndim:
msg = 'One slice is required for each array axis.'
raise DataResourceException(msg)

for axis, curr_slice in enumerate(self.array_slice):
data = data.take(
indices=range(curr_slice.start, curr_slice.stop),
axis=axis
)
data = data[self.array_slice]

return data

Expand All @@ -53,33 +54,33 @@ class NPZ(Adrio[Any]):
"""The path to the .npz file containing data."""
array_name: str
"""The name of the array in the .npz file to load."""
array_slice: Sequence[slice] | None
array_slice: tuple[slice, ...] | None
"""Optional slice(s) of the array to load."""

def __init__(self, file_path: PathLike, array_name: str, array_slice: Sequence[slice] | None = None) -> None:
def __init__(self, file_path: PathLike, array_name: str, array_slice: tuple[slice, ...] | None = None) -> None:
self.file_path = file_path
self.array_name = array_name
self.array_slice = array_slice

def evaluate(self) -> NDArray:
if Path(self.file_path).suffix != '.npz':
msg = 'Incorrect file type. Only .npz files can be loaded through NPZ ADRIOs.'
raise DataResourceException(msg)

try:
data = cast(NDArray, np.load(self.file_path))
data = cast(NDArray, np.load(self.file_path)[self.array_name])
except OSError as e:
msg = 'File not found.'
raise DataResourceException(msg) from e
except ValueError as e:
msg = 'Object arrays cannot be loaded.'
raise DataResourceException(msg) from e

if self.array_slice is not None:
if self.array_slice is not None and len(self.array_slice) != 0:
if len(self.array_slice) != data.ndim:
msg = 'One slice is required for each array axis.'
raise DataResourceException(msg)

for axis, curr_slice in enumerate(self.array_slice):
data = data.take(
indices=range(curr_slice.start, curr_slice.stop),
axis=axis
)
data = data[self.array_slice]

return data

0 comments on commit 68ce04f

Please sign in to comment.