Skip to content

Commit 138baf6

Browse files
authored
Merge branch 'master' into add_num_uploaders
2 parents 29113f5 + 71f4477 commit 138baf6

File tree

3 files changed

+43
-8
lines changed

3 files changed

+43
-8
lines changed

src/lightning/data/streaming/client.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
if _BOTO3_AVAILABLE:
88
import boto3
99
import botocore
10+
from botocore.credentials import InstanceMetadataProvider
11+
from botocore.utils import InstanceMetadataFetcher
1012

1113

1214
class S3Client:
@@ -18,20 +20,34 @@ def __init__(self, refetch_interval: int = 3300) -> None:
1820
self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
1921
self._client: Optional[Any] = None
2022

23+
def _create_client(self) -> None:
    """Build the boto3 S3 client and store it on ``self._client``.

    When both ``AWS_SHARED_CREDENTIALS_FILE`` and ``AWS_CONFIG_FILE`` point at
    ``/.credentials/.aws_credentials``, boto3 is left to resolve credentials on
    its own. Otherwise credentials are loaded explicitly through
    ``InstanceMetadataProvider`` and handed to the client.

    If the provider yields no credentials, the client is created without
    explicit keys so boto3 falls back to its default resolution instead of
    failing with an ``AttributeError`` on ``None``.
    """
    shared_credentials_path = "/.credentials/.aws_credentials"
    has_shared_credentials_file = (
        os.getenv("AWS_SHARED_CREDENTIALS_FILE") == os.getenv("AWS_CONFIG_FILE") == shared_credentials_path
    )

    # The retry policy is identical for both credential sources, so build it once.
    client_kwargs = {
        "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
    }

    if not has_shared_credentials_file:
        provider = InstanceMetadataProvider(
            iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5)
        )
        credentials = provider.load()
        # NOTE(review): botocore documents load() as returning None when no
        # credentials are found; guard so we do not dereference None below.
        if credentials is not None:
            client_kwargs.update(
                aws_access_key_id=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_session_token=credentials.token,
            )

    self._client = boto3.client("s3", **client_kwargs)
40+
2141
@property
2242
def client(self) -> Any:
2343
if not self._has_cloud_space_id:
2444
if self._client is None:
25-
self._client = boto3.client(
26-
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
27-
)
45+
self._create_client()
2846
return self._client
2947

3048
# Re-generate credentials for EC2
3149
if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
32-
self._client = boto3.client(
33-
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
34-
)
50+
self._create_client()
3551
self._last_time = time()
3652

3753
return self._client

src/lightning/data/streaming/downloader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
5757
return
5858

5959
try:
60-
with FileLock(local_filepath + ".lock", timeout=1 if obj.path.endswith(_INDEX_FILENAME) else 0):
60+
with FileLock(local_filepath + ".lock", timeout=3 if obj.path.endswith(_INDEX_FILENAME) else 0):
6161
if self._s5cmd_available:
6262
proc = subprocess.Popen(
6363
f"s5cmd cp {remote_filepath} {local_filepath}",

tests/tests_data/streaming/test_client.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
1313
botocore = mock.MagicMock()
1414
monkeypatch.setattr(client, "botocore", botocore)
1515

16+
instance_metadata_provider = mock.MagicMock()
17+
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
18+
19+
instance_metadata_fetcher = mock.MagicMock()
20+
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
21+
1622
s3 = client.S3Client(1)
1723
assert s3.client
1824
assert s3.client
@@ -24,7 +30,8 @@ def test_s3_client_without_cloud_space_id(monkeypatch):
2430

2531

2632
@pytest.mark.skipif(sys.platform == "win32", reason="not supported on windows")
27-
def test_s3_client_with_cloud_space_id(monkeypatch):
33+
@pytest.mark.parametrize("use_shared_credentials", [False, True])
34+
def test_s3_client_with_cloud_space_id(use_shared_credentials, monkeypatch):
2835
boto3 = mock.MagicMock()
2936
monkeypatch.setattr(client, "boto3", boto3)
3037

@@ -33,6 +40,16 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
3340

3441
monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")
3542

43+
if use_shared_credentials:
44+
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", "/.credentials/.aws_credentials")
45+
monkeypatch.setenv("AWS_CONFIG_FILE", "/.credentials/.aws_credentials")
46+
47+
instance_metadata_provider = mock.MagicMock()
48+
monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)
49+
50+
instance_metadata_fetcher = mock.MagicMock()
51+
monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)
52+
3653
s3 = client.S3Client(1)
3754
assert s3.client
3855
assert s3.client
@@ -45,3 +62,5 @@ def test_s3_client_with_cloud_space_id(monkeypatch):
4562
assert s3.client
4663
assert s3.client
4764
assert len(boto3.client._mock_mock_calls) == 9
65+
66+
# Parenthesise the conditional: without it the statement parses as
# `assert ((count == 0) if use_shared_credentials else 3)`, which is always
# truthy when shared credentials are not used and therefore checks nothing.
assert instance_metadata_provider._mock_call_count == (0 if use_shared_credentials else 3)

0 commit comments

Comments
 (0)