2
2
# SPDX-License-Identifier: Apache-2.0
3
3
4
4
"""You may copy this file as the starting point of your own model."""
5
-
6
- from collections .abc import Iterable
7
- from logging import getLogger
8
5
import os
9
6
import sys
7
+ from collections .abc import Iterable
8
+ from logging import getLogger
10
9
11
-
12
- from openfl .federated import PyTorchDataLoader
13
10
import numpy as np
14
- from openfl .federated .data .sources .torch .verifiable_map_style_image_folder import VerifiableImageFolder
15
- from openfl .federated .data .sources .data_sources_json_parser import DataSourcesJsonParser
16
- from openfl .utilities .path_check import is_directory_traversal
17
11
import torch
18
12
from torch .utils .data import random_split
19
13
from torchvision .transforms import ToTensor
20
14
15
+ from openfl .federated import PyTorchDataLoader
16
+ from openfl .federated .data .sources .data_sources_json_parser import DataSourcesJsonParser
17
+ from openfl .federated .data .sources .torch .verifiable_map_style_image_folder import VerifiableImageFolder
18
+ from openfl .utilities .path_check import is_directory_traversal
21
19
22
20
logger = getLogger (__name__ )
23
21
24
22
25
23
class PyTorchHistologyVerifiableDataLoader (PyTorchDataLoader ):
26
24
"""PyTorch data loader for Histology dataset."""
27
25
28
- def __init__ (self , data_path , batch_size , ** kwargs ):
26
+ def __init__ (self , data_path = None , batch_size = 32 , ** kwargs ):
29
27
"""Instantiate the data object.
30
28
31
29
Args:
32
- data_path: The file path to the data
30
+ data_path: The file path to the data. If None, initialize for model creation only.
33
31
batch_size: The batch size of the data loader
34
32
**kwargs: Additional arguments, passed to super init
35
33
and load_histology_shard
@@ -61,17 +59,19 @@ def __init__(self, data_path, batch_size, **kwargs):
61
59
else :
62
60
logger .info ("The dataset is valid." )
63
61
64
- _ , num_classes , X_train , y_train , X_valid , y_valid = load_histology_shard (
65
- verifible_dataset_info = verifible_dataset_info , verify_dataset_items = verify_dataset_items , ** kwargs )
62
+ X_train , y_train , X_valid , y_valid = load_histology_shard (
63
+ verifible_dataset_info = verifible_dataset_info ,
64
+ verify_dataset_items = verify_dataset_items ,
65
+ feature_shape = self .feature_shape ,
66
+ num_classes = self .num_classes ,
67
+ ** kwargs
68
+ )
66
69
67
70
self .X_train = X_train
68
71
self .y_train = y_train
69
72
self .X_valid = X_valid
70
73
self .y_valid = y_valid
71
74
72
- self .num_classes = num_classes
73
-
74
-
75
75
def get_feature_shape (self ):
76
76
"""Returns the shape of an example feature array.
77
77
@@ -101,7 +101,6 @@ def get_verifiable_dataset_info(self, data_path):
101
101
Raises:
102
102
SystemExit: If `data_path` is invalid or missing `datasources.json`.
103
103
"""
104
- """Return the verifiable dataset info object for the given data sources."""
105
104
if data_path and is_directory_traversal (data_path ):
106
105
logger .error ("Data path is out of the openfl workspace scope." )
107
106
if not os .path .isdir (data_path ):
@@ -152,7 +151,8 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
152
151
n_train = int (train_split_ratio * len (dataset ))
153
152
n_valid = len (dataset ) - n_train
154
153
ds_train , ds_val = random_split (
155
- dataset , lengths = [n_train , n_valid ], generator = torch .manual_seed (0 ))
154
+ dataset , lengths = [n_train , n_valid ], generator = torch .manual_seed (0 )
155
+ )
156
156
157
157
# create the shards
158
158
X_train , y_train = list (zip (* ds_train ))
@@ -164,41 +164,40 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
164
164
return (X_train , y_train ), (X_valid , y_valid )
165
165
166
166
167
-
168
- def load_histology_shard (verifible_dataset_info , verify_dataset_items ,
167
+ def load_histology_shard (verifible_dataset_info , verify_dataset_items , feature_shape = None , num_classes = None ,
169
168
categorical = False , channels_last = False , ** kwargs ):
170
169
"""
171
170
Load the Histology dataset.
172
171
173
172
Args:
174
- data_path (str): path to data directory
173
+ verifible_dataset_info (VerifiableDatasetInfo): The verifiable dataset info object.
174
+ verify_dataset_items (bool): True = verify the dataset items while loading data
175
+ feature_shape (list, optional): The shape of input features.
176
+ num_classes (int, optional): Number of classes.
175
177
categorical (bool): True = convert the labels to one-hot encoded
176
178
vectors (Default = False)
177
179
channels_last (bool): True = The input images have the channels
178
180
last (Default = False)
179
181
**kwargs: Additional parameters to pass to the function
180
182
181
183
Returns:
182
- list: The input shape
183
- int: The number of classes
184
184
numpy.ndarray: The training data
185
185
numpy.ndarray: The training labels
186
186
numpy.ndarray: The validation data
187
187
numpy.ndarray: The validation labels
188
188
"""
189
- img_rows , img_cols = 150 , 150
190
- num_classes = 8
189
+ img_rows , img_cols = feature_shape [1 ], feature_shape [2 ]
191
190
192
- (X_train , y_train ), (X_valid , y_valid ) = _load_raw_data (verifible_dataset_info , verify_dataset_items , ** kwargs )
191
+ (X_train , y_train ), (X_valid , y_valid ) = _load_raw_data (
192
+ verifible_dataset_info , verify_dataset_items , ** kwargs
193
+ )
193
194
194
195
if channels_last :
195
196
X_train = X_train .reshape (X_train .shape [0 ], img_rows , img_cols , 3 )
196
197
X_valid = X_valid .reshape (X_valid .shape [0 ], img_rows , img_cols , 3 )
197
- input_shape = (img_rows , img_cols , 3 )
198
198
else :
199
199
X_train = X_train .reshape (X_train .shape [0 ], 3 , img_rows , img_cols )
200
200
X_valid = X_valid .reshape (X_valid .shape [0 ], 3 , img_rows , img_cols )
201
- input_shape = (3 , img_rows , img_cols )
202
201
203
202
logger .info (f'Histology > X_train Shape : { X_train .shape } ' )
204
203
logger .info (f'Histology > y_train Shape : { y_train .shape } ' )
@@ -210,4 +209,4 @@ def load_histology_shard(verifible_dataset_info, verify_dataset_items,
210
209
y_train = np .eye (num_classes )[y_train ]
211
210
y_valid = np .eye (num_classes )[y_valid ]
212
211
213
- return input_shape , num_classes , X_train , y_train , X_valid , y_valid
212
+ return X_train , y_train , X_valid , y_valid
0 commit comments