Skip to content

Commit 5e1839b

Browse files
Merge pull request #87 from uc-cdis/feat/add_missing_arborist_checks
Feat/add missing arborist checks
2 parents ecb2237 + aa75cb3 commit 5e1839b

File tree

8 files changed

+191
-46
lines changed

8 files changed

+191
-46
lines changed

controllers/cohortdata.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func (u CohortDataController) RetrieveHistogramForCohortIdAndConceptId(c *gin.Co
5252
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
5353
if !validAccessRequest {
5454
log.Printf("Error: invalid request")
55-
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
55+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
5656
c.Abort()
5757
return
5858
}
@@ -101,7 +101,7 @@ func (u CohortDataController) RetrieveDataBySourceIdAndCohortIdAndVariables(c *g
101101
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
102102
if !validAccessRequest {
103103
log.Printf("Error: invalid request")
104-
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
104+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
105105
c.Abort()
106106
return
107107
}
@@ -254,7 +254,7 @@ func (u CohortDataController) RetrieveCohortOverlapStatsWithoutFilteringOnConcep
254254
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{caseCohortId, controlCohortId}, cohortPairs)
255255
if !validAccessRequest {
256256
log.Printf("Error: invalid request")
257-
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
257+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
258258
c.Abort()
259259
return
260260
}

controllers/cohortdefinition.go

+25-4
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,41 @@
11
package controllers
22

33
import (
4+
"log"
45
"net/http"
56
"strconv"
67

78
"github.com/gin-gonic/gin"
9+
"github.com/uc-cdis/cohort-middleware/middlewares"
810
"github.com/uc-cdis/cohort-middleware/models"
911
"github.com/uc-cdis/cohort-middleware/utils"
1012
)
1113

1214
type CohortDefinitionController struct {
1315
cohortDefinitionModel models.CohortDefinitionI
16+
teamProjectAuthz middlewares.TeamProjectAuthzI
1417
}
1518

16-
func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI) CohortDefinitionController {
17-
return CohortDefinitionController{cohortDefinitionModel: cohortDefinitionModel}
19+
func NewCohortDefinitionController(cohortDefinitionModel models.CohortDefinitionI, teamProjectAuthz middlewares.TeamProjectAuthzI) CohortDefinitionController {
20+
return CohortDefinitionController{
21+
cohortDefinitionModel: cohortDefinitionModel,
22+
teamProjectAuthz: teamProjectAuthz,
23+
}
1824
}
1925

2026
func (u CohortDefinitionController) RetriveById(c *gin.Context) {
21-
// TODO - add teamproject validation - check if user has the necessary atlas and arborist permissions
2227
cohortDefinitionId := c.Param("id")
2328

2429
if cohortDefinitionId != "" {
2530
cohortDefinitionId, _ := strconv.Atoi(cohortDefinitionId)
31+
// validate teamproject access permission for cohort:
32+
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortDefinitionId)
33+
if !validAccessRequest {
34+
log.Printf("Error: invalid request")
35+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
36+
c.Abort()
37+
return
38+
}
2639
cohortDefinition, err := u.cohortDefinitionModel.GetCohortDefinitionById(cohortDefinitionId)
2740
if err != nil {
2841
c.JSON(http.StatusInternalServerError, gin.H{"message": "Error retrieving cohortDefinition", "error": err.Error()})
@@ -45,7 +58,15 @@ func (u CohortDefinitionController) RetriveStatsBySourceIdAndTeamProject(c *gin.
4558
c.Abort()
4659
return
4760
}
48-
// TODO - validate teamproject against arborist
61+
// validate teamproject access permission:
62+
validAccessRequest := u.teamProjectAuthz.HasAccessToTeamProject(c, teamProject)
63+
if !validAccessRequest {
64+
log.Printf("Error: invalid request")
65+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
66+
c.Abort()
67+
return
68+
}
69+
4970
if err1 == nil {
5071
cohortDefinitionsAndStats, err := u.cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceId, teamProject)
5172
if err != nil {

controllers/concept.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortId(c *gin.Co
102102
validAccessRequest := u.teamProjectAuthz.TeamProjectValidationForCohort(c, cohortId)
103103
if !validAccessRequest {
104104
log.Printf("Error: invalid request")
105-
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
105+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
106106
c.Abort()
107107
return
108108
}
@@ -135,7 +135,7 @@ func (u ConceptController) RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariabl
135135
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
136136
if !validAccessRequest {
137137
log.Printf("Error: invalid request")
138-
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
138+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
139139
c.Abort()
140140
return
141141
}
@@ -201,7 +201,7 @@ func (u ConceptController) RetrieveAttritionTable(c *gin.Context) {
201201
validAccessRequest := u.teamProjectAuthz.TeamProjectValidation(c, []int{cohortId}, cohortPairs)
202202
if !validAccessRequest {
203203
log.Printf("Error: invalid request")
204-
c.JSON(http.StatusBadRequest, gin.H{"message": "access denied"})
204+
c.JSON(http.StatusForbidden, gin.H{"message": "access denied"})
205205
c.Abort()
206206
return
207207
}

middlewares/teamprojectauthz.go

+28-17
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ type TeamProjectAuthzI interface {
1313
TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool
1414
TeamProjectValidation(ctx *gin.Context, cohortDefinitionIds []int, filterCohortPairs []utils.CustomDichotomousVariableDef) bool
1515
TeamProjectValidationForCohortIdsList(ctx *gin.Context, uniqueCohortDefinitionIdsList []int) bool
16+
HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool
1617
}
1718

1819
type HttpClientI interface {
@@ -30,30 +31,40 @@ func NewTeamProjectAuthz(cohortDefinitionModel models.CohortDefinitionI, httpCli
3031
httpClient: httpClient,
3132
}
3233
}
33-
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
3434

35-
// query Arborist and return as soon as one of the teamProjects access check returns 200:
36-
for _, teamProject := range teamProjects {
37-
teamProjectAsResourcePath := teamProject
38-
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"
35+
func (u TeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
36+
teamProjectAsResourcePath := teamProject
37+
teamProjectAccessService := "atlas-argo-wrapper-and-cohort-middleware"
3938

40-
req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
41-
if err != nil {
42-
ctx.AbortWithStatus(500)
43-
panic("Error while preparing Arborist request")
44-
}
45-
// send the request to Arborist:
46-
resp, _ := u.httpClient.Do(req)
47-
log.Printf("Got response status %d from Arborist...", resp.StatusCode)
39+
req, err := PrepareNewArboristRequestForResourceAndService(ctx, teamProjectAsResourcePath, teamProjectAccessService)
40+
if err != nil {
41+
ctx.AbortWithStatus(500)
42+
panic("Error while preparing Arborist request")
43+
}
44+
// send the request to Arborist:
45+
resp, _ := u.httpClient.Do(req)
46+
log.Printf("Got response status %d from Arborist...", resp.StatusCode)
47+
48+
// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
49+
if resp.StatusCode == 200 {
50+
return true
51+
} else {
52+
// unauthorized or otherwise:
53+
log.Printf("Authorization check for team project failed with status %d ...", resp.StatusCode)
54+
return false
55+
}
56+
}
4857

49-
// arborist will return with 200 if the user has been granted access to the cohort-middleware URL in ctx:
50-
if resp.StatusCode == 200 {
58+
func (u TeamProjectAuthz) hasAccessToAtLeastOne(ctx *gin.Context, teamProjects []string) bool {
59+
for _, teamProject := range teamProjects {
60+
if u.HasAccessToTeamProject(ctx, teamProject) {
5161
return true
5262
} else {
53-
// unauthorized or otherwise:
54-
log.Printf("Status %d does NOT give access to team project...", resp.StatusCode)
63+
// unauthorized:
64+
log.Printf("NO access to team project...checking next one (if any)...")
5565
}
5666
}
67+
log.Printf("NO access to any of the team projects queried...")
5768
return false
5869
}
5970

server/router.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ func NewRouter() *gin.Engine {
2828
authorized.GET("/source/by-name/:name", source.RetriveByName)
2929
authorized.GET("/sources", source.RetriveAll)
3030

31-
cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
31+
cohortdefinitions := controllers.NewCohortDefinitionController(*new(models.CohortDefinition),
32+
middlewares.NewTeamProjectAuthz(*new(models.CohortDefinition), &http.Client{}))
3233
authorized.GET("/cohortdefinition/by-id/:id", cohortdefinitions.RetriveById)
3334

3435
authorized.GET("/cohortdefinition-stats/by-source-id/:sourceid/by-team-project", cohortdefinitions.RetriveStatsBySourceIdAndTeamProject)

tests/controllers_tests/controllers_test.go

+52-13
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@ var cohortDataController = controllers.NewCohortDataController(*new(dummyCohortD
5454
var cohortDataControllerWithFailingTeamProjectAuthz = controllers.NewCohortDataController(*new(dummyCohortDataModel), *new(dummyFailingTeamProjectAuthz))
5555

5656
// instance of the controller that talks to the regular model implementation (that needs a real DB):
57-
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition))
57+
var cohortDefinitionControllerNeedsDb = controllers.NewCohortDefinitionController(*new(models.CohortDefinition), *new(dummyTeamProjectAuthz))
5858

5959
// instance of the controller that talks to a mock implementation of the model:
60-
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel))
60+
var cohortDefinitionController = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
61+
var cohortDefinitionControllerWithFailingTeamProjectAuthz = controllers.NewCohortDefinitionController(*new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))
6162

6263
type dummyCohortDataModel struct{}
6364

@@ -151,6 +152,10 @@ func (h dummyTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx *gin.Co
151152
return true
152153
}
153154

155+
func (h dummyTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
156+
return true
157+
}
158+
154159
type dummyFailingTeamProjectAuthz struct{}
155160

156161
func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohort(ctx *gin.Context, cohortDefinitionId int) bool {
@@ -165,6 +170,10 @@ func (h dummyFailingTeamProjectAuthz) TeamProjectValidationForCohortIdsList(ctx
165170
return false
166171
}
167172

173+
func (h dummyFailingTeamProjectAuthz) HasAccessToTeamProject(ctx *gin.Context, teamProject string) bool {
174+
return false
175+
}
176+
168177
var conceptController = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyTeamProjectAuthz))
169178
var conceptControllerWithFailingTeamProjectAuthz = controllers.NewConceptController(*new(dummyConceptDataModel), *new(dummyCohortDefinitionDataModel), *new(dummyFailingTeamProjectAuthz))
170179

@@ -463,18 +472,37 @@ func TestRetriveStatsBySourceIdAndTeamProjectCheckMandatoryTeamProject(t *testin
463472
}
464473
}
465474

475+
func TestRetriveStatsBySourceIdAndTeamProjectAuthorizationError(t *testing.T) {
476+
setUp(t)
477+
requestContext := new(gin.Context)
478+
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
479+
requestContext.Request = &http.Request{URL: &url.URL{}}
480+
teamProject := "/test/dummyname/dummy-team-project"
481+
requestContext.Request.URL.RawQuery = "team-project=" + teamProject
482+
requestContext.Writer = new(tests.CustomResponseWriter)
483+
cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveStatsBySourceIdAndTeamProject(requestContext)
484+
result := requestContext.Writer.(*tests.CustomResponseWriter)
485+
if !requestContext.IsAborted() {
486+
t.Errorf("Expected aborted request")
487+
}
488+
if result.Status() != http.StatusForbidden {
489+
t.Errorf("Expected StatusForbidden, got %d", result.Status())
490+
}
491+
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
492+
t.Errorf("Expected 'access denied' in response")
493+
}
494+
}
495+
466496
func TestRetriveStatsBySourceIdAndTeamProject(t *testing.T) {
467497
setUp(t)
468498
requestContext := new(gin.Context)
469499
requestContext.Params = append(requestContext.Params, gin.Param{Key: "sourceid", Value: strconv.Itoa(tests.GetTestSourceId())})
470-
//requestContext.Params = append(requestContext.Params, gin.Param{Key: "teamproject", Value: "dummy-team-project"})
471500
requestContext.Request = &http.Request{URL: &url.URL{}}
472501
teamProject := "/test/dummyname/dummy-team-project"
473502
requestContext.Request.URL.RawQuery = "team-project=" + teamProject
474503
requestContext.Writer = new(tests.CustomResponseWriter)
475504
cohortDefinitionController.RetriveStatsBySourceIdAndTeamProject(requestContext)
476505
result := requestContext.Writer.(*tests.CustomResponseWriter)
477-
log.Printf("result: %s", result)
478506
// expect result with all of the dummy data:
479507
if !strings.Contains(result.CustomResponseWriterOut, "name1_"+teamProject) ||
480508
!strings.Contains(result.CustomResponseWriterOut, "name2_"+teamProject) ||
@@ -502,7 +530,6 @@ func TestRetriveById(t *testing.T) {
502530
requestContext.Writer = new(tests.CustomResponseWriter)
503531
cohortDefinitionController.RetriveById(requestContext)
504532
result := requestContext.Writer.(*tests.CustomResponseWriter)
505-
log.Printf("result: %s", result)
506533
// expect result with dummy data:
507534
if !strings.Contains(result.CustomResponseWriterOut, "test 1") {
508535
t.Errorf("Expected data in result")
@@ -522,6 +549,26 @@ func TestRetriveByIdModelError(t *testing.T) {
522549
}
523550
}
524551

552+
func TestRetriveByIdAuthorizationError(t *testing.T) {
553+
setUp(t)
554+
requestContext := new(gin.Context)
555+
requestContext.Params = append(requestContext.Params, gin.Param{Key: "id", Value: "1"})
556+
requestContext.Writer = new(tests.CustomResponseWriter)
557+
cohortDefinitionControllerWithFailingTeamProjectAuthz.RetriveById(requestContext)
558+
result := requestContext.Writer.(*tests.CustomResponseWriter)
559+
if !requestContext.IsAborted() {
560+
t.Errorf("Expected aborted request")
561+
}
562+
if result.Status() != http.StatusForbidden {
563+
t.Errorf("Expected StatusForbidden, got %d", result.Status())
564+
}
565+
// expect result with dummy data:
566+
if !strings.Contains(result.CustomResponseWriterOut, "access denied") {
567+
t.Errorf("Expected 'access denied' in response")
568+
}
569+
570+
}
571+
525572
func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
526573
setUp(t)
527574
requestContext := new(gin.Context)
@@ -532,7 +579,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortId(t *testing.T) {
532579
requestContext.Writer = new(tests.CustomResponseWriter)
533580
conceptController.RetrieveBreakdownStatsBySourceIdAndCohortId(requestContext)
534581
result := requestContext.Writer.(*tests.CustomResponseWriter)
535-
log.Printf("result: %s", result)
536582
// expect result with dummy data:
537583
if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") {
538584
t.Errorf("Expected data in result")
@@ -563,7 +609,6 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(t *testing.T) {
563609
requestContext.Writer = new(tests.CustomResponseWriter)
564610
conceptController.RetrieveBreakdownStatsBySourceIdAndCohortIdAndVariables(requestContext)
565611
result := requestContext.Writer.(*tests.CustomResponseWriter)
566-
log.Printf("result: %s", result)
567612
// expect result with dummy data:
568613
if !strings.Contains(result.CustomResponseWriterOut, "persons_in_cohort_with_value") {
569614
t.Errorf("Expected data in result")
@@ -608,7 +653,6 @@ func TestRetrieveInfoBySourceIdAndConceptIds(t *testing.T) {
608653
requestContext.Writer = new(tests.CustomResponseWriter)
609654
conceptController.RetrieveInfoBySourceIdAndConceptIds(requestContext)
610655
result := requestContext.Writer.(*tests.CustomResponseWriter)
611-
log.Printf("result: %s", result)
612656
// expect result with dummy data:
613657
if !strings.Contains(result.CustomResponseWriterOut, "Concept A") ||
614658
!strings.Contains(result.CustomResponseWriterOut, "Concept B") {
@@ -625,7 +669,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypes(t *testing.T) {
625669
requestContext.Writer = new(tests.CustomResponseWriter)
626670
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
627671
result := requestContext.Writer.(*tests.CustomResponseWriter)
628-
log.Printf("result: %s", result)
629672
// expect result with dummy data:
630673
if !strings.Contains(result.CustomResponseWriterOut, "Concept A") ||
631674
!strings.Contains(result.CustomResponseWriterOut, "Concept B") {
@@ -644,7 +687,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesModelError(t *testing.T) {
644687
dummyModelReturnError = true
645688
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
646689
result := requestContext.Writer.(*tests.CustomResponseWriter)
647-
log.Printf("result: %s", result)
648690
if !requestContext.IsAborted() {
649691
t.Errorf("Expected aborted request")
650692
}
@@ -662,7 +704,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesArgsError(t *testing.T) {
662704
dummyModelReturnError = true
663705
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
664706
result := requestContext.Writer.(*tests.CustomResponseWriter)
665-
log.Printf("result: %s", result)
666707
if !requestContext.IsAborted() {
667708
t.Errorf("Expected aborted request")
668709
}
@@ -680,7 +721,6 @@ func TestRetrieveInfoBySourceIdAndConceptTypesMissingBody(t *testing.T) {
680721
dummyModelReturnError = true
681722
conceptController.RetrieveInfoBySourceIdAndConceptTypes(requestContext)
682723
result := requestContext.Writer.(*tests.CustomResponseWriter)
683-
log.Printf("result: %s", result)
684724
if !requestContext.IsAborted() {
685725
t.Errorf("Expected aborted request")
686726
}
@@ -982,7 +1022,6 @@ func TestRetrieveAttritionTable(t *testing.T) {
9821022
requestContext.Writer = new(tests.CustomResponseWriter)
9831023
conceptController.RetrieveAttritionTable(requestContext)
9841024
result := requestContext.Writer.(*tests.CustomResponseWriter)
985-
log.Printf("result: %s", result.CustomResponseWriterOut)
9861025
// check result vs expect result:
9871026
csvLines := strings.Split(strings.TrimRight(result.CustomResponseWriterOut, "\n"), "\n")
9881027
expectedLines := []string{

0 commit comments

Comments
 (0)