Skip to content

Commit c4afe8c

Browse files
committed
address comments
1 parent 6f40589 commit c4afe8c

File tree

2 files changed

+56
-53
lines changed

2 files changed

+56
-53
lines changed

src/argowrapper/routes/routes.py

+32-53
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from argowrapper import logger
2323
from argowrapper.auth import Auth
2424
from argowrapper.engine.argo_engine import ArgoEngine
25+
from argowrapper.utils import get_team_cohort_id
26+
2527
import argowrapper.engine.helpers.argo_engine_helper as argo_engine_helper
2628

2729
import requests
@@ -150,74 +152,51 @@ def check_team_projects_and_cohorts(fn):
150152

151153
@wraps(fn)
152154
def wrapper(*args, **kwargs):
153-
import json
154155

155-
logger.info("arguments:")
156-
logger.info(json.dumps(kwargs["request_body"]))
157-
request = kwargs["request"]
158-
token = request.headers.get("Authorization")
156+
token = kwargs["request"].headers.get("Authorization")
159157
request_body = kwargs["request_body"]
160-
161158
team_project = request_body[TEAM_PROJECT_FIELD_NAME]
162-
cohort_ids = []
163159
source_id = request_body["source_id"]
164-
if "outcome" in request_body and "cohort_ids" in request_body["outcome"]:
160+
161+
# Construct set with all cohort ids requested
162+
cohort_ids = []
163+
if "cohort_ids" in request_body["outcome"]:
165164
cohort_ids.extend(request_body["outcome"]["cohort_ids"])
166165

167-
if "variables" in request_body:
168-
variables = request_body["variables"]
169-
for v in variables:
170-
if "cohort_ids" in v:
171-
cohort_ids.extend(v["cohort_ids"])
166+
variables = request_body["variables"]
167+
for v in variables:
168+
if "cohort_ids" in v:
169+
cohort_ids.extend(v["cohort_ids"])
172170

173171
if "source_population_cohort" in request_body:
174172
cohort_ids.append(request_body["source_population_cohort"])
175173

174+
cohort_id_set = set(cohort_ids)
175+
176176
if team_project and source_id and len(team_project) > 0 and len(cohort_ids) > 0:
177-
header = {"Authorization": token, "cookie": "fence={}".format(token)}
178-
url = "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={}".format(
179-
source_id, team_project
177+
# Get team project cohort ids
178+
team_cohort_id_set = get_team_cohort_id(token, source_id, team_project)
179+
180+
logger.debug("cohort ids are " + " ".join(str(c) for c in cohort_ids))
181+
logger.debug(
182+
"team cohort ids are " + " ".join(str(c) for c in team_cohort_id_set)
180183
)
181184

182-
logger.info("team project is " + team_project)
183-
logger.info("source_id is " + str(source_id))
184-
logger.info("request url is " + url)
185-
186-
try:
187-
r = requests.get(url=url, headers=header)
188-
r.raise_for_status()
189-
team_cohort_info = r.json()
190-
team_cohort_id_set = set()
191-
if "cohort_definitions_and_stats" in team_cohort_info:
192-
for t in team_cohort_info["cohort_definitions_and_stats"]:
193-
if "cohort_definition_id" in t:
194-
team_cohort_id_set.add(t["cohort_definition_id"])
195-
cohort_id_set = set(cohort_ids)
196-
197-
logger.info("cohort ids are " + " ".join(str(c) for c in cohort_ids))
198-
logger.info(
199-
"team cohort ids are "
200-
+ " ".join(str(c) for c in team_cohort_id_set)
185+
# Compare the two sets
186+
if cohort_id_set.issubset(team_cohort_id_set):
187+
logger.debug(
188+
"cohort ids submitted all belong to the same team project. Continue.."
189+
)
190+
return fn(*args, **kwargs)
191+
else:
192+
logger.error(
193+
"Cohort ids submitted do NOT all belong to the same team project."
194+
)
195+
return HTMLResponse(
196+
content="Cohort ids submitted do NOT all belong to the same team project.",
197+
status_code=HTTP_401_UNAUTHORIZED,
201198
)
202199

203-
if cohort_id_set.issubset(team_cohort_id_set):
204-
logger.info(
205-
"cohort ids submitted belong to user's team project. Continue.."
206-
)
207-
return fn(*args, **kwargs)
208-
else:
209-
logger.error(
210-
"cohort ids submitted do NOT belong to user's team project."
211-
)
212-
return HTMLResponse(
213-
content="Cohort ids submitted do NOT belong to user's team project.",
214-
status_code=HTTP_401_UNAUTHORIZED,
215-
)
216-
except Exception as e:
217-
exception = Exception("Something went wrong", e)
218-
logger.error(exception)
219-
traceback.print_exc()
220-
raise exception
221200
else:
222201
# some required parameters is missing, return bad request:
223202
return HTMLResponse(

src/argowrapper/utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import requests
2+
from argowrapper import logger
3+
4+
5+
def get_team_cohort_id(token, source_id, team_project):
6+
header = {"Authorization": token, "cookie": "fence={}".format(token)}
7+
url = "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={}".format(
8+
source_id, team_project
9+
)
10+
11+
try:
12+
r = requests.get(url=url, headers=header)
13+
r.raise_for_status()
14+
team_cohort_info = r.json()
15+
team_cohort_id_set = set()
16+
if "cohort_definitions_and_stats" in team_cohort_info:
17+
for t in team_cohort_info["cohort_definitions_and_stats"]:
18+
if "cohort_definition_id" in t:
19+
team_cohort_id_set.add(t["cohort_definition_id"])
20+
return team_cohort_id_set
21+
except Exception as e:
22+
exception = Exception("Could not get team project cohort ids", e)
23+
logger.error(exception)
24+
raise exception

0 commit comments

Comments
 (0)