Skip to content

Commit ca07c44

Browse files
committed
fixing typing pylance errors
1 parent 9c25071 commit ca07c44

File tree

5 files changed

+34
-27
lines changed

5 files changed

+34
-27
lines changed

src/aibs_informatics_aws_utils/dynamodb/conditions.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ class ExpressionComponentsBase:
3333
@cached_property
3434
def expression_attribute_values__serialized(self) -> Dict[str, Dict[str, Any]]:
3535
serializer = TypeSerializer()
36-
return {k: serializer.serialize(v) for k, v in self.expression_attribute_values.items()}
36+
return {
37+
k: cast(Dict[str, Any], serializer.serialize(v))
38+
for k, v in self.expression_attribute_values.items()
39+
}
3740

3841

3942
@dataclass

src/aibs_informatics_aws_utils/dynamodb/functions.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ def table_put_item(
6161

6262

6363
def table_get_item(
64-
table_name: str, key: Mapping[str, Any], attrs: str = None
64+
table_name: str, key: Mapping[str, Any], attrs: Optional[str] = None
6565
) -> Optional[Dict[str, Any]]:
6666
table = table_as_resource(table_name)
67-
props: GetItemInputRequestTypeDef = {"Key": key, "ReturnConsumedCapacity": "NONE"}
67+
props: GetItemInputRequestTypeDef = {"Key": key, "ReturnConsumedCapacity": "NONE"} # type: ignore # we modify use of this type (no table name is needed here)
6868

6969
if attrs is not None:
7070
props["ProjectionExpression"] = attrs
7171

72-
response = table.get_item(**props)
72+
response = table.get_item(**props) # type: ignore # pylance complains about extra fields
7373

7474
logger.info("Response from table.get_item: %s", response)
7575

@@ -79,8 +79,8 @@ def table_get_item(
7979
def table_get_items(
8080
table_name: str,
8181
keys: List[Mapping[str, Any]],
82-
attrs: str = None,
83-
region: str = None,
82+
attrs: Optional[str] = None,
83+
region: Optional[str] = None,
8484
) -> List[Dict[str, Any]]:
8585
db = get_dynamodb_client(region=region)
8686
serializer = TypeSerializer()
@@ -183,7 +183,7 @@ def table_query(
183183
key_condition_expression: ConditionBase,
184184
index_name: Optional[str] = None,
185185
filter_expression: Optional[ConditionBase] = None,
186-
region: str = None,
186+
region: Optional[str] = None,
187187
consistent_read: bool = False,
188188
) -> List[Dict[str, Any]]:
189189
"""Query a table
@@ -252,7 +252,7 @@ def table_query(
252252
items: List[Dict[str, Any]] = []
253253
paginator = db.get_paginator("query")
254254
logger.info(f"Performing DB 'query' on {table.name} with following parameters: {db_request}")
255-
for i, response in enumerate(paginator.paginate(**db_request)):
255+
for i, response in enumerate(paginator.paginate(**db_request)): # type: ignore # pylance complains about extra fields
256256
new_items = response.get("Items", [])
257257
items.extend(new_items)
258258
logger.debug(f"Iter #{i+1}: item count from table. Query: {len(new_items)}")
@@ -266,7 +266,7 @@ def table_scan(
266266
table_name: str,
267267
index_name: Optional[str] = None,
268268
filter_expression: Optional[ConditionBase] = None,
269-
region: str = None,
269+
region: Optional[str] = None,
270270
consistent_read: bool = False,
271271
) -> List[Dict[str, Any]]:
272272
"""Scan a table
@@ -319,7 +319,7 @@ def table_scan(
319319
items: List[Dict[str, Any]] = []
320320
paginator = db.get_paginator("scan")
321321
logger.info(f"Performing DB 'scan' on {table.name} with following parameters: {db_request}")
322-
for i, response in enumerate(paginator.paginate(**db_request)):
322+
for i, response in enumerate(paginator.paginate(**db_request)): # type: ignore # pylance complains about extra fields
323323
new_items = response.get("Items", [])
324324
items.extend(new_items)
325325
logger.debug(f"Iter #{i+1}: item count from table. Scan: {len(new_items)}")
@@ -339,7 +339,9 @@ def table_get_key_schema(table_name: str) -> Dict[str, str]:
339339
return {k["KeyType"]: k["AttributeName"] for k in table.key_schema}
340340

341341

342-
def execute_partiql_statement(statement: str, region: str = None) -> List[Dict[str, Any]]:
342+
def execute_partiql_statement(
343+
statement: str, region: Optional[str] = None
344+
) -> List[Dict[str, Any]]:
343345
db = get_dynamodb_client(region=region)
344346

345347
response = db.execute_statement(Statement=statement)
@@ -351,7 +353,7 @@ def execute_partiql_statement(statement: str, region: str = None) -> List[Dict[s
351353
return results
352354

353355

354-
def table_as_resource(table: str, region: str = None):
356+
def table_as_resource(table: str, region: Optional[str] = None):
355357
"""Helper method to get the table as a resource for given env_label
356358
if provided.
357359
"""

src/aibs_informatics_aws_utils/dynamodb/table.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -202,20 +202,20 @@ def build_optimized_condition_expression_set(
202202
new_condition = Key(k).eq(v)
203203
if (
204204
k in candidate_conditions
205-
and candidate_conditions[k]._values[1:] != new_condition._values[1:]
205+
and candidate_conditions[k]._values[1:] != new_condition._values[1:] # type: ignore[union-attr]
206206
):
207207
raise DBQueryException(f"Multiple values provided for attribute {k}!")
208208
candidate_conditions[k] = Key(k).eq(v)
209-
elif len(_._values) and isinstance(_._values[0], (Key, Attr)):
210-
attr_name = cast(str, _._values[0].name)
209+
elif len(_._values) and isinstance(_._values[0], (Key, Attr)): # type: ignore[union-attr]
210+
attr_name = cast(str, _._values[0].name) # type: ignore[union-attr]
211211
if attr_name not in index_all_key_names or not isinstance(
212212
_, SupportedKeyComparisonTypes
213213
):
214214
non_candidate_conditions.append(_)
215215
continue
216216
if (
217217
attr_name in candidate_conditions
218-
and candidate_conditions[attr_name]._values[1:] != _._values[1:]
218+
and candidate_conditions[attr_name]._values[1:] != _._values[1:] # type: ignore[union-attr]
219219
):
220220
raise DBQueryException(f"Multiple values provided for attribute {attr_name}!")
221221
candidate_conditions[attr_name] = _
@@ -228,12 +228,12 @@ def build_optimized_condition_expression_set(
228228
):
229229
target_index = index
230230
partition_key = candidate_conditions.pop(index.key_name)
231-
partition_key._values = (Key(index.key_name), *partition_key._values[1:])
231+
partition_key._values = (Key(index.key_name), *partition_key._values[1:]) # type: ignore[union-attr]
232232
if index.sort_key_name is not None and index.sort_key_name in candidate_conditions:
233233
sort_key_condition_expression = candidate_conditions.pop(index.sort_key_name)
234-
sort_key_condition_expression._values = (
234+
sort_key_condition_expression._values = ( # type: ignore[union-attr]
235235
Key(index.sort_key_name),
236-
*sort_key_condition_expression._values[1:],
236+
*sort_key_condition_expression._values[1:], # type: ignore[union-attr]
237237
)
238238
break
239239

@@ -315,7 +315,9 @@ def build_key(
315315
return key
316316
index = cls.index_or_default(index)
317317
return (
318-
index.get_primary_key(*key) if isinstance(key, tuple) else index.get_primary_key(key)
318+
index.get_primary_key(key[0], key[1])
319+
if isinstance(key, tuple)
320+
else index.get_primary_key(key)
319321
)
320322

321323
# --------------------------------------------------------------------------
@@ -386,8 +388,8 @@ def batch_get(
386388
items = table_get_items(table_name=self.table_name, keys=item_keys)
387389
if len(items) != len(item_keys) and not ignore_missing:
388390
missing_keys = set(
389-
[(_[index.key_name], _.get(index.sort_key_name)) for _ in item_keys]
390-
).difference((_[index.key_name], _.get(index.sort_key_name)) for _ in items)
391+
[(_[index.key_name], _.get(index.sort_key_name or "")) for _ in item_keys]
392+
).difference((_[index.key_name], _.get(index.sort_key_name or "")) for _ in items)
391393

392394
raise DBReadException(f"Could not find items for {missing_keys}")
393395
entries = [self.build_entry(_, partial=partial) for _ in items]
@@ -704,7 +706,7 @@ def delete(
704706
e_msg = f"{self.table_name} - Delete failed for the following primary key: {key}"
705707
try:
706708
deleted_attributes = table_delete_item(
707-
table_name=self.table_name, key=key, return_values="ALL_OLD"
709+
table_name=self.table_name, key=key, return_values="ALL_OLD" # type: ignore[arg-type] # expected type more general than specified here
708710
)
709711

710712
if not deleted_attributes:

src/aibs_informatics_aws_utils/efs/core.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def list_efs_file_systems(
6969

7070
file_systems: List[FileSystemDescriptionTypeDef] = []
7171
paginator_kwargs = remove_null_values(dict(FileSystemId=file_system_id))
72-
for results in paginator.paginate(**paginator_kwargs):
72+
for results in paginator.paginate(**paginator_kwargs): # type: ignore
7373
for fs in results["FileSystems"]:
7474
if name and fs.get("Name") != name:
7575
continue
@@ -159,12 +159,12 @@ def list_efs_access_points(
159159

160160
for fs_id in file_system_ids:
161161
response = efs.describe_access_points(
162-
**remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id))
162+
**remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id)) # type: ignore
163163
)
164164
access_points.extend(response["AccessPoints"])
165165
while response.get("NextToken"):
166166
response = efs.describe_access_points(
167-
**remove_null_values(
167+
**remove_null_values( # type: ignore
168168
dict(
169169
AccessPointId=access_point_id,
170170
FileSystemId=fs_id,

src/aibs_informatics_aws_utils/efs/mount_point.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def is_mounted_path(self, path: StrPath) -> bool:
225225
def as_env_vars(self, name: Optional[str] = None) -> Dict[str, str]:
226226
"""Converts the mount point configuration to environment variables."""
227227
if self.access_point and self.access_point.get("AccessPointId"):
228-
mount_point_id = self.access_point.get("AccessPointId")
228+
mount_point_id = self.access_point["AccessPointId"] # type: ignore pylance complains even though we checked
229229
else:
230230
mount_point_id = self.file_system["FileSystemId"]
231231

0 commit comments

Comments
 (0)