Skip to content

Commit 227a4e1

Browse files
committed
Merge client, connection, version and variable_api
1 parent d6b578a commit 227a4e1

23 files changed

+895
-1815
lines changed

python/hopsworks/client/__init__.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
#
1616

17+
from typing import Literal, Optional, Union
18+
1719
from hopsworks.client import external, hopsworks
1820

1921

@@ -22,16 +24,19 @@
2224

2325

2426
def init(
25-
client_type,
26-
host=None,
27-
port=None,
28-
project=None,
29-
hostname_verification=None,
30-
trust_store_path=None,
31-
cert_folder=None,
32-
api_key_file=None,
33-
api_key_value=None,
34-
):
27+
client_type: Union[Literal["hopsworks"], Literal["external"]],
28+
host: Optional[str] = None,
29+
port: Optional[int] = None,
30+
project: Optional[str] = None,
31+
engine: Optional[str] = None,
32+
region_name: Optional[str] = None,
33+
secrets_store=None,
34+
hostname_verification: Optional[bool] = None,
35+
trust_store_path: Optional[str] = None,
36+
cert_folder: Optional[str] = None,
37+
api_key_file: Optional[str] = None,
38+
api_key_value: Optional[str] = None,
39+
) -> None:
3540
global _client
3641
if not _client:
3742
if client_type == "hopsworks":
@@ -41,6 +46,9 @@ def init(
4146
host,
4247
port,
4348
project,
49+
engine,
50+
region_name,
51+
secrets_store,
4452
hostname_verification,
4553
trust_store_path,
4654
cert_folder,
@@ -49,7 +57,7 @@ def init(
4957
)
5058

5159

52-
def get_instance():
60+
def get_instance() -> Union[hopsworks.Client, external.Client]:
5361
global _client
5462
if _client:
5563
return _client

python/hopsworks/client/auth.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,39 @@
1414
# limitations under the License.
1515
#
1616

17+
from __future__ import annotations
18+
1719
import requests
1820

1921

2022
class BearerAuth(requests.auth.AuthBase):
2123
"""Class to encapsulate a Bearer token."""
2224

23-
def __init__(self, token):
25+
def __init__(self, token: str) -> requests.Request:
2426
self._token = token
2527

26-
def __call__(self, r):
28+
def __call__(self, r: requests.Request) -> requests.Request:
2729
r.headers["Authorization"] = "Bearer " + self._token.strip()
2830
return r
2931

3032

3133
class ApiKeyAuth(requests.auth.AuthBase):
3234
"""Class to encapsulate an API key."""
3335

34-
def __init__(self, token):
36+
def __init__(self, token: str) -> None:
3537
self._token = token
3638

37-
def __call__(self, r):
39+
def __call__(self, r: requests.Request) -> requests.Request:
3840
r.headers["Authorization"] = "ApiKey " + self._token.strip()
3941
return r
42+
43+
44+
class OnlineStoreKeyAuth(requests.auth.AuthBase):
45+
"""Class to encapsulate an API key."""
46+
47+
def __init__(self, token: str) -> None:
48+
self._token = token.strip()
49+
50+
def __call__(self, r: requests.Request) -> requests.Request:
51+
r.headers["X-API-KEY"] = self._token
52+
return r

python/hopsworks/client/base.py

+120-12
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,13 @@
1414
# limitations under the License.
1515
#
1616

17+
from __future__ import annotations
18+
19+
import base64
1720
import os
18-
from abc import ABC, abstractmethod
21+
import textwrap
22+
import time
23+
from pathlib import Path
1924

2025
import furl
2126
import requests
@@ -24,21 +29,26 @@
2429
from hopsworks.decorators import connected
2530

2631

32+
try:
33+
import jks
34+
except ImportError:
35+
pass
36+
37+
2738
urllib3.disable_warnings(urllib3.exceptions.SecurityWarning)
2839
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
2940

3041

31-
class Client(ABC):
42+
class Client:
3243
TOKEN_FILE = "token.jwt"
44+
TOKEN_EXPIRED_RETRY_INTERVAL = 0.6
45+
TOKEN_EXPIRED_MAX_RETRIES = 10
46+
3347
APIKEY_FILE = "api.key"
3448
REST_ENDPOINT = "REST_ENDPOINT"
49+
DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV = "DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV"
3550
HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"
3651

37-
@abstractmethod
38-
def __init__(self):
39-
"""To be implemented by clients."""
40-
pass
41-
4252
def _get_verify(self, verify, trust_store_path):
4353
"""Get verification method for sending HTTP requests to Hopsworks.
4454
@@ -163,11 +173,9 @@ def _send_request(
163173

164174
if response.status_code == 401 and self.REST_ENDPOINT in os.environ:
165175
# refresh token and retry request - only on hopsworks
166-
self._auth = auth.BearerAuth(self._read_jwt())
167-
# Update request with the new token
168-
request.auth = self._auth
169-
prepped = self._session.prepare_request(request)
170-
response = self._session.send(prepped, verify=self._verify, stream=stream)
176+
response = self._retry_token_expired(
177+
request, stream, self.TOKEN_EXPIRED_RETRY_INTERVAL, 1
178+
)
171179

172180
if response.status_code // 100 != 2:
173181
raise exceptions.RestAPIError(url, response)
@@ -180,6 +188,106 @@ def _send_request(
180188
return None
181189
return response.json()
182190

191+
def _retry_token_expired(self, request, stream, wait, retries):
192+
"""Refresh the JWT token and retry the request. Only on Hopsworks.
193+
As the token might take a while to get refreshed. Keep trying
194+
"""
195+
# Sleep the waited time before re-issuing the request
196+
time.sleep(wait)
197+
198+
self._auth = auth.BearerAuth(self._read_jwt())
199+
# Update request with the new token
200+
request.auth = self._auth
201+
prepped = self._session.prepare_request(request)
202+
response = self._session.send(prepped, verify=self._verify, stream=stream)
203+
204+
if response.status_code == 401 and retries < self.TOKEN_EXPIRED_MAX_RETRIES:
205+
# Try again.
206+
return self._retry_token_expired(request, stream, wait * 2, retries + 1)
207+
else:
208+
# If the number of retries have expired, the _send_request method
209+
# will throw an exception to the user as part of the status_code validation.
210+
return response
211+
183212
def _close(self):
184213
"""Closes a client. Can be implemented for clean up purposes, not mandatory."""
185214
self._connected = False
215+
216+
def _write_pem(
217+
self, keystore_path, keystore_pw, truststore_path, truststore_pw, prefix
218+
):
219+
ks = jks.KeyStore.load(Path(keystore_path), keystore_pw, try_decrypt_keys=True)
220+
ts = jks.KeyStore.load(
221+
Path(truststore_path), truststore_pw, try_decrypt_keys=True
222+
)
223+
224+
ca_chain_path = os.path.join("/tmp", f"{prefix}_ca_chain.pem")
225+
self._write_ca_chain(ks, ts, ca_chain_path)
226+
227+
client_cert_path = os.path.join("/tmp", f"{prefix}_client_cert.pem")
228+
self._write_client_cert(ks, client_cert_path)
229+
230+
client_key_path = os.path.join("/tmp", f"{prefix}_client_key.pem")
231+
self._write_client_key(ks, client_key_path)
232+
233+
return ca_chain_path, client_cert_path, client_key_path
234+
235+
def _write_ca_chain(self, ks, ts, ca_chain_path):
236+
"""
237+
Converts JKS keystore and truststore file into ca chain PEM to be compatible with Python libraries
238+
"""
239+
ca_chain = ""
240+
for store in [ks, ts]:
241+
for _, c in store.certs.items():
242+
ca_chain = ca_chain + self._bytes_to_pem_str(c.cert, "CERTIFICATE")
243+
244+
with Path(ca_chain_path).open("w") as f:
245+
f.write(ca_chain)
246+
247+
def _write_client_cert(self, ks, client_cert_path):
248+
"""
249+
Converts JKS keystore file into client cert PEM to be compatible with Python libraries
250+
"""
251+
client_cert = ""
252+
for _, pk in ks.private_keys.items():
253+
for c in pk.cert_chain:
254+
client_cert = client_cert + self._bytes_to_pem_str(c[1], "CERTIFICATE")
255+
256+
with Path(client_cert_path).open("w") as f:
257+
f.write(client_cert)
258+
259+
def _write_client_key(self, ks, client_key_path):
260+
"""
261+
Converts JKS keystore file into client key PEM to be compatible with Python libraries
262+
"""
263+
client_key = ""
264+
for _, pk in ks.private_keys.items():
265+
client_key = client_key + self._bytes_to_pem_str(
266+
pk.pkey_pkcs8, "PRIVATE KEY"
267+
)
268+
269+
with Path(client_key_path).open("w") as f:
270+
f.write(client_key)
271+
272+
def _bytes_to_pem_str(self, der_bytes, pem_type):
273+
"""
274+
Utility function for creating PEM files
275+
276+
Args:
277+
der_bytes: DER encoded bytes
278+
pem_type: type of PEM, e.g Certificate, Private key, or RSA private key
279+
280+
Returns:
281+
PEM String for a DER-encoded certificate or private key
282+
"""
283+
pem_str = ""
284+
pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n"
285+
pem_str = (
286+
pem_str
287+
+ "\r\n".join(
288+
textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64)
289+
)
290+
+ "\n"
291+
)
292+
pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n"
293+
return pem_str

python/hopsworks/client/exceptions.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,29 @@
1414
# limitations under the License.
1515
#
1616

17+
from __future__ import annotations
18+
19+
from enum import Enum
20+
from typing import Any, Union
21+
22+
import requests
23+
1724

1825
class RestAPIError(Exception):
1926
"""REST Exception encapsulating the response object and url."""
2027

21-
def __init__(self, url, response):
28+
class FeatureStoreErrorCode(int, Enum):
29+
FEATURE_GROUP_COMMIT_NOT_FOUND = 270227
30+
STATISTICS_NOT_FOUND = 270228
31+
32+
def __eq__(self, other: Union[int, Any]) -> bool:
33+
if isinstance(other, int):
34+
return self.value == other
35+
if isinstance(other, self.__class__):
36+
return self is other
37+
return False
38+
39+
def __init__(self, url: str, response: requests.Response) -> None:
2240
try:
2341
error_object = response.json()
2442
except Exception:
@@ -77,8 +95,47 @@ class JobExecutionException(Exception):
7795
"""Generic job executions exception"""
7896

7997

98+
class FeatureStoreException(Exception):
99+
"""Generic feature store exception"""
100+
101+
80102
class ExternalClientError(TypeError):
81103
"""Raised when external client cannot be initialized due to missing arguments."""
82104

83-
def __init__(self, message):
105+
def __init__(self, missing_argument: str) -> None:
106+
message = (
107+
"{0} cannot be of type NoneType, {0} is a non-optional "
108+
"argument to connect to hopsworks from an external environment."
109+
).format(missing_argument)
110+
super().__init__(message)
111+
112+
113+
class VectorDatabaseException(Exception):
114+
# reason
115+
REQUESTED_K_TOO_LARGE = "REQUESTED_K_TOO_LARGE"
116+
REQUESTED_NUM_RESULT_TOO_LARGE = "REQUESTED_NUM_RESULT_TOO_LARGE"
117+
OTHERS = "OTHERS"
118+
119+
# info
120+
REQUESTED_K_TOO_LARGE_INFO_K = "k"
121+
REQUESTED_NUM_RESULT_TOO_LARGE_INFO_N = "n"
122+
123+
def __init__(self, reason: str, message: str, info: str) -> None:
124+
super().__init__(message)
125+
self._info = info
126+
self._reason = reason
127+
128+
@property
129+
def reason(self) -> str:
130+
return self._reason
131+
132+
@property
133+
def info(self) -> str:
134+
return self._info
135+
136+
137+
class DataValidationException(FeatureStoreException):
138+
"""Raised when data validation fails only when using "STRICT" validation ingestion policy."""
139+
140+
def __init__(self, message: str) -> None:
84141
super().__init__(message)

0 commit comments

Comments
 (0)