Skip to content

Commit 02a28f1

Browse files
authored
add workflow cap for all users (#140)
add workflow cap for all users
1 parent e77e37d commit 02a28f1

File tree

5 files changed

+141
-40
lines changed

5 files changed

+141
-40
lines changed

src/argowrapper/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
GEN3_WORKFLOW_PHASE_LABEL: Final = "phase"
4343
GEN3_SUBMIT_TIMESTAMP_LABEL: Final = "submittedAt"
4444
GEN3_NON_VA_WORKFLOW_MONTHLY_CAP: Final = 20
45+
GEN3_DEFAULT_WORKFLOW_MONTHLY_CAP: Final = 50
4546

4647

4748
class POD_COMPLETION_STRATEGY(Enum):

src/argowrapper/engine/argo_engine.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,11 @@ def get_user_workflows_for_current_month(self, auth_header: str) -> List[Dict]:
405405
)
406406
user_monthly_workflows = []
407407
for workflow in all_user_workflows:
408-
if workflow[GEN3_WORKFLOW_PHASE_LABEL] in {"Running", "Succeeded"}:
408+
if workflow[GEN3_WORKFLOW_PHASE_LABEL] in {
409+
"Running",
410+
"Succeeded",
411+
"Failed",
412+
}:
409413
submitted_time_str = workflow[GEN3_SUBMIT_TIMESTAMP_LABEL]
410414
submitted_time = datetime.strptime(
411415
submitted_time_str, "%Y-%m-%dT%H:%M:%SZ"

src/argowrapper/routes/routes.py

+52-27
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
GEN3_TEAM_PROJECT_METADATA_LABEL,
1818
GEN3_USER_METADATA_LABEL,
1919
GEN3_NON_VA_WORKFLOW_MONTHLY_CAP,
20+
GEN3_DEFAULT_WORKFLOW_MONTHLY_CAP,
2021
)
2122

2223
from argowrapper import logger
@@ -209,7 +210,7 @@ def wrapper(*args, **kwargs):
209210
return wrapper
210211

211212

212-
def check_user_billing_id(request):
213+
def check_user_info_for_billing_id_and_workflow_limit(request):
213214
"""
214215
Check whether user is non-VA user
215216
if user is VA-user, do nothing and proceed
@@ -232,16 +233,26 @@ def check_user_billing_id(request):
232233
raise exception
233234
logger.info("Got user info successfully. Checking for billing id..")
234235

235-
if "tags" in user_info and "billing_id" in user_info["tags"]:
236-
billing_id = user_info["tags"]["billing_id"]
237-
logger.info("billing id found in user tags: " + billing_id)
238-
return billing_id
236+
if "tags" in user_info:
237+
if "billing_id" in user_info["tags"]:
238+
billing_id = user_info["tags"]["billing_id"]
239+
logger.info("billing id found in user tags: " + billing_id)
240+
else:
241+
billing_id = None
242+
243+
if "workflow_limit" in user_info["tags"]:
244+
workflow_limit = user_info["tags"]["workflow_limit"]
245+
logger.info("Workflow limit found in user tags: " + workflow_limit)
246+
else:
247+
workflow_limit = None
248+
249+
return billing_id, workflow_limit
239250
else:
240-
logger.info("billing id not found.")
241-
return None
251+
logger.info("User info does not have tags")
252+
return None, None
242253

243254

244-
def check_user_reached_monthly_workflow_cap(request_token):
255+
def check_user_reached_monthly_workflow_cap(request_token, billing_id, custom_limit):
245256
"""
246257
Query Argo service to see how many successful run user already
247258
have in the current calendar month. If the number is greater than
@@ -252,20 +263,28 @@ def check_user_reached_monthly_workflow_cap(request_token):
252263
current_month_workflows = argo_engine.get_user_workflows_for_current_month(
253264
request_token
254265
)
266+
username = argo_engine_helper.get_username_from_token(request_token)
267+
if custom_limit:
268+
limit = custom_limit
269+
else:
270+
if billing_id:
271+
limit = GEN3_NON_VA_WORKFLOW_MONTHLY_CAP
272+
else:
273+
limit = GEN3_DEFAULT_WORKFLOW_MONTHLY_CAP
255274

