Skip to content
This repository was archived by the owner on Apr 19, 2023. It is now read-only.

Commit 175af35

Browse files
Add doc strings and tests for volume loader
1 parent 0cbaf1d commit 175af35

File tree

2 files changed

+109
-0
lines changed

2 files changed

+109
-0
lines changed

inferno/io/volumetric/volume.py

+52
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,32 @@
1111

1212

1313
class VolumeLoader(SyncableDataset):
14+
""" Loader for in-memory volumetric data.
15+
16+
Parameters
17+
----------
18+
volume: np.ndarray
19+
the volumetric data
20+
window_size: list or tuple
21+
size of the (3d) sliding window used for iteration
22+
stride: list or tuple
23+
stride of the (3d) sliding window used for iteration
24+
downsampling_ratio: list or tuple (default: None)
25+
factor by which the data is downsampled (no downsapling by default)
26+
padding: list (default: None)
27+
padding for data, follows np.pad syntax
28+
padding_mode: str (default: 'reflect')
29+
padding mode as in np.pad
30+
transforms: callable (default: None)
31+
transforms applied on each batch loaded from volume
32+
return_index_spec: bool (default: False)
33+
whether to return the index spec for each batch
34+
name: str (default: None)
35+
name of this volume
36+
is_multichannel: bool (default: False)
37+
is this a multichannel volume? sliding window is NOT applied to channel dimension
38+
"""
39+
1440
def __init__(self, volume, window_size, stride, downsampling_ratio=None, padding=None,
1541
padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
1642
is_multichannel=False):
@@ -125,6 +151,32 @@ def __repr__(self):
125151

126152

127153
class HDF5VolumeLoader(VolumeLoader):
154+
""" Loader for volumes stored in hdf5, zarr or n5.
155+
156+
Zarr and n5 are file formats very similar to hdf5, but use
157+
the regular filesystem to store data instead of a filesystem
158+
in a file as hdf5.
159+
The file type will be infered from the extension:
160+
.hdf5, .h5 and .hdf map to hdf5
161+
.n5 maps to n5
162+
.zr and .zarr map to zarr
163+
It will fail for other extensions.
164+
165+
Parameters
166+
----------
167+
path: str
168+
path to file
169+
path_in_h5_dataset: str (default: None)
170+
path in file
171+
data_slice: slice (default: None)
172+
slice loaded from dataset
173+
transforms: callable (default: None)
174+
transforms applied on each batch loaded from volume
175+
name: str (default: None)
176+
name of this volume
177+
slicing_config: kwargs
178+
keyword arguments for base class `VolumeLoader`
179+
"""
128180

129181
@staticmethod
130182
def is_h5(file_path):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import unittest
2+
import os
3+
from shutil import rmtree
4+
5+
import numpy as np
6+
import h5py
7+
8+
9+
class TestVolumeLoader(unittest.TestCase):
10+
shape = (100, 100, 100)
11+
def setUp(self):
12+
self.data = np.random.rand(*self.shape)
13+
14+
def test_loader(self):
15+
from inferno.io.volumetric import VolumeLoader
16+
loader = VolumeLoader(self.data,
17+
window_size=(10, 10, 10),
18+
stride=(10, 10, 10), return_index_spec=True)
19+
for batch, idx in loader:
20+
slice_ = loader.base_sequence[int(idx)]
21+
expected = self.data[slice_]
22+
self.assertEqual(batch.shape, expected.shape)
23+
self.assertTrue(np.allclose(batch, expected))
24+
25+
26+
class TestHDF5VolumeLoader(unittest.TestCase):
27+
shape = (100, 100, 100)
28+
def setUp(self):
29+
try:
30+
os.mkdir('./tmp')
31+
except OSError:
32+
pass
33+
self.data = np.random.rand(*self.shape)
34+
with h5py.File('./tmp/data.h5') as f:
35+
f.create_dataset('data', data=self.data)
36+
37+
def tearDown(self):
38+
try:
39+
rmtree('./tmp')
40+
except OSError:
41+
pass
42+
43+
def test_hdf5_loader(self):
44+
from inferno.io.volumetric import HDF5VolumeLoader
45+
loader = HDF5VolumeLoader('./tmp/data.h5', 'data',
46+
window_size=(10, 10, 10),
47+
stride=(10, 10, 10), return_index_spec=True)
48+
for batch, idx in loader:
49+
slice_ = loader.base_sequence[int(idx)]
50+
expected = self.data[slice_]
51+
self.assertEqual(batch.shape, expected.shape)
52+
self.assertTrue(np.allclose(batch, expected))
53+
54+
55+
56+
if __name__ == '__main__':
57+
unittest.main()

0 commit comments

Comments
 (0)