15
15
#
16
16
from __future__ import annotations
17
17
18
+ import base64
19
+ from datetime import datetime
18
20
import itertools
19
21
import logging
20
22
from typing import Any , Dict , List , Optional , Tuple , Union
@@ -32,6 +34,9 @@ class OnlineStoreRestClientEngine:
32
34
RETURN_TYPE_FEATURE_VALUE_LIST = "feature_value_list"
33
35
RETURN_TYPE_RESPONSE_JSON = "response_json" # as a python dict
34
36
MISSING_STATUS = "MISSING"
37
+ BINARY_TYPE = "binary"
38
+ DATE_TYPE = "date"
39
+ FEATURE_TYPE_TO_DECODE = [BINARY_TYPE , DATE_TYPE ]
35
40
36
41
def __init__ (
37
42
self ,
@@ -78,10 +83,21 @@ def __init__(
78
83
self ._is_inference_helpers_list .append (True )
79
84
elif not feat .training_helper_column :
80
85
self ._is_inference_helpers_list .append (False )
86
+ self ._feature_to_decode = self .get_feature_to_decode (features )
81
87
_logger .debug (
82
88
f"Mapping fg_id to feature names: { self ._feature_names_per_fg_id } ."
83
89
)
84
90
91
def get_feature_to_decode(self, features: List[td_feature_mod.TrainingDatasetFeature]) -> Dict[int, str]:
    """Map response positions to the type of every feature that needs decoding.

    A feature needs decoding when its ``type`` is listed in
    ``FEATURE_TYPE_TO_DECODE``. Its position is the index of its name in
    ``self._ordered_feature_names``.

    Args:
        features: Feature metadata objects; only ``name`` and ``type`` are read.

    Returns:
        A dict mapping position in ``self._ordered_feature_names`` to the
        feature type, in the order the features were given. Features whose
        name is not in ``self._ordered_feature_names`` are skipped, since
        there is no position to decode for them.
    """
    # Build the name -> position lookup once instead of calling
    # ``list.index`` per feature. ``setdefault`` keeps the first occurrence,
    # which is what ``list.index`` would have returned.
    position_by_name: Dict[str, int] = {}
    for position, name in enumerate(self._ordered_feature_names):
        position_by_name.setdefault(name, position)

    feature_to_decode: Dict[int, str] = {}
    for feat in features:
        if feat.type not in self.FEATURE_TYPE_TO_DECODE:
            continue
        position = position_by_name.get(feat.name)
        if position is None:
            # ``list.index`` would raise ValueError here; skip instead.
            continue
        feature_to_decode[position] = feat.type
    return feature_to_decode
100
+
85
101
def build_base_payload (
86
102
self ,
87
103
metadata_options : Optional [Dict [str , bool ]] = None ,
@@ -127,6 +143,17 @@ def build_base_payload(
127
143
128
144
_logger .debug (f"Base payload: { base_payload } " )
129
145
return base_payload
146
+
147
def decode_rdrs_feature_values(self, feature_values: Optional[List[Any]]) -> Optional[List[Any]]:
    """Decode binary and date values of one feature row, in place.

    For every position recorded in ``self._feature_to_decode``:
    binary values are base64-decoded to ``bytes`` and date values are
    parsed from ``"%Y-%m-%d"`` strings into ``datetime.date`` objects.

    Args:
        feature_values: The row of raw feature values, or ``None``.

    Returns:
        The same list object with the relevant entries replaced. ``None``
        or an empty list is returned unchanged.
    """
    # The row may be None: convert_rdrs_response_to_feature_value_row calls
    # this before its own ``is not None`` check.
    if not feature_values:
        return feature_values

    for feature_index, data_type in self._feature_to_decode.items():
        raw_value = feature_values[feature_index]
        # Only strings are decoded. This skips None (missing value) and
        # values already decoded by an earlier pass over the same list,
        # so calling this twice on one row is harmless.
        if not isinstance(raw_value, str):
            continue
        if data_type == self.BINARY_TYPE:
            feature_values[feature_index] = base64.b64decode(raw_value)
        elif data_type == self.DATE_TYPE:
            feature_values[feature_index] = datetime.strptime(raw_value, "%Y-%m-%d").date()
    return feature_values
130
157
131
158
def get_single_feature_vector (
132
159
self ,
@@ -185,7 +212,6 @@ def get_single_feature_vector(
185
212
response = self ._online_store_rest_client_api .get_single_raw_feature_vector (
186
213
payload = payload
187
214
)
188
-
189
215
if return_type != self .RETURN_TYPE_RESPONSE_JSON :
190
216
return self .convert_rdrs_response_to_feature_value_row (
191
217
row_feature_values = response ["features" ],
@@ -309,6 +335,7 @@ def convert_rdrs_response_to_feature_value_row(
309
335
A dictionary with the feature names as keys and the feature values as values. Values types are not guaranteed to
310
336
match the feature type in the metadata. Timestamp SQL types are converted to python datetime.
311
337
"""
338
+ row_feature_values = self .decode_rdrs_feature_values (row_feature_values )
312
339
if drop_missing and (
313
340
detailed_status is None and row_feature_values is not None
314
341
):
0 commit comments