get_uri -> prepare_spark_location #6

Merged · 5 commits · Sep 25, 2024
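
Summary: this PR replaces the `get_uri()` helper on feature groups with a simpler `prepare_spark_location()`, which passes the feature group's `location` through `StorageConnector.prepare_spark()` whenever a connector is attached. It also removes the `textFile(...).collect()` calls that previously forced an eager read of the feature group location, in both the Java and Python Spark engines, and updates the Delta engine, the Hudi engine, and the Spark tests accordingly.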
@@ -219,10 +219,6 @@ public Dataset<Row> registerOnDemandTemporaryTable(ExternalFeatureGroup onDemand
         ? onDemandFeatureGroup.getDataFormat().toString() : null, getOnDemandOptions(onDemandFeatureGroup),
         onDemandFeatureGroup.getStorageConnector().getPath(onDemandFeatureGroup.getPath()));
 
-    if (!Strings.isNullOrEmpty(onDemandFeatureGroup.getLocation())) {
-      sparkSession.sparkContext().textFile(onDemandFeatureGroup.getLocation(), 0).collect();
-    }
-
     dataset.createOrReplaceTempView(alias);
     return dataset;
   }
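
The deleted block forced Spark to read the feature group location up front via `textFile(...).collect()`; with path preparation now handled by `prepare_spark_location()`, the PR treats that eager read as unnecessary, and the test changes below drop the matching assertions.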
22 changes: 12 additions & 10 deletions python/hsfs/core/delta_engine.py
@@ -52,10 +52,12 @@ def save_delta_fg(self, dataset, write_options, validation_id=None):
         return self._feature_group_api.commit(self._feature_group, fg_commit)
 
     def register_temporary_table(self, delta_fg_alias, read_options):
+        location = self._feature_group.prepare_spark_location()
+
         delta_options = self._setup_delta_read_opts(delta_fg_alias, read_options)
         self._spark_session.read.format(self.DELTA_SPARK_FORMAT).options(
             **delta_options
-        ).load(self._feature_group.get_uri()).createOrReplaceTempView(
+        ).load(location).createOrReplaceTempView(
             delta_fg_alias.alias
         )
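
For context, a minimal sketch of how a view registered this way is typically consumed afterwards; the session handle and the alias value are illustrative, not taken from this PR:

    # Illustrative only: query the temporary view that
    # register_temporary_table created under delta_fg_alias.alias.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.sql("SELECT * FROM fg_1_1")  # "fg_1_1" stands in for the alias
    df.show(5)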
@@ -85,17 +87,17 @@ def _setup_delta_read_opts(self, delta_fg_alias, read_options):
         return delta_options
 
     def delete_record(self, delete_df):
-        uri = self._feature_group.get_uri()
+        location = self._feature_group.prepare_spark_location()
 
         if not DeltaTable.isDeltaTable(
-            self._spark_session, uri
+            self._spark_session, location
         ):
             raise FeatureStoreException(
                 f"There is no data available in feature group {self._feature_group.name}, or it is not DELTA enabled."
             )
         else:
             fg_source_table = DeltaTable.forPath(
-                self._spark_session, uri
+                self._spark_session, location
             )
 
             source_alias = (
@@ -111,18 +113,18 @@ def delete_record(self, delete_df):
         ).whenMatchedDelete().execute()
 
         fg_commit = self._get_last_commit_metadata(
-            self._spark_session, uri
+            self._spark_session, location
         )
         return self._feature_group_api.commit(self._feature_group, fg_commit)
 
     def _write_delta_dataset(self, dataset, write_options):
-        uri = self._feature_group.get_uri()
+        location = self._feature_group.prepare_spark_location()
 
         if write_options is None:
             write_options = {}
 
         if not DeltaTable.isDeltaTable(
-            self._spark_session, uri
+            self._spark_session, location
         ):
             (
                 dataset.write.format(DeltaEngine.DELTA_SPARK_FORMAT)
@@ -133,11 +135,11 @@ def _write_delta_dataset(self, dataset, write_options):
                     else []
                 )
                 .mode("append")
-                .save(uri)
+                .save(location)
             )
         else:
             fg_source_table = DeltaTable.forPath(
-                self._spark_session, uri
+                self._spark_session, location
             )
 
             source_alias = (
@@ -153,7 +155,7 @@ def _write_delta_dataset(self, dataset, write_options):
         ).whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
 
         return self._get_last_commit_metadata(
-            self._spark_session, uri
+            self._spark_session, location
         )
 
     def _generate_merge_query(self, source_alias, updates_alias):
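
The check-then-open pattern above comes from the Delta Lake Python API; a self-contained sketch, assuming the delta-spark package, an existing `spark` session, and `location` standing in for the prepared feature group path:

    # Sketch of the DeltaTable access pattern used by delete_record and
    # _write_delta_dataset: probe the path first, then open it as a table.
    from delta.tables import DeltaTable

    if DeltaTable.isDeltaTable(spark, location):
        delta_table = DeltaTable.forPath(spark, location)
        delta_table.toDF().show(5)
    else:
        print(f"{location} is not a Delta table")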
10 changes: 6 additions & 4 deletions python/hsfs/core/hudi_engine.py
@@ -104,23 +104,25 @@ def delete_record(self, delete_df, write_options):
         return self._feature_group_api.commit(self._feature_group, fg_commit)
 
     def register_temporary_table(self, hudi_fg_alias, read_options):
+        location = self._feature_group.prepare_spark_location()
+
         hudi_options = self._setup_hudi_read_opts(hudi_fg_alias, read_options)
         self._spark_session.read.format(self.HUDI_SPARK_FORMAT).options(
             **hudi_options
-        ).load(self._feature_group.get_uri()).createOrReplaceTempView(
+        ).load(location).createOrReplaceTempView(
             hudi_fg_alias.alias
         )
 
     def _write_hudi_dataset(self, dataset, save_mode, operation, write_options):
-        uri = self._feature_group.get_uri()
+        location = self._feature_group.prepare_spark_location()
 
         hudi_options = self._setup_hudi_write_opts(operation, write_options)
         dataset.write.format(HudiEngine.HUDI_SPARK_FORMAT).options(**hudi_options).mode(
             save_mode
-        ).save(uri)
+        ).save(location)
 
         feature_group_commit = self._get_last_commit_metadata(
-            self._spark_context, uri
+            self._spark_context, location
         )
 
         return feature_group_commit
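
The write path boils down to a standard Hudi datasource call; a sketch with placeholder options, since the real values come from `_setup_hudi_write_opts` and the prepared location:

    # Placeholder options: the table name, record key, and precombine field
    # below are illustrative, not the ones this engine actually derives.
    hudi_options = {
        "hoodie.table.name": "fg_1",
        "hoodie.datasource.write.recordkey.field": "id",
        "hoodie.datasource.write.precombine.field": "event_time",
    }
    dataset.write.format("hudi").options(**hudi_options).mode("append").save(location)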
16 changes: 7 additions & 9 deletions python/hsfs/engine/spark.py
@@ -193,12 +193,10 @@ def register_external_temporary_table(self, external_fg, alias):
                 external_fg.query,
                 external_fg.data_format,
                 external_fg.options,
-                external_fg.get_uri(),
+                external_fg.prepare_spark_location(),
             )
         else:
             external_dataset = external_fg.dataframe
-            if external_fg.location:
-                self._spark_session.sparkContext.textFile(external_fg.location).collect()
 
         external_dataset.createOrReplaceTempView(alias)
         return external_dataset
@@ -1250,9 +1248,9 @@ def is_spark_dataframe(self, dataframe):
         return False
 
     def save_empty_dataframe(self, feature_group, new_features=None):
-        dataframe = self._spark_session.read.format("hudi").load(
-            feature_group.get_uri()
-        )
+        location = feature_group.prepare_spark_location()
+
+        dataframe = self._spark_session.read.format("hudi").load(location)
 
         if (new_features is not None):
             if isinstance(new_features, list):
@@ -1273,9 +1271,9 @@ def save_empty_dataframe(self, feature_group, new_features=None):
         )
 
     def add_cols_to_delta_table(self, feature_group, new_features):
-        uri = self._feature_group.get_uri()
+        location = feature_group.prepare_spark_location()
 
-        dataframe = self._spark_session.read.format("delta").load(uri)
+        dataframe = self._spark_session.read.format("delta").load(location)
 
         if (new_features is not None):
             if isinstance(new_features, list):
@@ -1288,7 +1286,7 @@ def add_cols_to_delta_table(self, feature_group, new_features):
             "append"
         ).option("mergeSchema", "true").option(
             "spark.databricks.delta.schema.autoMerge.enabled", "true"
-        ).save(uri)
+        ).save(location)
 
     def _apply_transformation_function(
         self,
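
Worth noting: the old `add_cols_to_delta_table` read `self._feature_group.get_uri()` even though the feature group arrives as the `feature_group` parameter, so switching to `feature_group.prepare_spark_location()` also looks like it fixes a latent attribute bug, not just a rename.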
14 changes: 5 additions & 9 deletions python/hsfs/feature_group.py
@@ -2067,15 +2067,11 @@ def path(self) -> Optional[str]:
     def storage_connector(self) -> "sc.StorageConnector":
         return self._storage_connector
 
-    def get_uri(self) -> str:
-        """Location of data."""
-        if (self.storage_connector is None):
-            return self.location
-        else:
-            path = self.storage_connector._get_path(self.path)
-            if engine.get_type().startswith("spark"):
-                path = self.storage_connector.prepare_spark(path)
-            return path
+    def prepare_spark_location(self) -> str:
+        location = self.location
+        if (self.storage_connector is not None):
+            location = self.storage_connector.prepare_spark(location)
+        return location
 
     @property
     def topic_name(self) -> Optional[str]:
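
A hedged sketch of a call site for the new helper, mirroring how `save_empty_dataframe` and `add_cols_to_delta_table` use it above; `fg` and `spark_session` are illustrative names:

    # fg is a feature group, optionally with a storage connector attached;
    # the helper returns a path already wrapped for Spark by prepare_spark().
    location = fg.prepare_spark_location()
    dataframe = spark_session.read.format("delta").load(location)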
40 changes: 0 additions & 40 deletions python/tests/engine/test_spark.py
@@ -174,43 +174,6 @@ def test_register_external_temporary_table(self, mocker):
         # Arrange
         mocker.patch("hopsworks_common.client.get_instance")
         mock_sc_read = mocker.patch("hsfs.storage_connector.JdbcConnector.read")
-        mock_pyspark_getOrCreate = mocker.patch(
-            "pyspark.sql.session.SparkSession.builder.getOrCreate"
-        )
-
-        spark_engine = spark.Engine()
-
-        jdbc_connector = storage_connector.JdbcConnector(
-            id=1,
-            name="test_connector",
-            featurestore_id=1,
-            connection_string="",
-            arguments="",
-        )
-
-        external_fg = feature_group.ExternalFeatureGroup(
-            storage_connector=jdbc_connector, id=10
-        )
-
-        # Act
-        spark_engine.register_external_temporary_table(
-            external_fg=external_fg,
-            alias=None,
-        )
-
-        # Assert
-        assert (
-            mock_pyspark_getOrCreate.return_value.sparkContext.textFile.call_count == 0
-        )
-        assert mock_sc_read.return_value.createOrReplaceTempView.call_count == 1
-
-    def test_register_external_temporary_table_external_fg_location(self, mocker):
-        # Arrange
-        mocker.patch("hopsworks_common.client.get_instance")
-        mock_sc_read = mocker.patch("hsfs.storage_connector.JdbcConnector.read")
-        mock_pyspark_getOrCreate = mocker.patch(
-            "pyspark.sql.session.SparkSession.builder.getOrCreate"
-        )
 
         spark_engine = spark.Engine()
 
@@ -233,9 +196,6 @@ def test_register_external_temporary_table_external_fg_location(self, mocker):
         )
 
         # Assert
-        assert (
-            mock_pyspark_getOrCreate.return_value.sparkContext.textFile.call_count == 1
-        )
         assert mock_sc_read.return_value.createOrReplaceTempView.call_count == 1
 
     def test_register_hudi_temporary_table(self, mocker):