Skip to content

Commit 9ab34c7

Browse files
committed
feat: improve tests for GetTeamProjectsThatMatchAllCohortDefinitionIds
1 parent 922c748 commit 9ab34c7

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

tests/models_tests/models_test.go

+34-5
Original file line numberDiff line numberDiff line change
@@ -566,39 +566,68 @@ func TestRetrieveBreakdownStatsBySourceIdAndCohortIdWithResultsWithOnePersonTwoH
566566

567567
func TestGetTeamProjectsThatMatchAllCohortDefinitionIdsOnlyDefaultMatch(t *testing.T) {
568568
setUp(t)
569-
cohortDefinitionId := 2
569+
cohortDefinitionId := 2 // 'Medium cohort' in test_data_atlas.sql
570570
filterCohortPairs := []utils.CustomDichotomousVariableDef{
571571
{
572572
CohortDefinitionId1: smallestCohort.Id,
573573
CohortDefinitionId2: largestCohort.Id,
574574
ProvidedName: "test"},
575575
}
576576
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs)
577+
if len(uniqueCohortDefinitionIdsList) != 3 {
578+
t.Errorf("Expected uniqueCohortDefinitionIdsList length to be 3")
579+
}
577580
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
578581
if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" {
579582
t.Errorf("Expected to find only defaultteamproject")
580583
}
584+
585+
// Should also hold true if the uniqueCohortDefinitionIdsList is length 2 (which matches teamprojectX's cohort
586+
// list length but not in contents):
587+
filterCohortPairs = []utils.CustomDichotomousVariableDef{
588+
{
589+
CohortDefinitionId1: 2,
590+
CohortDefinitionId2: largestCohort.Id,
591+
ProvidedName: "test"},
592+
}
593+
uniqueCohortDefinitionIdsList = utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs)
594+
if len(uniqueCohortDefinitionIdsList) != 2 {
595+
t.Errorf("Expected uniqueCohortDefinitionIdsList length to be 2")
596+
}
597+
teamProjects, _ = cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
598+
if len(teamProjects) != 1 || teamProjects[0] != "defaultteamproject" {
599+
t.Errorf("Expected to find only defaultteamproject")
600+
}
581601
}
582602

583603
func TestGetTeamProjectsThatMatchAllCohortDefinitionIds(t *testing.T) {
584604
setUp(t)
585-
cohortDefinitionId := 2
605+
cohortDefinitionId := 2 // 'Medium cohort' in test_data_atlas.sql
586606
filterCohortPairs := []utils.CustomDichotomousVariableDef{
587607
{
588608
CohortDefinitionId1: 2,
589-
CohortDefinitionId2: 2,
609+
CohortDefinitionId2: 32,
590610
ProvidedName: "test"},
591611
}
592612
uniqueCohortDefinitionIdsList := utils.GetUniqueCohortDefinitionIdsListFromRequest([]int{cohortDefinitionId}, filterCohortPairs)
613+
if len(uniqueCohortDefinitionIdsList) != 2 {
614+
t.Errorf("Expected uniqueCohortDefinitionIdsList length to be 2")
615+
}
593616
teamProjects, _ := cohortDefinitionModel.GetTeamProjectsThatMatchAllCohortDefinitionIds(uniqueCohortDefinitionIdsList)
594617
if len(teamProjects) != 2 {
595618
t.Errorf("Expected to find two 'team projects' matching the cohort list, found %s", teamProjects)
596619
}
620+
if !utils.ContainsString(teamProjects, "defaultteamproject") {
621+
t.Errorf("Expected to find 'defaultteamproject' in the results, found %s", teamProjects)
622+
}
623+
if !utils.ContainsString(teamProjects, "teamprojectX") {
624+
t.Errorf("Expected to find 'teamprojectX' in the results, found %s", teamProjects)
625+
}
597626
}
598627

599628
func TestGetCohortDefinitionIdsForTeamProject(t *testing.T) {
600629
setUp(t)
601-
testTeamProject := "teamprojectX"
630+
testTeamProject := "teamprojectY"
602631
allowedCohortDefinitionIds, _ := cohortDefinitionModel.GetCohortDefinitionIdsForTeamProject(testTeamProject)
603632
if len(allowedCohortDefinitionIds) != 1 {
604633
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",
@@ -633,7 +662,7 @@ func TestGetAllCohortDefinitionsAndStatsOrderBySizeDesc(t *testing.T) {
633662
}
634663

635664
// some extra tests to cover also the teamProject option for this method:
636-
testTeamProject := "teamprojectX"
665+
testTeamProject := "teamprojectY"
637666
allowedCohortDefinitions, _ := cohortDefinitionModel.GetAllCohortDefinitionsAndStatsOrderBySizeDesc(testSourceId, testTeamProject)
638667
if len(allowedCohortDefinitions) != 1 {
639668
t.Errorf("Expected teamProject '%s' to have one cohort, but found %d",

tests/setup_local_db/test_data_atlas.sql

+25-24
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,29 @@ values
6969
(1458, 1005, 1185),
7070
(1459, 1005, 1186),
7171
(1460, 1005, 1187),
72-
(1461, 1009, 1188),
73-
(1462, 1009, 1189),
74-
(1463, 1009, 1190),
75-
(1464, 1009, 1191),
76-
(1465, 1009, 1192),
77-
(1466, 1009, 1193),
78-
(1467, 1009, 1194),
79-
(2454, 4000, 1181),
80-
(2455, 4000, 1182),
81-
(2456, 4000, 1183),
82-
(2457, 4000, 1184),
83-
(2458, 4000, 1185),
84-
(2459, 4000, 1186),
85-
(2460, 4000, 1187),
86-
(2461, 4000, 1188),
87-
(2462, 4000, 1189),
88-
(2463, 4000, 1190),
89-
(2464, 4000, 1191),
90-
(2465, 4000, 1192),
91-
(2466, 4000, 1193),
92-
(2467, 4000, 1194),
93-
(2468, 4000, 2193),
94-
(2469, 4000, 3193),
95-
(2470, 4000, 4193)
72+
(1461, 1005, 4193), -- 1005 teamprojectX has access to cohorts 2 and 32
73+
(2461, 1009, 1188),
74+
(2462, 1009, 1189),
75+
(2463, 1009, 1190),
76+
(2464, 1009, 1191),
77+
(2465, 1009, 1192),
78+
(2466, 1009, 1193),
79+
(2467, 1009, 1194), -- 1009 teamprojectY has access to cohort 4
80+
(4454, 4000, 1181),
81+
(4455, 4000, 1182),
82+
(4456, 4000, 1183),
83+
(4457, 4000, 1184),
84+
(4458, 4000, 1185),
85+
(4459, 4000, 1186),
86+
(4460, 4000, 1187),
87+
(4461, 4000, 1188),
88+
(4462, 4000, 1189),
89+
(4463, 4000, 1190),
90+
(4464, 4000, 1191),
91+
(4465, 4000, 1192),
92+
(4466, 4000, 1193),
93+
(4467, 4000, 1194),
94+
(4468, 4000, 2193),
95+
(4469, 4000, 3193),
96+
(4470, 4000, 4193) -- 4000 defaultteamproject has access to all cohorts
9697
;

utils/parsing.go

+9
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ func Pos(value int64, list []int64) int {
4343
return -1
4444
}
4545

46+
func ContainsString(list []string, value string) bool {
47+
for _, item := range list {
48+
if item == value {
49+
return true
50+
}
51+
}
52+
return false
53+
}
54+
4655
func Contains(list []int, value int) bool {
4756
for _, item := range list {
4857
if item == value {

0 commit comments

Comments
 (0)