diff --git a/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaQueryDataFetcher.java b/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaQueryDataFetcher.java index 97eb4943..f305cb7f 100644 --- a/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaQueryDataFetcher.java +++ b/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/GraphQLJpaQueryDataFetcher.java @@ -34,7 +34,7 @@ import graphql.schema.DataFetcher; import graphql.schema.DataFetchingEnvironment; import graphql.schema.GraphQLScalarType; -import java.util.ArrayList; +import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -53,6 +53,12 @@ class GraphQLJpaQueryDataFetcher implements DataFetcher> { private static final Logger logger = LoggerFactory.getLogger(GraphQLJpaQueryDataFetcher.class); + public static final String AGGREGATE_PARAM_NAME = "aggregate"; + public static final String COUNT_FIELD_NAME = "count"; + public static final String GROUP_FIELD_NAME = "group"; + public static final String BY_FILED_NAME = "by"; + public static final String FIELD_ARGUMENT_NAME = "field"; + public static final String OF_ARGUMENT_NAME = "of"; private final int defaultMaxResults; private final int defaultPageLimitSize; @@ -76,7 +82,7 @@ public PagedResult get(DataFetchingEnvironment environment) { Optional pagesSelection = getSelectionField(rootNode, PAGE_PAGES_PARAM_NAME); Optional totalSelection = getSelectionField(rootNode, PAGE_TOTAL_PARAM_NAME); Optional recordsSelection = searchByFieldName(rootNode, QUERY_SELECT_PARAM_NAME); - Optional aggregateSelection = getSelectionField(rootNode, "aggregate"); + Optional aggregateSelection = getSelectionField(rootNode, AGGREGATE_PARAM_NAME); final int firstResult = page.getOffset(); final int maxResults = Integer.min(page.getLimit(), defaultMaxResults); // Limit max results to avoid OoM @@ -85,27 +91,39 @@ public PagedResult get(DataFetchingEnvironment environment) { .builder() .withOffset(firstResult) .withLimit(maxResults); - Optional> restrictedKeys = queryFactory.getRestrictedKeys(environment); + + final Optional> restrictedKeys = queryFactory.getRestrictedKeys(environment); if (recordsSelection.isPresent()) { if (restrictedKeys.isPresent()) { - final List queryKeys = new ArrayList<>(); - if (pageArgument.isPresent() || enableDefaultMaxResults) { - queryKeys.addAll( - queryFactory.queryKeys(environment, firstResult, maxResults, restrictedKeys.get()) + final List queryKeys = queryFactory.queryKeys( + environment, + firstResult, + maxResults, + restrictedKeys.get() ); + + if (!queryKeys.isEmpty()) { + pagedResult.withSelect( + queryFactory.queryResultList(environment, maxResults, restrictedKeys.get()) + ); + } else { + pagedResult.withSelect(List.of()); + } } else { - queryKeys.addAll(restrictedKeys.get()); + pagedResult.withSelect(queryFactory.queryResultList(environment, maxResults, restrictedKeys.get())); } - - final List resultList = queryFactory.queryResultList(environment, maxResults, queryKeys); - pagedResult.withSelect(resultList); } } if (totalSelection.isPresent() || pagesSelection.isPresent()) { - final Long total = queryFactory.queryTotalCount(environment, restrictedKeys); + final var selectResult = pagedResult.getSelect(); + + final long total = recordsSelection.isEmpty() || + selectResult.filter(Predicate.not(Collection::isEmpty)).isPresent() + ? queryFactory.queryTotalCount(environment, restrictedKeys) + : 0L; pagedResult.withTotal(total); } @@ -113,7 +131,7 @@ public PagedResult get(DataFetchingEnvironment environment) { aggregateSelection.ifPresent(aggregateField -> { Map aggregate = new LinkedHashMap<>(); - getFields(aggregateField.getSelectionSet(), "count") + getFields(aggregateField.getSelectionSet(), COUNT_FIELD_NAME) .forEach(countField -> { getCountOfArgument(countField) .ifPresentOrElse( @@ -130,16 +148,16 @@ public PagedResult get(DataFetchingEnvironment environment) { ); }); - getFields(aggregateField.getSelectionSet(), "group") + getFields(aggregateField.getSelectionSet(), GROUP_FIELD_NAME) .forEach(groupField -> { - var countField = getFields(groupField.getSelectionSet(), "count") + var countField = getFields(groupField.getSelectionSet(), COUNT_FIELD_NAME) .stream() .findFirst() .orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField)); var countOfArgumentValue = getCountOfArgument(countField); - Map.Entry[] groupings = getFields(groupField.getSelectionSet(), "by") + Map.Entry[] groupings = getFields(groupField.getSelectionSet(), BY_FILED_NAME) .stream() .map(GraphQLJpaQueryDataFetcher::groupByFieldEntry) .toArray(Map.Entry[]::new); @@ -176,21 +194,21 @@ public PagedResult get(DataFetchingEnvironment environment) { aggregate.put(getAliasOrName(groupField), resultList); }); - getSelectionField(aggregateField, "by") + getSelectionField(aggregateField, BY_FILED_NAME) .map(byField -> byField.getSelectionSet().getSelections().stream().map(Field.class::cast).toList()) .filter(Predicate.not(List::isEmpty)) .ifPresent(aggregateBySelections -> { var aggregatesBy = new LinkedHashMap<>(); - aggregate.put("by", aggregatesBy); + aggregate.put(BY_FILED_NAME, aggregatesBy); aggregateBySelections.forEach(groupField -> { - var countField = getFields(groupField.getSelectionSet(), "count") + var countField = getFields(groupField.getSelectionSet(), COUNT_FIELD_NAME) .stream() .findFirst() .orElseThrow(() -> new GraphQLException("Missing aggregate count for group: " + groupField) ); - Map.Entry[] groupings = getFields(groupField.getSelectionSet(), "by") + Map.Entry[] groupings = getFields(groupField.getSelectionSet(), BY_FILED_NAME) .stream() .map(GraphQLJpaQueryDataFetcher::groupByFieldEntry) .toArray(Map.Entry[]::new); @@ -239,7 +257,7 @@ public PagedResult get(DataFetchingEnvironment environment) { static Map.Entry groupByFieldEntry(Field selectedField) { String key = Optional.ofNullable(selectedField.getAlias()).orElse(selectedField.getName()); - String value = findArgument(selectedField, "field") + String value = findArgument(selectedField, FIELD_ARGUMENT_NAME) .map(Argument::getValue) .map(EnumValue.class::cast) .map(EnumValue::getName) @@ -257,7 +275,7 @@ static Map.Entry countFieldEntry(Field selectedField) { } static Optional getCountOfArgument(Field selectedField) { - return findArgument(selectedField, "of") + return findArgument(selectedField, OF_ARGUMENT_NAME) .map(Argument::getValue) .map(EnumValue.class::cast) .map(EnumValue::getName); diff --git a/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/PagedResult.java b/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/PagedResult.java index e41c9202..abba0b1c 100644 --- a/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/PagedResult.java +++ b/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/PagedResult.java @@ -20,6 +20,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; public class PagedResult { @@ -135,6 +136,10 @@ public Builder withSelect(List select) { return this; } + public Optional> getSelect() { + return Optional.ofNullable(select); + } + /** * Builder method for select parameter. * @param select field to set diff --git a/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/ResultStreamWrapper.java b/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/ResultStreamWrapper.java index 1196bb9a..4ab11f06 100644 --- a/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/ResultStreamWrapper.java +++ b/schema/src/main/java/com/introproventures/graphql/jpa/query/schema/impl/ResultStreamWrapper.java @@ -63,6 +63,8 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl return System.identityHashCode(proxy); } else if ("spliterator".equals(method.getName())) { return stream.spliterator(); + } else if ("isEmpty".equals(method.getName())) { + return size == 0; } throw new UnsupportedOperationException(method + " is not supported"); }