Skip to content

Skip redundant fetch query for empty query keys results #508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLScalarType;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -53,6 +53,12 @@
class GraphQLJpaQueryDataFetcher implements DataFetcher<PagedResult<Object>> {

private static final Logger logger = LoggerFactory.getLogger(GraphQLJpaQueryDataFetcher.class);
public static final String AGGREGATE_PARAM_NAME = "aggregate";
public static final String COUNT_FIELD_NAME = "count";
public static final String GROUP_FIELD_NAME = "group";
public static final String BY_FILED_NAME = "by";
public static final String FIELD_ARGUMENT_NAME = "field";
public static final String OF_ARGUMENT_NAME = "of";

private final int defaultMaxResults;
private final int defaultPageLimitSize;
Expand All @@ -76,7 +82,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
Optional<Field> pagesSelection = getSelectionField(rootNode, PAGE_PAGES_PARAM_NAME);
Optional<Field> totalSelection = getSelectionField(rootNode, PAGE_TOTAL_PARAM_NAME);
Optional<Field> recordsSelection = searchByFieldName(rootNode, QUERY_SELECT_PARAM_NAME);
Optional<Field> aggregateSelection = getSelectionField(rootNode, "aggregate");
Optional<Field> aggregateSelection = getSelectionField(rootNode, AGGREGATE_PARAM_NAME);

final int firstResult = page.getOffset();
final int maxResults = Integer.min(page.getLimit(), defaultMaxResults); // Limit max results to avoid OoM
Expand All @@ -85,35 +91,47 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
.builder()
.withOffset(firstResult)
.withLimit(maxResults);
Optional<List<Object>> restrictedKeys = queryFactory.getRestrictedKeys(environment);

final Optional<List<Object>> restrictedKeys = queryFactory.getRestrictedKeys(environment);

if (recordsSelection.isPresent()) {
if (restrictedKeys.isPresent()) {
final List<Object> queryKeys = new ArrayList<>();

if (pageArgument.isPresent() || enableDefaultMaxResults) {
queryKeys.addAll(
queryFactory.queryKeys(environment, firstResult, maxResults, restrictedKeys.get())
final List<Object> queryKeys = queryFactory.queryKeys(
environment,
firstResult,
maxResults,
restrictedKeys.get()
);

if (!queryKeys.isEmpty()) {
pagedResult.withSelect(
queryFactory.queryResultList(environment, maxResults, restrictedKeys.get())
);
} else {
pagedResult.withSelect(List.of());
}
} else {
queryKeys.addAll(restrictedKeys.get());
pagedResult.withSelect(queryFactory.queryResultList(environment, maxResults, restrictedKeys.get()));
}

final List<Object> resultList = queryFactory.queryResultList(environment, maxResults, queryKeys);
pagedResult.withSelect(resultList);
}
}

if (totalSelection.isPresent() || pagesSelection.isPresent()) {
final Long total = queryFactory.queryTotalCount(environment, restrictedKeys);
final var selectResult = pagedResult.getSelect();

final long total = recordsSelection.isEmpty() ||
selectResult.filter(Predicate.not(Collection::isEmpty)).isPresent()
? queryFactory.queryTotalCount(environment, restrictedKeys)
: 0L;

pagedResult.withTotal(total);
}

aggregateSelection.ifPresent(aggregateField -> {
Map<String, Object> aggregate = new LinkedHashMap<>();

getFields(aggregateField.getSelectionSet(), "count")
getFields(aggregateField.getSelectionSet(), COUNT_FIELD_NAME)
.forEach(countField -> {
getCountOfArgument(countField)
.ifPresentOrElse(
Expand All @@ -130,16 +148,16 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
);
});

getFields(aggregateField.getSelectionSet(), "group")
getFields(aggregateField.getSelectionSet(), GROUP_FIELD_NAME)
.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
var countField = getFields(groupField.getSelectionSet(), COUNT_FIELD_NAME)
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField));

var countOfArgumentValue = getCountOfArgument(countField);

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), BY_FILED_NAME)
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);
Expand Down Expand Up @@ -176,21 +194,21 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
aggregate.put(getAliasOrName(groupField), resultList);
});

getSelectionField(aggregateField, "by")
getSelectionField(aggregateField, BY_FILED_NAME)
.map(byField -> byField.getSelectionSet().getSelections().stream().map(Field.class::cast).toList())
.filter(Predicate.not(List::isEmpty))
.ifPresent(aggregateBySelections -> {
var aggregatesBy = new LinkedHashMap<>();
aggregate.put("by", aggregatesBy);
aggregate.put(BY_FILED_NAME, aggregatesBy);

aggregateBySelections.forEach(groupField -> {
var countField = getFields(groupField.getSelectionSet(), "count")
var countField = getFields(groupField.getSelectionSet(), COUNT_FIELD_NAME)
.stream()
.findFirst()
.orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)
);

Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), "by")
Map.Entry<String, String>[] groupings = getFields(groupField.getSelectionSet(), BY_FILED_NAME)
.stream()
.map(GraphQLJpaQueryDataFetcher::groupByFieldEntry)
.toArray(Map.Entry[]::new);
Expand Down Expand Up @@ -239,7 +257,7 @@ public PagedResult<Object> get(DataFetchingEnvironment environment) {
static Map.Entry<String, String> groupByFieldEntry(Field selectedField) {
String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName());

String value = findArgument(selectedField, "field")
String value = findArgument(selectedField, FIELD_ARGUMENT_NAME)
.map(Argument::getValue)
.map(EnumValue.class::cast)
.map(EnumValue::getName)
Expand All @@ -257,7 +275,7 @@ static Map.Entry<String, String> countFieldEntry(Field selectedField) {
}

static Optional<String> getCountOfArgument(Field selectedField) {
return findArgument(selectedField, "of")
return findArgument(selectedField, OF_ARGUMENT_NAME)
.map(Argument::getValue)
.map(EnumValue.class::cast)
.map(EnumValue::getName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class PagedResult<T> {

Expand Down Expand Up @@ -135,6 +136,10 @@ public Builder<T> withSelect(List<T> select) {
return this;
}

public Optional<List<T>> getSelect() {
return Optional.ofNullable(select);
}

/**
* Builder method for select parameter.
* @param select field to set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl
return System.identityHashCode(proxy);
} else if ("spliterator".equals(method.getName())) {
return stream.spliterator();
} else if ("isEmpty".equals(method.getName())) {
return size == 0;
}
throw new UnsupportedOperationException(method + " is not supported");
}
Expand Down
Loading