Skip to content

Commit ba01e7b

Browse files
Merge pull request #69 from uc-cdis/add_timeout_to_queries
Adding a default timeout to all queries
2 parents 18121be + 9c96c6b commit ba01e7b

File tree

6 files changed

+123
-42
lines changed

6 files changed

+123
-42
lines changed

models/cohortdata.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,14 @@ func (h CohortData) RetrieveDataByOriginalCohortAndNewCohort(sourceId int, origi
4747
resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results)
4848
var personData []*PersonIdAndCohort
4949

50-
meta_result := resultsDataSource.Db.Model(&Cohort{}).
50+
query := resultsDataSource.Db.Model(&Cohort{}).
5151
Select("cohort.subject_id as person_id, cohort.cohort_definition_id as cohort_id").
5252
Joins("INNER JOIN "+resultsDataSource.Schema+".cohort as original_cohort ON cohort.subject_id = original_cohort.subject_id").
5353
Where("cohort.cohort_definition_id = ?", cohortDefinitionId).
54-
Where("original_cohort.cohort_definition_id = ?", originalCohortDefinitionId).
55-
Scan(&personData)
54+
Where("original_cohort.cohort_definition_id = ?", originalCohortDefinitionId)
55+
query, cancel := utils.AddTimeoutToQuery(query)
56+
defer cancel()
57+
meta_result := query.Scan(&personData)
5658
return personData, meta_result.Error
5759
}
5860

@@ -68,14 +70,16 @@ func (h CohortData) RetrieveDataBySourceIdAndCohortIdAndConceptIdsOrderedByPerso
6870

6971
// get the observations for the subjects and the concepts, to build up the data rows to return:
7072
var cohortData []*PersonConceptAndValue
71-
meta_result := omopDataSource.Db.Table(omopDataSource.Schema+".observation_continuous as observation"+omopDataSource.GetViewDirective()).
73+
query := omopDataSource.Db.Table(omopDataSource.Schema+".observation_continuous as observation"+omopDataSource.GetViewDirective()).
7274
Select("observation.person_id, observation.observation_concept_id as concept_id, concept.concept_class_id, observation.value_as_string as concept_value_as_string, observation.value_as_number as concept_value_as_number, observation.value_as_concept_id as concept_value_as_concept_id").
7375
Joins("INNER JOIN "+resultsDataSource.Schema+".cohort as cohort ON cohort.subject_id = observation.person_id").
7476
Joins("INNER JOIN "+omopDataSource.Schema+".concept as concept ON concept.concept_id = observation.observation_concept_id").
7577
Where("cohort.cohort_definition_id = ?", cohortDefinitionId).
7678
Where("observation.observation_concept_id in (?)", conceptIds).
77-
Order("observation.person_id asc"). // this order is important!
78-
Scan(&cohortData)
79+
Order("observation.person_id asc") // this order is important!
80+
query, cancel := utils.AddTimeoutToQuery(query)
81+
defer cancel()
82+
meta_result := query.Scan(&cohortData)
7983
return cohortData, meta_result.Error
8084
}
8185

@@ -95,6 +99,8 @@ func (h CohortData) RetrieveHistogramDataBySourceIdAndCohortIdAndConceptIdsAndCo
9599

96100
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")
97101

