34
34
import graphql .schema .DataFetcher ;
35
35
import graphql .schema .DataFetchingEnvironment ;
36
36
import graphql .schema .GraphQLScalarType ;
37
- import java .util .ArrayList ;
37
+ import java .util .Collection ;
38
38
import java .util .LinkedHashMap ;
39
39
import java .util .List ;
40
40
import java .util .Map ;
53
53
class GraphQLJpaQueryDataFetcher implements DataFetcher <PagedResult <Object >> {
54
54
55
55
private static final Logger logger = LoggerFactory .getLogger (GraphQLJpaQueryDataFetcher .class );
56
+ public static final String AGGREGATE_PARAM_NAME = "aggregate" ;
57
+ public static final String COUNT_FIELD_NAME = "count" ;
58
+ public static final String GROUP_FIELD_NAME = "group" ;
59
+ public static final String BY_FILED_NAME = "by" ;
60
+ public static final String FIELD_ARGUMENT_NAME = "field" ;
61
+ public static final String OF_ARGUMENT_NAME = "of" ;
56
62
57
63
private final int defaultMaxResults ;
58
64
private final int defaultPageLimitSize ;
@@ -76,7 +82,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
76
82
Optional <Field > pagesSelection = getSelectionField (rootNode , PAGE_PAGES_PARAM_NAME );
77
83
Optional <Field > totalSelection = getSelectionField (rootNode , PAGE_TOTAL_PARAM_NAME );
78
84
Optional <Field > recordsSelection = searchByFieldName (rootNode , QUERY_SELECT_PARAM_NAME );
79
- Optional <Field > aggregateSelection = getSelectionField (rootNode , "aggregate" );
85
+ Optional <Field > aggregateSelection = getSelectionField (rootNode , AGGREGATE_PARAM_NAME );
80
86
81
87
final int firstResult = page .getOffset ();
82
88
final int maxResults = Integer .min (page .getLimit (), defaultMaxResults ); // Limit max results to avoid OoM
@@ -85,35 +91,47 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
85
91
.builder ()
86
92
.withOffset (firstResult )
87
93
.withLimit (maxResults );
88
- Optional <List <Object >> restrictedKeys = queryFactory .getRestrictedKeys (environment );
94
+
95
+ final Optional <List <Object >> restrictedKeys = queryFactory .getRestrictedKeys (environment );
89
96
90
97
if (recordsSelection .isPresent ()) {
91
98
if (restrictedKeys .isPresent ()) {
92
- final List <Object > queryKeys = new ArrayList <>();
93
-
94
99
if (pageArgument .isPresent () || enableDefaultMaxResults ) {
95
- queryKeys .addAll (
96
- queryFactory .queryKeys (environment , firstResult , maxResults , restrictedKeys .get ())
100
+ final List <Object > queryKeys = queryFactory .queryKeys (
101
+ environment ,
102
+ firstResult ,
103
+ maxResults ,
104
+ restrictedKeys .get ()
97
105
);
106
+
107
+ if (!queryKeys .isEmpty ()) {
108
+ pagedResult .withSelect (
109
+ queryFactory .queryResultList (environment , maxResults , restrictedKeys .get ())
110
+ );
111
+ } else {
112
+ pagedResult .withSelect (List .of ());
113
+ }
98
114
} else {
99
- queryKeys . addAll ( restrictedKeys .get ());
115
+ pagedResult . withSelect ( queryFactory . queryResultList ( environment , maxResults , restrictedKeys .get () ));
100
116
}
101
-
102
- final List <Object > resultList = queryFactory .queryResultList (environment , maxResults , queryKeys );
103
- pagedResult .withSelect (resultList );
104
117
}
105
118
}
106
119
107
120
if (totalSelection .isPresent () || pagesSelection .isPresent ()) {
108
- final Long total = queryFactory .queryTotalCount (environment , restrictedKeys );
121
+ final var selectResult = pagedResult .getSelect ();
122
+
123
+ final long total = recordsSelection .isEmpty () ||
124
+ selectResult .filter (Predicate .not (Collection ::isEmpty )).isPresent ()
125
+ ? queryFactory .queryTotalCount (environment , restrictedKeys )
126
+ : 0L ;
109
127
110
128
pagedResult .withTotal (total );
111
129
}
112
130
113
131
aggregateSelection .ifPresent (aggregateField -> {
114
132
Map <String , Object > aggregate = new LinkedHashMap <>();
115
133
116
- getFields (aggregateField .getSelectionSet (), "count" )
134
+ getFields (aggregateField .getSelectionSet (), COUNT_FIELD_NAME )
117
135
.forEach (countField -> {
118
136
getCountOfArgument (countField )
119
137
.ifPresentOrElse (
@@ -130,16 +148,16 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
130
148
);
131
149
});
132
150
133
- getFields (aggregateField .getSelectionSet (), "group" )
151
+ getFields (aggregateField .getSelectionSet (), GROUP_FIELD_NAME )
134
152
.forEach (groupField -> {
135
- var countField = getFields (groupField .getSelectionSet (), "count" )
153
+ var countField = getFields (groupField .getSelectionSet (), COUNT_FIELD_NAME )
136
154
.stream ()
137
155
.findFirst ()
138
156
.orElseThrow (() -> new GraphQLException ("Missing aggregate count for group: " + groupField ));
139
157
140
158
var countOfArgumentValue = getCountOfArgument (countField );
141
159
142
- Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), "by" )
160
+ Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), BY_FILED_NAME )
143
161
.stream ()
144
162
.map (GraphQLJpaQueryDataFetcher ::groupByFieldEntry )
145
163
.toArray (Map .Entry []::new );
@@ -176,21 +194,21 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
176
194
aggregate .put (getAliasOrName (groupField ), resultList );
177
195
});
178
196
179
- getSelectionField (aggregateField , "by" )
197
+ getSelectionField (aggregateField , BY_FILED_NAME )
180
198
.map (byField -> byField .getSelectionSet ().getSelections ().stream ().map (Field .class ::cast ).toList ())
181
199
.filter (Predicate .not (List ::isEmpty ))
182
200
.ifPresent (aggregateBySelections -> {
183
201
var aggregatesBy = new LinkedHashMap <>();
184
- aggregate .put ("by" , aggregatesBy );
202
+ aggregate .put (BY_FILED_NAME , aggregatesBy );
185
203
186
204
aggregateBySelections .forEach (groupField -> {
187
- var countField = getFields (groupField .getSelectionSet (), "count" )
205
+ var countField = getFields (groupField .getSelectionSet (), COUNT_FIELD_NAME )
188
206
.stream ()
189
207
.findFirst ()
190
208
.orElseThrow (() -> new GraphQLException ("Missing aggregate count for group: " + groupField )
191
209
);
192
210
193
- Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), "by" )
211
+ Map .Entry <String , String >[] groupings = getFields (groupField .getSelectionSet (), BY_FILED_NAME )
194
212
.stream ()
195
213
.map (GraphQLJpaQueryDataFetcher ::groupByFieldEntry )
196
214
.toArray (Map .Entry []::new );
@@ -239,7 +257,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
239
257
static Map .Entry <String , String > groupByFieldEntry (Field selectedField ) {
240
258
String key = Optional .ofNullable (selectedField .getAlias ()).orElse (selectedField .getName ());
241
259
242
- String value = findArgument (selectedField , "field" )
260
+ String value = findArgument (selectedField , FIELD_ARGUMENT_NAME )
243
261
.map (Argument ::getValue )
244
262
.map (EnumValue .class ::cast )
245
263
.map (EnumValue ::getName )
@@ -257,7 +275,7 @@ static Map.Entry<String, String> countFieldEntry(Field selectedField) {
257
275
}
258
276
259
277
static Optional <String > getCountOfArgument (Field selectedField ) {
260
- return findArgument (selectedField , "of" )
278
+ return findArgument (selectedField , OF_ARGUMENT_NAME )
261
279
.map (Argument ::getValue )
262
280
.map (EnumValue .class ::cast )
263
281
.map (EnumValue ::getName );
0 commit comments