Commit 391b2da

Modify process_local_sources to support mount paths
Signed-off-by: Dar, Efrat <efrat.dar@intel.com>
1 parent: 333049f

File tree: 3 files changed, +74 -54 lines

openfl-workspace/torch/histology_s3/src/dataloader.py

Lines changed: 27 additions & 28 deletions
@@ -2,34 +2,32 @@
 # SPDX-License-Identifier: Apache-2.0
 
 """You may copy this file as the starting point of your own model."""
-
-from collections.abc import Iterable
-from logging import getLogger
 import os
 import sys
+from collections.abc import Iterable
+from logging import getLogger
 
-
-from openfl.federated import PyTorchDataLoader
 import numpy as np
-from openfl.federated.data.sources.torch.verifiable_map_style_image_folder import VerifiableImageFolder
-from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser
-from openfl.utilities.path_check import is_directory_traversal
 import torch
 from torch.utils.data import random_split
 from torchvision.transforms import ToTensor
 
+from openfl.federated import PyTorchDataLoader
+from openfl.federated.data.sources.data_sources_json_parser import DataSourcesJsonParser
+from openfl.federated.data.sources.torch.verifiable_map_style_image_folder import VerifiableImageFolder
+from openfl.utilities.path_check import is_directory_traversal
 
 logger = getLogger(__name__)
 
 
 class PyTorchHistologyVerifiableDataLoader(PyTorchDataLoader):
     """PyTorch data loader for Histology dataset."""
 
-    def __init__(self, data_path, batch_size, **kwargs):
+    def __init__(self, data_path=None, batch_size=32, **kwargs):
         """Instantiate the data object.
 
         Args:
-            data_path: The file path to the data
+            data_path: The file path to the data. If None, initialize for model creation only.
            batch_size: The batch size of the data loader
            **kwargs: Additional arguments, passed to super init
                and load_mnist_shard
@@ -61,17 +59,19 @@ def __init__(self, data_path, batch_size, **kwargs):
         else:
             logger.info("The dataset is valid.")
 
-        _, num_classes, X_train, y_train, X_valid, y_valid = load_histology_shard(
-            verifible_dataset_info=verifible_dataset_info, verify_dataset_items=verify_dataset_items, **kwargs)
+        X_train, y_train, X_valid, y_valid = load_histology_shard(
+            verifible_dataset_info=verifible_dataset_info,
+            verify_dataset_items=verify_dataset_items,
+            feature_shape=self.feature_shape,
+            num_classes=self.num_classes,
+            **kwargs
+        )
 
         self.X_train = X_train
         self.y_train = y_train
         self.X_valid = X_valid
         self.y_valid = y_valid
 
-        self.num_classes = num_classes
-
-
     def get_feature_shape(self):
         """Returns the shape of an example feature array.
 
@@ -101,7 +101,6 @@ def get_verifiable_dataset_info(self, data_path):
         Raises:
             SystemExit: If `data_path` is invalid or missing `datasources.json`.
         """
-        """Return the verifiable dataset info object for the given data sources."""
         if data_path and is_directory_traversal(data_path):
             logger.error("Data path is out of the openfl workspace scope.")
         if not os.path.isdir(data_path):
@@ -152,7 +151,8 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
     n_train = int(train_split_ratio * len(dataset))
     n_valid = len(dataset) - n_train
     ds_train, ds_val = random_split(
-        dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0))
+        dataset, lengths=[n_train, n_valid], generator=torch.manual_seed(0)
+    )
 
     # create the shards
     X_train, y_train = list(zip(*ds_train))
@@ -164,41 +164,40 @@ def _load_raw_data(verifiable_dataset_info, verify_dataset_items=False, train_sp
     return (X_train, y_train), (X_valid, y_valid)
 
 
-
-def load_histology_shard(verifible_dataset_info, verify_dataset_items,
+def load_histology_shard(verifible_dataset_info, verify_dataset_items, feature_shape=None, num_classes=None,
                          categorical=False, channels_last=False, **kwargs):
     """
     Load the Histology dataset.
 
     Args:
-        data_path (str): path to data directory
+        verifible_dataset_info (VerifiableDatasetInfo): The verifiable dataset info object.
+        verify_dataset_items (bool): True = verify the dataset items while loading data
+        feature_shape (list, optional): The shape of input features.
+        num_classes (int, optional): Number of classes.
        categorical (bool): True = convert the labels to one-hot encoded
            vectors (Default = True)
        channels_last (bool): True = The input images have the channels
            last (Default = True)
        **kwargs: Additional parameters to pass to the function
 
    Returns:
-        list: The input shape
-        int: The number of classes
        numpy.ndarray: The training data
        numpy.ndarray: The training labels
        numpy.ndarray: The validation data
        numpy.ndarray: The validation labels
    """
-    img_rows, img_cols = 150, 150
-    num_classes = 8
+    img_rows, img_cols = feature_shape[1], feature_shape[2]
 
-    (X_train, y_train), (X_valid, y_valid) = _load_raw_data(verifible_dataset_info, verify_dataset_items, **kwargs)
+    (X_train, y_train), (X_valid, y_valid) = _load_raw_data(
+        verifible_dataset_info, verify_dataset_items, **kwargs
+    )
 
     if channels_last:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 3)
        X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 3)
-        input_shape = (img_rows, img_cols, 3)
    else:
        X_train = X_train.reshape(X_train.shape[0], 3, img_rows, img_cols)
        X_valid = X_valid.reshape(X_valid.shape[0], 3, img_rows, img_cols)
-        input_shape = (3, img_rows, img_cols)
 
    logger.info(f'Histology > X_train Shape : {X_train.shape}')
    logger.info(f'Histology > y_train Shape : {y_train.shape}')
@@ -210,4 +209,4 @@ def load_histology_shard(verifible_dataset_info, verify_dataset_items,
        y_train = np.eye(num_classes)[y_train]
        y_valid = np.eye(num_classes)[y_valid]
 
-    return input_shape, num_classes, X_train, y_train, X_valid, y_valid
+    return X_train, y_train, X_valid, y_valid
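
With this change the shard loader no longer decides the input shape or class count itself; the caller supplies both. A minimal usage sketch of the new contract, assuming the module context above (`vds` is a hypothetical VerifiableDatasetInfo; 150x150 and 8 are the values the function previously hard-coded):

```python
# Illustrative only: the reworked load_histology_shard is driven by caller-supplied shape info.
X_train, y_train, X_valid, y_valid = load_histology_shard(
    verifible_dataset_info=vds,        # hypothetical VerifiableDatasetInfo built from the data sources JSON
    verify_dataset_items=True,
    feature_shape=[3, 150, 150],       # img_rows/img_cols now come from feature_shape[1], feature_shape[2]
    num_classes=8,                     # previously fixed to 8 inside the function
)
```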

openfl/federated/data/sources/data_sources_json_parser.py

Lines changed: 41 additions & 22 deletions
@@ -15,7 +15,9 @@
 
 class DataSourcesJsonParser:
     @staticmethod
-    def parse(json_string: str) -> VerifiableDatasetInfo:
+    def parse(
+        json_string: str, label="", metadata="", check_dir_traversal=False
+    ) -> VerifiableDatasetInfo:
         """
         Parse a JSON string into a dictionary.
 
@@ -31,48 +33,65 @@ def parse(json_string: str) -> VerifiableDatasetInfo:
         except json.JSONDecodeError as e:
             raise ValueError(f"Invalid JSON format: {e}")
 
-        datasources = DataSourcesJsonParser.process_data_sources(data)
+        datasources = DataSourcesJsonParser.process_data_sources(data, check_dir_traversal)
         if not datasources:
             raise ValueError("No data sources were found.")
         return VerifiableDatasetInfo(
             data_sources=datasources,
-            label="",
+            label=label,
+            metadata=metadata,
         )
 
     @staticmethod
-    def process_data_sources(data):
+    def process_data_sources(data, check_dir_traversal=False):
         """Process and validate data sources."""
-        cwd = os.getcwd()
+        os.getcwd()
         datasources = []
+        local_datasources = {}
         for source_name, source_info in data.items():
             source_type = source_info.get("type", None)
             if source_type is None:
                 raise ValueError(f"Missing 'type' key in data source configuration: {source_info}")
             params = source_info.get("params", {})
-            if source_type == "local":
-                datasources.append(
-                    DataSourcesJsonParser.process_local_source(source_name, params, cwd)
-                )
+            if source_type == "fs":
+                local_datasources[source_name] = params
             elif source_type == "s3":
                 datasources.append(DataSourcesJsonParser.process_s3_source(source_name, params))
-            elif source_type == "azure_blob":
+            elif source_type == "ab":
                 datasources.append(
                     DataSourcesJsonParser.process_azure_blob_source(source_name, params)
                 )
+        if local_datasources:
+            DataSourcesJsonParser.process_local_sources(
+                local_datasources, datasources, check_dir_traversal
+            )
         return [ds for ds in datasources if ds]
 
-    @staticmethod
-    def process_local_source(source_name, params, cwd):
-        """Process a local data source."""
-        path = params.get("path", None)
-        if not path:
-            raise ValueError(f"Missing 'path' parameter for local data source '{source_name}'")
-        abs_path = os.path.abspath(path)
-        rel_path = os.path.relpath(abs_path, cwd)
-        if rel_path and not is_directory_traversal(rel_path):
-            return LocalDataSource(source_name, rel_path, base_path=Path("."))
-        else:
-            raise ValueError(f"Invalid path for local data source '{source_name}': {path}.")
+    def process_local_sources(local_datasources, datasources, check_dir_traversal=False):
+        """Process and validate local data sources."""
+        # The reason we use common base_dir and source_path relative to that base
+        # is to simplify path management in containerized environments, such as Docker.
+        # By using a common base_dir, we can ensure that paths remain consistent
+        # when mounting volumes, as only the base_dir needs to be adjusted to point
+        # to the mount path inside the container.
+        absolute_paths = {
+            source_name: os.path.realpath(params.get("path", None))
+            for source_name, params in local_datasources.items()
+        }
+        base_dir = os.path.commonpath(absolute_paths.values())
+        for source_name, data_path in absolute_paths.items():
+            relative_path = os.path.relpath(data_path, base_dir)
+            if check_dir_traversal and is_directory_traversal(data_path):
+                raise ValueError(
+                    f"Invalid path for local data source '{source_name}': {data_path}."
+                    f" Data path is out of the openfl workspace scope."
+                )
+            datasources.append(
+                LocalDataSource(
+                    name=source_name, source_path=Path(relative_path), base_path=base_dir
+                )
+            )
 
     @staticmethod
     def process_s3_source(source_name, params):
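
The heart of the mount-path support is the path arithmetic introduced in process_local_sources: all local sources share one base_dir, and each source stores only its path relative to that base, so remapping the data under a container mount means changing base_dir alone. A self-contained, stdlib-only sketch of that arithmetic with hypothetical host paths:

```python
import os

# Hypothetical host paths for two local ("fs") data sources.
sources = {
    "histology_train": "/data/collab1/histology/train",
    "histology_valid": "/data/collab1/histology/valid",
}

# Same scheme as process_local_sources: resolve, find the shared base, keep relative parts.
absolute = {name: os.path.realpath(path) for name, path in sources.items()}
base_dir = os.path.commonpath(absolute.values())
relative = {name: os.path.relpath(path, base_dir) for name, path in absolute.items()}

# If the same data is mounted elsewhere (say /mnt/dataset inside a container),
# only base_dir needs to change; the relative parts stay valid.
print(base_dir)   # /data/collab1/histology (assuming no symlinks rewrite the paths)
print(relative)   # {'histology_train': 'train', 'histology_valid': 'valid'}
```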

openfl/interface/collaborator.py

Lines changed: 6 additions & 4 deletions
@@ -231,9 +231,11 @@ def register_data_path(collaborator_name, data_path=None, silent=False):
     type=ClickPath(exists=True),
     help=(
         "Path to directory containing sources.json file defining the data sources of the dataset. "
-        "This file should contain a JSON object with the data sources to be registered. For 'local'"
-        " type, 'params' must include: 'path'. For 's3' type, 'params' must include: 'uri', "
-        "'access_key_env_name', 'secret_key_env_name', 'secret_name', and optionally 'endpoint'."
+        "This file should contain a JSON object with the data sources to be registered. For local "
+        "data source, 'type' is 'fs', and 'params' must include: 'path'. For 's3' type, 'params' "
+        "must include: 'uri', 'access_key_env_name', 'secret_key_env_name', 'secret_name', and "
+        "optionally 'endpoint'. For azure_blob, 'type' is 'ab', and 'params' must include: "
+        "'connection_string', 'container_name', and optionally 'folder_prefix'."
     ),
 )
 def calchash(data_path):
@@ -258,7 +260,7 @@ def calchash(data_path):
         sys.exit(1)
     with open(datasources_json_path, "r", encoding="utf-8") as file:
         data = file.read()
-    vds = DataSourcesJsonParser.parse(data)
+    vds = DataSourcesJsonParser.parse(data, check_dir_traversal=True)
     root_hash = vds.create_dataset_hash()
     hash_file_path = os.path.join(data_path, "hash.txt")
     with open(hash_file_path, "w", encoding="utf-8") as hash_file:
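
For reference, a hypothetical data sources JSON matching the updated help text and the parser changes above (source names, paths, bucket, and environment variable names are made up; only the keys and type names come from the diff):

```python
import json

# Hypothetical sources definition using the new type names: "fs" (local), "s3", "ab" (azure_blob).
datasources_json = json.dumps({
    "histology_local": {"type": "fs", "params": {"path": "/data/collab1/histology"}},
    "histology_s3": {
        "type": "s3",
        "params": {
            "uri": "s3://example-bucket/histology",
            "access_key_env_name": "AWS_ACCESS_KEY_ID",
            "secret_key_env_name": "AWS_SECRET_ACCESS_KEY",
            "secret_name": "histology-creds",
            "endpoint": "https://s3.example.com",  # optional
        },
    },
    "histology_blob": {
        "type": "ab",
        "params": {
            "connection_string": "<azure-connection-string>",
            "container_name": "histology",
            "folder_prefix": "collab1/",  # optional
        },
    },
})

# With openfl installed, the calchash path above would then do roughly:
# vds = DataSourcesJsonParser.parse(datasources_json, check_dir_traversal=True)
# root_hash = vds.create_dataset_hash()
```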
