Skip to content

Commit c078bcc

Browse files
committed
update to handle explicit features
1 parent 90b12a0 commit c078bcc

File tree

2 files changed

+64
-35
lines changed

2 files changed

+64
-35
lines changed

python/hsfs/core/schema_validation.py

+27-25
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def validate_schema(self, feature_group, df, df_features):
5454
)
5555
# Execute data type specific validation
5656
errors, column_lengths, is_pk_null, is_string_length_exceeded = (
57-
self._validate_df_specifics(feature_group, df, bool(feature_group.id))
57+
self._validate_df_specifics(feature_group, df)
5858
)
5959

6060
# Handle errors
@@ -68,7 +68,7 @@ def validate_schema(self, feature_group, df, df_features):
6868

6969
return df_features
7070

71-
def _validate_df_specifics(self, feature_group, df, is_fg_created):
71+
def _validate_df_specifics(self, feature_group, df):
7272
"""To be implemented by subclasses"""
7373
raise NotImplementedError("Subclasses must implement this method")
7474

@@ -77,7 +77,8 @@ def get_feature_from_list(feature_name, features):
7777
for i_feature in features:
7878
if i_feature.name == feature_name:
7979
return i_feature
80-
raise ValueError(f"Feature {feature_name} not found in feature list")
80+
81+
return None
8182

8283
@staticmethod
8384
def extract_numbers(input_string):
@@ -87,13 +88,14 @@ def extract_numbers(input_string):
8788
return re.findall(pattern, input_string)
8889

8990
def get_online_varchar_length(self, feature):
90-
# returns the column length of varchar columns
91-
if not feature.type == "string":
92-
raise ValueError("Feature not a string type")
93-
if not feature.online_type:
94-
raise ValueError("Feature is not online enabled")
95-
96-
return int(self.extract_numbers(feature.online_type)[0])
91+
# check if online_type is not null and starts with varchar
92+
if (
93+
feature
94+
and feature.online_type
95+
and feature.online_type.startswith("varchar")
96+
):
97+
return int(self.extract_numbers(feature.online_type)[0])
98+
return None
9799

98100
@staticmethod
99101
def increase_string_columns(column_lengths: dict, dataframe_features):
@@ -109,7 +111,7 @@ def increase_string_columns(column_lengths: dict, dataframe_features):
109111

110112
class PandasValidator(DataFrameValidator):
111113
# Pandas df specific validator
112-
def _validate_df_specifics(self, feature_group, df, is_fg_created):
114+
def _validate_df_specifics(self, feature_group, df):
113115
errors = {}
114116
column_lengths = {}
115117
is_pk_null = False
@@ -118,7 +120,7 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
118120
# Check for null values in primary key columns
119121
for pk in feature_group.primary_key:
120122
if df[pk].isnull().any():
121-
errors[pk] = f"Primary key column {pk} contains null values"
123+
errors[pk] = f"Primary key column {pk} contains null values."
122124
is_pk_null = True
123125

124126
# Check string lengths
@@ -128,13 +130,13 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
128130
self.get_online_varchar_length(
129131
self.get_feature_from_list(col, feature_group.features)
130132
)
131-
if is_fg_created
133+
if feature_group.features
132134
else 100
133135
)
134136

135-
if currentmax > col_max_len:
137+
if col_max_len is not None and currentmax > col_max_len:
136138
errors[col] = (
137-
f"Column {col} has string values longer than {col_max_len} characters"
139+
f"String length exceeded. Column {col} has string values longer than maximum colum limit of {col_max_len} characters."
138140
)
139141
column_lengths[col] = currentmax
140142
is_string_length_exceeded = True
@@ -144,7 +146,7 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
144146

145147
class PolarsValidator(DataFrameValidator):
146148
# Polars df specific validator
147-
def _validate_df_specifics(self, feature_group, df, is_fg_created):
149+
def _validate_df_specifics(self, feature_group, df):
148150
import polars as pl
149151

150152
errors = {}
@@ -155,7 +157,7 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
155157
# Check for null values in primary key columns
156158
for pk in feature_group.primary_key:
157159
if df[pk].is_null().any():
158-
errors[pk] = f"Primary key column {pk} contains null values"
160+
errors[pk] = f"Primary key column {pk} contains null values."
159161
is_pk_null = True
160162

161163
# Check string lengths
@@ -165,13 +167,13 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
165167
self.get_online_varchar_length(
166168
self.get_feature_from_list(col, feature_group.features)
167169
)
168-
if is_fg_created
170+
if feature_group.features
169171
else 100
170172
)
171173

172-
if currentmax > col_max_len:
174+
if col_max_len is not None and currentmax > col_max_len:
173175
errors[col] = (
174-
f"Column {col} has string values longer than {col_max_len} characters"
176+
f"String length exceeded. Column {col} has string values longer than maximum colum limit of {col_max_len} characters."
175177
)
176178
column_lengths[col] = currentmax
177179
is_string_length_exceeded = True
@@ -181,7 +183,7 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
181183

