34
34
import graphql .schema .DataFetcher ;
35
35
import graphql .schema .DataFetchingEnvironment ;
36
36
import graphql .schema .GraphQLScalarType ;
37
- import java .util .ArrayList ;
38
37
import java .util .LinkedHashMap ;
39
38
import java .util .List ;
40
39
import java .util .Map ;
53
52
class GraphQLJpaQueryDataFetcher implements DataFetcher <PagedResult <Object >> {
54
53
55
54
private static final Logger logger = LoggerFactory .getLogger (GraphQLJpaQueryDataFetcher .class );
55
+ public static final String AGGREGATE_PARAM_NAME = "aggregate" ;
56
+ public static final String COUNT_FIELD_NAME = "count" ;
57
+ public static final String GROUP_FIELD_NAME = "group" ;
58
+ public static final String BY_FILED_NAME = "by" ;
59
+ public static final String FIELD_ARGUMENT_NAME = "field" ;
60
+ public static final String OF_ARGUMENT_NAME = "of" ;
56
61
57
62
private final int defaultMaxResults ;
58
63
private final int defaultPageLimitSize ;
@@ -76,7 +81,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
76
81
Optional <Field > pagesSelection = getSelectionField (rootNode , PAGE_PAGES_PARAM_NAME );
77
82
Optional <Field > totalSelection = getSelectionField (rootNode , PAGE_TOTAL_PARAM_NAME );
78
83
Optional <Field > recordsSelection = searchByFieldName (rootNode , QUERY_SELECT_PARAM_NAME );
79
- Optional <Field > aggregateSelection = getSelectionField (rootNode , "aggregate" );
84
+ Optional <Field > aggregateSelection = getSelectionField (rootNode , AGGREGATE_PARAM_NAME );
80
85
81
86
final int firstResult = page .getOffset ();
82
87
final int maxResults = Integer .min (page .getLimit (), defaultMaxResults ); // Limit max results to avoid OoM
@@ -85,22 +90,29 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
85
90
.builder ()
86
91
.withOffset (firstResult )
87
92
.withLimit (maxResults );
88
- Optional <List <Object >> restrictedKeys = queryFactory .getRestrictedKeys (environment );
93
+
94
+ final Optional <List <Object >> restrictedKeys = queryFactory .getRestrictedKeys (environment );
89
95
90
96
if (recordsSelection .isPresent ()) {
91
97
if (restrictedKeys .isPresent ()) {
92
- final List <Object > queryKeys = new ArrayList <>();
93
-
94
98
if (pageArgument .isPresent () || enableDefaultMaxResults ) {
95
- queryKeys .addAll (
96
- queryFactory .queryKeys (environment , firstResult , maxResults , restrictedKeys .get ())
99
+ final List <Object > queryKeys = queryFactory .queryKeys (
100
+ environment ,
101
+ firstResult ,
102
+ maxResults ,
103
+ restrictedKeys .get ()
97
104
);
105
+
106
+ if (!queryKeys .isEmpty ()) {
107
+ pagedResult .withSelect (
108
+ queryFactory .queryResultList (environment , maxResults , restrictedKeys .get ())
109
+ );
110
+ } else {
111
+ pagedResult .withSelect (List .of ());
112
+ }
98
113
} else {
99
- queryKeys . addAll ( restrictedKeys .get ());
114
+ pagedResult . withSelect ( queryFactory . queryResultList ( environment , maxResults , restrictedKeys .get () ));
100
115
}
101
-
102
- final List <Object > resultList = queryFactory .queryResultList (environment , maxResults , queryKeys );
103
- pagedResult .withSelect (resultList );
104
116
}
105
117
}
106
118
@@ -113,7 +125,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
113
125
aggregateSelection .ifPresent (aggregateField -> {
114
126
Map <String , Object > aggregate = new LinkedHashMap <>();
115
127
116
- getFields (aggregateField .getSelectionSet (), "count" )
128
+ getFields (aggregateField .getSelectionSet (), COUNT_FIELD_NAME )
117
129
.forEach (countField -> {
118
130
getCountOfArgument (countField )
119
131
.ifPresentOrElse (
@@ -130,16 +142,16 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
130
142
);
131
143
});
132
144
133
- getFields (aggregateField .getSelectionSet (), "group" )
145
+ getFields (aggregateField .getSelectionSet (), GROUP_FIELD_NAME )
134
146
.forEach (groupField -> {
135
- var countField = getFields (groupField .getSelectionSet (), "count" )
147
+ var countField = getFields (groupField .getSelectionSet (), COUNT_FIELD_NAME )
136
148
.stream ()
137
149
.findFirst ()
138
150
.orElseThrow (() -> new GraphQLException ("Missing aggregate count for group: " + groupField ));
139
151
140
152
var countOfArgumentValue = getCountOfArgument (countField );
141
153
142
- Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), "by" )
154
+ Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), BY_FILED_NAME )
143
155
.stream ()
144
156
.map (GraphQLJpaQueryDataFetcher ::groupByFieldEntry )
145
157
.toArray (Map .Entry []::new );
@@ -176,21 +188,21 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
176
188
aggregate .put (getAliasOrName (groupField ), resultList );
177
189
});
178
190
179
- getSelectionField (aggregateField , "by" )
191
+ getSelectionField (aggregateField , BY_FILED_NAME )
180
192
.map (byField -> byField .getSelectionSet ().getSelections ().stream ().map (Field .class ::cast ).toList ())
181
193
.filter (Predicate .not (List ::isEmpty ))
182
194
.ifPresent (aggregateBySelections -> {
183
195
var aggregatesBy = new LinkedHashMap <>();
184
- aggregate .put ("by" , aggregatesBy );
196
+ aggregate .put (BY_FILED_NAME , aggregatesBy );
185
197
186
198
aggregateBySelections .forEach (groupField -> {
187
- var countField = getFields (groupField .getSelectionSet (), "count" )
199
+ var countField = getFields (groupField .getSelectionSet (), COUNT_FIELD_NAME )
188
200
.stream ()
189
201
.findFirst ()
190
202
.orElseThrow (() -> new GraphQLException ("Missing aggregate count for group: " + groupField )
191
203
);
192
204
193
- Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), "by" )
205
+ Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), BY_FILED_NAME )
194
206
.stream ()
195
207
.map (GraphQLJpaQueryDataFetcher ::groupByFieldEntry )
196
208
.toArray (Map .Entry []::new );
@@ -239,7 +251,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
239
251
static Map .Entry <String , String > groupByFieldEntry (Field selectedField ) {
240
252
String key = Optional .ofNullable (selectedField .getAlias ()).orElse (selectedField .getName ());
241
253
242
- String value = findArgument (selectedField , "field" )
254
+ String value = findArgument (selectedField , FIELD_ARGUMENT_NAME )
243
255
.map (Argument ::getValue )
244
256
.map (EnumValue .class ::cast )
245
257
.map (EnumValue ::getName )
@@ -257,7 +269,7 @@ static Map.Entry<String, String> countFieldEntry(Field selectedField) {
257
269
}
258
270
259
271
static Optional <String > getCountOfArgument (Field selectedField ) {
260
- return findArgument (selectedField , "of" )
272
+ return findArgument (selectedField , OF_ARGUMENT_NAME )
261
273
.map (Argument ::getValue )
262
274
.map (EnumValue .class ::cast )
263
275
.map (EnumValue ::getName );
0 commit comments