diff --git a/osaka/__init__.py b/osaka/__init__.py index 47429b5..e377743 100644 --- a/osaka/__init__.py +++ b/osaka/__init__.py @@ -3,6 +3,6 @@ from __future__ import division from __future__ import absolute_import -__version__ = "1.1.0" +__version__ = "1.2.0" __url__ = "https://github.com/hysds/osaka" __description__ = "Osaka (Object Store Abstraction K Arcitecture)" diff --git a/osaka/storage/s3.py b/osaka/storage/s3.py index 23ddb88..a59d876 100644 --- a/osaka/storage/s3.py +++ b/osaka/storage/s3.py @@ -5,6 +5,8 @@ from builtins import open from builtins import int from builtins import str + +import requests.exceptions from future import standard_library standard_library.install_aliases() @@ -28,24 +30,33 @@ # S3 region info S3_REGION_INFO = None +# entry that holds the default region +DEFAULT_REGION = None + # regexes NOT_FOUND_RE = re.compile(r"Not Found") +# Endpoint used to get the region if an AWS config file isn't found +AWS_ID_ENDPOINT = "http://169.254.169.254/latest/dynamic/instance-identity/document" + def get_region_info(): """ Return region info dict. """ global S3_REGION_INFO - if S3_REGION_INFO is None: + global DEFAULT_REGION + if S3_REGION_INFO is None and DEFAULT_REGION is None: S3_REGION_INFO = {} + DEFAULT_REGION = "" s = botocore.session.get_session() + DEFAULT_REGION = s.get_config_variable('region') for part in s.get_available_partitions(): for region in s.get_available_regions("s3", part): s3 = s.create_client("s3", region) ep = urllib.parse.urlparse(s3.meta.endpoint_url).netloc S3_REGION_INFO[region] = ep - return S3_REGION_INFO + return S3_REGION_INFO, DEFAULT_REGION class S3(osaka.base.StorageBase): @@ -58,6 +69,7 @@ def __init__(self): Constructor """ self.tmpfiles = [] + self.is_nominal_style = False def connect(self, uri, params={}): """ @@ -72,15 +84,42 @@ def connect(self, uri, params={}): session_kwargs = {} kwargs = {} check_host = parsed.hostname if "location" not in params else params["location"] - for region, ep in get_region_info().items(): + region_info, default_region = get_region_info() + found_ep_and_region = False + for region, ep in region_info.items(): if re.search(ep, check_host): kwargs["endpoint_url"] = ep session_kwargs["region_name"] = region + found_ep_and_region = True break - if parsed.hostname is not None: - kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, parsed.hostname) + # Use the default region obtained from the sessions object when the + # region info was being gathered. This check is used to support the cases + # when osaka receives an S3 url in the nominal pathing style. + if not found_ep_and_region: + if not default_region: + # If the default region was not found from the initial boto session, resort + # to trying to get that info from the API + try: + response = requests.get(AWS_ID_ENDPOINT) + data = response.json() + default_region = data.get("region", None) + if not default_region: + osaka.utils.LOGGER.error(f"No 'region' found in json response:\n{data}") + raise Exception + except Exception as e: + raise osaka.utils.OsakaException( + f"Cannot determine region. Verify that an AWS config file exists and that a region " + f"is set within that file.") from e + + ep = region_info.get(default_region) + kwargs["endpoint_url"] = f"{parsed.scheme}://{ep}" + session_kwargs["region_name"] = default_region + self.is_nominal_style = True else: - kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, kwargs["endpoint_url"]) + if parsed.hostname is not None: + kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, parsed.hostname) + else: + kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, kwargs["endpoint_url"]) if parsed.port is not None and parsed.port != 80 and parsed.port != 443: kwargs["endpoint_url"] = "%s:%s" % (kwargs["endpoint_url"], parsed.port) if parsed.username is not None: @@ -138,9 +177,7 @@ def get(self, uri): @param uri: uri to get """ osaka.utils.LOGGER.debug("Getting stream from URI: {0}".format(uri)) - container, key = osaka.utils.get_container_and_path( - urllib.parse.urlparse(uri).path - ) + container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style) bucket = self.bucket(container, create=False) obj = bucket.Object(key) fname = "/tmp/osaka-s3-" + str(datetime.datetime.now()) @@ -167,9 +204,7 @@ def put(self, stream, uri): @param uri: uri to put """ osaka.utils.LOGGER.debug("Putting stream to URI: {0}".format(uri)) - container, key = osaka.utils.get_container_and_path( - urllib.parse.urlparse(uri).path - ) + container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style) bucket = self.bucket(container) obj = bucket.Object(key) extra = {} @@ -200,15 +235,20 @@ def listAllChildren(self, uri): if "__top__" in self.cache and uri == self.cache["__top__"]: return [k for k in list(self.cache.keys()) if k != "__top__"] parsed = urllib.parse.urlparse(uri) - container, key = osaka.utils.get_container_and_path(parsed.path) + container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style) bucket = self.bucket(container, create=False) collection = bucket.objects.filter(Prefix=key) - uriBase = ( - parsed.scheme - + "://" - + parsed.hostname - + (":" + str(parsed.port) if parsed.port is not None else "") - ) + if self.is_nominal_style: + # Needs only 1 slash because the 2nd one gets added when the "full" + # value gets put together + uriBase = f"{parsed.scheme}:/" + else: + uriBase = ( + parsed.scheme + + "://" + + parsed.hostname + + (":" + str(parsed.port) if parsed.port is not None else "") + ) # Setup cache, and fill it with listings self.cache["__top__"] = uri for item in collection: @@ -236,7 +276,7 @@ def list(self, uri): depth = len(uri.rstrip("/").split("/")) return [ item - for item in self.listAllChildren() + for item in self.listAllChildren(uri) if len(item.rstrip("/").split("/")) == (depth + 1) ] @@ -297,9 +337,7 @@ def rm(self, uri): Remove this uri from backend @param uri: uri to remove """ - container, key = osaka.utils.get_container_and_path( - urllib.parse.urlparse(uri).path - ) + container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style) bucket = self.bucket(container, create=False) obj = bucket.Object(key) obj.delete() diff --git a/osaka/tests/test_utils.py b/osaka/tests/test_utils.py new file mode 100644 index 0000000..25e8eb7 --- /dev/null +++ b/osaka/tests/test_utils.py @@ -0,0 +1,32 @@ +import unittest + +from osaka import utils + + +class UtilsTest(unittest.TestCase): + def setUp(self): + self.legacy_s3_url = "s3://s3-us-west-2.amazonaws.com:80/my_bucket/foo/bar/key" + self.nominal_s3_url = "s3://my_bucket/foo/bar/key" + self.expected_container = "my_bucket" + self.expected_key = "foo/bar/key" + + def test_get_s3_container_and_path_legacy_style(self): + container, key = utils.get_s3_container_and_path(self.legacy_s3_url) + self.assertEquals(container, + self.expected_container, + f"Did not get expected container value: " + f"container={container}, expected={self.expected_container}") + self.assertEquals(key, self.expected_key, f"Did not get expected key value: key={key}, " + f"expected={self.expected_key}") + + def test_get_s3_container_and_path_nominal_style(self): + container, key = utils.get_s3_container_and_path(self.nominal_s3_url, is_nominal_style=True) + self.assertEquals(container, self.expected_container, + f"Did not get expected container value: " + f"container={container}, expected={self.expected_container}") + self.assertEquals(key, self.expected_key, f"Did not get expected key value: key={key}, " + f"expected={self.expected_key}") + + +if __name__ == "__main__": + unittest.main() diff --git a/osaka/utils.py b/osaka/utils.py index 10dd824..ade5d21 100644 --- a/osaka/utils.py +++ b/osaka/utils.py @@ -59,6 +59,24 @@ def get_container_and_path(urlpath): return (split[0], "" if not len(split) > 1 else split[1]) +def get_s3_container_and_path(s3_url, is_nominal_style=False): + """ + Gets the s3 container (a.k.a. bucket) and path (a.k.a key) from the + given url. + + :param s3_url: The S3 url to parse. + :param is_nominal_style: Set to true if the given S3 url is in a + nominal style format (i.e. s3://bucket/key). Otherwise, set to false. + :return: + """ + parsed = urllib.parse.urlparse(s3_url) + if is_nominal_style is True: + # If we are dealing with the nominal S3 style path + return parsed.hostname, parsed.path.split("/", 1)[1] + else: + return get_container_and_path(parsed.path) + + # def walk(func, directory,destdir, *params): # ''' # Walk the directory and call the function for each file