Skip to content

Commit 5fe3692

Browse files
committed
decode rdrs features
1 parent cca6eb0 commit 5fe3692

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

python/hsfs/core/online_store_rest_client_engine.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#
1616
from __future__ import annotations
1717

18+
import base64
19+
from datetime import datetime
1820
import itertools
1921
import logging
2022
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -32,6 +34,9 @@ class OnlineStoreRestClientEngine:
3234
RETURN_TYPE_FEATURE_VALUE_LIST = "feature_value_list"
3335
RETURN_TYPE_RESPONSE_JSON = "response_json" # as a python dict
3436
MISSING_STATUS = "MISSING"
37+
BINARY_TYPE = "binary"
38+
DATE_TYPE = "date"
39+
FEATURE_TYPE_TO_DECODE = [BINARY_TYPE, DATE_TYPE]
3540

3641
def __init__(
3742
self,
@@ -78,10 +83,21 @@ def __init__(
7883
self._is_inference_helpers_list.append(True)
7984
elif not feat.training_helper_column:
8085
self._is_inference_helpers_list.append(False)
86+
self._feature_to_decode = self.get_feature_to_decode(features)
8187
_logger.debug(
8288
f"Mapping fg_id to feature names: {self._feature_names_per_fg_id}."
8389
)
8490

91+
def get_feature_to_decode(self, features: List[td_feature_mod.TrainingDatasetFeature]) -> Dict[int, str]:
    """Map the position of each feature that needs client-side decoding to its type.

    A feature needs decoding when its `type` is listed in
    `self.FEATURE_TYPE_TO_DECODE`. Positions are the feature's index in
    `self._ordered_feature_names`.

    Args:
        features: Features to inspect; each must expose `name` and `type`.

    Returns:
        A dict keyed by index in `self._ordered_feature_names` with the
        feature type as value. Features whose name is not in
        `self._ordered_feature_names` are omitted.
    """
    # Build the name -> position lookup once instead of scanning the list with
    # `list.index` for every feature. `setdefault` keeps the first occurrence
    # of a repeated name, which is what `list.index` would have returned.
    position_by_name = {}
    for position, name in enumerate(self._ordered_feature_names):
        position_by_name.setdefault(name, position)

    feature_to_decode = {}
    for feat in features:
        if feat.type not in self.FEATURE_TYPE_TO_DECODE:
            continue
        position = position_by_name.get(feat.name)
        # A name with no position has no index to register; skip it rather
        # than raising ValueError as `list.index` would.
        if position is not None:
            feature_to_decode[position] = feat.type
    return feature_to_decode
100+
85101
def build_base_payload(
86102
self,
87103
metadata_options: Optional[Dict[str, bool]] = None,
@@ -127,6 +143,17 @@ def build_base_payload(
127143

128144
_logger.debug(f"Base payload: {base_payload}")
129145
return base_payload
146+
147+
def decode_rdrs_feature_values(self, feature_values: List[Any]) -> List[Any]:
    """Decode binary and date values in a feature-value row, in place.

    For every index registered in `self._feature_to_decode`, a non-None value
    is converted according to the registered type:

    - `self.BINARY_TYPE`: base64 string -> `bytes`
    - `self.DATE_TYPE`: `"%Y-%m-%d"` string -> `datetime.date`

    Args:
        feature_values: Row of feature values; may be None or empty, in which
            case it is returned unchanged.

    Returns:
        The same `feature_values` object, with the registered positions
        decoded. None entries are left as None.
    """
    # Nothing to decode for a missing or empty row; indexing None would raise
    # TypeError before the caller gets a chance to handle it.
    if not feature_values:
        return feature_values

    for feature_index, data_type in self._feature_to_decode.items():
        raw_value = feature_values[feature_index]
        if raw_value is None:
            continue
        if data_type == self.BINARY_TYPE:
            feature_values[feature_index] = base64.b64decode(raw_value)
        elif data_type == self.DATE_TYPE:
            feature_values[feature_index] = datetime.strptime(raw_value, "%Y-%m-%d").date()
    return feature_values
130157

131158
def get_single_feature_vector(
132159
self,
@@ -185,7 +212,6 @@ def get_single_feature_vector(
185212
response = self._online_store_rest_client_api.get_single_raw_feature_vector(
186213
payload=payload
187214
)
188-
189215
if return_type != self.RETURN_TYPE_RESPONSE_JSON:
190216
return self.convert_rdrs_response_to_feature_value_row(
191217
row_feature_values=response["features"],
@@ -309,6 +335,7 @@ def convert_rdrs_response_to_feature_value_row(
309335
A dictionary with the feature names as keys and the feature values as values. Values types are not guaranteed to
310336
match the feature type in the metadata. Timestamp SQL types are converted to python datetime.
311337
"""
338+
row_feature_values = self.decode_rdrs_feature_values(row_feature_values)
312339
if drop_missing and (
313340
detailed_status is None and row_feature_values is not None
314341
):

0 commit comments

Comments
 (0)