|
22 | 22 | from argowrapper import logger
|
23 | 23 | from argowrapper.auth import Auth
|
24 | 24 | from argowrapper.engine.argo_engine import ArgoEngine
|
| 25 | +from argowrapper.utils import get_team_cohort_id |
| 26 | + |
25 | 27 | import argowrapper.engine.helpers.argo_engine_helper as argo_engine_helper
|
26 | 28 |
|
27 | 29 | import requests
|
@@ -150,74 +152,51 @@ def check_team_projects_and_cohorts(fn):
|
150 | 152 |
|
151 | 153 | @wraps(fn)
|
152 | 154 | def wrapper(*args, **kwargs):
|
153 |
| - import json |
154 | 155 |
|
155 |
| - logger.info("arguments:") |
156 |
| - logger.info(json.dumps(kwargs["request_body"])) |
157 |
| - request = kwargs["request"] |
158 |
| - token = request.headers.get("Authorization") |
| 156 | + token = kwargs["request"].headers.get("Authorization") |
159 | 157 | request_body = kwargs["request_body"]
|
160 |
| - |
161 | 158 | team_project = request_body[TEAM_PROJECT_FIELD_NAME]
|
162 |
| - cohort_ids = [] |
163 | 159 | source_id = request_body["source_id"]
|
164 |
| - if "outcome" in request_body and "cohort_ids" in request_body["outcome"]: |
| 160 | + |
| 161 | + # Construct set with all cohort ids requested |
| 162 | + cohort_ids = [] |
| 163 | + if "cohort_ids" in request_body["outcome"]: |
165 | 164 | cohort_ids.extend(request_body["outcome"]["cohort_ids"])
|
166 | 165 |
|
167 |
| - if "variables" in request_body: |
168 |
| - variables = request_body["variables"] |
169 |
| - for v in variables: |
170 |
| - if "cohort_ids" in v: |
171 |
| - cohort_ids.extend(v["cohort_ids"]) |
| 166 | + variables = request_body["variables"] |
| 167 | + for v in variables: |
| 168 | + if "cohort_ids" in v: |
| 169 | + cohort_ids.extend(v["cohort_ids"]) |
172 | 170 |
|
173 | 171 | if "source_population_cohort" in request_body:
|
174 | 172 | cohort_ids.append(request_body["source_population_cohort"])
|
175 | 173 |
|
| 174 | + cohort_id_set = set(cohort_ids) |
| 175 | + |
176 | 176 | if team_project and source_id and len(team_project) > 0 and len(cohort_ids) > 0:
|
177 |
| - header = {"Authorization": token, "cookie": "fence={}".format(token)} |
178 |
| - url = "http://cohort-middleware-service/cohortdefinition-stats/by-source-id/{}/by-team-project?team-project={}".format( |
179 |
| - source_id, team_project |
| 177 | + # Get team project cohort ids |
| 178 | + team_cohort_id_set = get_team_cohort_id(token, source_id, team_project) |
| 179 | + |
| 180 | + logger.debug("cohort ids are " + " ".join(str(c) for c in cohort_ids)) |
| 181 | + logger.debug( |
| 182 | + "team cohort ids are " + " ".join(str(c) for c in team_cohort_id_set) |
180 | 183 | )
|
181 | 184 |
|
182 |
| - logger.info("team project is " + team_project) |
183 |
| - logger.info("source_id is " + str(source_id)) |
184 |
| - logger.info("request url is " + url) |
185 |
| - |
186 |
| - try: |
187 |
| - r = requests.get(url=url, headers=header) |
188 |
| - r.raise_for_status() |
189 |
| - team_cohort_info = r.json() |
190 |
| - team_cohort_id_set = set() |
191 |
| - if "cohort_definitions_and_stats" in team_cohort_info: |
192 |
| - for t in team_cohort_info["cohort_definitions_and_stats"]: |
193 |
| - if "cohort_definition_id" in t: |
194 |
| - team_cohort_id_set.add(t["cohort_definition_id"]) |
195 |
| - cohort_id_set = set(cohort_ids) |
196 |
| - |
197 |
| - logger.info("cohort ids are " + " ".join(str(c) for c in cohort_ids)) |
198 |
| - logger.info( |
199 |
| - "team cohort ids are " |
200 |
| - + " ".join(str(c) for c in team_cohort_id_set) |
| 185 | + # Compare the two sets |
| 186 | + if cohort_id_set.issubset(team_cohort_id_set): |
| 187 | + logger.debug( |
| 188 | + "cohort ids submitted all belong to the same team project. Continue.." |
| 189 | + ) |
| 190 | + return fn(*args, **kwargs) |
| 191 | + else: |
| 192 | + logger.error( |
| 193 | + "Cohort ids submitted do NOT all belong to the same team project." |
| 194 | + ) |
| 195 | + return HTMLResponse( |
| 196 | + content="Cohort ids submitted do NOT all belong to the same team project.", |
| 197 | + status_code=HTTP_401_UNAUTHORIZED, |
201 | 198 | )
|
202 | 199 |
|
203 |
| - if cohort_id_set.issubset(team_cohort_id_set): |
204 |
| - logger.info( |
205 |
| - "cohort ids submitted belong to user's team project. Continue.." |
206 |
| - ) |
207 |
| - return fn(*args, **kwargs) |
208 |
| - else: |
209 |
| - logger.error( |
210 |
| - "cohort ids submitted do NOT belong to user's team project." |
211 |
| - ) |
212 |
| - return HTMLResponse( |
213 |
| - content="Cohort ids submitted do NOT belong to user's team project.", |
214 |
| - status_code=HTTP_401_UNAUTHORIZED, |
215 |
| - ) |
216 |
| - except Exception as e: |
217 |
| - exception = Exception("Something went wrong", e) |
218 |
| - logger.error(exception) |
219 |
| - traceback.print_exc() |
220 |
| - raise exception |
221 | 200 | else:
|
222 | 201 | # some required parameters is missing, return bad request:
|
223 | 202 | return HTMLResponse(
|
|
0 commit comments