Skip to content

Commit 48344bb

Browse files
authored
Merge pull request #15 from AllenInstitute/feature/copy-changes-from-ocs
Copy over changes from OCS
2 parents 6c4e3f6 + e01f86b commit 48344bb

File tree

5 files changed

+278
-30
lines changed

5 files changed

+278
-30
lines changed

src/aibs_informatics_aws_utils/data_sync/operations.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
MAX_LOCK_WAIT_TIME_IN_SECS = 60 * 60 * 6 # 6 hours
3636

37+
LOCK_ROOT_ENV_VAR = "DATA_SYNC_LOCK_ROOT"
3738

3839
LocalPath = Union[Path, EFSPath]
3940

@@ -100,7 +101,7 @@ def sync_s3_to_local(self, source_path: S3URI, destination_path: LocalPath):
100101
@retry(CannotAcquirePathLockError, tries=tries, delay=delay, backoff=1)
101102
@functools.wraps(sync_paths)
102103
def sync_paths_with_lock(*args, **kwargs):
103-
with PathLock(destination_path) as lock:
104+
with PathLock(destination_path, lock_root=os.getenv(LOCK_ROOT_ENV_VAR)) as lock:
104105
response = sync_paths(*args, **kwargs)
105106
return response
106107

src/aibs_informatics_aws_utils/dynamodb/conditions.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from collections import defaultdict
33
from dataclasses import dataclass
4-
from typing import Any, Dict, Iterable, List, Mapping, Match, Union, cast
4+
from typing import Any, Dict, Iterable, List, Mapping, Match, Optional, Union, cast
55

66
from aibs_informatics_core.collections import ValidatedStr
77
from aibs_informatics_core.models.aws.dynamodb import (
@@ -24,6 +24,39 @@
2424
logger = get_logger(__name__)
2525

2626

27+
def condition_to_str(condition: Optional[ConditionBase]) -> Optional[str]:
    """Converts a ConditionBase Boto3 object to its str representation

    The condition is built with a ``ConditionExpressionBuilder`` and every
    attribute name/value placeholder in the built expression is then replaced
    by ``str()`` of the name/value it stands for. The result is intended for
    logging/debugging only; it is not a valid DynamoDB expression.

    NOTE: Function should be removed if this PR is merged: https://github.com/boto/boto3/pull/3254

    Examples:

        >>> condition_to_str(Key("name").eq("new_name") & Attr("description").begins_with("new"))
        (name = new_name AND begins_with(description, new))

        >>> condition_to_str(Attr("description").contains("cool"))
        contains(description, cool)

        >>> condition_to_str(None)
        None

    Args:
        condition (Optional[ConditionBase]): The condition to render, or None.

    Returns:
        Optional[str]: The human readable expression, or None if `condition` is None.
    """
    if condition is None:
        return None

    builder = ConditionExpressionBuilder()
    expression = builder.build_expression(condition)

    # Merge name and value placeholders into one lookup table so that all
    # substitutions can be applied in a single pass over the expression.
    replacements = {
        placeholder: str(actual)
        for placeholders in (
            expression.attribute_name_placeholders,
            expression.attribute_value_placeholders,
        )
        for placeholder, actual in placeholders.items()
    }
    if not replacements:
        # Nothing to substitute (an empty alternation would match everywhere).
        return expression.condition_expression

    # Chained str.replace calls are unsafe here: a placeholder can be a textual
    # prefix of another one (e.g. "#n1" and "#n10"), and text that was already
    # substituted could be rewritten again by a later replace. A single regex
    # pass avoids both; longest placeholders come first so the alternation
    # always prefers the full placeholder over its prefix.
    ordered_placeholders = sorted(replacements, key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(p) for p in ordered_placeholders))
    return pattern.sub(
        lambda match: replacements[match.group(0)], expression.condition_expression
    )
58+
59+
2760
@dataclass
2861
class ExpressionComponentsBase:
2962
expression: str

src/aibs_informatics_aws_utils/dynamodb/table.py

+50-27
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Dict,
66
Generic,
77
List,
8+
Literal,
89
Mapping,
910
MutableMapping,
1011
Optional,
@@ -14,6 +15,7 @@
1415
TypeVar,
1516
Union,
1617
cast,
18+
overload,
1719
)
1820

1921
from aibs_informatics_core.env import EnvBase
@@ -42,6 +44,7 @@
4244
from botocore.exceptions import ClientError
4345

