|
1 | 1 | #
|
2 |
| -# Copyright 2022 Logical Clocks AB |
| 2 | +# Copyright 2024 Hopsworks AB |
3 | 3 | #
|
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | # you may not use this file except in compliance with the License.
|
|
14 | 14 | # limitations under the License.
|
15 | 15 | #
|
16 | 16 |
|
17 |
| -import os |
18 |
| -from abc import ABC, abstractmethod |
| 17 | +from hopsworks_common.client.base import ( |
| 18 | + Client, |
| 19 | +) |
19 | 20 |
|
20 |
| -import furl |
21 |
| -import requests |
22 |
| -import urllib3 |
23 |
| -from hopsworks.client import auth, exceptions |
24 |
| -from hopsworks.decorators import connected |
25 | 21 |
|
26 |
| - |
27 |
| -urllib3.disable_warnings(urllib3.exceptions.SecurityWarning) |
28 |
| -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) |
29 |
| - |
30 |
| - |
31 |
| -class Client(ABC): |
32 |
| - TOKEN_FILE = "token.jwt" |
33 |
| - APIKEY_FILE = "api.key" |
34 |
| - REST_ENDPOINT = "REST_ENDPOINT" |
35 |
| - HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST" |
36 |
| - |
37 |
| - @abstractmethod |
38 |
| - def __init__(self): |
39 |
| - """To be implemented by clients.""" |
40 |
| - pass |
41 |
| - |
42 |
| - def _get_verify(self, verify, trust_store_path): |
43 |
| - """Get verification method for sending HTTP requests to Hopsworks. |
44 |
| -
|
45 |
| - Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c |
46 |
| -
|
47 |
| - :param verify: perform hostname verification, 'true' or 'false' |
48 |
| - :type verify: str |
49 |
| - :param trust_store_path: path of the truststore locally if it was uploaded manually to |
50 |
| - the external environment such as AWS Sagemaker |
51 |
| - :type trust_store_path: str |
52 |
| - :return: if verify is true and the truststore is provided, then return the trust store location |
53 |
| - if verify is true but the truststore wasn't provided, then return true |
54 |
| - if verify is false, then return false |
55 |
| - :rtype: str or boolean |
56 |
| - """ |
57 |
| - if verify == "true": |
58 |
| - if trust_store_path is not None: |
59 |
| - return trust_store_path |
60 |
| - else: |
61 |
| - return True |
62 |
| - |
63 |
| - return False |
64 |
| - |
65 |
| - def _get_host_port_pair(self): |
66 |
| - """ |
67 |
| - Removes "http or https" from the rest endpoint and returns a list |
68 |
| - [endpoint, port], where endpoint is on the format /path.. without http:// |
69 |
| -
|
70 |
| - :return: a list [endpoint, port] |
71 |
| - :rtype: list |
72 |
| - """ |
73 |
| - endpoint = self._base_url |
74 |
| - if "http" in endpoint: |
75 |
| - last_index = endpoint.rfind("/") |
76 |
| - endpoint = endpoint[last_index + 1 :] |
77 |
| - host, port = endpoint.split(":") |
78 |
| - return host, port |
79 |
| - |
80 |
| - def _read_jwt(self): |
81 |
| - """Retrieve jwt from local container.""" |
82 |
| - return self._read_file(self.TOKEN_FILE) |
83 |
| - |
84 |
| - def _read_apikey(self): |
85 |
| - """Retrieve apikey from local container.""" |
86 |
| - return self._read_file(self.APIKEY_FILE) |
87 |
| - |
88 |
| - def _read_file(self, secret_file): |
89 |
| - """Retrieve secret from local container.""" |
90 |
| - with open(os.path.join(self._secrets_dir, secret_file), "r") as secret: |
91 |
| - return secret.read() |
92 |
| - |
93 |
| - def _get_credentials(self, project_id): |
94 |
| - """Makes a REST call to hopsworks for getting the project user certificates needed to connect to services such as Hive |
95 |
| -
|
96 |
| - :param project_id: id of the project |
97 |
| - :type project_id: int |
98 |
| - :return: JSON response with credentials |
99 |
| - :rtype: dict |
100 |
| - """ |
101 |
| - return self._send_request("GET", ["project", project_id, "credentials"]) |
102 |
| - |
103 |
| - def _write_pem_file(self, content: str, path: str) -> None: |
104 |
| - with open(path, "w") as f: |
105 |
| - f.write(content) |
106 |
| - |
107 |
| - @connected |
108 |
| - def _send_request( |
109 |
| - self, |
110 |
| - method, |
111 |
| - path_params, |
112 |
| - query_params=None, |
113 |
| - headers=None, |
114 |
| - data=None, |
115 |
| - stream=False, |
116 |
| - files=None, |
117 |
| - with_base_path_params=True, |
118 |
| - ): |
119 |
| - """Send REST request to Hopsworks. |
120 |
| -
|
121 |
| - Uses the client it is executed from. Path parameters are url encoded automatically. |
122 |
| -
|
123 |
| - :param method: 'GET', 'PUT' or 'POST' |
124 |
| - :type method: str |
125 |
| - :param path_params: a list of path params to build the query url from starting after |
126 |
| - the api resource, for example `["project", 119, "featurestores", 67]`. |
127 |
| - :type path_params: list |
128 |
| - :param query_params: A dictionary of key/value pairs to be added as query parameters, |
129 |
| - defaults to None |
130 |
| - :type query_params: dict, optional |
131 |
| - :param headers: Additional header information, defaults to None |
132 |
| - :type headers: dict, optional |
133 |
| - :param data: The payload as a python dictionary to be sent as json, defaults to None |
134 |
| - :type data: dict, optional |
135 |
| - :param stream: Set if response should be a stream, defaults to False |
136 |
| - :type stream: boolean, optional |
137 |
| - :param files: dictionary for multipart encoding upload |
138 |
| - :type files: dict, optional |
139 |
| - :raises RestAPIError: Raised when request wasn't correctly received, understood or accepted |
140 |
| - :return: Response json |
141 |
| - :rtype: dict |
142 |
| - """ |
143 |
| - f_url = furl.furl(self._base_url) |
144 |
| - if with_base_path_params: |
145 |
| - base_path_params = ["hopsworks-api", "api"] |
146 |
| - f_url.path.segments = base_path_params + path_params |
147 |
| - else: |
148 |
| - f_url.path.segments = path_params |
149 |
| - url = str(f_url) |
150 |
| - |
151 |
| - request = requests.Request( |
152 |
| - method, |
153 |
| - url=url, |
154 |
| - headers=headers, |
155 |
| - data=data, |
156 |
| - params=query_params, |
157 |
| - auth=self._auth, |
158 |
| - files=files, |
159 |
| - ) |
160 |
| - |
161 |
| - prepped = self._session.prepare_request(request) |
162 |
| - response = self._session.send(prepped, verify=self._verify, stream=stream) |
163 |
| - |
164 |
| - if response.status_code == 401 and self.REST_ENDPOINT in os.environ: |
165 |
| - # refresh token and retry request - only on hopsworks |
166 |
| - self._auth = auth.BearerAuth(self._read_jwt()) |
167 |
| - # Update request with the new token |
168 |
| - request.auth = self._auth |
169 |
| - prepped = self._session.prepare_request(request) |
170 |
| - response = self._session.send(prepped, verify=self._verify, stream=stream) |
171 |
| - |
172 |
| - if response.status_code // 100 != 2: |
173 |
| - raise exceptions.RestAPIError(url, response) |
174 |
| - |
175 |
| - if stream: |
176 |
| - return response |
177 |
| - else: |
178 |
| - # handle different success response codes |
179 |
| - if len(response.content) == 0: |
180 |
| - return None |
181 |
| - return response.json() |
182 |
| - |
183 |
| - def _close(self): |
184 |
| - """Closes a client. Can be implemented for clean up purposes, not mandatory.""" |
185 |
| - self._connected = False |
| 22 | +__all__ = [ |
| 23 | + Client, |
| 24 | +] |
0 commit comments