Skip to content

Commit

Permalink
HC-418: Support nominal S3 pathing style (#33)
Browse files Browse the repository at this point in the history
* HC-418: Support usage of nominal S3 pathing styles

* Add unit test for utility function that determines the S3 bucket and key

* add newline

* fixed typos, make unit test more generic

* get endpoint from region_info; check for default region

* try querying 169.254.169.254 (the instance-identity endpoint) to determine the region

* fix messaging

* Bump version

Co-authored-by: Mike Cayanan <michael.d.cayanan@jpl.nasa.gov>
  • Loading branch information
mcayanan and Mike Cayanan authored May 23, 2022
1 parent fbfb11c commit 981b576
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 24 deletions.
2 changes: 1 addition & 1 deletion osaka/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
from __future__ import division
from __future__ import absolute_import

__version__ = "1.1.0"
__version__ = "1.2.0"
__url__ = "https://github.com/hysds/osaka"
__description__ = "Osaka (Object Store Abstraction K Arcitecture)"
84 changes: 61 additions & 23 deletions osaka/storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from builtins import open
from builtins import int
from builtins import str

import requests.exceptions
from future import standard_library

standard_library.install_aliases()
Expand All @@ -28,24 +30,33 @@
# S3 region info
S3_REGION_INFO = None

# entry that holds the default region
DEFAULT_REGION = None

# regexes
NOT_FOUND_RE = re.compile(r"Not Found")

# Endpoint used to get the region if an AWS config file isn't found
AWS_ID_ENDPOINT = "http://169.254.169.254/latest/dynamic/instance-identity/document"


def get_region_info():
    """
    Lazily build and cache the S3 region/endpoint map and the default region.

    :return: Tuple of (S3_REGION_INFO, DEFAULT_REGION), where S3_REGION_INFO
        maps each region name to its S3 endpoint hostname and DEFAULT_REGION
        is the region configured for the current boto session ("" if none
        is configured).
    """
    global S3_REGION_INFO
    global DEFAULT_REGION
    # Populate both caches exactly once; later calls return the cached values.
    if S3_REGION_INFO is None and DEFAULT_REGION is None:
        S3_REGION_INFO = {}
        s = botocore.session.get_session()
        # get_config_variable() returns None when no region is configured;
        # coerce to "" so the cache check above does not re-run this
        # expensive all-regions scan on every call.
        DEFAULT_REGION = s.get_config_variable('region') or ""
        for part in s.get_available_partitions():
            for region in s.get_available_regions("s3", part):
                s3 = s.create_client("s3", region)
                ep = urllib.parse.urlparse(s3.meta.endpoint_url).netloc
                S3_REGION_INFO[region] = ep
    return S3_REGION_INFO, DEFAULT_REGION


class S3(osaka.base.StorageBase):
Expand All @@ -58,6 +69,7 @@ def __init__(self):
Constructor
"""
self.tmpfiles = []
self.is_nominal_style = False

def connect(self, uri, params={}):
"""
Expand All @@ -72,15 +84,42 @@ def connect(self, uri, params={}):
session_kwargs = {}
kwargs = {}
check_host = parsed.hostname if "location" not in params else params["location"]
for region, ep in get_region_info().items():
region_info, default_region = get_region_info()
found_ep_and_region = False
for region, ep in region_info.items():
if re.search(ep, check_host):
kwargs["endpoint_url"] = ep
session_kwargs["region_name"] = region
found_ep_and_region = True
break
if parsed.hostname is not None:
kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, parsed.hostname)
# Use the default region obtained from the sessions object when the
# region info was being gathered. This check is used to support the cases
# when osaka receives an S3 url in the nominal pathing style.
if not found_ep_and_region:
if not default_region:
# If the default region was not found from the initial boto session, resort
# to trying to get that info from the API
try:
response = requests.get(AWS_ID_ENDPOINT)
data = response.json()
default_region = data.get("region", None)
if not default_region:
osaka.utils.LOGGER.error(f"No 'region' found in json response:\n{data}")
raise Exception
except Exception as e:
raise osaka.utils.OsakaException(
f"Cannot determine region. Verify that an AWS config file exists and that a region "
f"is set within that file.") from e