4446
from aibs_informatics_aws_utils.core import get_client_error_code
47+
from aibs_informatics_aws_utils.dynamodb.conditions import condition_to_str
4548
from aibs_informatics_aws_utils.dynamodb.functions import (
4649
convert_floats_to_decimals,
4750
execute_partiql_statement,
@@ -72,8 +75,7 @@ def check_db_query_unique(
7275
key_condition_expression: Optional[ConditionBase] = None,
7376
filter_expression: Optional[ConditionBase] = None,
7477
):
75-
# TODO: this should be len(query_result) > 1
76-
if len(query_result) != 1:
78+
if len(query_result) > 1:
7779
readable_key_expression: Optional[BuiltConditionExpression] = None
7880
if key_condition_expression:
7981
expression_builder = ConditionExpressionBuilder()
@@ -245,8 +247,6 @@ def build_optimized_condition_expression_set(
245247

246248
@dataclass
247249
class DynamoDBTable(LoggingMixin, Generic[DB_MODEL, DB_INDEX]):
248-
# env_base: EnvBase = field(default_factory=EnvBase.from_env)
249-
250250
def __post_init__(self):
251251
check_table_name_and_index_match(self.table_name, self.get_db_index_cls())
252252

@@ -432,10 +432,9 @@ def query(
432432
expect_non_empty (bool, optional): Whether the resulting query should return at least
433433
one result. An error will be raised if expect_non_empty=True and 0 results were
434434
returned by the query.
435-
expect_unique (bool, option): Whether the result of the query is expected to be
436-
unique (i.e. returns 1 and ONLY 1 result). An error will be raised if 0 or
437-
more than 1 results were returned for the query.
438-
435+
expect_unique (bool, optional): Whether the result of the query is expected to
436+
return AT MOST one result. An error will be raised if expect_unique=True and MORE
437+
than 1 result was returned for the query.
439438
Returns:
440439
Sequence[Dict[str, Any]]: A sequence of dictionaries representing database rows
441440
where partition_key/sort_key and filter conditions are satisfied.
@@ -455,7 +454,8 @@ def query(
455454

456455
self.log.info(
457456
f"Calling query on {self.table_name} table (index: {index_name}, "
458-
f"key condition: {key_condition_expression}, filters: {filter_expression})"
457+
f"key condition: {condition_to_str(key_condition_expression)}, "
458+
f"filters: {condition_to_str(filter_expression)})"
459459
)
460460

461461
items = table_query(
@@ -507,9 +507,9 @@ def scan(
507507
expect_non_empty (bool, optional): Whether the resulting query should return at least
508508
one result. An error will be raised if expect_non_empty=True and 0 results were
509509
returned by the query.
510-
expect_unique (bool, option): Whether the result of the query is expected to be
511-
unique (i.e. returns 1 and ONLY 1 result). An error will be raised if 0 or
512-
more than 1 results were returned for the query.
510+
expect_unique (bool, optional): Whether the result of the query is expected to
511+
return AT MOST one result. An error will be raised if expect_unique=True and MORE
512+
than 1 result was returned for the query.
513513
514514
Returns:
515515
Sequence[Dict[str, Any]]: A sequence of dictionaries representing database rows
@@ -524,7 +524,7 @@ def scan(
524524

525525
self.log.info(
526526
f"Calling scan on {self.table_name} table (index: {index_name},"
527-
f" filters: {filter_expression})"
527+
f" filters: {condition_to_str(filter_expression)})"
528528
)
529529

530530
items = table_scan(
@@ -611,17 +611,9 @@ def put(
611611
condition_expression: Optional[ConditionBase] = None,
612612
**table_put_item_kwargs,
613613
) -> DB_MODEL:
614-
# Convert our ConditionBase class to a readable string (for logging/debugging purposes)
615-
if condition_expression:
616-
expression = ConditionExpressionBuilder().build_expression(condition_expression)
617-
expression_str = (
618-
f"({expression.condition_expression}, "
619-
f"{expression.attribute_name_placeholders}, "
620-
f"{expression.attribute_value_placeholders})"
621-
)
622-
else:
623-
expression_str = "None"
624-
put_summary = f"(entry: {entry}, condition_expression: {expression_str})"
614+
put_summary = (
615+
f"(entry: {entry}, condition_expression: {condition_to_str(condition_expression)})"
616+
)
625617
self.log.debug(f"{self.table_name} - Putting new entry: {put_summary}")
626618

627619
e_msg_intro = f"{self.table_name} - Error putting entry: {put_summary}."
@@ -659,22 +651,30 @@ def update(
659651

660652
for k in key:
661653
new_attributes.pop(k, None)
654+
# Add k:v pair from new_attributes if new != old value for a given key
655+
new_clean_attrs: Dict[str, Any] = {}
656+
if old_entry:
657+
for k, new_v in new_attributes.items():
658+
if getattr(old_entry, k) != new_v:
659+
new_clean_attrs[k] = new_v
660+
else:
661+
new_clean_attrs = new_attributes
662662

663-
if not new_attributes:
663+
if not new_clean_attrs:
664664
self.log.debug(
665665
f"{self.table_name} - No attr_updates to do! Skipping _update_entry call."
666666
)
667667
if not old_entry:
668668
old_entry = self.get(key)
669669
return old_entry
670670

671-
update_summary = f"(old_entry: {old_entry}, new_attributes: {new_attributes})"
671+
update_summary = f"(old_entry: {old_entry}, new_attributes: {new_clean_attrs})"
672672
self.log.debug(f"{self.table_name} - Updating entry: {update_summary}")
673673
try:
674674
updated_item = table_update_item(
675675
table_name=self.table_name,
676676
key=key,
677-
attributes=new_attributes,
677+
attributes=new_clean_attrs,
678678
return_values="ALL_NEW",
679679
**table_update_item_kwargs,
680680
)
@@ -694,6 +694,29 @@ def update(
694694
self.log.debug(f"{self.table_name} - Successfully updated entry: {updated_entry}")
695695
return updated_entry
696696

697+
@overload
698+
def delete(
699+
self,
700+
key: Union[DynamoDBKey, DB_MODEL],
701+
error_on_nonexistent: Literal[True],
702+
) -> DB_MODEL:
703+
...
704+
705+
@overload
706+
def delete(
707+
self,
708+
key: Union[DynamoDBKey, DB_MODEL],
709+
error_on_nonexistent: Literal[False],
710+
) -> Optional[DB_MODEL]:
711+
...
712+
713+
@overload
714+
def delete(
715+
self,
716+
key: Union[DynamoDBKey, DB_MODEL],
717+
) -> Optional[DB_MODEL]:
718+
...
719+
697720
def delete(
698721
self,
699722
key: Union[DynamoDBKey, DB_MODEL],

src/aibs_informatics_aws_utils/sqs.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
import logging
3+
from typing import Optional, Type
34

5+
from aibs_informatics_core.utils.json import DecimalEncoder
46
from botocore.exceptions import ClientError
57

68
from aibs_informatics_aws_utils.core import AWSService, get_region
@@ -13,7 +15,7 @@
1315
get_sqs_resource = AWSService.SQS.get_resource
1416

1517

16-
def delete_from_queue(queue_name: str, receipt_handle: str, region: str = None):
18+
def delete_from_queue(queue_name: str, receipt_handle: str, region: Optional[str] = None):
1719
sqs = get_sqs_client()
1820
queue_url_response = sqs.get_queue_url(QueueName=queue_name)
1921
queue_url = queue_url_response["QueueUrl"]
@@ -40,3 +42,63 @@ def send_to_dispatch_queue(payload: dict, env_base: str):
4042
response = sqs.send_message(QueueUrl=queue_url, MessageBody=json.dumps(payload))
4143

4244
return response["MD5OfMessageBody"]
45+
46+
47+
def send_sqs_message(
    queue_name: str,
    payload: dict,
    message_deduplication_id: Optional[str] = None,
    message_group_id: Optional[str] = None,
    payload_json_encoder: Type[json.JSONEncoder] = DecimalEncoder,
) -> str:
    """Send a message to an SQS queue by providing a queue name

    Args:
        queue_name (str): The name of the queue that you want to send a message to.
            (e.g. 'aps-sync-request-queue.fifo')
        payload (dict): A dictionary representing the message payload you would like to send.
        message_deduplication_id (Optional[str], optional): An ID that can be used by SQS
            to remove messages that have the same deduplication_id. Do not set if your
            SQS queue already uses content based deduplication. Defaults to None.
        message_group_id (Optional[str], optional): Required for FIFO queues.
            Messages sent with the same message_group_id will obey FIFO rules. Messages with
            different message_group_ids may be interleaved. Defaults to None.
        payload_json_encoder (Type[json.JSONEncoder], optional): The JSONEncoder
            class that should be used to convert the input `payload` dictionary into
            a json string. By default uses a DecimalEncoder which can handle decimal.Decimal types.

    Raises:
        AWSError: If the provided queue_name cannot be resolved to an SQS url.
            HINT: Does the code calling this function have the correct SQS permissions?
        RuntimeError: If the destination queue is a FIFO queue, then `message_group_id` MUST
            be provided. This is checked before any request is made to SQS.

    Returns:
        str: Returns an MD5 digest of the sent message body.
    """
    # Validate arguments up front so that an invalid call fails fast instead of
    # first spending a round trip to SQS resolving the queue url.
    if message_group_id is None and queue_name.endswith(".fifo"):
        raise RuntimeError("SQS messages for a FIFO queue *must* include a message_group_id!")

    sqs = get_sqs_client(region=get_region())
    try:
        queue_url_response = sqs.get_queue_url(QueueName=queue_name)
    except ClientError as e:
        # Chain the original ClientError so the underlying AWS error code is
        # still visible in the traceback.
        raise AWSError(
            f"Could not find SQS queue with name: {queue_name}. "
            "Does the code calling send_sqs_message() have sqs:GetQueueUrl permissions?"
        ) from e

    send_sqs_message_args = {
        "QueueUrl": queue_url_response["QueueUrl"],
        "MessageBody": json.dumps(payload, cls=payload_json_encoder),
    }

    # Optional parameters are only included when explicitly provided.
    if message_group_id is not None:
        send_sqs_message_args["MessageGroupId"] = message_group_id

    if message_deduplication_id is not None:
        send_sqs_message_args["MessageDeduplicationId"] = message_deduplication_id

    response = sqs.send_message(**send_sqs_message_args)  # type: ignore  # complains about valid kwargs

    return response["MD5OfMessageBody"]

0 commit comments

Comments
 (0)