Skip to content

Commit e786c35

Browse files
authored
Merge pull request #123 from uc-cdis/feat/VADC-900
VADC-900
2 parents a810d32 + 6581e6a commit e786c35

File tree

3 files changed

+217
-52
lines changed

3 files changed

+217
-52
lines changed

src/argowrapper/auth/utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import requests
2+
from argowrapper import logger
3+
4+
5+
def get_cohort_ids_for_team_project(token, source_id, team_project):
6+
header = {"Authorization": token, "cookie": "fence={}".format(token)}
7+
url = "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={}".format(
8+
source_id, team_project
9+
)
10+
11+
try:
12+
r = requests.get(url=url, headers=header)
13+
r.raise_for_status()
14+
team_cohort_info = r.json()
15+
team_cohort_id_set = set()
16+
if "cohort_definitions_and_stats" in team_cohort_info:
17+
for t in team_cohort_info["cohort_definitions_and_stats"]:
18+
if "cohort_definition_id" in t:
19+
team_cohort_id_set.add(t["cohort_definition_id"])
20+
return team_cohort_id_set
21+
except Exception as e:
22+
exception = Exception("Could not get team project cohort ids", e)
23+
logger.error(exception)
24+
raise exception

src/argowrapper/routes/routes.py

+66
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
HTTP_401_UNAUTHORIZED,
1010
HTTP_403_FORBIDDEN,
1111
HTTP_500_INTERNAL_SERVER_ERROR,
12+
HTTP_400_BAD_REQUEST,
1213
)
1314
from argowrapper.constants import (
1415
TEAM_PROJECT_FIELD_NAME,
@@ -21,6 +22,8 @@
2122
from argowrapper import logger
2223
from argowrapper.auth import Auth
2324
from argowrapper.engine.argo_engine import ArgoEngine
25+
from argowrapper.auth.utils import get_cohort_ids_for_team_project
26+
2427
import argowrapper.engine.helpers.argo_engine_helper as argo_engine_helper
2528

2629
import requests
@@ -144,6 +147,68 @@ def wrapper(*args, **kwargs):
144147
return wrapper
145148

146149

150+
def check_team_projects_and_cohorts(fn):
151+
"""custom annotation to make sure cohort in request belong to user's team project"""
152+
153+
@wraps(fn)
154+
def wrapper(*args, **kwargs):
155+
156+
token = kwargs["request"].headers.get("Authorization")
157+
request_body = kwargs["request_body"]
158+
team_project = request_body[TEAM_PROJECT_FIELD_NAME]
159+
source_id = request_body["source_id"]
160+
161+
# Construct set with all cohort ids requested
162+
cohort_ids = []
163+
if "cohort_ids" in request_body["outcome"]:
164+
cohort_ids.extend(request_body["outcome"]["cohort_ids"])
165+
166+
variables = request_body["variables"]
167+
for v in variables:
168+
if "cohort_ids" in v:
169+
cohort_ids.extend(v["cohort_ids"])
170+
171+
if "source_population_cohort" in request_body:
172+
cohort_ids.append(request_body["source_population_cohort"])
173+
174+
cohort_id_set = set(cohort_ids)
175+
176+
if team_project and source_id and len(team_project) > 0 and len(cohort_ids) > 0:
177+
# Get team project cohort ids
178+
team_cohort_id_set = get_cohort_ids_for_team_project(
179+
token, source_id, team_project
180+
)
181+
182+
logger.debug("cohort ids are " + " ".join(str(c) for c in cohort_ids))
183+
logger.debug(
184+
"team cohort ids are " + " ".join(str(c) for c in team_cohort_id_set)
185+
)
186+
187+
# Compare the two sets
188+
if cohort_id_set.issubset(team_cohort_id_set):
189+
logger.debug(
190+
"cohort ids submitted all belong to the same team project. Continue.."
191+
)
192+
return fn(*args, **kwargs)
193+
else:
194+
logger.error(
195+
"Cohort ids submitted do NOT all belong to the same team project."
196+
)
197+
return HTMLResponse(
198+
content="Cohort ids submitted do NOT all belong to the same team project.",
199+
status_code=HTTP_400_BAD_REQUEST,
200+
)
201+
202+
else:
203+
# some required parameters is missing, return bad request:
204+
return HTMLResponse(
205+
content="Missing required parameters",
206+
status_code=HTTP_400_BAD_REQUEST,
207+
)
208+
209+
return wrapper
210+
211+
147212
def check_user_billing_id(request):
148213
"""
149214
Check whether user is non-VA user
@@ -217,6 +282,7 @@ def test():
217282
# submit argo workflow
218283
@router.post("/submit", status_code=HTTP_200_OK)
219284
@check_auth_and_team_project
285+
@check_team_projects_and_cohorts
220286
def submit_workflow(
221287
request_body: Dict[Any, Any],
222288
request: Request, # pylint: disable=unused-argument

test/test_routes.py

+127-52
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,30 @@
2222
"variables": variables,
2323
"hare_population": "hare",
2424
"out_prefix": "vadc_genesis",
25-
"outcome": 1,
25+
"outcome": {
26+
"variable_type": "custom_dichotomous",
27+
"cohort_ids": [2],
28+
"provided_name": "test Pheno",
29+
},
2630
"maf_threshold": 0.01,
2731
"imputation_score_cutoff": 0.3,
2832
"template_version": "gwas-template-latest",
2933
"source_id": 4,
3034
"case_cohort_definition_id": 70,
3135
"control_cohort_definition_id": -1,
36+
"source_population_cohort": 4,
3237
"workflow_name": "wf_name",
3338
TEAM_PROJECT_FIELD_NAME: "dummy-team-project",
39+
"user_tags": None, # For testing purpose
40+
}
41+
42+
cohort_definition_data = {
43+
"cohort_definitions_and_stats": [
44+
{"cohort_definition_id": 1, "cohort_name": "Cohort 1", "size": 1},
45+
{"cohort_definition_id": 2, "cohort_name": "Cohort 2", "size": 2},
46+
{"cohort_definition_id": 3, "cohort_name": "Cohort 3", "size": 3},
47+
{"cohort_definition_id": 4, "cohort_name": "Cohort 4", "size": 4},
48+
]
3449
}
3550

3651

@@ -55,17 +70,50 @@ def client(app: FastAPI) -> Generator[TestClient, Any, None]:
5570
yield client
5671

5772

73+
def mocked_requests_get(*args, **kwargs):
74+
class MockResponse:
75+
def __init__(self, json_data, status_code):
76+
self.json_data = json_data
77+
self.status_code = status_code
78+
79+
def json(self):
80+
return self.json_data
81+
82+
def raise_for_status(self):
83+
if self.status_code == 500:
84+
raise Exception("fence is down")
85+
if self.status_code != 200:
86+
raise Exception()
87+
88+
if (
89+
kwargs["url"]
90+
== "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/4/by-team-project?team-project=dummy-team-project"
91+
):
92+
return MockResponse(cohort_definition_data, 200)
93+
94+
if kwargs["url"] == "http://fence-service/user":
95+
if data["user_tags"] != 500:
96+
return MockResponse(data["user_tags"], 200)
97+
else:
98+
return MockResponse({}, 500)
99+
100+
return None
101+
102+
58103
def test_submit_workflow(client):
59104
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
60105
"argowrapper.routes.routes.argo_engine.workflow_submission"
61106
) as mock_engine, patch(
62107
"argowrapper.routes.routes.log_auth_check_type"
63108
) as mock_log, patch(
64109
"argowrapper.routes.routes.check_user_billing_id"
65-
) as mock_check_billing_id:
110+
) as mock_check_billing_id, patch(
111+
"requests.get"
112+
) as mock_requests:
66113
mock_auth.return_value = True
67114
mock_engine.return_value = "workflow_123"
68115
mock_check_billing_id.return_value = None
116+
mock_requests.side_effect = mocked_requests_get
69117

70118
response = client.post(
71119
"/submit",
@@ -466,59 +514,44 @@ def test_submit_workflow_with_user_billing_id(client):
466514
"argowrapper.routes.routes.argo_engine.workflow_submission"
467515
) as mock_engine, patch(
468516
"argowrapper.routes.routes.log_auth_check_type"
469-
) as mock_log:
517+
) as mock_log, patch(
518+
"requests.get"
519+
) as mock_requests:
470520
mock_auth.return_value = True
471521
mock_engine.return_value = "workflow_123"
472-
with patch("requests.get") as mock_request:
473-
mock_resp = mock.Mock()
474-
mock_resp.status_code = 200
475-
mock_resp.raise_for_status = mock.Mock()
476-
mock_resp.json = mock.Mock(return_value={"tags": {}})
477-
mock_request.return_value = mock_resp
522+
mock_requests.side_effect = mocked_requests_get
478523

479-
response = client.post(
480-
"/submit",
481-
data=json.dumps(data),
482-
headers={
483-
"Content-Type": "application/json",
484-
"Authorization": EXAMPLE_AUTH_HEADER,
485-
},
486-
)
487-
assert response.status_code == 200
488-
assert mock_engine.call_args.args[2] == None
524+
data["user_tags"] = {"tags": {}}
525+
response = client.post(
526+
"/submit",
527+
data=json.dumps(data),
528+
headers={
529+
"Content-Type": "application/json",
530+
"Authorization": EXAMPLE_AUTH_HEADER,
531+
},
532+
)
533+
assert response.status_code == 200
534+
assert mock_engine.call_args.args[2] == None
489535

490-
mock_resp.json = mock.Mock(return_value={"tags": {"othertag1": "tag1"}})
536+
data["user_tags"] = {"tags": {"othertag1": "tag1"}}
491537

492-
response = client.post(
493-
"/submit",
494-
data=json.dumps(data),
495-
headers={
496-
"Content-Type": "application/json",
497-
"Authorization": EXAMPLE_AUTH_HEADER,
498-
},
499-
)
500-
assert response.status_code == 200
501-
assert mock_engine.call_args.args[2] == None
538+
response = client.post(
539+
"/submit",
540+
data=json.dumps(data),
541+
headers={
542+
"Content-Type": "application/json",
543+
"Authorization": EXAMPLE_AUTH_HEADER,
544+
},
545+
)
546+
assert response.status_code == 200
547+
assert mock_engine.call_args.args[2] == None
502548

503-
mock_resp.json = mock.Mock(
504-
return_value={"tags": {"othertag1": "tag1", "billing_id": "1234"}}
505-
)
506-
with patch(
507-
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
508-
) as mock_check_monthly_cap:
509-
mock_check_monthly_cap.return_value = False
510-
response = client.post(
511-
"/submit",
512-
data=json.dumps(data),
513-
headers={
514-
"Content-Type": "application/json",
515-
"Authorization": EXAMPLE_AUTH_HEADER,
516-
},
517-
)
518-
assert mock_engine.call_args.args[2] == "1234"
519-
520-
mock_resp.status_code == 500
521-
mock_resp.raise_for_status.side_effect = Exception("fence is down")
549+
data["user_tags"] = {"tags": {"othertag1": "tag1", "billing_id": "1234"}}
550+
551+
with patch(
552+
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
553+
) as mock_check_monthly_cap:
554+
mock_check_monthly_cap.return_value = False
522555
response = client.post(
523556
"/submit",
524557
data=json.dumps(data),
@@ -527,8 +560,19 @@ def test_submit_workflow_with_user_billing_id(client):
527560
"Authorization": EXAMPLE_AUTH_HEADER,
528561
},
529562
)
530-
assert response.status_code == 500
531-
assert "fence is down" in str(response.content)
563+
assert mock_engine.call_args.args[2] == "1234"
564+
565+
data["user_tags"] == 500
566+
567+
response = client.post(
568+
"/submit",
569+
data=json.dumps(data),
570+
headers={
571+
"Content-Type": "application/json",
572+
"Authorization": EXAMPLE_AUTH_HEADER,
573+
},
574+
)
575+
assert response.status_code == 500
532576

533577

534578
def test_check_user_reached_monthly_workflow_cap():
@@ -566,11 +610,14 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
566610
"argowrapper.routes.routes.check_user_billing_id"
567611
) as mock_check_billing_id, patch(
568612
"argowrapper.routes.routes.check_user_reached_monthly_workflow_cap"
569-
) as mock_check_monthly_cap:
613+
) as mock_check_monthly_cap, patch(
614+
"requests.get"
615+
) as mock_requests:
570616
mock_auth.return_value = True
571617
mock_engine.return_value = "workflow_123"
572618
mock_check_billing_id.return_value = "1234"
573619
mock_check_monthly_cap.return_value = True
620+
mock_requests.side_effect = mocked_requests_get
574621

575622
response = client.post(
576623
"/submit",
@@ -581,3 +628,31 @@ def test_submit_workflow_with_billing_id_and_over_monthly_cap(client):
581628
},
582629
)
583630
assert response.status_code == 403
631+
632+
633+
def test_submit_workflow_with_non_team_project_cohort(client):
634+
with patch("argowrapper.routes.routes.auth.authenticate") as mock_auth, patch(
635+
"argowrapper.routes.routes.argo_engine.workflow_submission"
636+
) as mock_engine, patch(
637+
"argowrapper.routes.routes.log_auth_check_type"
638+
) as mock_log, patch(
639+
"argowrapper.routes.routes.check_user_billing_id"
640+
) as mock_check_billing_id, patch(
641+
"requests.get"
642+
) as mock_requests:
643+
mock_auth.return_value = True
644+
mock_engine.return_value = "workflow_123"
645+
mock_check_billing_id.return_value = None
646+
mock_requests.side_effect = mocked_requests_get
647+
648+
data["outcome"]["cohort_ids"] = [400]
649+
650+
response = client.post(
651+
"/submit",
652+
data=json.dumps(data),
653+
headers={
654+
"Content-Type": "application/json",
655+
"Authorization": EXAMPLE_AUTH_HEADER,
656+
},
657+
)
658+
assert response.status_code == 400

0 commit comments

Comments
 (0)