Skip to content

Commit e964b89

Browse files
committed
add unit tests
1 parent caa3ca5 commit e964b89

File tree

1 file changed

+105
-66
lines changed

1 file changed

+105
-66
lines changed

test/test_routes.py

+105-66
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"control_cohort_definition_id": -1,
3636
"workflow_name": "wf_name",
3737
TEAM_PROJECT_FIELD_NAME: "dummy-team-project",
38+
"user_tags": None, # For testing purpose
3839
}
3940

4041
cohort_definition_data = {
@@ -67,26 +68,37 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]:
6768
yield client
6869

6970

70-
def test_submit_workflow(client):
71-
def mocked_requests_get(*args, **kwargs):
72-
class MockResponse:
73-
def __init__(self, json_data, status_code):
74-
self.json_data = json_data
75-
self.status_code = status_code
71+
def mocked_requests_get(*args, **kwargs):
72+
class MockResponse:
73+
def __init__(self, json_data, status_code):
74+
self.json_data = json_data
75+
self.status_code = status_code
76+
77+
def json(self):
78+
return self.json_data
79+
80+
def raise_for_status(self):
81+
if self.status_code == 500:
82+
raise Exception("fence is down")
83+
if self.status_code != 200:
84+
raise Exception()
7685

77-
def json(self):
78-
return self.json_data
86+
if (
87+
kwargs["url"]
88+
== "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/4/by-team-project?team-project=dummy-team-project"
89+
):
90+
return MockResponse(cohort_definition_data, 200)
7991

80-
def raise_for_status(self):
81-
if self.status_code != 200:
82-
raise Exception()
92+
if kwargs["url"] == "http://fence-service/user":
93+
if data["user_tags"] != 500:
94+
return MockResponse(data["user_tags"], 200)
95+
else:
96+
return MockResponse({}, 500)
97+
98+
return None
8399

84-
if (
85-
kwargs["url"]
86-
== "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/4/by-team-project?team-project=dummy-team-project"
87-
):
88-
return MockResponse(cohort_definition_data, 200)
89100

101+
def test_submit_workflow(client):
90102
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
91103
"argowrapper.routes.routes.argo_engine.workflow_submission"
92104
) as mock_engine, patch(
@@ -500,59 +512,44 @@ def test_submit_workflow_with_user_billing_id(client):
500512
"argowrapper.routes.routes.argo_engine.workflow_submission"
501513
) as mock_engine, patch(
502514
"argowrapper.routes.routes.log_auth_check_type"
503-
) as mock_log:
515+
) as mock_log, patch(
516+
"requests.get"
517+
) as mock_requests:
504518
mock_auth.return_value = True
505519
mock_engine.return_value = "workflow_123"
506-
with patch("requests.get") as mock_request:
507-
mock_resp = mock.Mock()
508-
mock_resp.status_code = 200
509-
mock_resp.raise_for_status = mock.Mock()
510-
mock_resp.json = mock.Mock(return_value={"tags": {}})
511-
mock_request.return_value = mock_resp
520+
mock_requests.side_effect = mocked_requests_get
512521

513-
response = client.post(
514-
"/submit",
515-
data=json.dumps(data),
516-
headers={
517-
"Content-Type": "application/json",
518-
"Authorization": EXAMPLE_AUTH_HEADER,
519-
},
520-
)
521-
assert response.status_code == 200
522-
assert mock_engine.call_args.args[2] == None
522+
data["user_tags"] = {"tags": {}}
523+
response = client.post(
524+
"/submit",
525+
data=json.dumps(data),
526+
headers={
527+
"Content-Type": "application/json",
528+
"Authorization": EXAMPLE_AUTH_HEADER,
529+
},
530+
)
531+
assert response.status_code == 200
532+
assert mock_engine.call_args.args[2] == None
523533

524-
mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}})
534+
data["user_tags"] = {"tags": {"othertag1": "tag1"}}
525535