102+
query, cancel := utils.AddTimeoutToQuery(query)
103+
defer cancel()
98104
meta_result := query.Scan(&cohortData)
99105
return cohortData, meta_result.Error
100106
}
@@ -118,6 +124,8 @@ func (h CohortData) RetrieveCohortOverlapStatsWithoutFilteringOnConceptValue(sou
118124
query = QueryFilterByConceptIdsHelper(query, sourceId, otherFilterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")
119125
}
120126
query = query.Where("control_cohort.cohort_definition_id = ?", controlCohortId)
127+
query, cancel := utils.AddTimeoutToQuery(query)
128+
defer cancel()
121129
meta_result := query.Scan(&cohortOverlapStats)
122130
return cohortOverlapStats, meta_result.Error
123131
}
@@ -152,6 +160,8 @@ func (h CohortData) ValidateObservationData(observationConceptIdsToCheck []int64
152160
Group("observation.person_id, observation.observation_concept_id").
153161
Having("count(*) > 1")
154162

163+
query, cancel := utils.AddTimeoutToQuery(query)
164+
defer cancel()
155165
meta_result := query.Scan(&personConceptAndCount)
156166
if meta_result.Error != nil {
157167
return -1, meta_result.Error

models/cohortdefinition.go

+21-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55

66
"github.com/uc-cdis/cohort-middleware/db"
7+
"github.com/uc-cdis/cohort-middleware/utils"
78
)
89

910
type CohortDefinitionI interface {
@@ -32,29 +33,35 @@ type CohortDefinitionStats struct {
3233
func (h CohortDefinition) GetCohortDefinitionById(id int) (*CohortDefinition, error) {
3334
db2 := db.GetAtlasDB().Db
3435
var cohortDefinition *CohortDefinition
35-
meta_result := db2.Model(&CohortDefinition{}).
36+
query := db2.Model(&CohortDefinition{}).
3637
Select("id, name, description").
37-
Where("id = ?", id).
38-
Scan(&cohortDefinition)
38+
Where("id = ?", id)
39+
query, cancel := utils.AddTimeoutToQuery(query)
40+
defer cancel()
41+
meta_result := query.Scan(&cohortDefinition)
3942
return cohortDefinition, meta_result.Error
4043
}
4144

4245
func (h CohortDefinition) GetCohortDefinitionByName(name string) (*CohortDefinition, error) {
4346
db2 := db.GetAtlasDB().Db
4447
var cohortDefinition *CohortDefinition
45-
meta_result := db2.Model(&CohortDefinition{}).
48+
query := db2.Model(&CohortDefinition{}).
4649
Select("id, name, description").
47-
Where("name = ?", name).
48-
Scan(&cohortDefinition)
50+
Where("name = ?", name)
51+
query, cancel := utils.AddTimeoutToQuery(query)
52+
defer cancel()
53+
meta_result := query.Scan(&cohortDefinition)
4954
return cohortDefinition, meta_result.Error
5055
}
5156

5257
func (h CohortDefinition) GetAllCohortDefinitions() ([]*CohortDefinition, error) {
5358
db2 := db.GetAtlasDB().Db
5459
var cohortDefinition []*CohortDefinition
55-
meta_result := db2.Model(&CohortDefinition{}).
56-
Select("id, name, description").
57-
Scan(&cohortDefinition)
60+
query := db2.Model(&CohortDefinition{}).
61+
Select("id, name, description")
62+
query, cancel := utils.AddTimeoutToQuery(query)
63+
defer cancel()
64+
meta_result := query.Scan(&cohortDefinition)
5865
return cohortDefinition, meta_result.Error
5966
}
6067

@@ -64,11 +71,13 @@ func (h CohortDefinition) GetAllCohortDefinitionsAndStatsOrderBySizeDesc(sourceI
6471
var dataSourceModel = new(Source)
6572
resultsDataSource := dataSourceModel.GetDataSource(sourceId, Results)
6673
var cohortDefinitionStats []*CohortDefinitionStats
67-
meta_result := resultsDataSource.Db.Model(&Cohort{}).
74+
query := resultsDataSource.Db.Model(&Cohort{}).
6875
Select("cohort_definition_id as id, '' as name, count(*) as cohort_size").
6976
Group("cohort_definition_id").
70-
Order("count(*) desc").
71-
Scan(&cohortDefinitionStats)
77+
Order("count(*) desc")
78+
query, cancel := utils.AddTimeoutToQuery(query)
79+
defer cancel()
80+
meta_result := query.Scan(&cohortDefinitionStats)
7281

7382
// add name details:
7483
for _, cohortDefinitionStat := range cohortDefinitionStats {

models/concept.go

+19-9
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ func (h Concept) RetriveAllBySourceId(sourceId int) ([]*Concept, error) {
5858
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
5959

6060
var concepts []*Concept
61-
meta_result := omopDataSource.Db.Model(&Concept{}).
61+
query := omopDataSource.Db.Model(&Concept{}).
6262
Select("concept_id, concept_name, concept_class_id as concept_type").
63-
Order("concept_name").
64-
Scan(&concepts)
63+
Order("concept_name")
64+
query, cancel := utils.AddTimeoutToQuery(query)
65+
defer cancel()
66+
meta_result := query.Scan(&concepts)
6567
return concepts, meta_result.Error
6668
}
6769

@@ -95,11 +97,13 @@ func (h Concept) RetrieveInfoBySourceIdAndConceptIds(sourceId int, conceptIds []
9597
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
9698

9799
var conceptItems []*ConceptSimple
98-
meta_result := omopDataSource.Db.Model(&Concept{}).
100+
query := omopDataSource.Db.Model(&Concept{}).
99101
Select("concept_id, concept_name, concept_code, concept_class_id as concept_type").
100102
Where("concept_id in (?)", conceptIds).
101-
Order("concept_name").
102-
Scan(&conceptItems)
103+
Order("concept_name")
104+
query, cancel := utils.AddTimeoutToQuery(query)
105+
defer cancel()
106+
meta_result := query.Scan(&conceptItems)
103107
if meta_result.Error != nil {
104108
return nil, meta_result.Error
105109
}
@@ -118,11 +122,14 @@ func (h Concept) RetrieveInfoBySourceIdAndConceptTypes(sourceId int, conceptType
118122
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
119123

120124
var conceptItems []*ConceptSimple
121-
meta_result := omopDataSource.Db.Model(&Concept{}).
125+
query := omopDataSource.Db.Model(&Concept{}).
122126
Select("concept_id, concept_name, concept_class_id as concept_type").
123127
Where("concept_class_id in (?)", conceptTypes).
124-
Order("concept_name").
125-
Scan(&conceptItems)
128+
Order("concept_name")
129+
130+
query, cancel := utils.AddTimeoutToQuery(query)
131+
defer cancel()
132+
meta_result := query.Scan(&conceptItems)
126133
if meta_result.Error != nil {
127134
return nil, meta_result.Error
128135
}
@@ -135,6 +142,7 @@ func (h Concept) RetrieveInfoBySourceIdAndConceptTypes(sourceId int, conceptType
135142

136143
// Retrieve concept name, type and missing ratio statistics for given list of conceptIds.
137144
// Assumption is that both OMOP and RESULTS schemas are on same DB.
145+
// TODO - remove this code as it is NOT used anymore by the frontend
138146
func (h Concept) RetrieveStatsBySourceIdAndCohortIdAndConceptIds(sourceId int, cohortDefinitionId int, conceptIds []int64) ([]*ConceptStats, error) {
139147
var dataSourceModel = new(Source)
140148
omopDataSource := dataSourceModel.GetDataSource(sourceId, Omop)
@@ -227,6 +235,8 @@ func (h Concept) RetrieveBreakdownStatsBySourceIdAndCohortIdAndConceptIdsAndCoho
227235
// which is a better performing SQL in this particular scenario:
228236
query = QueryFilterByConceptIdsHelper(query, sourceId, filterConceptIds, omopDataSource, resultsDataSource.Schema, "observation")
229237

238+
query, cancel := utils.AddTimeoutToQuery(query)
239+
defer cancel()
230240
meta_result := query.Group("observation.value_as_concept_id").
231241
Scan(&conceptBreakdownList)
232242

models/source.go

+25-15
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,24 @@ type Source struct {
1717
func (h Source) GetSourceById(id int) (*Source, error) {
1818
db2 := db.GetAtlasDB().Db
1919
var dataSource *Source
20-
db2.Model(&Source{}).
20+
query := db2.Model(&Source{}).
2121
Select("source_id, source_name").
22-
Where("source_id = ?", id).
23-
Scan(&dataSource)
22+
Where("source_id = ?", id)
23+
query, cancel := utils.AddTimeoutToQuery(query)
24+
defer cancel()
25+
query.Scan(&dataSource)
2426
return dataSource, nil
2527
}
2628

2729
func (h Source) GetSourceByIdWithConnection(id int) (*Source, error) {
2830
db2 := db.GetAtlasDB().Db
2931
var dataSource *Source
30-
db2.Model(&Source{}).
32+
query := db2.Model(&Source{}).
3133
Select("source_id, source_name, source_connection, source_dialect, username, password").
32-
Where("source_id = ?", id).
33-
Scan(&dataSource)
34+
Where("source_id = ?", id)
35+
query, cancel := utils.AddTimeoutToQuery(query)
36+
defer cancel()
37+
query.Scan(&dataSource)
3438
return dataSource, nil
3539
}
3640

@@ -42,12 +46,14 @@ func (h Source) GetSourceSchemaNameBySourceIdAndSourceType(id int, sourceType So
4246
atlasDb := db.GetAtlasDB()
4347
db2 := atlasDb.Db
4448
var sourceSchema *SourceSchema
45-
db2.Model(&Source{}).
49+
query := db2.Model(&Source{}).
4650
Select("source_daimon.table_qualifier as schema_name").
4751
Joins("INNER JOIN "+atlasDb.Schema+".source_daimon ON source.source_id = source_daimon.source_id").
4852
Where("source.source_id = ?", id).
49-
Where("source_daimon.daimon_type = ?", sourceType).
50-
Scan(&sourceSchema)
53+
Where("source_daimon.daimon_type = ?", sourceType)
54+
query, cancel := utils.AddTimeoutToQuery(query)
55+
defer cancel()
56+
query.Scan(&sourceSchema)
5157
return sourceSchema, nil
5258
}
5359

@@ -74,18 +80,22 @@ func (h Source) GetDataSource(sourceId int, sourceType SourceType) *utils.DbAndS
7480
func (h Source) GetSourceByName(name string) (*Source, error) {
7581
db2 := db.GetAtlasDB().Db
7682
var dataSource *Source
77-
db2.Model(&Source{}).
83+
query := db2.Model(&Source{}).
7884
Select("source_id, source_name").
79-
Where("source_name = ?", name).
80-
Scan(&dataSource)
85+
Where("source_name = ?", name)
86+
query, cancel := utils.AddTimeoutToQuery(query)
87+
defer cancel()
88+
query.Scan(&dataSource)
8189
return dataSource, nil
8290
}
8391

8492
func (h Source) GetAllSources() ([]*Source, error) {
8593
db2 := db.GetAtlasDB().Db
8694
var dataSource []*Source
87-
db2.Model(&Source{}).
88-
Select("source_id, source_name").
89-
Scan(&dataSource)
95+
query := db2.Model(&Source{}).
96+
Select("source_id, source_name")
97+
query, cancel := utils.AddTimeoutToQuery(query)
98+
defer cancel()
99+
query.Scan(&dataSource)
90100
return dataSource, nil
91101
}

tests/models_tests/models_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"os"
66
"strings"
77
"testing"
8+
"time"
89

910
"github.com/uc-cdis/cohort-middleware/config"
1011
"github.com/uc-cdis/cohort-middleware/db"
@@ -889,3 +890,28 @@ func TestRetrieveDataByOriginalCohortAndNewCohort(t *testing.T) {
889890
}
890891
}
891892
}
893+
894+
func TestAddTimeoutToQuery(t *testing.T) {
895+
setUp(t)
896+
897+
// take a simple query, run with short timeout, and expect error:
898+
db2 := db.GetAtlasDB().Db
899+
var dataSource []*models.Source
900+
query := db2.Model(&models.Source{}).
901+
Select("source_id, source_name")
902+
query, cancel := utils.AddSpecificTimeoutToQuery(query, 2*time.Nanosecond)
903+
defer cancel()
904+
meta_result := query.Scan(&dataSource)
905+
if meta_result.Error == nil || len(dataSource) > 0 {
906+
t.Errorf("Expected timeout error and NO data")
907+
}
908+
909+
// then switch to default (longer) timeout and expect a result:
910+
query2, cancel2 := utils.AddTimeoutToQuery(query)
911+
defer cancel2()
912+
meta_result2 := query2.Scan(&dataSource)
913+
914+
if meta_result2.Error != nil || len(dataSource) == 0 {
915+
t.Errorf("Expected result and NO error")
916+
}
917+
}

utils/db.go

+16
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package utils
22

33
import (
4+
"context"
45
"log"
56
"strings"
7+
"time"
68

79
"gorm.io/driver/postgres"
810
"gorm.io/driver/sqlserver"
@@ -56,6 +58,20 @@ func GetDataSourceDB(sourceConnectionString string, dbSchema string) *DbAndSchem
5658
return dataSourceDb
5759
}
5860

61+
// Adds a default timeout to a query
62+
func AddTimeoutToQuery(query *gorm.DB) (*gorm.DB, context.CancelFunc) {
63+
// default timeout of 3 minutes:
64+
query, cancel := AddSpecificTimeoutToQuery(query, 180*time.Second)
65+
return query, cancel
66+
}
67+
68+
// Adds a specific timeout to a query
69+
func AddSpecificTimeoutToQuery(query *gorm.DB, timeout time.Duration) (*gorm.DB, context.CancelFunc) {
70+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
71+
query = query.WithContext(ctx)
72+
return query, cancel
73+
}
74+
5975
// Returns extra DB dialect specific directives to optimize performance when using views:
6076
func (h DbAndSchema) GetViewDirective() string {
6177
if h.Vendor == "sqlserver" {

0 commit comments

Comments
 (0)