Skip to content

Commit 5838d82

Browse files
authored
[FSTORE-1672] Allow multiple on-demand features to be returned from an on-demand transformation function and allow passing of local variables to a transformation function (logicalclocks#452) (logicalclocks#468)
1 parent c575e11 commit 5838d82

17 files changed

+1028
-178
lines changed

python/hsfs/core/feature_group_engine.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import warnings
18-
from typing import List, Union
18+
from typing import Any, Dict, List, Union
1919

2020
from hsfs import engine, feature, util
2121
from hsfs import feature_group as fg
@@ -49,12 +49,18 @@ def _update_feature_group_schema_on_demand_transformations(
4949
transformed_features = []
5050
dropped_features = []
5151
for tf in feature_group.transformation_functions:
52-
transformed_features.append(
53-
feature.Feature(
54-
tf.hopsworks_udf.output_column_names[0],
55-
tf.hopsworks_udf.return_types[0],
56-
on_demand=True,
57-
)
52+
transformed_features.extend(
53+
[
54+
feature.Feature(
55+
output_column_name,
56+
return_type,
57+
on_demand=True,
58+
)
59+
for output_column_name, return_type in zip(
60+
tf.hopsworks_udf.output_column_names,
61+
tf.hopsworks_udf.return_types,
62+
)
63+
]
5864
)
5965
if tf.hopsworks_udf.dropped_features:
6066
dropped_features.extend(tf.hopsworks_udf.dropped_features)
@@ -141,6 +147,8 @@ def insert(
141147
storage,
142148
write_options,
143149
validation_options: dict = None,
150+
transformation_context: Dict[str, Any] = None,
151+
transform: bool = True,
144152
):
145153
dataframe_features = engine.get_instance().parse_schema_feature_group(
146154
feature_dataframe,
@@ -152,16 +160,20 @@ def insert(
152160
if (
153161
not isinstance(feature_group, fg.ExternalFeatureGroup)
154162
and feature_group.transformation_functions
163+
and transform
155164
):
156165
feature_dataframe = engine.get_instance()._apply_transformation_function(
157-
feature_group.transformation_functions, feature_dataframe
166+
feature_group.transformation_functions,
167+
feature_dataframe,
168+
transformation_context=transformation_context,
158169
)
159170

160-
dataframe_features = (
161-
self._update_feature_group_schema_on_demand_transformations(
162-
feature_group=feature_group, features=dataframe_features
171+
dataframe_features = (
172+
self._update_feature_group_schema_on_demand_transformations(
173+
feature_group=feature_group, features=dataframe_features
174+
)
163175
)
164-
)
176+
165177
util.validate_embedding_feature_type(
166178
feature_group.embedding_index, dataframe_features
167179
)
@@ -361,6 +373,8 @@ def insert_stream(
361373
timeout,
362374
checkpoint_dir,
363375
write_options,
376+
transformation_context: Dict[str, Any] = None,
377+
transform: bool = True,
364378
):
365379
if not feature_group.online_enabled and not feature_group.stream:
366380
raise exceptions.FeatureStoreException(
@@ -377,9 +391,11 @@ def insert_stream(
377391
)
378392
)
379393

380-
if feature_group.transformation_functions:
394+
if feature_group.transformation_functions and transform:
381395
dataframe = engine.get_instance()._apply_transformation_function(
382-
feature_group.transformation_functions, dataframe
396+
feature_group.transformation_functions,
397+
dataframe,
398+
transformation_context=transformation_context,
383399
)
384400

385401
util.validate_embedding_feature_type(

python/hsfs/core/feature_view_engine.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def create_training_dataset(
392392
primary_keys=False,
393393
event_time=False,
394394
training_helper_columns=False,
395+
transformation_context: Dict[str, Any] = None,
395396
):
396397
self._set_event_time(feature_view_obj, training_dataset_obj)
397398
updated_instance = self._create_training_data_metadata(
@@ -405,6 +406,7 @@ def create_training_dataset(
405406
primary_keys=primary_keys,
406407
event_time=event_time,
407408
training_helper_columns=training_helper_columns,
409+
transformation_context=transformation_context,
408410
)
409411
return updated_instance, td_job
410412

@@ -420,6 +422,7 @@ def get_training_data(
420422
event_time=False,
421423
training_helper_columns=False,
422424
dataframe_type="default",
425+
transformation_context: Dict[str, Any] = None,
423426
):
424427
# check if provided td version has already existed.
425428
if training_dataset_version:
@@ -497,6 +500,7 @@ def get_training_data(
497500
read_options,
498501
dataframe_type,
499502
training_dataset_version,
503+
transformation_context=transformation_context,
500504
)
501505
self.compute_training_dataset_statistics(
502506
feature_view_obj, td_updated, split_df
@@ -581,6 +585,7 @@ def recreate_training_dataset(
581585
statistics_config,
582586
user_write_options,
583587
spine=None,
588+
transformation_context: Dict[str, Any] = None,
584589
):
585590
training_dataset_obj = self._get_training_dataset_metadata(
586591
feature_view_obj, training_dataset_version
@@ -597,6 +602,7 @@ def recreate_training_dataset(
597602
user_write_options,
598603
training_dataset_obj=training_dataset_obj,
599604
spine=spine,
605+
transformation_context=transformation_context,
600606
)
601607
# Set training dataset schema after training dataset has been generated
602608
training_dataset_obj.schema = self.get_training_dataset_schema(
@@ -757,6 +763,7 @@ def compute_training_dataset(
757763
primary_keys=False,
758764
event_time=False,
759765
training_helper_columns=False,
766+
transformation_context: Dict[str, Any] = None,
760767
):
761768
if training_dataset_obj:
762769
pass
@@ -791,6 +798,7 @@ def compute_training_dataset(
791798
user_write_options,
792799
self._OVERWRITE,
793800
feature_view_obj=feature_view_obj,
801+
transformation_context=transformation_context,
794802
)
795803

796804
# Set training dataset schema after training dataset has been generated
@@ -913,6 +921,7 @@ def get_batch_data(
913921
inference_helper_columns=False,
914922
dataframe_type="default",
915923
transformed=True,
924+
transformation_context: Dict[str, Any] = None,
916925
):
917926
self._check_feature_group_accessibility(feature_view_obj)
918927

@@ -936,7 +945,9 @@ def get_batch_data(
936945
).read(read_options=read_options, dataframe_type=dataframe_type)
937946
if transformation_functions and transformed:
938947
return engine.get_instance()._apply_transformation_function(
939-
transformation_functions, dataset=feature_dataframe
948+
transformation_functions,
949+
dataset=feature_dataframe,
950+
transformation_context=transformation_context,
940951
)
941952
else:
942953
return feature_dataframe

0 commit comments

Comments (0)