Skip to content

Commit d9e5c8f

Browse files
authored
Merge pull request #3 from AllenInstitute/hotfix/invalid-DescribeAccessPoints-call-pattern
Fix DescribeAccessPoints call pattern
2 parents 235693f + 0ae9c18 commit d9e5c8f

File tree

2 files changed

+38
-19
lines changed

2 files changed

+38
-19
lines changed

src/aibs_informatics_aws_utils/efs/core.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import logging
1212
from pathlib import Path
13-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
13+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1414

1515
from aibs_informatics_core.utils.decorators import retry
1616
from aibs_informatics_core.utils.tools.dicttools import remove_null_values
@@ -144,35 +144,36 @@ def list_efs_access_points(
144144
"""
145145
efs = get_efs_client()
146146

147-
file_system_ids: List[Optional[str]] = []
147+
file_system_ids: List[str] = []
148148
if file_system_id:
149149
file_system_ids.append(file_system_id)
150150
elif file_system_name or file_system_tags:
151151
file_systems = list_efs_file_systems(
152152
file_system_id=file_system_id, name=file_system_name, tags=file_system_tags
153153
)
154154
file_system_ids.extend(map(lambda _: _["FileSystemId"], file_systems))
155-
else:
156-
file_system_ids.append(None)
157155

158156
access_points: List[AccessPointDescriptionTypeDef] = []
159157

160-
for fs_id in file_system_ids:
158+
if access_point_id or not file_system_ids:
161159
response = efs.describe_access_points(
162-
**remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id)) # type: ignore
160+
**remove_null_values(dict(AccessPointId=access_point_id)) # type: ignore
163161
)
164-
access_points.extend(response["AccessPoints"])
165-
while response.get("NextToken"):
166-
response = efs.describe_access_points(
167-
**remove_null_values( # type: ignore
168-
dict(
169-
AccessPointId=access_point_id,
170-
FileSystemId=fs_id,
171-
NextToken=response["NextToken"],
172-
)
173-
)
174-
)
162+
# If file_system_ids is empty, we want to include all access points. Otherwise,
163+
# we only want to include access points that belong to the file systems
164+
# in file_system_ids.
165+
for access_point in response["AccessPoints"]:
166+
if not file_system_ids or access_point.get("FileSystemId") in file_system_ids:
167+
access_points.append(access_point)
168+
else:
169+
for fs_id in file_system_ids:
170+
response = efs.describe_access_points(FileSystemId=fs_id)
175171
access_points.extend(response["AccessPoints"])
172+
while response.get("NextToken"):
173+
response = efs.describe_access_points(
174+
FileSystemId=fs_id, NextToken=response["NextToken"]
175+
)
176+
access_points.extend(response["AccessPoints"])
176177

177178
filtered_access_points: List[AccessPointDescriptionTypeDef] = []
178179

@@ -227,13 +228,13 @@ def get_efs_access_point(
227228
if len(access_points) > 1:
228229
raise ValueError(
229230
f"Found more than one access points ({len(access_points)}) "
230-
f"based on access point filters (id={access_point_id}, name={access_point_id}, tags={access_point_tags}) "
231+
f"based on access point filters (id={access_point_id}, name={access_point_name}, tags={access_point_tags}) "
231232
f"and on file system filters (id={file_system_id}, name={file_system_name}, tags={file_system_tags}) "
232233
)
233234
elif len(access_points) == 0:
234235
raise ValueError(
235236
f"Found no access points "
236-
f"based on access point filters (id={access_point_id}, name={access_point_id}, tags={access_point_tags}) "
237+
f"based on access point filters (id={access_point_id}, name={access_point_name}, tags={access_point_tags}) "
237238
f"and on file system filters (id={file_system_id}, name={file_system_name}, tags={file_system_tags}) "
238239
)
239240
return access_points[0]

test/aibs_informatics_aws_utils/efs/test_core.py

+18
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,24 @@ def test__list_efs_access_points__filters_based_on_ap_tag(self):
6060
self.assertEqual(len(access_points), 1)
6161
self.assertEqual(access_points[0].get("AccessPointId"), access_point_id1)
6262

63+
def test__list_efs_access_points__all_filters_provided(self):
64+
file_system_id = self.create_file_system("fs1", env="dev")
65+
access_point_id1 = self.create_access_point(
66+
file_system_id=file_system_id, access_point_name="ap1", env="dev"
67+
)
68+
access_point_id2 = self.create_access_point(
69+
file_system_id=file_system_id, access_point_name="ap2", env="prod"
70+
)
71+
access_points = list_efs_access_points(
72+
file_system_id=file_system_id,
73+
file_system_name="fs1",
74+
access_point_id=access_point_id1,
75+
access_point_name="ap1",
76+
access_point_tags=dict(env="dev"),
77+
)
78+
self.assertEqual(len(access_points), 1)
79+
self.assertEqual(access_points[0].get("AccessPointId"), access_point_id1)
80+
6381
def test__list_efs_access_points__filters_based_on_name(self):
6482
file_system_id = self.create_file_system("fs1", env="dev")
6583
access_point_id1 = self.create_access_point(

0 commit comments

Comments
 (0)