Skip to content

Commit 30fb6b4

Browse files
Merge pull request #83 from uc-cdis/feat/integrate_arborist_validation_for_team_project_for_cohort_data_endpoints
Feat: integrate Arborist validation for team project for cohort data endpoints AND remove unused endpoints
2 parents 198efbf + f7b7fc7 commit 30fb6b4

File tree

10 files changed

+117
-122
lines changed

10 files changed

+117
-122
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ cd tests/setup_local_db/
118118
JSON summary data endpoints:
119119
```bash
120120
curl http://localhost:8080/sources | python -m json.tool
121-
curl http://localhost:8080/cohortdefinition-stats/by-source-id/1 | python -m json.tool
121+
curl "http://localhost:8080/cohortdefinition-stats/by-source-id/1/by-team-project?team-project=test" | python -m json.tool
122122
curl http://localhost:8080/concept/by-source-id/1 | python -m json.tool
123123
curl -d '{"ConceptIds":[2000000324,2000006885]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1 | python -m json.tool
124124
curl -d '{"ConceptTypes":["Measurement","Person"]}' -H "Content-Type: application/json" -X POST http://localhost:8080/concept/by-source-id/1/by-type | python -m json.tool

controllers/cohortdata.go

+32-3
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,21 @@ import (
99
"strconv"
1010

1111
"github.com/gin-gonic/gin"
12+
"github.com/uc-cdis/cohort-middleware/middlewares"
1213
"github.com/uc-cdis/cohort-middleware/models"
1314
"github.com/uc-cdis/cohort-middleware/utils"
1415
)
1516

1617
type CohortDataController struct {
17-
cohortDataModel models.CohortDataI
18+
cohortDataModel models.CohortDataI
19+
teamProjectAuthz middlewares.TeamProjectAuthzI
1820
}
1921

20-
func NewCohortDataController(cohortDataModel models.CohortDataI) CohortDataController {
21-
return CohortDataController{cohortDataModel: cohortDataModel}
22+
func NewCohortDataController(cohortDataModel models.CohortDataI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDataController {
23+
return CohortDataController{
24+
cohortDataModel: cohortDataModel,
25+
teamProjectAuthz: teamProjectAuthz,
26+
}
2227
}
2328

2429
func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Context) {
@@ -44,6 +49,14 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
4449
cohortId, _ := strconv.Atoi(cohortIdStr)
4550
histogramConceptId, _ := strconv.ParseInt(histogramIdStr, 10, 64)
4651

52+
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
53+
if !validAccessRequest {
54+
log.Printf("Error: invalid request")
55+
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
56+
c.Abort()
57+
return
58+
}
59+
4760
cohortData, err := u.cohortDataModel.RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCohortPairs(sourceId, cohortId, histogramConceptId, filterConceptIds, cohortPairs)
4861
if err != nil {
4962
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving concept details", "error": err.Error()})
@@ -85,6 +98,14 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
8598
sourceId, _ := strconv.Atoi(sourceIdStr)
8699
cohortId, _ := strconv.Atoi(cohortIdStr)
87100

101+
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
102+
if !validAccessRequest {
103+
log.Printf("Error: invalid request")
104+
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
105+
c.Abort()
106+
return
107+
}
108+
88109
// call model method:
89110
cohortData, err := u.cohortDataModel.RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPersonId(sourceId, cohortId, conceptIds)
90111
if err != nil {
@@ -230,6 +251,14 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
230251
controlCohortId, errors[2] = utils.ParseNumericArg(c, "controlcohortid")
231252
conceptIds, cohortPairs, errors[3] = utils.ParseConceptIdsAndDichotomousDefs(c)
232253

254+
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs)
255+
if !validAccessRequest {
256+
log.Printf("Error: invalid request")
257+
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
258+
c.Abort()
259+
return
260+
}
261+
233262
if utils.ContainsNonNil(errors) {
234263
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
235264
c.Abort()

controllers/cohortdefinition.go

+1-48
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package controllers
22

33
import (
44
"net/http"
5-
"strconv"
65

76
"github.com/gin-gonic/gin"
87
"github.com/uc-cdis/cohort-middleware/models"
@@ -17,56 +16,10 @@ func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinition
1716
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
1817
}
1918

20-
func (u CohortDefinitionController) RetriveById(c *gin.Context) {
21-
cohortDefinitionId := c.Param("id")
22-
23-
if cohortDefinitionId != "" {
24-
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
25-
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
26-
if err != nil {
27-
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
28-
c.Abort()
29-
return
30-
}
31-
c.JSON(http.StatusOK, gin.H{"cohort_definition": cohortDefinition})
32-
return
33-
}
34-
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
35-
c.Abort()
36-
}
37-
38-
func (u CohortDefinitionController) RetriveByName(c *gin.Context) {
39-
cohortDefinitionName := c.Param("name")
40-
41-
if cohortDefinitionName != "" {
42-
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionByName(cohortDefinitionName)
43-
if err != nil {
44-
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
45-
c.Abort()
46-
return
47-
}
48-
c.JSON(http.StatusOK, gin.H{"CohortDefinition": cohortDefinition})
49-
return
50-
}
51-
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
52-
c.Abort()
53-
}
54-
55-
func (u CohortDefinitionController) RetriveAll(c *gin.Context) {
56-
cohortDefinitions, err := u.cohortDefinitionModel.GetAllCohortDefinitions()
57-
if err != nil {
58-
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
59-
c.Abort()
60-
return
61-
}
62-
c.JSON(http.StatusOK, gin.H{"cohort_definitions": cohortDefinitions})
63-
}
64-
6519
func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.Context) {
6620
// This method returns ALL cohortdefinition entries with cohort size statistics (for a given source)
67-
6821
sourceId, err1 := utils.ParseNumericArg(c, "sourceid")
69-
teamProject := c.Param("teamproject")
22+
teamProject := c.Query("team-project")
7023
if teamProject == "" {
7124
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error while parsing request", "error": "team-project is a mandatory parameter but was found to be empty!"})
7225
c.Abort()

controllers/concept.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
132132
c.Abort()
133133
return
134134
}
135-
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
135+
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
136136
if !validAccessRequest {
137137
log.Printf("Error: invalid request")
138138
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
@@ -198,7 +198,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
198198
return
199199
}
200200
_, cohortPairs := utils.GetConceptIdsAndCohortPairsAsSeparateLists(conceptIdsAndCohortPairs)
201-
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, cohortId, cohortPairs)
201+
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
202202
if !validAccessRequest {
203203
log.Printf("Error: invalid request")
204204
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})

middlewares/teamprojectauthz.go

+11-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ import (
1111

1212
type TeamProjectAuthzI interface {
1313
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
14-
TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
14+
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
15+
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
1516
}
1617

1718
type HttpClientI interface {
@@ -58,16 +59,20 @@ func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects [
5859

5960
func (u TeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
6061
filterCohortPairs := []utils.CustomDichotomousVariableDef{}
61-
return u.TeamProjectValidation(ctx, cohortDefinitionId, filterCohortPairs)
62+
return u.TeamProjectValidation(ctx, []int{cohortDefinitionId}, filterCohortPairs)
63+
}
64+
65+
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {
66+
67+
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionIds, filterCohortPairs)
68+
return u.TeamProjectValidationForCohortIdsList(ctx, uniqueCohortDefinitionIdsList)
6269
}
6370

6471
// "team project" related checks:
65-
// (1) check if the request contains any cohorts and if all cohorts belong to the same "team project"
72+
// (1) check if all cohorts belong to the same "team project"
6673
// (2) check if the user has permission in the "team project"
6774
// Returns true if both checks above pass, false otherwise.
68-
func (u TeamProjectAuthz) TeamProjectValidation(ctx *gin.Context, cohortDefinitionId int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool {
69-
70-
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest(cohortDefinitionId, filterCohortPairs)
75+
func (u TeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool {
7176
teamProjects, _ := u.cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
7277
if len(teamProjects) == 0 {
7378
log.Printf("Invalid request error: could not find a 'team project' that is associated to ALL the cohorts present in this request")

server/router.go

+2-5
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,7 @@ func NewRouter() *gin.Engine {
2929
authorized.GET("/sources", source.RetriveAll)
3030

3131
cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
32-
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
33-
authorized.GET("/cohortdefinition/by-name/:name", cohortdefinitions.RetriveByName)
34-
authorized.GET("/cohortdefinitions", cohortdefinitions.RetriveAll)
35-
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project/:teamproject", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
32+
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)
3633

3734
// concept endpoints:
3835
concepts := controllers.NewConceptController(*new(models.Concept), *new(models.CohortDefinition),
@@ -46,7 +43,7 @@ func NewRouter() *gin.Engine {
4643
authorized.POST("/concept-stats/by-source-id/:sourceid/by-cohort-definition-id/:cohortid/breakdown-by-concept-id/:breakdownconceptid/csv", concepts.RetrieveAttritionTable)
4744

4845
// cohort stats and checks:
49-
cohortData := controllers.NewCohortDataController(*new(models.CohortData))
46+
cohortData := controllers.NewCohortDataController(*new(models.CohortData), middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
5047
// :casecohortid/:controlcohortid are just labels here and have no special meaning. Could also just be :cohortAId/:cohortBId here:
5148
authorized.POST("/cohort-stats/check-overlap/by-source-id/:sourceid/by-cohort-definition-ids/:casecohortid/:controlcohortid", cohortData.RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue)
5249

0 commit comments

Comments
 (0)