256-
if len(current_month_workflows) >= GEN3_NON_VA_WORKFLOW_MONTHLY_CAP:
257-
logger.info(
258-
"User already executed {} workflows this month and cannot create new ones anymore.".format(
259-
len(current_month_workflows)
275+
if len(current_month_workflows) >= limit:
276+
logger.warn(
277+
"This user {} already executed {} workflows this month and cannot create new ones anymore. The currently monthly cap for this user is {}.".format(
278+
username, len(current_month_workflows), limit
260279
)
261280
)
281+
return True
282+
else:
262283
logger.info(
263-
"The currently monthly cap is {}.".format(
264-
GEN3_NON_VA_WORKFLOW_MONTHLY_CAP
284+
"This user {} executed {} workflows this month. The currently monthly cap for this user is {}.".format(
285+
username, len(current_month_workflows), limit
265286
)
266287
)
267-
return True
268-
269288
return False
270289
except Exception as e:
271290
logger.error(e)
@@ -292,13 +311,14 @@ def submit_workflow(
292311
reached_monthly_cap = False
293312

294313
# check if user has a billing id tag:
295-
billing_id = check_user_billing_id(request)
314+
billing_id, workflow_limit = check_user_info_for_billing_id_and_workflow_limit(
315+
request
316+
)
296317

297318
# if user has billing_id (non-VA user), check if they already reached the monthly cap
298-
if billing_id:
299-
reached_monthly_cap = check_user_reached_monthly_workflow_cap(
300-
request.headers.get("Authorization")
301-
)
319+
reached_monthly_cap = check_user_reached_monthly_workflow_cap(
320+
request.headers.get("Authorization"), billing_id, workflow_limit
321+
)
302322

303323
# submit workflow:
304324
if not reached_monthly_cap:
@@ -311,8 +331,9 @@ def submit_workflow(
311331
status_code=HTTP_403_FORBIDDEN,
312332
)
313333
except Exception as exception:
334+
logger.error(str(exception))
314335
return HTMLResponse(
315-
content=str(exception),
336+
content="Unexpected Error Occurred",
316337
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
317338
)
318339

@@ -331,8 +352,9 @@ def get_workflow_details(
331352
return argo_engine.get_workflow_details(workflow_name, uid)
332353

333354
except Exception as exception:
355+
logger.error(str(exception))
334356
return HTMLResponse(
335-
content=str(exception),
357+
content="Unexpected Error Occurred",
336358
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
337359
)
338360

@@ -354,7 +376,7 @@ def retry_workflow(
354376
logger.error(traceback.format_exc())
355377
logger.error(f"could not retry {workflow_name}, failed with error {exception}")
356378
return HTMLResponse(
357-
content=str(exception),
379+
content="Could not retry workflow, error occurred",
358380
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
359381
)
360382

@@ -372,8 +394,9 @@ def cancel_workflow(
372394
return argo_engine.cancel_workflow(workflow_name)
373395

374396
except Exception as exception:
397+
logger.error(str(exception))
375398
return HTMLResponse(
376-
content=str(exception),
399+
content="Unexpected Error Occurred",
377400
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
378401
)
379402

@@ -400,8 +423,9 @@ def get_workflows(
400423
)
401424

402425
except Exception as exception:
426+
logger.error(str(exception))
403427
return HTMLResponse(
404-
content=exception,
428+
content="Unexpected Error Occurred",
405429
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
406430
)
407431

@@ -419,7 +443,8 @@ def get_workflow_logs(
419443
return argo_engine.get_workflow_logs(workflow_name, uid)
420444

421445
except Exception as exception:
446+
logger.error(str(exception))
422447
return HTMLResponse(
423-
content=exception,
448+
content="Unexpected Error Occurred",
424449
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
425450
)

test/test_argo_engine.py

+5
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,11 @@ def test_get_user_workflows_for_current_month(monkeypatch):
825825
"phase": "Succeeded",
826826
"submittedAt": "2023-11-15T17:52:52Z",
827827
},
828+
{
829+
"uid": "uid_3",
830+
"phase": "Failed",
831+
"submittedAt": "2023-11-02T00:00:00Z",
832+
},
828833
]
829834
engine.get_workflows_for_label_selector = mock.MagicMock(
830835
return_value=workflows_mock_response

test/test_routes.py

+78-12
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,17 @@ def test_submit_workflow(client):
106106
) as mock_engine, patch(
107107
"argowrapper.routes.routes.log_auth_check_type"
108108
) as mock_log, patch(
109-
"argowrapper.routes.routes.check_user_billing_id"
109+
"argowrapper.routes.routes.check_user_info_for_billing_id_and_workflow_limit"
110110
) as mock_check_billing_id, patch(
111111
"requests.get"
112-
) as mock_requests:
112+
) as mock_requests, patch(
113+
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
114+
) as mock_check_monthly_cap:
113115
mock_auth.return_value = True
114116
mock_engine.return_value = "workflow_123"
115-
mock_check_billing_id.return_value = None
117+
mock_check_billing_id.return_value = None, None
116118
mock_requests.side_effect = mocked_requests_get
119+
mock_check_monthly_cap.return_value = False
117120

118121
response = client.post(
119122
"/submit",
@@ -516,10 +519,16 @@ def test_submit_workflow_with_user_billing_id(client):
516519
"argowrapper.routes.routes.log_auth_check_type"
517520
) as mock_log, patch(
518521
"requests.get"
519-
) as mock_requests:
522+
) as mock_requests, patch(
523+
"argowrapper.engine.argo_engine.ArgoEngine.get_user_workflows_for_current_month"
524+
) as mock_get_workflow:
520525
mock_auth.return_value = True
521526
mock_engine.return_value = "workflow_123"
522527
mock_requests.side_effect = mocked_requests_get
528+
mock_get_workflow.return_value = [
529+
{"wf_name": "workflow1"},
530+
{"wf_name": "workflow2"},
531+
]
523532

524533
data["user_tags"] = {"tags": {}}
525534
response = client.post(
@@ -562,7 +571,7 @@ def test_submit_workflow_with_user_billing_id(client):
562571
)
563572
assert mock_engine.call_args.args[2] == "1234"
564573

565-
data["user_tags"] == 500
574+
data["user_tags"] = 500
566575

567576
response = client.post(
568577
"/submit",
@@ -588,17 +597,45 @@ def test_check_user_reached_monthly_workflow_cap():
588597
{"wf_name": "workflow1"},
589598
{"wf_name": "workflow2"},
590599
]
600+
601+
# Test Under Default Limit
602+
assert (
603+
check_user_reached_monthly_workflow_cap(
604+
headers["Authorization"], None, None
605+
)
606+
== False
607+
)
608+
609+
# Test Custom Limit
591610
assert (
592-
check_user_reached_monthly_workflow_cap(headers["Authorization"]) == False
611+
check_user_reached_monthly_workflow_cap(headers["Authorization"], None, 2)
612+
== True
593613
)
594614

615+
# Test Billing Id User Exceeding Limit
595616
workflows = []
596617
for index in range(GEN3_NON_VA_WORKFLOW_MONTHLY_CAP + 1):
597618
workflows.append({"wf_name": "workflow" + str(index)})
598-
599619
mock_get_workflow.return_value = workflows
600620

601-
assert check_user_reached_monthly_workflow_cap(headers["Authorization"]) == True
621+
assert (
622+
check_user_reached_monthly_workflow_cap(
623+
headers["Authorization"], "1234", None
624+
)
625+
== True
626+
)
627+
628+
# Test VA User Exceeding Limit
629+
workflows = []
630+
for index in range(GEN3_DEFAULT_WORKFLOW_MONTHLY_CAP + 1):
631+
workflows.append({"wf_name": "workflow" + str(index)})
632+
mock_get_workflow.return_value = workflows
633+
assert (
634+
check_user_reached_monthly_workflow_cap(
635+
headers["Authorization"], None, None
636+
)
637+
== True
638+
)
602639

603640

604641
def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
@@ -607,15 +644,44 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
607644
) as mock_engine, patch(
608645
"argowrapper.routes.routes.log_auth_check_type"
609646
) as mock_log, patch(
610-
"argowrapper.routes.routes.check_user_billing_id"
647+
"argowrapper.routes.routes.check_user_info_for_billing_id_and_workflow_limit"
648+
) as mock_check_billing_id, patch(
649+
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
650+
) as mock_check_monthly_cap, patch(
651+
"requests.get"
652+
) as mock_requests:
653+
mock_auth.return_value = True
654+
mock_engine.return_value = "workflow_123"
655+
mock_check_billing_id.return_value = "1234", None
656+
mock_check_monthly_cap.return_value = True
657+
mock_requests.side_effect = mocked_requests_get
658+
659+
response = client.post(
660+
"/submit",
661+
data=json.dumps(data),
662+
headers={
663+
"Content-Type": "application/json",
664+
"Authorization": EXAMPLE_AUTH_HEADER,
665+
},
666+
)
667+
assert response.status_code == 403
668+
669+
670+
def test_submit_workflow_over_monthly_cap(client):
671+
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
672+
"argowrapper.routes.routes.argo_engine.workflow_submission"
673+
) as mock_engine, patch(
674+
"argowrapper.routes.routes.log_auth_check_type"
675+
) as mock_log, patch(
676+
"argowrapper.routes.routes.check_user_info_for_billing_id_and_workflow_limit"
611677
) as mock_check_billing_id, patch(
612678
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
613679
) as mock_check_monthly_cap, patch(
614680
"requests.get"
615681
) as mock_requests:
616682
mock_auth.return_value = True
617683
mock_engine.return_value = "workflow_123"
618-
mock_check_billing_id.return_value = "1234"
684+
mock_check_billing_id.return_value = None, None
619685
mock_check_monthly_cap.return_value = True
620686
mock_requests.side_effect = mocked_requests_get
621687

@@ -636,13 +702,13 @@ def test_submit_workflow_with_non_team_project_cohort(client):
636702
) as mock_engine, patch(
637703
"argowrapper.routes.routes.log_auth_check_type"
638704
) as mock_log, patch(
639-
"argowrapper.routes.routes.check_user_billing_id"
705+
"argowrapper.routes.routes.check_user_info_for_billing_id_and_workflow_limit"
640706
) as mock_check_billing_id, patch(
641707
"requests.get"
642708
) as mock_requests:
643709
mock_auth.return_value = True
644710
mock_engine.return_value = "workflow_123"
645-
mock_check_billing_id.return_value = None
711+
mock_check_billing_id.return_value = None, None
646712
mock_requests.side_effect = mocked_requests_get
647713

648714
data["outcome"]["cohort_ids"] = [400]

0 commit comments

Comments
 (0)