182184
class PySparkValidator(DataFrameValidator):
183185
# PySpark-specific validator
184-
def _validate_df_specifics(self, feature_group, df, is_fg_created):
186+
def _validate_df_specifics(self, feature_group, df):
185187
# Import PySpark SQL functions and types
186188
import pyspark.sql.functions as sf
187189
from pyspark.sql.types import StringType
@@ -194,7 +196,7 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
194196
# Check for null values in primary key columns
195197
for pk in feature_group.primary_key:
196198
if df.filter(df[pk].isNull()).count() > 0:
197-
errors[pk] = f"Primary key column {pk} contains null values"
199+
errors[pk] = f"Primary key column {pk} contains null values."
198200
is_pk_null = True
199201

200202
# Check string lengths for string columns
@@ -209,13 +211,13 @@ def _validate_df_specifics(self, feature_group, df, is_fg_created):
209211
self.get_online_varchar_length(
210212
self.get_feature_from_list(col, feature_group.features)
211213
)
212-
if is_fg_created
214+
if feature_group.features
213215
else 100
214216
)
215217

216-
if currentmax > col_max_len:
218+
if col_max_len is not None and currentmax > col_max_len:
217219
errors[col] = (
218-
f"Column {col} has string values longer than {col_max_len} characters"
220+
f"String length exceeded. Column {col} has string values longer than maximum colum limit of {col_max_len} characters."
219221
)
220222
column_lengths[col] = currentmax
221223
is_string_length_exceeded = True

python/tests/test_schema_validator.py

+37-10
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def pandas_df():
2828
for _ in range(random.randint(1, 100))
2929
)
3030
)
31-
for i in range(3)
31+
for _ in range(3)
3232
],
3333
}
3434
)
@@ -47,7 +47,7 @@ def polars_df():
4747
for _ in range(random.randint(1, 100))
4848
)
4949
)
50-
for i in range(3)
50+
for _ in range(3)
5151
],
5252
}
5353
)
@@ -138,9 +138,7 @@ def test_validate_schema_string_length_exceeded(
138138
self, pandas_df, feature_group_created, mocker
139139
):
140140
pandas_df.loc[0, "string_col"] = "a" * 101
141-
with pytest.raises(
142-
ValueError, match="Column string_col has string values longer than 100"
143-
):
141+
with pytest.raises(ValueError, match="String length exceeded"):
144142
DataFrameValidator().validate_schema(
145143
feature_group_created, pandas_df, feature_group_created.features
146144
)
@@ -149,22 +147,38 @@ def test_validate_schema_feature_group_created(
149147
self, pandas_df, feature_group_created, mocker
150148
):
151149
pandas_df.loc[0, "string_col"] = "a" * 101
152-
with pytest.raises(
153-
ValueError, match="Column string_col has string values longer"
154-
):
150+
with pytest.raises(ValueError, match="String length exceeded"):
155151
DataFrameValidator().validate_schema(
156152
feature_group_created, pandas_df, feature_group_created.features
157153
)
158154

159155
def test_validate_schema_feature_group_not_created(
160156
self, pandas_df, feature_group_data
161157
):
158+
# test with non existing feature group with no explicit features
159+
# arrange
162160
pandas_df.loc[0, "string_col"] = "a" * 101
161+
initial_features = [
162+
Feature("primary_key", "int"),
163+
Feature("event_time", "string"),
164+
Feature("string_col", "string"),
165+
]
166+
feature_group_data.features = []
163167
df_features = DataFrameValidator().validate_schema(
164-
feature_group_data, pandas_df, feature_group_data.features
168+
feature_group_data, pandas_df, initial_features
165169
)
166170
assert df_features[2].online_type == "varchar(101)"
167171

172+
def test_validate_schema_feature_group_with_features_not_created(
173+
self, pandas_df, feature_group_data
174+
):
175+
# test with feature group with explicit features
176+
df_features = DataFrameValidator().validate_schema(
177+
feature_group_data, pandas_df, feature_group_data.features
178+
)
179+
# assert that the online type of the string_col feature is the same as explicitly set in the feature group
180+
assert df_features[2].online_type == "varchar(200)"
181+
168182
def test_pk_null_string_length_exceeded(self, pandas_df, feature_group_data):
169183
pandas_df.loc[0, "primary_key"] = None
170184
pandas_df.loc[0, "string_col"] = "a" * 101
@@ -180,5 +194,18 @@ def test_offline_fg(self, pandas_df, feature_group_data, caplog):
180194
df_features = PandasValidator().validate_schema(
181195
feature_group_data, pandas_df, feature_group_data.features
182196
)
197+
# assert no changes were made
198+
assert df_features == feature_group_data.features
199+
200+
def test_should_not_update_nonvarchar(self, pandas_df, feature_group_data):
201+
# Test that the validator does not update the online type of a non-varchar column
202+
# arrange
203+
# set string_col feature online type to text
204+
feature_group_data.features[2].online_type = "text"
205+
pandas_df.loc[0, "string_col"] = "b" * 1001
206+
# act
207+
df_features = PandasValidator().validate_schema(
208+
feature_group_data, pandas_df, feature_group_data.features
209+
)
210+
183211
assert df_features == feature_group_data.features
184-
assert "Feature group is not online enabled. Skipping validation" in caplog.text

0 commit comments

Comments
 (0)