Skip to content

Commit 4b64f18

Browse files
committed
update_table_schema also for delta table
1 parent 34c6963 commit 4b64f18

8 files changed

+95
-39
lines changed

python/hsfs/core/feature_group_engine.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,7 @@ def append_features(self, feature_group, new_features):
296296
)
297297

298298
# write empty dataframe to update parquet schema
299-
if feature_group.time_travel_format == "DELTA":
300-
engine.get_instance().add_cols_to_delta_table(feature_group)
301-
else:
302-
engine.get_instance().save_empty_dataframe(feature_group)
299+
engine.get_instance().update_table_schema(feature_group)
303300

304301
def update_description(self, feature_group, description):
305302
"""Updates the description of a feature group."""

python/hsfs/core/hudi_engine.py

-19
Original file line numberDiff line numberDiff line change
@@ -234,25 +234,6 @@ def _setup_hudi_read_opts(self, hudi_fg_alias, read_options):
234234

235235
return hudi_options
236236

237-
def reconcile_hudi_schema(
238-
self, save_empty_dataframe_callback, hudi_fg_alias, read_options
239-
):
240-
if sorted(self._spark_session.table(hudi_fg_alias.alias).columns) != sorted(
241-
[feature.name for feature in hudi_fg_alias.feature_group._features] + self.HUDI_SPEC_FEATURE_NAMES
242-
):
243-
full_fg = self._feature_group_api.get(
244-
feature_store_id=hudi_fg_alias.feature_group._feature_store_id,
245-
name=hudi_fg_alias.feature_group.name,
246-
version=hudi_fg_alias.feature_group.version,
247-
)
248-
249-
save_empty_dataframe_callback(full_fg)
250-
251-
self.register_temporary_table(
252-
hudi_fg_alias,
253-
read_options,
254-
)
255-
256237
@staticmethod
257238
def _get_last_commit_metadata(spark_context, base_path):
258239
hopsfs_conf = spark_context._jvm.org.apache.hadoop.fs.FileSystem.get(

python/hsfs/engine/python.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1212,12 +1212,8 @@ def save_stream_dataframe(
12121212
"Stream ingestion is not available on Python environments, because it requires Spark as engine."
12131213
)
12141214

1215-
def save_empty_dataframe(self, feature_group: Union[FeatureGroup, ExternalFeatureGroup]) -> None:
1216-
"""Wrapper around save_dataframe in order to provide no-op."""
1217-
pass
1218-
1219-
def add_cols_to_delta_table(self, feature_group: FeatureGroup) -> None:
1220-
"""Wrapper around add_cols_to_delta_table in order to provide no-op."""
1215+
def update_table_schema(self, feature_group: Union[FeatureGroup, ExternalFeatureGroup]) -> None:
1216+
"""Wrapper around update_table_schema in order to provide no-op."""
12211217
pass
12221218

12231219
def _get_app_options(

python/hsfs/engine/spark.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ def register_hudi_temporary_table(
221221
read_options,
222222
)
223223

224-
hudi_engine_instance.reconcile_hudi_schema(
225-
self.save_empty_dataframe, hudi_fg_alias, read_options
224+
self.reconcile_schema(
225+
hudi_fg_alias, read_options, hudi_engine_instance
226226
)
227227

228228
def register_delta_temporary_table(
@@ -241,6 +241,30 @@ def register_delta_temporary_table(
241241
read_options,
242242
)
243243

244+
self.reconcile_schema(
245+
delta_fg_alias, read_options, delta_engine_instance
246+
)
247+
248+
def reconcile_schema(
249+
self, fg_alias, read_options, engine_instance
250+
):
251+
if sorted(self._spark_session.table(fg_alias.alias).columns) != sorted(
252+
[feature.name for feature in fg_alias.feature_group._features] +
253+
(self.HUDI_SPEC_FEATURE_NAMES if fg_alias.feature_group.time_travel_format == "HUDI" else [])
254+
):
255+
full_fg = self._feature_group_api.get(
256+
feature_store_id=fg_alias.feature_group._feature_store_id,
257+
name=fg_alias.feature_group.name,
258+
version=fg_alias.feature_group.version,
259+
)
260+
261+
self.update_table_schema(full_fg)
262+
263+
engine_instance.register_temporary_table(
264+
fg_alias,
265+
read_options,
266+
)
267+
244268
def _return_dataframe_type(self, dataframe, dataframe_type):
245269
if dataframe_type.lower() in ["default", "spark"]:
246270
return dataframe
@@ -1324,7 +1348,13 @@ def is_spark_dataframe(self, dataframe):
13241348
return True
13251349
return False
13261350

1327-
def save_empty_dataframe(self, feature_group):
1351+
def update_table_schema(self, feature_group):
1352+
if feature_group.time_travel_format == "DELTA":
1353+
self._add_cols_to_delta_table(feature_group)
1354+
else:
1355+
self._save_empty_dataframe(feature_group)
1356+
1357+
def _save_empty_dataframe(self, feature_group):
13281358
location = feature_group.prepare_spark_location()
13291359

13301360
dataframe = self._spark_session.read.format("hudi").load(location)
@@ -1343,7 +1373,7 @@ def save_empty_dataframe(self, feature_group):
13431373
{},
13441374
)
13451375

1346-
def add_cols_to_delta_table(self, feature_group):
1376+
def _add_cols_to_delta_table(self, feature_group):
13471377
location = feature_group.prepare_spark_location()
13481378

13491379
dataframe = self._spark_session.read.format("delta").load(location)

python/tests/client/test_base_client.py

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import requests
2121
from hsfs.client.base import Client
2222
from hsfs.client.exceptions import RestAPIError
23-
2423
from tests.util import changes_environ
2524

2625

python/tests/core/test_feature_group_engine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def test_append_features(self, mocker):
709709

710710
# Assert
711711
assert (
712-
mock_engine_get_instance.return_value.save_empty_dataframe.call_count == 1
712+
mock_engine_get_instance.return_value.update_table_schema.call_count == 1
713713
)
714714
assert len(mock_fg_engine_update_features_metadata.call_args[0][1]) == 4
715715

python/tests/engine/test_python.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2565,12 +2565,12 @@ def test_save_stream_dataframe(self):
25652565
== "Stream ingestion is not available on Python environments, because it requires Spark as engine."
25662566
)
25672567

2568-
def test_save_empty_dataframe(self):
2568+
def test_update_table_schema(self):
25692569
# Arrange
25702570
python_engine = python.Engine()
25712571

25722572
# Act
2573-
result = python_engine.save_empty_dataframe(feature_group=None)
2573+
result = python_engine.update_table_schema(feature_group=None)
25742574

25752575
# Assert
25762576
assert result is None

python/tests/engine/test_spark.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def test_register_hudi_temporary_table(self, mocker):
203203
# Arrange
204204
mock_hudi_engine = mocker.patch("hsfs.core.hudi_engine.HudiEngine")
205205
mocker.patch("hsfs.feature_group.FeatureGroup.from_response_json")
206+
mock_reconcile_schema = mocker.patch("hsfs.engine.spark.Engine.reconcile_schema")
206207

207208
spark_engine = spark.Engine()
208209

@@ -220,6 +221,31 @@ def test_register_hudi_temporary_table(self, mocker):
220221

221222
# Assert
222223
assert mock_hudi_engine.return_value.register_temporary_table.call_count == 1
224+
assert mock_reconcile_schema.call_count == 1
225+
226+
def test_register_delta_temporary_table(self, mocker):
227+
# Arrange
228+
mock_delta_engine = mocker.patch("hsfs.core.delta_engine.DeltaEngine")
229+
mocker.patch("hsfs.feature_group.FeatureGroup.from_response_json")
230+
mock_reconcile_schema = mocker.patch("hsfs.engine.spark.Engine.reconcile_schema")
231+
232+
spark_engine = spark.Engine()
233+
234+
hudi_fg_alias = hudi_feature_group_alias.HudiFeatureGroupAlias(
235+
feature_group=None, alias=None
236+
)
237+
238+
# Act
239+
spark_engine.register_delta_temporary_table(
240+
delta_fg_alias=hudi_fg_alias,
241+
feature_store_id=None,
242+
feature_store_name=None,
243+
read_options=None,
244+
)
245+
246+
# Assert
247+
assert mock_delta_engine.return_value.register_temporary_table.call_count == 1
248+
assert mock_reconcile_schema.call_count == 1
223249

224250
def test_return_dataframe_type_default(self, mocker):
225251
# Arrange
@@ -4540,7 +4566,7 @@ def test_is_spark_dataframe_spark_dataframe(self):
45404566
# Assert
45414567
assert result is True
45424568

4543-
def test_save_empty_dataframe(self, mocker):
4569+
def test_update_table_schema_hudi(self, mocker):
45444570
# Arrange
45454571
mock_spark_engine_save_dataframe = mocker.patch(
45464572
"hsfs.engine.spark.Engine.save_dataframe"
@@ -4560,15 +4586,42 @@ def test_save_empty_dataframe(self, mocker):
45604586
partition_key=[],
45614587
id=10,
45624588
featurestore_name="test_featurestore",
4589+
time_travel_format="HUDI",
45634590
)
45644591

45654592
# Act
4566-
spark_engine.save_empty_dataframe(feature_group=fg)
4593+
spark_engine.update_table_schema(feature_group=fg)
45674594

45684595
# Assert
45694596
assert mock_spark_engine_save_dataframe.call_count == 1
45704597
assert mock_spark_read.format.call_count == 1
45714598

4599+
def test_update_table_schema_delta(self, mocker):
4600+
# Arrange
4601+
mock_spark_read = mocker.patch("pyspark.sql.SparkSession.read")
4602+
mock_format = mocker.Mock()
4603+
mock_spark_read.format.return_value = mock_format
4604+
4605+
# Arrange
4606+
spark_engine = spark.Engine()
4607+
4608+
fg = feature_group.FeatureGroup(
4609+
name="test",
4610+
version=1,
4611+
featurestore_id=99,
4612+
primary_key=[],
4613+
partition_key=[],
4614+
id=10,
4615+
featurestore_name="test_featurestore",
4616+
time_travel_format="DELTA",
4617+
)
4618+
4619+
# Act
4620+
spark_engine.update_table_schema(feature_group=fg)
4621+
4622+
# Assert
4623+
assert mock_spark_read.format.call_count == 1
4624+
45724625
def test_apply_transformation_function_single_output_udf_default_mode(self, mocker):
45734626
# Arrange
45744627
mocker.patch("hopsworks_common.client.get_instance")

0 commit comments

Comments
 (0)