[FSTORE-1667] Feature store client doesn't work on Databricks with BYOK setup #482

Merged
merged 10 commits on Mar 12, 2025
python/hsfs/engine/spark.py: 18 changes (8 additions, 10 deletions)
@@ -19,7 +19,6 @@
import json
import os
import re
import shutil
import uuid
import warnings
from datetime import date, datetime, timezone
@@ -1121,22 +1120,18 @@ def add_file(self, file):

# for external clients, download the file
if client._is_external():
tmp_file = os.path.join(SparkFiles.getRootDirectory(), file_name)
tmp_file = f"/tmp/{file_name}"
print("Reading key file from storage connector.")
response = self._dataset_api.read_content(file, util.get_dataset_type(file))

with open(tmp_file, "wb") as f:
f.write(response.content)
else:
self._spark_context.addFile(file)

# The file is not added to the driver current working directory
# We should add it manually by copying from the download location
# The file will be added to the executors current working directory
# before the next task is executed
shutil.copy(SparkFiles.get(file_name), file_name)
file = f"file://{tmp_file}"

self._spark_context.addFile(file)

return file_name
return SparkFiles.get(file_name)

def profile(
self,
@@ -1681,6 +1676,9 @@ def read_feature_log(query, time_col):
df = query.read()
return df.drop("log_id", time_col)

def get_spark_version(self):
return self._spark_session.version


class SchemaError(Exception):
"""Thrown when schemas don't match"""
python/hsfs/storage_connector.py: 53 changes (49 additions, 4 deletions)
@@ -1304,16 +1304,61 @@ def confluent_options(self) -> Dict[str, Any]:

return config

def _read_pem(self, file_name):
with open(file_name, "r") as file:
return file.read()

def spark_options(self) -> Dict[str, Any]:
"""Return prepared options to be passed to Spark, based on the additional arguments.
This is done by adding the 'kafka.' prefix to kafka_options.
https://spark.apache.org/docs/latest/structured-streaming-kafka-integration.html#kafka-specific-configurations
"""
config = {}
for key, value in self.kafka_options().items():
config[f"{KafkaConnector.SPARK_FORMAT}.{key}"] = value
from packaging import version

return config
spark_config = {}

kafka_options = self.kafka_options()

for key, value in kafka_options.items():
if key in [
"ssl.truststore.location",
"ssl.truststore.password",
"ssl.keystore.location",
"ssl.keystore.password",
"ssl.key.password",
] and version.parse(
engine.get_instance().get_spark_version()
) >= version.parse("3.2.0"):
# We can only use this with newer Spark versions, which depend on Kafka > 2.7.0.
# Kafka 2.7.0 adds support for providing SSL credentials as PEM objects.
if not self._pem_files_created:
(
ca_chain_path,
client_cert_path,
client_key_path,
) = client.get_instance()._write_pem(
kafka_options["ssl.keystore.location"],
kafka_options["ssl.keystore.password"],
kafka_options["ssl.truststore.location"],
kafka_options["ssl.truststore.password"],
f"kafka_sc_{client.get_instance()._project_id}_{self._id}",
)
self._pem_files_created = True
spark_config["kafka.ssl.truststore.certificates"] = self._read_pem(
ca_chain_path
)
spark_config["kafka.ssl.keystore.certificate.chain"] = (
self._read_pem(client_cert_path)
)
spark_config["kafka.ssl.keystore.key"] = self._read_pem(
client_key_path
)
spark_config["kafka.ssl.truststore.type"] = "PEM"
spark_config["kafka.ssl.keystore.type"] = "PEM"
else:
spark_config[f"{KafkaConnector.SPARK_FORMAT}.{key}"] = value

return spark_config

def read(
self,
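
The new spark_options() logic gates the PEM path on the Spark version (compared with packaging.version against the new get_spark_version()) because only Spark >= 3.2.0 ships a Kafka client (>= 2.7.0) that accepts SSL credentials as PEM strings; on older versions the prefixed JKS options are passed through unchanged. A minimal sketch of how the returned dictionary is typically consumed by a structured-streaming Kafka source, where the connector instance and topic name are assumptions, not part of this PR:

    from packaging import version
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Same gate as in spark_options(): PEM credentials require Spark >= 3.2.0.
    supports_pem = version.parse(spark.version) >= version.parse("3.2.0")

    # "kafka_connector" is assumed to be a KafkaConnector obtained elsewhere (hypothetical).
    options = kafka_connector.spark_options()
    stream = (
        spark.readStream.format("kafka")
        .options(**options)                  # contains kafka.ssl.*.type = "PEM" on newer Spark
        .option("subscribe", "my_topic")     # hypothetical topic
        .load()
    )
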
python/tests/core/test_kafka_engine.py: 3 changes (3 additions, 0 deletions)
@@ -414,6 +414,7 @@ def test_spark_get_kafka_config(self, mocker, backend_fixtures):
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]
sc = storage_connector.StorageConnector.from_response_json(json)
mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

mocker.patch("hopsworks_common.client._is_external", return_value=False)
# Act
@@ -456,6 +457,7 @@ def test_spark_get_kafka_config_external_client(self, mocker, backend_fixtures):
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]
sc = storage_connector.StorageConnector.from_response_json(json)
mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

# Act
results = kafka_engine.get_kafka_config(
@@ -497,6 +499,7 @@ def test_spark_get_kafka_config_internal_kafka(self, mocker, backend_fixtures):
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]
sc = storage_connector.StorageConnector.from_response_json(json)
mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

# Act
results = kafka_engine.get_kafka_config(
python/tests/engine/test_spark.py: 9 changes (8 additions, 1 deletion)
@@ -866,6 +866,7 @@ def test_save_stream_dataframe(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -993,6 +994,7 @@ def test_save_stream_dataframe_query_name(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -1124,6 +1126,7 @@ def test_save_stream_dataframe_checkpoint_dir(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -1251,6 +1254,7 @@ def test_save_stream_dataframe_await_termination(self, mocker, backend_fixtures)
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -1515,6 +1519,7 @@ def test_save_online_dataframe(self, mocker, backend_fixtures):
)

mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -3221,8 +3226,10 @@ def test_read_location_format_tsv(self, mocker):

def test_read_stream(self, mocker):
# Arrange
mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mocker.patch("hopsworks_common.client.get_instance")
mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

mock_pyspark_getOrCreate = mocker.patch(
"pyspark.sql.session.SparkSession.builder.getOrCreate"
)
python/tests/test_storage_connector.py: 52 changes (51 additions, 1 deletion)
@@ -618,10 +618,12 @@ def test_kafka_options_external(self, mocker, backend_fixtures):

def test_spark_options(self, mocker, backend_fixtures):
# Arrange
mocker.patch("hsfs.engine.get_instance")
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance")
json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"]

mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"

mock_client_get_instance.return_value._get_jks_trust_store_path.return_value = (
"result_from_get_jks_trust_store_path"
)
@@ -653,6 +655,7 @@ def test_spark_options_external(self, mocker, backend_fixtures):
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"]

mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
mock_engine_get_instance.return_value.add_file.return_value = (
"result_from_add_file"
)
@@ -675,6 +678,53 @@ def test_spark_options_external(self, mocker, backend_fixtures):
"kafka.ssl.key.password": "test_ssl_key_password",
}

def test_spark_options_spark_35(self, mocker, backend_fixtures):
# Arrange
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance")
json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"]

mock_engine_get_instance.return_value.get_spark_version.return_value = "3.5.0"

mock_client_get_instance.return_value._get_jks_trust_store_path.return_value = (
"result_from_get_jks_trust_store_path"
)
mock_client_get_instance.return_value._get_jks_key_store_path.return_value = (
"result_from_get_jks_key_store_path"
)
mock_client_get_instance.return_value._cert_key = "result_from_cert_key"
mock_client_get_instance.return_value._write_pem.return_value = (
None,
None,
None,
)

sc = storage_connector.StorageConnector.from_response_json(json)

# Mock the read pem method in the storage connector itself
sc._read_pem = mocker.Mock()
sc._read_pem.side_effect = [
"test_ssl_ca",
"test_ssl_certificate",
"test_ssl_key",
]

# Act
config = sc.spark_options()

# Assert
assert config == {
"kafka.test_option_name": "test_option_value",
"kafka.bootstrap.servers": "test_bootstrap_servers",
"kafka.security.protocol": "test_security_protocol",
"kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm",
"kafka.ssl.truststore.type": "PEM",
"kafka.ssl.keystore.type": "PEM",
"kafka.ssl.truststore.certificates": "test_ssl_ca",
"kafka.ssl.keystore.certificate.chain": "test_ssl_certificate",
"kafka.ssl.keystore.key": "test_ssl_key",
}

def test_confluent_options(self, mocker, backend_fixtures):
# Arrange
mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")