Skip to content

Commit 4e3cfca

Browse files
committed
Skip redundant fetch query for empty query keys results
1 parent 57d484b commit 4e3cfca

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaQueryDataFetcher.java

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import graphql.schema.DataFetcher;
3535
import graphql.schema.DataFetchingEnvironment;
3636
import graphql.schema.GraphQLScalarType;
37-
import java.util.ArrayList;
3837
import java.util.LinkedHashMap;
3938
import java.util.List;
4039
import java.util.Map;
@@ -53,6 +52,12 @@
5352
class GraphQLJpaQueryDataFetcher implements DataFetcher<PagedResult<Object>> {
5453

5554
private static final Logger logger = LoggerFactory.getLogger(GraphQLJpaQueryDataFetcher.class);
55+
public static final String AGGREGATE_PARAM_NAME = "aggregate";
56+
public static final String COUNT_FIELD_NAME = "count";
57+
public static final String GROUP_FIELD_NAME = "group";
58+
public static final String BY_FILED_NAME = "by";
59+
public static final String FIELD_ARGUMENT_NAME = "field";
60+
public static final String OF_ARGUMENT_NAME = "of";
5661

5762
private final int defaultMaxResults;
5863
private final int defaultPageLimitSize;
@@ -76,7 +81,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
7681
Optional<Field> pagesSelection = getSelectionField(rootNode, PAGE_PAGES_PARAM_NAME);
7782
Optional<Field> totalSelection = getSelectionField(rootNode, PAGE_TOTAL_PARAM_NAME);
7883
Optional<Field> recordsSelection = searchByFieldName(rootNode, QUERY_SELECT_PARAM_NAME);
79-
Optional<Field> aggregateSelection = getSelectionField(rootNode, "aggregate");
84+
Optional<Field> aggregateSelection = getSelectionField(rootNode, AGGREGATE_PARAM_NAME);
8085

8186
final int firstResult = page.getOffset();
8287
final int maxResults = Integer.min(page.getLimit(), defaultMaxResults); // Limit max results to avoid OoM
@@ -85,22 +90,29 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
8590
.builder()
8691
.withOffset(firstResult)
8792
.withLimit(maxResults);
88-
Optional<List<Object>> restrictedKeys = queryFactory.getRestrictedKeys(environment);
93+
94+
final Optional<List<Object>> restrictedKeys = queryFactory.getRestrictedKeys(environment);
8995

9096
if (recordsSelection.isPresent()) {
9197
if (restrictedKeys.isPresent()) {
92-
final List<Object> queryKeys = new ArrayList<>();
93-
9498
if (pageArgument.isPresent() || enableDefaultMaxResults) {
95-
queryKeys.addAll(
96-
queryFactory.queryKeys(environment, firstResult, maxResults, restrictedKeys.get())
99+
final List<Object> queryKeys = queryFactory.queryKeys(
100+
environment,
101+
firstResult,
102+
maxResults,
103+
restrictedKeys.get()
97104
);
105+
106+
if (!queryKeys.isEmpty()) {
107+
pagedResult.withSelect(
108+
queryFactory.queryResultList(environment, maxResults, restrictedKeys.get())
109+
);
110+
} else {
111+
pagedResult.withSelect(List.of());
112+
}
98113
} else {
99-
queryKeys.addAll(restrictedKeys.get());
114+
pagedResult.withSelect(queryFactory.queryResultList(environment, maxResults, restrictedKeys.get()));
100115
}
101-
102-
final List<Object> resultList = queryFactory.queryResultList(environment, maxResults, queryKeys);
103-
pagedResult.withSelect(resultList);
104116
}
105117
}
106118

@@ -113,7 +125,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
113125
aggregateSelection.ifPresent(aggregateField -> {
114126
Map<String, Object> aggregate = new LinkedHashMap<>();
115127

116-
getFields(aggregateField.getSelectionSet(), "count")
128+
getFields(aggregateField.getSelectionSet(), COUNT_FIELD_NAME)
117129
.forEach(countField -> {
118130
getCountOfArgument(countField)
119131
.ifPresentOrElse(
@@ -130,16 +142,16 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
130142
);
131143
});
132144

133-
getFields(aggregateField.getSelectionSet(), "group")
145+
getFields(aggregateField.getSelectionSet(), GROUP_FIELD_NAME)
134146
.forEach(groupField -> {
135-
var countField = getFields(groupField.getSelectionSet(), "count")
147+
var countField = getFields(groupField.getSelectionSet(), COUNT_FIELD_NAME)
136148
.stream()
137149
.findFirst()
138150
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));
139151

140152
var countOfArgumentValue = getCountOfArgument(countField);
141153

142-
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
154+
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), BY_FILED_NAME)
143155
.stream()
144156
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
145157
.toArray(Map.Entry[]::new);
@@ -176,21 +188,21 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
176188
aggregate.put(getAliasOrName(groupField), resultList);
177189
});
178190

179-
getSelectionField(aggregateField, "by")
191+
getSelectionField(aggregateField, BY_FILED_NAME)
180192
.map(byField -> byField.getSelectionSet().getSelections().stream().map(Field.class::cast).toList())
181193
.filter(Predicate.not(List::isEmpty))
182194
.ifPresent(aggregateBySelections -> {
183195
var aggregatesBy = new LinkedHashMap<>();
184-
aggregate.put("by", aggregatesBy);
196+
aggregate.put(BY_FILED_NAME, aggregatesBy);
185197

186198
aggregateBySelections.forEach(groupField -> {
187-
var countField = getFields(groupField.getSelectionSet(), "count")
199+
var countField = getFields(groupField.getSelectionSet(), COUNT_FIELD_NAME)
188200
.stream()
189201
.findFirst()
190202
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)
191203
);
192204

193-
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
205+
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), BY_FILED_NAME)
194206
.stream()
195207
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
196208
.toArray(Map.Entry[]::new);
@@ -239,7 +251,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
239251
static Map.Entry<String, String> groupByFieldEntry(Field selectedField) {
240252
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());
241253

242-
String value = findArgument(selectedField, "field")
254+
String value = findArgument(selectedField, FIELD_ARGUMENT_NAME)
243255
.map(Argument::getValue)
244256
.map(EnumValue.class::cast)
245257
.map(EnumValue::getName)
@@ -257,7 +269,7 @@ static Map.Entry<String, String> countFieldEntry(Field selectedField) {
257269
}
258270

259271
static Optional<String> getCountOfArgument(Field selectedField) {
260-
return findArgument(selectedField, "of")
272+
return findArgument(selectedField, OF_ARGUMENT_NAME)
261273
.map(Argument::getValue)
262274
.map(EnumValue.class::cast)
263275
.map(EnumValue::getName);

0 commit comments

Comments
 (0)