From 68ce04f211c17206eba625b95580b41aafa6f13c Mon Sep 17 00:00:00 2001 From: TJohnsonAZ Date: Thu, 29 Aug 2024 14:54:39 -0700 Subject: [PATCH] Additional error handling and typing adjustments. --- doc/devlog/2024-08-28.ipynb | 10 +++++----- epymorph/adrio/numpy.py | 37 +++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/doc/devlog/2024-08-28.ipynb b/doc/devlog/2024-08-28.ipynb index 25e83112..ee9462f5 100644 --- a/doc/devlog/2024-08-28.ipynb +++ b/doc/devlog/2024-08-28.ipynb @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -66,7 +66,7 @@ "np.savez('./scratch/npz_test.npz', arr=array, arr2=array2)\n", "\n", "# this adrio usees a slice to exclude the first element\n", - "adrio = NPZ(Path('./scratch/npz_test.npz'), 'arr', [slice(1, 3)])\n", + "adrio = NPZ(Path('./scratch/npz_test.npz'), 'arr', np.s_[1:3,])\n", "adrio2 = NPZ(Path('./scratch/npz_test.npz'), 'arr2')\n", "\n", "\n", @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -86,7 +86,7 @@ " ['of', 'strings']], dtype=' None: + def __init__(self, file_path: PathLike, array_slice: tuple[slice, ...] | None = None) -> None: self.file_path = file_path self.array_slice = array_slice @override def evaluate(self) -> NDArray: + if Path(self.file_path).suffix != '.npy': + msg = 'Incorrect file type. Only .npy files can be loaded through NPY ADRIOs.' + raise DataResourceException(msg) + try: data = cast(NDArray, np.load(self.file_path)) except OSError as e: @@ -32,16 +37,12 @@ def evaluate(self) -> NDArray: msg = 'Object arrays cannot be loaded.' raise DataResourceException(msg) from e - if self.array_slice is not None: + if self.array_slice is not None and len(self.array_slice) != 0: if len(self.array_slice) != data.ndim: msg = 'One slice is required for each array axis.' raise DataResourceException(msg) - for axis, curr_slice in enumerate(self.array_slice): - data = data.take( - indices=range(curr_slice.start, curr_slice.stop), - axis=axis - ) + data = data[self.array_slice] return data @@ -53,17 +54,21 @@ class NPZ(Adrio[Any]): """The path to the .npz file containing data.""" array_name: str """The name of the array in the .npz file to load.""" - array_slice: Sequence[slice] | None + array_slice: tuple[slice, ...] | None """Optional slice(s) of the array to load.""" - def __init__(self, file_path: PathLike, array_name: str, array_slice: Sequence[slice] | None = None) -> None: + def __init__(self, file_path: PathLike, array_name: str, array_slice: tuple[slice, ...] | None = None) -> None: self.file_path = file_path self.array_name = array_name self.array_slice = array_slice def evaluate(self) -> NDArray: + if Path(self.file_path).suffix != '.npz': + msg = 'Incorrect file type. Only .npz files can be loaded through NPZ ADRIOs.' + raise DataResourceException(msg) + try: - data = cast(NDArray, np.load(self.file_path)) + data = cast(NDArray, np.load(self.file_path)[self.array_name]) except OSError as e: msg = 'File not found.' raise DataResourceException(msg) from e @@ -71,15 +76,11 @@ def evaluate(self) -> NDArray: msg = 'Object arrays cannot be loaded.' raise DataResourceException(msg) from e - if self.array_slice is not None: + if self.array_slice is not None and len(self.array_slice) != 0: if len(self.array_slice) != data.ndim: msg = 'One slice is required for each array axis.' raise DataResourceException(msg) - for axis, curr_slice in enumerate(self.array_slice): - data = data.take( - indices=range(curr_slice.start, curr_slice.stop), - axis=axis - ) + data = data[self.array_slice] return data