14
14
# limitations under the License.
15
15
#
16
16
17
+ from __future__ import annotations
18
+
19
+ import base64
17
20
import os
18
- from abc import ABC , abstractmethod
21
+ import textwrap
22
+ import time
23
+ from pathlib import Path
19
24
20
25
import furl
21
26
import requests
24
29
from hopsworks .decorators import connected
25
30
26
31
32
+ try :
33
+ import jks
34
+ except ImportError :
35
+ pass
36
+
37
+
27
38
urllib3 .disable_warnings (urllib3 .exceptions .SecurityWarning )
28
39
urllib3 .disable_warnings (urllib3 .exceptions .InsecureRequestWarning )
29
40
30
41
31
- class Client ( ABC ) :
42
+ class Client :
32
43
TOKEN_FILE = "token.jwt"
44
+ TOKEN_EXPIRED_RETRY_INTERVAL = 0.6
45
+ TOKEN_EXPIRED_MAX_RETRIES = 10
46
+
33
47
APIKEY_FILE = "api.key"
34
48
REST_ENDPOINT = "REST_ENDPOINT"
49
+ DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV = "DEFAULT_DATABRICKS_ROOT_VIRTUALENV_ENV"
35
50
HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"
36
51
37
- @abstractmethod
38
- def __init__ (self ):
39
- """To be implemented by clients."""
40
- pass
41
-
42
52
def _get_verify (self , verify , trust_store_path ):
43
53
"""Get verification method for sending HTTP requests to Hopsworks.
44
54
@@ -163,11 +173,9 @@ def _send_request(
163
173
164
174
if response .status_code == 401 and self .REST_ENDPOINT in os .environ :
165
175
# refresh token and retry request - only on hopsworks
166
- self ._auth = auth .BearerAuth (self ._read_jwt ())
167
- # Update request with the new token
168
- request .auth = self ._auth
169
- prepped = self ._session .prepare_request (request )
170
- response = self ._session .send (prepped , verify = self ._verify , stream = stream )
176
+ response = self ._retry_token_expired (
177
+ request , stream , self .TOKEN_EXPIRED_RETRY_INTERVAL , 1
178
+ )
171
179
172
180
if response .status_code // 100 != 2 :
173
181
raise exceptions .RestAPIError (url , response )
@@ -180,6 +188,106 @@ def _send_request(
180
188
return None
181
189
return response .json ()
182
190
191
+ def _retry_token_expired (self , request , stream , wait , retries ):
192
+ """Refresh the JWT token and retry the request. Only on Hopsworks.
193
+ As the token might take a while to get refreshed. Keep trying
194
+ """
195
+ # Sleep the waited time before re-issuing the request
196
+ time .sleep (wait )
197
+
198
+ self ._auth = auth .BearerAuth (self ._read_jwt ())
199
+ # Update request with the new token
200
+ request .auth = self ._auth
201
+ prepped = self ._session .prepare_request (request )
202
+ response = self ._session .send (prepped , verify = self ._verify , stream = stream )
203
+
204
+ if response .status_code == 401 and retries < self .TOKEN_EXPIRED_MAX_RETRIES :
205
+ # Try again.
206
+ return self ._retry_token_expired (request , stream , wait * 2 , retries + 1 )
207
+ else :
208
+ # If the number of retries have expired, the _send_request method
209
+ # will throw an exception to the user as part of the status_code validation.
210
+ return response
211
+
183
212
def _close (self ):
184
213
"""Closes a client. Can be implemented for clean up purposes, not mandatory."""
185
214
self ._connected = False
215
+
216
+ def _write_pem (
217
+ self , keystore_path , keystore_pw , truststore_path , truststore_pw , prefix
218
+ ):
219
+ ks = jks .KeyStore .load (Path (keystore_path ), keystore_pw , try_decrypt_keys = True )
220
+ ts = jks .KeyStore .load (
221
+ Path (truststore_path ), truststore_pw , try_decrypt_keys = True
222
+ )
223
+
224
+ ca_chain_path = os .path .join ("/tmp" , f"{ prefix } _ca_chain.pem" )
225
+ self ._write_ca_chain (ks , ts , ca_chain_path )
226
+
227
+ client_cert_path = os .path .join ("/tmp" , f"{ prefix } _client_cert.pem" )
228
+ self ._write_client_cert (ks , client_cert_path )
229
+
230
+ client_key_path = os .path .join ("/tmp" , f"{ prefix } _client_key.pem" )
231
+ self ._write_client_key (ks , client_key_path )
232
+
233
+ return ca_chain_path , client_cert_path , client_key_path
234
+
235
+ def _write_ca_chain (self , ks , ts , ca_chain_path ):
236
+ """
237
+ Converts JKS keystore and truststore file into ca chain PEM to be compatible with Python libraries
238
+ """
239
+ ca_chain = ""
240
+ for store in [ks , ts ]:
241
+ for _ , c in store .certs .items ():
242
+ ca_chain = ca_chain + self ._bytes_to_pem_str (c .cert , "CERTIFICATE" )
243
+
244
+ with Path (ca_chain_path ).open ("w" ) as f :
245
+ f .write (ca_chain )
246
+
247
+ def _write_client_cert (self , ks , client_cert_path ):
248
+ """
249
+ Converts JKS keystore file into client cert PEM to be compatible with Python libraries
250
+ """
251
+ client_cert = ""
252
+ for _ , pk in ks .private_keys .items ():
253
+ for c in pk .cert_chain :
254
+ client_cert = client_cert + self ._bytes_to_pem_str (c [1 ], "CERTIFICATE" )
255
+
256
+ with Path (client_cert_path ).open ("w" ) as f :
257
+ f .write (client_cert )
258
+
259
+ def _write_client_key (self , ks , client_key_path ):
260
+ """
261
+ Converts JKS keystore file into client key PEM to be compatible with Python libraries
262
+ """
263
+ client_key = ""
264
+ for _ , pk in ks .private_keys .items ():
265
+ client_key = client_key + self ._bytes_to_pem_str (
266
+ pk .pkey_pkcs8 , "PRIVATE KEY"
267
+ )
268
+
269
+ with Path (client_key_path ).open ("w" ) as f :
270
+ f .write (client_key )
271
+
272
+ def _bytes_to_pem_str (self , der_bytes , pem_type ):
273
+ """
274
+ Utility function for creating PEM files
275
+
276
+ Args:
277
+ der_bytes: DER encoded bytes
278
+ pem_type: type of PEM, e.g Certificate, Private key, or RSA private key
279
+
280
+ Returns:
281
+ PEM String for a DER-encoded certificate or private key
282
+ """
283
+ pem_str = ""
284
+ pem_str = pem_str + "-----BEGIN {}-----" .format (pem_type ) + "\n "
285
+ pem_str = (
286
+ pem_str
287
+ + "\r \n " .join (
288
+ textwrap .wrap (base64 .b64encode (der_bytes ).decode ("ascii" ), 64 )
289
+ )
290
+ + "\n "
291
+ )
292
+ pem_str = pem_str + "-----END {}-----" .format (pem_type ) + "\n "
293
+ return pem_str
0 commit comments