Skip to content

Commit

Permalink
Rename variables for consistency.
Browse files Browse the repository at this point in the history
  • Loading branch information
TJohnsonAZ committed Aug 28, 2024
1 parent 232ccb2 commit daf9695
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
6 changes: 3 additions & 3 deletions doc/devlog/2024-08-28.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -84,8 +84,8 @@
],
"source": [
"# save and load a multidimensional array\n",
"array3 = np.array([[1, 2, 3], [4, 5, 6]])\n",
"np.savez('./scratch/npz_test2.npz', array3)\n",
"array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n",
"np.savez('./scratch/npz_test2.npz', array)\n",
"\n",
"# slice each axis of the array individually\n",
"adrio = NPZ(Path('./scratch/npz_test2.npz'), 'arr_0', [slice(0, 2), slice(1, 3)])\n",
Expand Down
38 changes: 22 additions & 16 deletions epymorph/adrio/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,29 @@ class NPY(Adrio[Any]):

file_path: PathLike
"""The path to the .npy file containing data."""
arr_slice: list[slice] | None
array_slice: list[slice] | None
"""Optional slice(s) of the array to load."""

def __init__(self, file_path: PathLike, slice: list[slice] | None = None) -> None:
def __init__(self, file_path: PathLike, array_slice: list[slice] | None = None) -> None:
self.file_path = file_path
self.slice = slice
self.array_slice = array_slice

def evaluate(self) -> NDArray:
data = np.load(self.file_path)

data = np.array(data)

if self.arr_slice is not None:
if len(self.arr_slice) != data.ndim:
if self.array_slice is not None:
if len(self.array_slice) != data.ndim:
msg = 'One slice is required for each array axis.'
raise DataResourceException(msg)
axis = 0
for x in self.arr_slice:
if x is not None and isinstance(x, slice):
data = data.take(indices=range(x.start, x.stop), axis=axis)
for curr_slice in self.array_slice:
if curr_slice is not None and isinstance(curr_slice, slice):
data = data.take(
indices=range(curr_slice.start, curr_slice.stop),
axis=axis
)
axis += 1

return data
Expand All @@ -45,26 +48,29 @@ class NPZ(Adrio[Any]):
"""The path to the .npz file containing data."""
array_name: str
"""The name of the array in the .npz file to load."""
arr_slice: list[slice] | None
array_slice: list[slice] | None
"""Optional slice(s) of the array to load."""

def __init__(self, file_path: PathLike, array_name: str, arr_slice: list[slice] | None = None) -> None:
def __init__(self, file_path: PathLike, array_name: str, array_slice: list[slice] | None = None) -> None:
self.file_path = file_path
self.array_name = array_name
self.arr_slice = arr_slice
self.array_slice = array_slice

def evaluate(self) -> NDArray:
data = np.load(self.file_path)
data = np.array(data[self.array_name])

if self.arr_slice is not None:
if len(self.arr_slice) != data.ndim:
if self.array_slice is not None:
if len(self.array_slice) != data.ndim:
msg = 'One slice is required for each array axis.'
raise DataResourceException(msg)
axis = 0
for x in self.arr_slice:
if x is not None and isinstance(x, slice):
data = data.take(indices=range(x.start, x.stop), axis=axis)
for curr_slice in self.array_slice:
if curr_slice is not None and isinstance(curr_slice, slice):
data = data.take(
indices=range(curr_slice.start, curr_slice.stop),
axis=axis
)
axis += 1

return data

0 comments on commit daf9695

Please sign in to comment.