Skip to content

Commit 9d98ff6

Browse files
authored
Resolve unsafe thread race condition for workflow submission (#145)
* Add lock to workflow submission logic * Update lock to be per user * Add sleep timer to submission to ensure the workflow endpoint registers the new workflow
1 parent f0c1533 commit 9d98ff6

File tree

1 file changed

+70
-50
lines changed

1 file changed

+70
-50
lines changed

src/argowrapper/engine/argo_engine.py

+70-50
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
from argowrapper.engine.helpers.workflow_factory import WorkflowFactory
4141
from argowrapper.workflows.argo_workflows.gwas import GWAS
4242
import requests
43+
import time
44+
from threading import Lock
4345

4446

4547
class ArgoEngine:
@@ -57,6 +59,7 @@ def __repr__(self) -> str:
5759
return f"dry_run={self.dry_run}"
5860

5961
def __init__(self, dry_run: bool = False):
62+
self.user_locks = {}
6063
self.dry_run = dry_run
6164
# workflow "given names" by uid cache:
6265
self.workflow_given_names_cache = {}
@@ -196,6 +199,11 @@ def _get_log_errors(self, uid: str, status_nodes_dict: Dict) -> List[Dict]:
196199
pass
197200
return errors
198201

202+
def _get_lock_for_user(self, username: str) -> Lock:
203+
if username not in self.user_locks:
204+
self.user_locks[username] = Lock()
205+
return self.user_locks[username]
206+
199207
def get_workflow_details(
200208
self, workflow_name: str, uid: str = None
201209
) -> Dict[str, any]:
@@ -582,59 +590,71 @@ def get_workflow_logs(self, workflow_name: str, uid: str) -> List[Dict]:
582590
)
583591

584592
def workflow_submission(self, request_body: Dict, auth_header: str):
    """Submit a GWAS workflow to Argo, one submission at a time per user.

    The username is read from ``auth_header`` and the matching per-user lock
    is held for the whole submission, so concurrent requests from the same
    user run one after another.

    Args:
        request_body: submission parameters handed to ``WorkflowFactory``.
            If it contains ``"workflow_name"``, that name is logged.
        auth_header: the caller's auth token/header; used to resolve the
            username, the billing id and the workflow limit.

    Returns:
        ``workflow.wf_name`` of the workflow that was submitted.

    Raises:
        Exception: with ``EXCEED_WORKFLOW_LIMIT_ERROR`` when the user's
            workflow count has reached the limit; any exception raised by
            ``create_workflow`` is logged and re-raised unchanged.
    """
    username = argo_engine_helper.get_username_from_token(auth_header)

    # Lock so only one submission can run at a time per user. The ``with``
    # block releases the lock on every exit path, including exceptions.
    with self._get_lock_for_user(username):
        try:
            if "workflow_name" in request_body:
                logger.info(f"lock acquired for {request_body['workflow_name']}")

            workflow = WorkflowFactory._get_workflow(
                ARGO_NAMESPACE, request_body, auth_header, WORKFLOW.GWAS
            )
            workflow_yaml = workflow._to_dict()

            # Default is "cap not reached"; starting from True would reject a
            # submission whenever the cap check below did not overwrite it.
            reached_monthly_cap = False

            # check if user has a billing id tag:
            (
                billing_id,
                workflow_limit,
            ) = self.check_user_info_for_billing_id_and_workflow_limit(auth_header)

            # If billing_id exists for user, add it to workflow label and pod
            # metadata, and blank out gen3username in the pod metadata.
            if billing_id:
                workflow_yaml["metadata"]["labels"]["billing_id"] = billing_id
                pod_labels = workflow_yaml["spec"]["podMetadata"]["labels"]
                pod_labels["billing_id"] = billing_id
                pod_labels["gen3username"] = ""

            # Check whether the user already reached the monthly cap.
            # NOTE(review): billing_id is passed through even when empty, so
            # the helper is assumed to handle users without one -- confirm.
            workflow_run, workflow_limit = self.check_user_monthly_workflow_cap(
                auth_header, billing_id, workflow_limit
            )
            reached_monthly_cap = workflow_run >= workflow_limit

            if reached_monthly_cap:
                logger.warning(EXCEED_WORKFLOW_LIMIT_ERROR)
                raise Exception(EXCEED_WORKFLOW_LIMIT_ERROR)

            # submit workflow:
            try:
                response = self.api_instance.create_workflow(
                    namespace=ARGO_NAMESPACE,
                    body=IoArgoprojWorkflowV1alpha1WorkflowCreateRequest(
                        workflow=workflow_yaml,
                        _check_return_type=False,
                        _check_type=False,
                    ),
                    _check_return_type=False,
                    async_req=False,
                )
                logger.debug(response)
            except Exception as exception:
                logger.error(traceback.format_exc())
                logger.error(
                    f"could not submit workflow, failed with error {exception}"
                )
                # Bare raise keeps the original traceback intact.
                raise

            return workflow.wf_name
        finally:
            # Make sure submission registers in Argo before the lock is
            # released. This runs on failure paths too, so error responses
            # are also delayed by the same 5 seconds.
            time.sleep(5)
638658

639659
def check_user_info_for_billing_id_and_workflow_limit(self, request_token):
640660
"""

0 commit comments

Comments
 (0)