From 25ec663b7f94e44715e83529613fe226bbd8da69 Mon Sep 17 00:00:00 2001 From: Fabio Buso Date: Fri, 24 Jan 2025 11:46:38 +0100 Subject: [PATCH 1/8] [FSTORE-1667] Feature store client doesn't work on Databricks with BYOK setup (#1415) * Add tests * Fix gh workflow * Fix mocks in pyspark tests * Fix mock in /test_spark.py::TestSpark::test_read_stream --- .github/workflows/python.yml | 19 +++++++++ python/hsfs/engine/spark.py | 6 +++ python/hsfs/storage_connector.py | 53 ++++++++++++++++++++++++-- python/tests/engine/test_spark.py | 7 ++++ python/tests/test_storage_connector.py | 50 ++++++++++++++++++++++++ 5 files changed, 131 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index b9566d4a1..d89806da0 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -68,8 +68,15 @@ jobs: java-version: "8" distribution: "adopt" + - name: Set up JDK 8 + uses: actions/setup-java@v3 + with: + java-version: "8" + distribution: "adopt" + - name: Set Timezone run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone + run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone - uses: actions/checkout@v4 - name: Copy README @@ -131,6 +138,12 @@ jobs: java-version: "8" distribution: "adopt" + - name: Set up JDK 8 + uses: actions/setup-java@v3 + with: + java-version: "8" + distribution: "adopt" + - name: Set Timezone run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone @@ -164,6 +177,12 @@ jobs: java-version: "8" distribution: "adopt" + - name: Set up JDK 8 + uses: actions/setup-java@v3 + with: + java-version: "8" + distribution: "adopt" + - name: Set Timezone run: sudo timedatectl set-timezone Europe/Amsterdam && echo Europe/Amsterdam | sudo tee /etc/timezone diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index a126ab8c5..faf4d86fa 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -1681,6 +1681,12 @@ def read_feature_log(query, time_col): df = query.read() return df.drop("log_id", time_col) + def get_spark_version(self): + return self._spark_session.version + + def get_spark_version(self): + return self._spark_session.version + class SchemaError(Exception): """Thrown when schemas don't match""" diff --git a/python/hsfs/storage_connector.py b/python/hsfs/storage_connector.py index 790e5d2c1..146ab1c13 100644 --- a/python/hsfs/storage_connector.py +++ b/python/hsfs/storage_connector.py @@ -1298,16 +1298,61 @@ def confluent_options(self) -> Dict[str, Any]: return config + def _read_pem(self, file_name): + with open(file_name, "r") as file: + return file.read() + def spark_options(self) -> Dict[str, Any]: """Return prepared options to be passed to Spark, based on the additional arguments. This is done by just adding 'kafka.' prefix to kafka_options. 
https://spark.apache.org/docs/latest/structured-streaming-kafka-integration.html#kafka-specific-configurations """ - config = {} - for key, value in self.kafka_options().items(): - config[f"{KafkaConnector.SPARK_FORMAT}.{key}"] = value + from packaging import version - return config + spark_config = {} + + kafka_options = self.kafka_options() + + for key, value in kafka_options.items(): + if key in [ + "ssl.truststore.location", + "ssl.truststore.password", + "ssl.keystore.location", + "ssl.keystore.password", + "ssl.key.password", + ] and version.parse( + engine.get_instance().get_spark_version() + ) >= version.parse("3.2.0"): + # We can only use this in the newer version of Spark which depend on Kafka > 2.7.0 + # Kafka 2.7.0 adds support for providing the SSL credentials as PEM objects. + if not self._pem_files_created: + ( + ca_chain_path, + client_cert_path, + client_key_path, + ) = client.get_instance()._write_pem( + kafka_options["ssl.keystore.location"], + kafka_options["ssl.keystore.password"], + kafka_options["ssl.truststore.location"], + kafka_options["ssl.truststore.password"], + f"kafka_sc_{client.get_instance()._project_id}_{self._id}", + ) + self._pem_files_created = True + spark_config["kafka.ssl.truststore.certificates"] = self._read_pem( + ca_chain_path + ) + spark_config["kafka.ssl.keystore.certificate.chain"] = ( + self._read_pem(client_cert_path) + ) + spark_config["kafka.ssl.keystore.key"] = self._read_pem( + client_key_path + ) + spark_config["kafka.ssl.truststore.type"] = "PEM" + spark_config["kafka.ssl.keystore.type"] = "PEM" + else: + spark_config[f"{KafkaConnector.SPARK_FORMAT}.{key}"] = value + + return spark_config def read( self, diff --git a/python/tests/engine/test_spark.py b/python/tests/engine/test_spark.py index 44eff9209..24a98c0f7 100644 --- a/python/tests/engine/test_spark.py +++ b/python/tests/engine/test_spark.py @@ -866,6 +866,7 @@ def test_save_stream_dataframe(self, mocker, backend_fixtures): ) mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mock_engine_get_instance.return_value.add_file.return_value = ( "result_from_add_file" ) @@ -993,6 +994,7 @@ def test_save_stream_dataframe_query_name(self, mocker, backend_fixtures): ) mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mock_engine_get_instance.return_value.add_file.return_value = ( "result_from_add_file" ) @@ -1124,6 +1126,7 @@ def test_save_stream_dataframe_checkpoint_dir(self, mocker, backend_fixtures): ) mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mock_engine_get_instance.return_value.add_file.return_value = ( "result_from_add_file" ) @@ -1251,6 +1254,7 @@ def test_save_stream_dataframe_await_termination(self, mocker, backend_fixtures) ) mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mock_engine_get_instance.return_value.add_file.return_value = ( "result_from_add_file" ) @@ -1515,6 +1519,7 @@ def test_save_online_dataframe(self, mocker, backend_fixtures): ) mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mock_engine_get_instance.return_value.add_file.return_value = ( "result_from_add_file" ) 
@@ -3223,6 +3228,8 @@ def test_read_stream(self, mocker): # Arrange mocker.patch("hsfs.engine.get_instance") mocker.patch("hopsworks_common.client.get_instance") + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" + mock_pyspark_getOrCreate = mocker.patch( "pyspark.sql.session.SparkSession.builder.getOrCreate" ) diff --git a/python/tests/test_storage_connector.py b/python/tests/test_storage_connector.py index a1f7e77c3..9bb502671 100644 --- a/python/tests/test_storage_connector.py +++ b/python/tests/test_storage_connector.py @@ -618,6 +618,8 @@ def test_spark_options(self, mocker, backend_fixtures): mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance") json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"] + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" + mock_client_get_instance.return_value._get_jks_trust_store_path.return_value = ( "result_from_get_jks_trust_store_path" ) @@ -649,6 +651,7 @@ def test_spark_options_external(self, mocker, backend_fixtures): mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mock_engine_get_instance.return_value.add_file.return_value = ( "result_from_add_file" ) @@ -671,6 +674,53 @@ def test_spark_options_external(self, mocker, backend_fixtures): "kafka.ssl.key.password": "test_ssl_key_password", } + def test_spark_options_spark_35(self, mocker, backend_fixtures): + # Arrange + mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") + mock_client_get_instance = mocker.patch("hsfs.client.get_instance") + json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"] + + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.5.0" + + mock_client_get_instance.return_value._get_jks_trust_store_path.return_value = ( + "result_from_get_jks_trust_store_path" + ) + mock_client_get_instance.return_value._get_jks_key_store_path.return_value = ( + "result_from_get_jks_key_store_path" + ) + mock_client_get_instance.return_value._cert_key = "result_from_cert_key" + mock_client_get_instance.return_value._write_pem.return_value = ( + None, + None, + None, + ) + + sc = storage_connector.StorageConnector.from_response_json(json) + + # Mock the read pem method in the storage connector itself + sc._read_pem = mocker.Mock() + sc._read_pem.side_effect = [ + "test_ssl_ca", + "test_ssl_certificate", + "test_ssl_key", + ] + + # Act + config = sc.spark_options() + + # Assert + assert config == { + "kafka.test_option_name": "test_option_value", + "kafka.bootstrap.servers": "test_bootstrap_servers", + "kafka.security.protocol": "test_security_protocol", + "kafka.ssl.endpoint.identification.algorithm": "test_ssl_endpoint_identification_algorithm", + "kafka.ssl.truststore.type": "PEM", + "kafka.ssl.keystore.type": "PEM", + "kafka.ssl.truststore.certificates": "test_ssl_ca", + "kafka.ssl.keystore.certificate.chain": "test_ssl_certificate", + "kafka.ssl.keystore.key": "test_ssl_key", + } + def test_confluent_options(self, mocker, backend_fixtures): # Arrange mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") From 2a2c67eb57a9e066710608c7ee5268c793fedc01 Mon Sep 17 00:00:00 2001 From: bubriks Date: Wed, 12 Feb 2025 11:54:53 +0200 Subject: [PATCH 2/8] merge fix --- .github/workflows/python.yml | 18 ------------------ 1 file changed, 18 
deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index d89806da0..ed8f4c430 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -68,12 +68,6 @@ jobs: java-version: "8" distribution: "adopt" - - name: Set up JDK 8 - uses: actions/setup-java@v3 - with: - java-version: "8" - distribution: "adopt" - - name: Set Timezone run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone @@ -138,12 +132,6 @@ jobs: java-version: "8" distribution: "adopt" - - name: Set up JDK 8 - uses: actions/setup-java@v3 - with: - java-version: "8" - distribution: "adopt" - - name: Set Timezone run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone @@ -177,12 +165,6 @@ jobs: java-version: "8" distribution: "adopt" - - name: Set up JDK 8 - uses: actions/setup-java@v3 - with: - java-version: "8" - distribution: "adopt" - - name: Set Timezone run: sudo timedatectl set-timezone Europe/Amsterdam && echo Europe/Amsterdam | sudo tee /etc/timezone From c072cda2c63e48085a09b9aab95f8e01832f71db Mon Sep 17 00:00:00 2001 From: bubriks Date: Wed, 12 Feb 2025 11:55:52 +0200 Subject: [PATCH 3/8] more fixing --- .github/workflows/python.yml | 1 - python/hsfs/engine/spark.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index ed8f4c430..b9566d4a1 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -70,7 +70,6 @@ jobs: - name: Set Timezone run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone - run: sudo timedatectl set-timezone UTC && echo UTC | sudo tee /etc/timezone - uses: actions/checkout@v4 - name: Copy README diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index faf4d86fa..10c06a1a5 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -1684,9 +1684,6 @@ def read_feature_log(query, time_col): def get_spark_version(self): return self._spark_session.version - def get_spark_version(self): - return self._spark_session.version - class SchemaError(Exception): """Thrown when schemas don't match""" From 38bf266ac681d46d05191821dd1bd5ad74d564a0 Mon Sep 17 00:00:00 2001 From: bubriks Date: Wed, 12 Feb 2025 12:02:19 +0200 Subject: [PATCH 4/8] some tests fixed --- python/tests/engine/test_spark.py | 2 +- python/tests/test_storage_connector.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tests/engine/test_spark.py b/python/tests/engine/test_spark.py index 24a98c0f7..d31ebd6ea 100644 --- a/python/tests/engine/test_spark.py +++ b/python/tests/engine/test_spark.py @@ -3226,7 +3226,7 @@ def test_read_location_format_tsv(self, mocker): def test_read_stream(self, mocker): # Arrange - mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") mocker.patch("hopsworks_common.client.get_instance") mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" diff --git a/python/tests/test_storage_connector.py b/python/tests/test_storage_connector.py index 9bb502671..f96b27aa0 100644 --- a/python/tests/test_storage_connector.py +++ b/python/tests/test_storage_connector.py @@ -614,7 +614,7 @@ def test_kafka_options_external(self, mocker, backend_fixtures): def test_spark_options(self, mocker, backend_fixtures): # Arrange - mocker.patch("hsfs.engine.get_instance") + mock_engine_get_instance = 
mocker.patch("hsfs.engine.get_instance") mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance") json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"] @@ -677,7 +677,7 @@ def test_spark_options_external(self, mocker, backend_fixtures): def test_spark_options_spark_35(self, mocker, backend_fixtures): # Arrange mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance") - mock_client_get_instance = mocker.patch("hsfs.client.get_instance") + mock_client_get_instance = mocker.patch("hopsworks_common.client.get_instance") json = backend_fixtures["storage_connector"]["get_kafka_internal"]["response"] mock_engine_get_instance.return_value.get_spark_version.return_value = "3.5.0" From bda863f219d0157fe9be3b78f5f2953d00c3aff3 Mon Sep 17 00:00:00 2001 From: bubriks Date: Wed, 12 Feb 2025 14:29:07 +0200 Subject: [PATCH 5/8] test fix --- python/tests/core/test_kafka_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tests/core/test_kafka_engine.py b/python/tests/core/test_kafka_engine.py index 8ccbe46c1..9c0d85d65 100644 --- a/python/tests/core/test_kafka_engine.py +++ b/python/tests/core/test_kafka_engine.py @@ -414,6 +414,7 @@ def test_spark_get_kafka_config(self, mocker, backend_fixtures): json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] sc = storage_connector.StorageConnector.from_response_json(json) mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" mocker.patch("hopsworks_common.client._is_external", return_value=False) # Act @@ -456,6 +457,7 @@ def test_spark_get_kafka_config_external_client(self, mocker, backend_fixtures): json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] sc = storage_connector.StorageConnector.from_response_json(json) mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" # Act results = kafka_engine.get_kafka_config( @@ -497,6 +499,7 @@ def test_spark_get_kafka_config_internal_kafka(self, mocker, backend_fixtures): json = backend_fixtures["storage_connector"]["get_kafka_external"]["response"] sc = storage_connector.StorageConnector.from_response_json(json) mock_storage_connector_api.return_value.get_kafka_connector.return_value = sc + mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" # Act results = kafka_engine.get_kafka_config( From 858dded27d359b67faac611ae8df937028b2a290 Mon Sep 17 00:00:00 2001 From: Fabio Buso Date: Tue, 28 Jan 2025 16:34:03 +0100 Subject: [PATCH 6/8] Don't copy files in the home directory --- python/hsfs/engine/spark.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index 10c06a1a5..01d8cdf22 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -19,7 +19,6 @@ import json import os import re -import shutil import uuid import warnings from datetime import date, datetime, timezone @@ -1121,22 +1120,16 @@ def add_file(self, file): # for external clients, download the file if client._is_external(): - tmp_file = os.path.join(SparkFiles.getRootDirectory(), file_name) + tmp_file = "f/tmp/{file_name)" print("Reading key file from storage connector.") response = self._dataset_api.read_content(file, util.get_dataset_type(file)) with open(tmp_file, "wb") as f: 
f.write(response.content) - else: - self._spark_context.addFile(file) - # The file is not added to the driver current working directory - # We should add it manually by copying from the download location - # The file will be added to the executors current working directory - # before the next task is executed - shutil.copy(SparkFiles.get(file_name), file_name) + self._spark_context.addFile(file) - return file_name + return SparkFiles.get(file_name) def profile( self, From e2f163f8165e097d7383648d055d67bde312e723 Mon Sep 17 00:00:00 2001 From: Fabio Buso Date: Tue, 28 Jan 2025 17:17:49 +0100 Subject: [PATCH 7/8] Fix typo --- python/hsfs/engine/spark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index 01d8cdf22..37b43b366 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -1120,7 +1120,7 @@ def add_file(self, file): # for external clients, download the file if client._is_external(): - tmp_file = "f/tmp/{file_name)" + tmp_file = f"/tmp/{file_name}" print("Reading key file from storage connector.") response = self._dataset_api.read_content(file, util.get_dataset_type(file)) From 667631d905a2cdfeaadec88f7d7bb12a07f69f25 Mon Sep 17 00:00:00 2001 From: bubriks Date: Thu, 13 Feb 2025 11:01:24 +0200 Subject: [PATCH 8/8] Fix add_file --- python/hsfs/engine/spark.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hsfs/engine/spark.py b/python/hsfs/engine/spark.py index 37b43b366..d9918d23e 100644 --- a/python/hsfs/engine/spark.py +++ b/python/hsfs/engine/spark.py @@ -1127,6 +1127,8 @@ def add_file(self, file): with open(tmp_file, "wb") as f: f.write(response.content) + file = f"file://{tmp_file}" + self._spark_context.addFile(file) return SparkFiles.get(file_name)
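
The bulk of patch 1 is the new branch in KafkaConnector.spark_options(): on Spark 3.2.0 and later (which, per the in-code comment, depends on Kafka newer than 2.7.0 and therefore accepts SSL credentials as PEM objects), the JKS keystore/truststore entries are swapped for inline PEM values, which is what makes the connector usable on Databricks with a BYOK setup per the patch subject. The following is a minimal standalone sketch of that version gate only; the function name, its arguments, and the assumption that the PEM material is already available as strings are illustrative and not part of the hsfs API.

from typing import Dict

from packaging import version

# JKS-specific options that get replaced by inline PEM credentials on Spark >= 3.2.0.
SSL_JKS_KEYS = {
    "ssl.truststore.location",
    "ssl.truststore.password",
    "ssl.keystore.location",
    "ssl.keystore.password",
    "ssl.key.password",
}


def build_spark_kafka_options(
    kafka_options: Dict[str, str],
    spark_version: str,
    ca_chain_pem: str,
    client_cert_pem: str,
    client_key_pem: str,
) -> Dict[str, str]:
    """Prefix every Kafka option with 'kafka.' and, on Spark >= 3.2.0,
    replace the JKS keystore/truststore settings with PEM objects."""
    use_pem = version.parse(spark_version) >= version.parse("3.2.0")
    config = {}
    for key, value in kafka_options.items():
        if use_pem and key in SSL_JKS_KEYS:
            continue  # superseded by the PEM entries added below
        config[f"kafka.{key}"] = value
    if use_pem:
        config["kafka.ssl.truststore.type"] = "PEM"
        config["kafka.ssl.keystore.type"] = "PEM"
        config["kafka.ssl.truststore.certificates"] = ca_chain_pem
        config["kafka.ssl.keystore.certificate.chain"] = client_cert_pem
        config["kafka.ssl.keystore.key"] = client_key_pem
    return config

On older Spark versions this degrades to the previous behaviour of simply prefixing the options, which is what the "3.1.0" stubs in the updated tests exercise, while test_spark_options_spark_35 covers the PEM path.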
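
Patches 6-8 together rework SparkEngine.add_file: the external-client branch now downloads the key material into /tmp instead of Spark's files root, registers it with addFile using a file:// URI, and callers receive the absolute path from SparkFiles.get rather than a bare file name, removing the shutil.copy into the driver's home/working directory. Below is a minimal sketch of the resulting flow, assuming an active SparkContext; download_fn stands in for the dataset-API read and is purely illustrative.

import os

from pyspark import SparkContext, SparkFiles


def add_file(sc: SparkContext, path: str, is_external: bool, download_fn) -> str:
    """Distribute a file to the Spark cluster and return its local path on the driver."""
    if not path:
        return path
    file_name = os.path.basename(path)
    if is_external:
        # External clients cannot read the dataset path directly, so download
        # the content into /tmp and register that local copy instead.
        tmp_file = f"/tmp/{file_name}"
        with open(tmp_file, "wb") as f:
            f.write(download_fn(path))
        path = f"file://{tmp_file}"
    sc.addFile(path)
    # SparkFiles.get resolves the distributed copy's path on the driver, which
    # replaces the earlier shutil.copy into the current working directory.
    return SparkFiles.get(file_name)

Returning SparkFiles.get(file_name) is also a small contract change worth noting for callers that previously relied on the file being available under its bare name in the working directory.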
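
A final note on the test changes: patch 1 added lines such as mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0" to tests that only called mocker.patch("hsfs.engine.get_instance") without binding the result, which raises a NameError at test time; patch 4 binds the mock in test_spark.py and test_storage_connector.py, and patch 5 adds the same Spark-version stub to the kafka_engine tests. A trimmed pytest-mock arrangement step, using the same names as the tests in the series:

def test_read_stream(mocker):
    # Capture the patched object so attributes can be stubbed on it.
    mock_engine_get_instance = mocker.patch("hsfs.engine.get_instance")
    mock_engine_get_instance.return_value.get_spark_version.return_value = "3.1.0"
    ...  # act and assert as in python/tests/engine/test_spark.py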