@@ -392,6 +392,7 @@ def create_training_dataset(
392
392
primary_keys = False ,
393
393
event_time = False ,
394
394
training_helper_columns = False ,
395
+ transformation_context : Dict [str , Any ] = None ,
395
396
):
396
397
self ._set_event_time (feature_view_obj , training_dataset_obj )
397
398
updated_instance = self ._create_training_data_metadata (
@@ -405,6 +406,7 @@ def create_training_dataset(
405
406
primary_keys = primary_keys ,
406
407
event_time = event_time ,
407
408
training_helper_columns = training_helper_columns ,
409
+ transformation_context = transformation_context ,
408
410
)
409
411
return updated_instance , td_job
410
412
@@ -420,6 +422,7 @@ def get_training_data(
420
422
event_time = False ,
421
423
training_helper_columns = False ,
422
424
dataframe_type = "default" ,
425
+ transformation_context : Dict [str , Any ] = None ,
423
426
):
424
427
# check if provided td version has already existed.
425
428
if training_dataset_version :
@@ -497,6 +500,7 @@ def get_training_data(
497
500
read_options ,
498
501
dataframe_type ,
499
502
training_dataset_version ,
503
+ transformation_context = transformation_context ,
500
504
)
501
505
self .compute_training_dataset_statistics (
502
506
feature_view_obj , td_updated , split_df
@@ -581,6 +585,7 @@ def recreate_training_dataset(
581
585
statistics_config ,
582
586
user_write_options ,
583
587
spine = None ,
588
+ transformation_context : Dict [str , Any ] = None ,
584
589
):
585
590
training_dataset_obj = self ._get_training_dataset_metadata (
586
591
feature_view_obj , training_dataset_version
@@ -597,6 +602,7 @@ def recreate_training_dataset(
597
602
user_write_options ,
598
603
training_dataset_obj = training_dataset_obj ,
599
604
spine = spine ,
605
+ transformation_context = transformation_context ,
600
606
)
601
607
# Set training dataset schema after training dataset has been generated
602
608
training_dataset_obj .schema = self .get_training_dataset_schema (
@@ -757,6 +763,7 @@ def compute_training_dataset(
757
763
primary_keys = False ,
758
764
event_time = False ,
759
765
training_helper_columns = False ,
766
+ transformation_context : Dict [str , Any ] = None ,
760
767
):
761
768
if training_dataset_obj :
762
769
pass
@@ -791,6 +798,7 @@ def compute_training_dataset(
791
798
user_write_options ,
792
799
self ._OVERWRITE ,
793
800
feature_view_obj = feature_view_obj ,
801
+ transformation_context = transformation_context ,
794
802
)
795
803
796
804
# Set training dataset schema after training dataset has been generated
@@ -913,6 +921,7 @@ def get_batch_data(
913
921
inference_helper_columns = False ,
914
922
dataframe_type = "default" ,
915
923
transformed = True ,
924
+ transformation_context : Dict [str , Any ] = None ,
916
925
):
917
926
self ._check_feature_group_accessibility (feature_view_obj )
918
927
@@ -936,7 +945,9 @@ def get_batch_data(
936
945
).read (read_options = read_options , dataframe_type = dataframe_type )
937
946
if transformation_functions and transformed :
938
947
return engine .get_instance ()._apply_transformation_function (
939
- transformation_functions , dataset = feature_dataframe
948
+ transformation_functions ,
949
+ dataset = feature_dataframe ,
950
+ transformation_context = transformation_context ,
940
951
)
941
952
else :
942
953
return feature_dataframe
0 commit comments