526-
response = client.post(
527-
"/submit",
528-
data=json.dumps(data),
529-
headers={
530-
"Content-Type": "application/json",
531-
"Authorization": EXAMPLE_AUTH_HEADER,
532-
},
533-
)
534-
assert response.status_code == 200
535-
assert mock_engine.call_args.args[2] == None
536+
response = client.post(
537+
"/submit",
538+
data=json.dumps(data),
539+
headers={
540+
"Content-Type": "application/json",
541+
"Authorization": EXAMPLE_AUTH_HEADER,
542+
},
543+
)
544+
assert response.status_code == 200
545+
assert mock_engine.call_args.args[2] == None
536546

537-
mock_resp.json = mock.Mock(
538-
return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}}
539-
)
540-
with patch(
541-
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
542-
) as mock_check_monthly_cap:
543-
mock_check_monthly_cap.return_value = False
544-
response = client.post(
545-
"/submit",
546-
data=json.dumps(data),
547-
headers={
548-
"Content-Type": "application/json",
549-
"Authorization": EXAMPLE_AUTH_HEADER,
550-
},
551-
)
552-
assert mock_engine.call_args.args[2] == "1234"
553-
554-
mock_resp.status_code == 500
555-
mock_resp.raise_for_status.side_effect = Exception("fence is down")
547+
data["user_tags"] = {"tags": {"othertag1": "tag1", "billing_id": "1234"}}
548+
549+
with patch(
550+
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
551+
) as mock_check_monthly_cap:
552+
mock_check_monthly_cap.return_value = False
556553
response = client.post(
557554
"/submit",
558555
data=json.dumps(data),
@@ -561,8 +558,19 @@ def test_submit_workflow_with_user_billing_id(client):
561558
"Authorization": EXAMPLE_AUTH_HEADER,
562559
},
563560
)
564-
assert response.status_code == 500
565-
assert "fence is down" in str(response.content)
561+
assert mock_engine.call_args.args[2] == "1234"
562+
563+
data["user_tags"] == 500
564+
565+
response = client.post(
566+
"/submit",
567+
data=json.dumps(data),
568+
headers={
569+
"Content-Type": "application/json",
570+
"Authorization": EXAMPLE_AUTH_HEADER,
571+
},
572+
)
573+
assert response.status_code == 500
566574

567575

568576
def test_check_user_reached_monthly_workflow_cap():
@@ -600,11 +608,14 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
600608
"argowrapper.routes.routes.check_user_billing_id"
601609
) as mock_check_billing_id, patch(
602610
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
603-
) as mock_check_monthly_cap:
611+
) as mock_check_monthly_cap, patch(
612+
"requests.get"
613+
) as mock_requests:
604614
mock_auth.return_value = True
605615
mock_engine.return_value = "workflow_123"
606616
mock_check_billing_id.return_value = "1234"
607617
mock_check_monthly_cap.return_value = True
618+
mock_requests.side_effect = mocked_requests_get
608619

609620
response = client.post(
610621
"/submit",
@@ -615,3 +626,31 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
615626
},
616627
)
617628
assert response.status_code == 403
629+
630+
631+
def test_submit_workflow_with_non_team_project_cohort(client):
632+
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
633+
"argowrapper.routes.routes.argo_engine.workflow_submission"
634+
) as mock_engine, patch(
635+
"argowrapper.routes.routes.log_auth_check_type"
636+
) as mock_log, patch(
637+
"argowrapper.routes.routes.check_user_billing_id"
638+
) as mock_check_billing_id, patch(
639+
"requests.get"
640+
) as mock_requests:
641+
mock_auth.return_value = True
642+
mock_engine.return_value = "workflow_123"
643+
mock_check_billing_id.return_value = None
644+
mock_requests.side_effect = mocked_requests_get
645+
646+
data["outcome"]["cohort_ids"] = [4]
647+
648+
response = client.post(
649+
"/submit",
650+
data=json.dumps(data),
651+
headers={
652+
"Content-Type": "application/json",
653+
"Authorization": EXAMPLE_AUTH_HEADER,
654+
},
655+
)
656+
assert response.status_code == 401

0 commit comments

Comments
 (0)