17
17
from __future__ import annotations
18
18
19
19
import base64
20
- import json
21
20
import logging
22
21
import os
23
22
24
- import boto3
25
23
import requests
26
24
from hopsworks_common .client import auth , base , exceptions
27
25
from hopsworks_common .client .exceptions import FeatureStoreException
37
35
38
36
39
37
class Client (base .Client ):
40
- DEFAULT_REGION = "default"
41
38
SECRETS_MANAGER = "secretsmanager"
42
39
PARAMETER_STORE = "parameterstore"
43
40
LOCAL_STORE = "local"
@@ -48,8 +45,6 @@ def __init__(
48
45
port ,
49
46
project ,
50
47
engine ,
51
- region_name ,
52
- secrets_store ,
53
48
hostname_verification ,
54
49
trust_store_path ,
55
50
cert_folder ,
@@ -65,17 +60,14 @@ def __init__(
65
60
self ._port = port
66
61
self ._base_url = "https://" + self ._host + ":" + str (self ._port )
67
62
_logger .info ("Base URL: %s" , self ._base_url )
68
- self ._region_name = region_name or self .DEFAULT_REGION
69
- _logger .debug ("Region name: %s" , self ._region_name )
70
63
71
64
if api_key_value is not None :
72
65
_logger .debug ("Using provided API key value" )
73
66
api_key = api_key_value
74
67
else :
75
- _logger .debug ("Querying secrets store for API key" )
76
- if secrets_store is None :
77
- secrets_store = self .LOCAL_STORE
78
- api_key = self ._get_secret (secrets_store , "api-key" , api_key_file )
68
+ _logger .debug (f"Reading api key from { api_key_file } " )
69
+ with open (api_key_file ) as f :
70
+ api_key = f .readline ().strip ()
79
71
80
72
_logger .debug ("Using api key to setup header authentification" )
81
73
self ._auth = auth .ApiKeyAuth (api_key )
@@ -84,17 +76,16 @@ def __init__(
84
76
self ._session = requests .session ()
85
77
self ._connected = True
86
78
87
- self ._verify = self ._get_verify (self . _host , trust_store_path )
79
+ self ._verify = self ._get_verify (hostname_verification , trust_store_path )
88
80
_logger .debug ("Verify: %s" , self ._verify )
89
81
90
82
self ._cert_key = None
91
83
self ._cert_folder_base = cert_folder
92
84
self ._cert_folder = None
93
85
94
- self ._hsfs_post_init (project , engine , region_name )
86
+ self ._hsfs_post_init (project , engine )
95
87
96
- def _hsfs_post_init (self , project , engine , region_name ):
97
- self ._region_name = region_name or self ._region_name or self .DEFAULT_REGION
88
+ def _hsfs_post_init (self , project , engine ):
98
89
self ._project_name = project
99
90
if project is not None :
100
91
project_info = self ._get_project_info (project )
@@ -295,7 +286,7 @@ def _get_client_key_path(self, project_name=None) -> str:
295
286
_logger .debug (f"Getting client key path { path } " )
296
287
return path
297
288
298
- def _get_secret (self , secrets_store , secret_key = None , api_key_file = None ):
289
+ def _get_secret (self , secret_key = None , api_key_file = None ):
299
290
"""Returns secret value from the AWS Secrets Manager or Parameter Store.
300
291
301
292
:param secrets_store: the underlying secrets storage to be used, e.g. `secretsmanager` or `parameterstore`
@@ -309,67 +300,9 @@ def _get_secret(self, secrets_store, secret_key=None, api_key_file=None):
309
300
:return: secret
310
301
:rtype: str
311
302
"""
312
- _logger .debug (f"Querying secrets store { secrets_store } for secret { secret_key } " )
313
- if secrets_store == self .SECRETS_MANAGER :
314
- return self ._query_secrets_manager (secret_key )
315
- elif secrets_store == self .PARAMETER_STORE :
316
- return self ._query_parameter_store (secret_key )
317
- elif secrets_store == self .LOCAL_STORE :
318
- if not api_key_file :
319
- raise exceptions .ExternalClientError (
320
- "api_key_file needs to be set for local mode"
321
- )
322
- _logger .debug (f"Reading api key from { api_key_file } " )
323
- with open (api_key_file ) as f :
324
- return f .readline ().strip ()
325
- else :
326
- raise exceptions .UnknownSecretStorageError (
327
- "Secrets storage " + secrets_store + " is not supported."
328
- )
329
-
330
- def _query_secrets_manager (self , secret_key ):
331
- _logger .debug ("Querying secrets manager for secret key: %s" , secret_key )
332
- secret_name = "hopsworks/role/" + self ._assumed_role ()
333
- args = {"service_name" : "secretsmanager" }
334
- region_name = self ._get_region ()
335
- if region_name :
336
- args ["region_name" ] = region_name
337
- client = boto3 .client (** args )
338
- get_secret_value_response = client .get_secret_value (SecretId = secret_name )
339
- return json .loads (get_secret_value_response ["SecretString" ])[secret_key ]
340
-
341
- def _assumed_role (self ):
342
- _logger .debug ("Getting assumed role" )
343
- client = boto3 .client ("sts" )
344
- response = client .get_caller_identity ()
345
- # arns for assumed roles in SageMaker follow the following schema
346
- # arn:aws:sts::123456789012:assumed-role/my-role-name/my-role-session-name
347
- local_identifier = response ["Arn" ].split (":" )[- 1 ].split ("/" )
348
- if len (local_identifier ) != 3 or local_identifier [0 ] != "assumed-role" :
349
- raise Exception (
350
- "Failed to extract assumed role from arn: " + response ["Arn" ]
351
- )
352
- return local_identifier [1 ]
353
-
354
- def _get_region (self ):
355
- if self ._region_name != self .DEFAULT_REGION :
356
- _logger .debug (f"Region name is not default, returning { self ._region_name } " )
357
- return self ._region_name
358
- else :
359
- _logger .debug ("Region name is default, returning None" )
360
- return None
361
-
362
- def _query_parameter_store (self , secret_key ):
363
- _logger .debug ("Querying parameter store for secret key: %s" , secret_key )
364
- args = {"service_name" : "ssm" }
365
- region_name = self ._get_region ()
366
- if region_name :
367
- args ["region_name" ] = region_name
368
- client = boto3 .client (** args )
369
- name = "/hopsworks/role/" + self ._assumed_role () + "/type/" + secret_key
370
- return client .get_parameter (Name = name , WithDecryption = True )["Parameter" ][
371
- "Value"
372
- ]
303
+ _logger .debug (f"Reading api key from { api_key_file } " )
304
+ with open (api_key_file ) as f :
305
+ return f .readline ().strip ()
373
306
374
307
def _get_project_info (self , project_name ):
375
308
"""Makes a REST call to hopsworks to get all metadata of a project for the provided project.
0 commit comments