ep = region_info.get(default_region)
kwargs["endpoint_url"] = f"{parsed.scheme}://{ep}"
session_kwargs["region_name"] = default_region
self.is_nominal_style = True
else:
kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, kwargs["endpoint_url"])
if parsed.hostname is not None:
kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, parsed.hostname)
else:
kwargs["endpoint_url"] = "%s://%s" % (parsed.scheme, kwargs["endpoint_url"])
if parsed.port is not None and parsed.port != 80 and parsed.port != 443:
kwargs["endpoint_url"] = "%s:%s" % (kwargs["endpoint_url"], parsed.port)
if parsed.username is not None:
Expand Down Expand Up @@ -138,9 +177,7 @@ def get(self, uri):
@param uri: uri to get
"""
osaka.utils.LOGGER.debug("Getting stream from URI: {0}".format(uri))
container, key = osaka.utils.get_container_and_path(
urllib.parse.urlparse(uri).path
)
container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style)
bucket = self.bucket(container, create=False)
obj = bucket.Object(key)
fname = "/tmp/osaka-s3-" + str(datetime.datetime.now())
Expand All @@ -167,9 +204,7 @@ def put(self, stream, uri):
@param uri: uri to put
"""
osaka.utils.LOGGER.debug("Putting stream to URI: {0}".format(uri))
container, key = osaka.utils.get_container_and_path(
urllib.parse.urlparse(uri).path
)
container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style)
bucket = self.bucket(container)
obj = bucket.Object(key)
extra = {}
Expand Down Expand Up @@ -200,15 +235,20 @@ def listAllChildren(self, uri):
if "__top__" in self.cache and uri == self.cache["__top__"]:
return [k for k in list(self.cache.keys()) if k != "__top__"]
parsed = urllib.parse.urlparse(uri)
container, key = osaka.utils.get_container_and_path(parsed.path)
container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style)
bucket = self.bucket(container, create=False)
collection = bucket.objects.filter(Prefix=key)
uriBase = (
parsed.scheme
+ "://"
+ parsed.hostname
+ (":" + str(parsed.port) if parsed.port is not None else "")
)
if self.is_nominal_style:
# Needs only 1 slash because the 2nd one gets added when the "full"
# value gets put together
uriBase = f"{parsed.scheme}:/"
else:
uriBase = (
parsed.scheme
+ "://"
+ parsed.hostname
+ (":" + str(parsed.port) if parsed.port is not None else "")
)
# Setup cache, and fill it with listings
self.cache["__top__"] = uri
for item in collection:
Expand Down Expand Up @@ -236,7 +276,7 @@ def list(self, uri):
depth = len(uri.rstrip("/").split("/"))
return [
item
for item in self.listAllChildren()
for item in self.listAllChildren(uri)
if len(item.rstrip("/").split("/")) == (depth + 1)
]

Expand Down Expand Up @@ -297,9 +337,7 @@ def rm(self, uri):
Remove this uri from backend
@param uri: uri to remove
"""
container, key = osaka.utils.get_container_and_path(
urllib.parse.urlparse(uri).path
)
container, key = osaka.utils.get_s3_container_and_path(uri, is_nominal_style=self.is_nominal_style)
bucket = self.bucket(container, create=False)
obj = bucket.Object(key)
obj.delete()
Expand Down
32 changes: 32 additions & 0 deletions osaka/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest

from osaka import utils


class UtilsTest(unittest.TestCase):
    """Unit tests for osaka.utils.get_s3_container_and_path."""

    def setUp(self):
        # Legacy style embeds the endpoint host in the URL; nominal style
        # starts directly with the bucket name.
        self.legacy_s3_url = "s3://s3-us-west-2.amazonaws.com:80/my_bucket/foo/bar/key"
        self.nominal_s3_url = "s3://my_bucket/foo/bar/key"
        self.expected_container = "my_bucket"
        self.expected_key = "foo/bar/key"

    def test_get_s3_container_and_path_legacy_style(self):
        """Legacy (endpoint-host) URLs parse without the nominal-style flag."""
        container, key = utils.get_s3_container_and_path(self.legacy_s3_url)
        # Use assertEqual: the assertEquals alias is deprecated and was
        # removed in Python 3.12.
        self.assertEqual(
            container,
            self.expected_container,
            f"Did not get expected container value: "
            f"container={container}, expected={self.expected_container}",
        )
        self.assertEqual(
            key,
            self.expected_key,
            f"Did not get expected key value: key={key}, "
            f"expected={self.expected_key}",
        )

    def test_get_s3_container_and_path_nominal_style(self):
        """Nominal (s3://bucket/key) URLs require is_nominal_style=True."""
        container, key = utils.get_s3_container_and_path(
            self.nominal_s3_url, is_nominal_style=True
        )
        self.assertEqual(
            container,
            self.expected_container,
            f"Did not get expected container value: "
            f"container={container}, expected={self.expected_container}",
        )
        self.assertEqual(
            key,
            self.expected_key,
            f"Did not get expected key value: key={key}, "
            f"expected={self.expected_key}",
        )


# Allow running this test module directly (outside of a test runner).
if __name__ == "__main__":
    unittest.main()
18 changes: 18 additions & 0 deletions osaka/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ def get_container_and_path(urlpath):
return (split[0], "" if not len(split) > 1 else split[1])


def get_s3_container_and_path(s3_url, is_nominal_style=False):
    """
    Gets the S3 container (a.k.a. bucket) and path (a.k.a. key) from the
    given URL.

    :param s3_url: The S3 URL to parse.
    :param is_nominal_style: Set to True if the given S3 URL is in the
        nominal style format (i.e. s3://bucket/key); otherwise the URL is
        treated as legacy style, where the hostname is the S3 endpoint.
    :return: Tuple of (container, key).
    """
    parsed = urllib.parse.urlparse(s3_url)
    if is_nominal_style:
        # Nominal style: the hostname is the bucket and the path (minus its
        # leading "/") is the key. Guard against a bucket-only URL such as
        # "s3://bucket", whose path contains no "/" to split on.
        key = parsed.path.split("/", 1)[1] if "/" in parsed.path else ""
        return parsed.hostname, key
    else:
        # Legacy style: the bucket is the first path component; delegate to
        # the generic container/path splitter.
        return get_container_and_path(parsed.path)


# def walk(func, directory,destdir, *params):
# '''
# Walk the directory and call the function for each file
Expand Down

0 comments on commit 981b576

Please sign in to comment.