From ab24d83e0b0cf21c7278f237d44fd1f3194ae727 Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Thu, 27 Feb 2025 00:46:44 +0800
Subject: [PATCH 1/2] Add more aggregation tests

Signed-off-by: Lantao Jin
---
 .../sql/calcite/CalcitePlanContext.java       |  15 +-
 .../sql/calcite/CalciteRelNodeVisitor.java    |   6 -
 .../sql/calcite/CalciteRexNodeVisitor.java    |  21 +-
 .../sql/calcite/ExtendedRexBuilder.java       |  49 +++
 .../calcite/utils/DataTypeTransformer.java    |  58 ----
 .../calcite/utils/OpenSearchTypeFactory.java  |   8 +
 .../opensearch/sql/executor/QueryService.java |   5 +-
 .../standalone/CalcitePPLAggregationIT.java   | 242 +++++++++++++-
 .../scan/OpenSearchIndexEnumerator.java       |   6 +-
 .../calcite/CalcitePPLAggregationTest.java    | 306 +++++++++++++++++-
 10 files changed, 614 insertions(+), 102 deletions(-)
 delete mode 100644 core/src/main/java/org/opensearch/sql/calcite/utils/DataTypeTransformer.java

diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java
index 775a0b8e9f..c1ab9869a0 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java
@@ -5,10 +5,11 @@
 
 package org.opensearch.sql.calcite;
 
+import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
+
 import java.sql.Connection;
 import java.util.function.BiFunction;
 import lombok.Getter;
-import org.apache.calcite.adapter.java.JavaTypeFactory;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.tools.FrameworkConfig;
 import org.apache.calcite.tools.RelBuilder;
@@ -24,10 +25,10 @@ public class CalcitePlanContext {
 
   @Getter private boolean isResolvingJoinCondition = false;
 
-  private CalcitePlanContext(FrameworkConfig config, JavaTypeFactory typeFactory) {
+  private CalcitePlanContext(FrameworkConfig config) {
     this.config = config;
-    this.connection = CalciteToolsHelper.connect(config, typeFactory);
-    this.relBuilder = CalciteToolsHelper.create(config, typeFactory, connection);
+    this.connection = CalciteToolsHelper.connect(config, TYPE_FACTORY);
+    this.relBuilder = CalciteToolsHelper.create(config, TYPE_FACTORY, connection);
     this.rexBuilder = new ExtendedRexBuilder(relBuilder.getRexBuilder());
   }
 
@@ -41,10 +42,6 @@ public RexNode resolveJoinCondition(
   }
 
   public static CalcitePlanContext create(FrameworkConfig config) {
-    return new CalcitePlanContext(config, null);
-  }
-
-  public static CalcitePlanContext create(FrameworkConfig config, JavaTypeFactory typeFactory) {
-    return new CalcitePlanContext(config, typeFactory);
+    return new CalcitePlanContext(config);
   }
 }
diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java
index 7a5646996e..15ae04db2d 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java
@@ -196,12 +196,6 @@ public RelNode visitAggregation(Aggregation node, CalcitePlanContext context) {
       groupByList.add(spanRex);
       // add span's group alias field (most recent added expression)
     }
-    //    List<RexNode> aggList = node.getAggExprList().stream()
-    //        .map(expr -> rexVisitor.analyze(expr, context))
-    //        .collect(Collectors.toList());
-    //    relBuilder.aggregate(relBuilder.groupKey(groupByList),
-    //        aggList.stream().map(rex -> (MyAggregateCall) rex)
-    //            .map(MyAggregateCall::getCall).collect(Collectors.toList()));
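+    // Group keys come from the by-clause (plus the span alias appended above);
+    // the aggregate calls resolved from the stats expressions are applied in a
+    // single RelBuilder.aggregate() call below.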
     context.relBuilder.aggregate(context.relBuilder.groupKey(groupByList), aggList);
     return context.relBuilder.peek();
   }
diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java
index 47fc0babbc..c75822a455 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java
@@ -19,6 +19,7 @@
 import org.apache.calcite.sql.SqlIntervalQualifier;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.parser.SqlParserUtil;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.util.DateString;
 import org.apache.calcite.util.TimeString;
@@ -29,6 +30,7 @@
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.ast.expression.EqualTo;
 import org.opensearch.sql.ast.expression.Function;
+import org.opensearch.sql.ast.expression.Interval;
 import org.opensearch.sql.ast.expression.Let;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
@@ -39,7 +41,6 @@
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
 import org.opensearch.sql.ast.expression.Xor;
 import org.opensearch.sql.calcite.utils.BuiltinFunctionUtils;
-import org.opensearch.sql.calcite.utils.DataTypeTransformer;
 
 public class CalciteRexNodeVisitor extends AbstractNodeVisitor<RexNode, CalcitePlanContext> {
 
@@ -198,13 +199,9 @@
   public RexNode visitSpan(Span node, CalcitePlanContext context) {
     RelDataTypeFactory typeFactory = context.rexBuilder.getTypeFactory();
     SpanUnit unit = node.getUnit();
     if (isTimeBased(unit)) {
-      String datetimeUnitString = DataTypeTransformer.translate(unit);
-      RexNode interval =
-          context.rexBuilder.makeIntervalLiteral(
-              new BigDecimal(value.toString()),
-              new SqlIntervalQualifier(datetimeUnitString, SqlParserPos.ZERO));
-      // TODO not supported yet
-      return interval;
+      SqlIntervalQualifier intervalQualifier = context.rexBuilder.createIntervalUntil(unit);
+      long millis = SqlParserUtil.intervalToMillis(value.toString(), intervalQualifier);
+      return context.rexBuilder.makeIntervalLiteral(new BigDecimal(millis), intervalQualifier);
     } else { // if the unit is not time-based, create a math expression to bucket the span partitions
       return context.rexBuilder.makeCall(
@@ -247,4 +244,12 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
     return context.rexBuilder.makeCall(
         BuiltinFunctionUtils.translate(node.getFuncName()), arguments);
   }
+
+  @Override
+  public RexNode visitInterval(Interval node, CalcitePlanContext context) {
+    RexNode field = analyze(node.getValue(), context);
+    return context.rexBuilder.makeIntervalLiteral(
+        new BigDecimal(field.toString()),
+        new SqlIntervalQualifier(node.getUnit().name(), SqlParserPos.ZERO));
+  }
 }
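Note: a minimal sketch of what the new time-span path above produces, assuming a PPL
`span(HIREDATE, 1 day)` expression (the local names here are illustrative, not part of
the patch; `createIntervalUntil`, `intervalToMillis` and `makeIntervalLiteral` are the
calls the patch actually uses):

    // createIntervalUntil maps SpanUnit.DAY to TimeUnit.DAY, and
    // SqlParserUtil.intervalToMillis("1", qualifier) evaluates to 86400000,
    // which is why the test plans below show the literal 86400000:INTERVAL DAY.
    SqlIntervalQualifier qualifier = rexBuilder.createIntervalUntil(SpanUnit.DAY);
    long millis = SqlParserUtil.intervalToMillis("1", qualifier);
    RexNode interval = rexBuilder.makeIntervalLiteral(new BigDecimal(millis), qualifier);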
diff --git a/core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java b/core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java
index 68498d83a3..5e0c78c0a0 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/ExtendedRexBuilder.java
@@ -5,9 +5,13 @@
 
 package org.opensearch.sql.calcite;
 
+import org.apache.calcite.avatica.util.TimeUnit;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlIntervalQualifier;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.opensearch.sql.ast.expression.SpanUnit;
 
 public class ExtendedRexBuilder extends RexBuilder {
 
@@ -22,4 +26,49 @@ public RexNode coalesce(RexNode... nodes) {
   public RexNode equals(RexNode n1, RexNode n2) {
     return this.makeCall(SqlStdOperatorTable.EQUALS, n1, n2);
   }
+
+  public SqlIntervalQualifier createIntervalUntil(SpanUnit unit) {
+    TimeUnit timeUnit;
+    switch (unit) {
+      case MILLISECOND:
+      case MS:
+        timeUnit = TimeUnit.MILLISECOND;
+        break;
+      case SECOND:
+      case S:
+        timeUnit = TimeUnit.SECOND;
+        break;
+      case MINUTE:
+      case m:
+        timeUnit = TimeUnit.MINUTE;
+        break;
+      case HOUR:
+      case H:
+        timeUnit = TimeUnit.HOUR;
+        break;
+      case DAY:
+      case D:
+        timeUnit = TimeUnit.DAY;
+        break;
+      case WEEK:
+      case W:
+        timeUnit = TimeUnit.WEEK;
+        break;
+      case MONTH:
+      case M:
+        timeUnit = TimeUnit.MONTH;
+        break;
+      case QUARTER:
+      case Q:
+        timeUnit = TimeUnit.QUARTER;
+        break;
+      case YEAR:
+      case Y:
+        timeUnit = TimeUnit.YEAR;
+        break;
+      default:
+        timeUnit = TimeUnit.EPOCH;
+    }
+    return new SqlIntervalQualifier(timeUnit, timeUnit, SqlParserPos.ZERO);
+  }
 }
diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/DataTypeTransformer.java b/core/src/main/java/org/opensearch/sql/calcite/utils/DataTypeTransformer.java
deleted file mode 100644
index dea36f2eb8..0000000000
--- a/core/src/main/java/org/opensearch/sql/calcite/utils/DataTypeTransformer.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.sql.calcite.utils;
-
-import static org.opensearch.sql.ast.expression.SpanUnit.DAY;
-import static org.opensearch.sql.ast.expression.SpanUnit.HOUR;
-import static org.opensearch.sql.ast.expression.SpanUnit.MILLISECOND;
-import static org.opensearch.sql.ast.expression.SpanUnit.MINUTE;
-import static org.opensearch.sql.ast.expression.SpanUnit.MONTH;
-import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
-import static org.opensearch.sql.ast.expression.SpanUnit.QUARTER;
-import static org.opensearch.sql.ast.expression.SpanUnit.SECOND;
-import static org.opensearch.sql.ast.expression.SpanUnit.WEEK;
-import static org.opensearch.sql.ast.expression.SpanUnit.YEAR;
-
-import org.opensearch.sql.ast.expression.SpanUnit;
-
-public interface DataTypeTransformer {
-
-  static String translate(SpanUnit unit) {
-    switch (unit) {
-      case UNKNOWN:
-      case NONE:
-        return NONE.name();
-      case MILLISECOND:
-      case MS:
-        return MILLISECOND.name();
-      case SECOND:
-      case S:
-        return SECOND.name();
-      case MINUTE:
-      case m:
-        return MINUTE.name();
-      case HOUR:
-      case H:
-        return HOUR.name();
-      case DAY:
-      case D:
-        return DAY.name();
-      case WEEK:
-      case W:
-        return WEEK.name();
-      case MONTH:
-      case M:
-        return MONTH.name();
-      case QUARTER:
-      case Q:
-        return QUARTER.name();
-      case YEAR:
-      case Y:
-        return YEAR.name();
-    }
-    return "";
-  }
-}
diff --git a/core/src/main/java/org/opensearch/sql/calcite/utils/OpenSearchTypeFactory.java b/core/src/main/java/org/opensearch/sql/calcite/utils/OpenSearchTypeFactory.java
index a60408ccda..4a992c885a 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/utils/OpenSearchTypeFactory.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/utils/OpenSearchTypeFactory.java
@@ -12,6 +12,7 @@
 import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE;
 import static org.opensearch.sql.data.type.ExprCoreType.FLOAT;
 import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
+import static org.opensearch.sql.data.type.ExprCoreType.INTERVAL;
 import static org.opensearch.sql.data.type.ExprCoreType.IP;
 import static org.opensearch.sql.data.type.ExprCoreType.LONG;
 import static org.opensearch.sql.data.type.ExprCoreType.SHORT;
@@ -151,6 +152,13 @@ public static ExprType convertRelDataTypeToExprType(RelDataType type) {
         return TIMESTAMP;
       case GEOMETRY:
         return IP;
+      case INTERVAL_YEAR:
+      case INTERVAL_MONTH:
+      case INTERVAL_DAY:
+      case INTERVAL_HOUR:
+      case INTERVAL_MINUTE:
+      case INTERVAL_SECOND:
+        return INTERVAL;
       case ARRAY:
         return ARRAY;
       case MAP:
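Note: a hedged sketch of the new type mapping above (the variable and the assertion
are illustrative, not from the patch):

    // An INTERVAL DAY RelDataType now converts to ExprCoreType.INTERVAL.
    RelDataType dayInterval =
        TYPE_FACTORY.createSqlIntervalType(
            new SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.DAY, SqlParserPos.ZERO));
    assert OpenSearchTypeFactory.convertRelDataTypeToExprType(dayInterval) == INTERVAL;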
diff --git a/core/src/main/java/org/opensearch/sql/executor/QueryService.java b/core/src/main/java/org/opensearch/sql/executor/QueryService.java
index 7a04ad7def..16845b4abb 100644
--- a/core/src/main/java/org/opensearch/sql/executor/QueryService.java
+++ b/core/src/main/java/org/opensearch/sql/executor/QueryService.java
@@ -8,8 +8,6 @@
 
 package org.opensearch.sql.executor;
 
-import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.TYPE_FACTORY;
-
 import java.security.AccessController;
 import java.security.PrivilegedAction;
 import java.util.List;
@@ -80,8 +78,7 @@ public void execute(
             (PrivilegedAction<Void>)
                 () -> {
                   final FrameworkConfig config = buildFrameworkConfig();
-                  final CalcitePlanContext context =
-                      CalcitePlanContext.create(config, TYPE_FACTORY);
+                  final CalcitePlanContext context = CalcitePlanContext.create(config);
                   executePlanByCalcite(analyze(plan, context), context, listener);
                   return null;
                 });
diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java
index ce9dbe8357..73c7d2bd36 100644
--- a/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java
+++ b/integ-test/src/test/java/org/opensearch/sql/calcite/standalone/CalcitePPLAggregationIT.java
@@ -158,6 +158,60 @@ public void testMultipleAggregatesWithAliases() {
         actual);
   }
 
+  @Test
+  public void testMultipleAggregatesWithAliasesByClause() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | stats avg(balance) as avg, max(balance) as max, min(balance) as min,"
+                    + " count() as cnt by gender",
+                TEST_INDEX_BANK));
+    assertEquals(
+        "{\n"
+            + "  \"schema\": [\n"
+            + "    {\n"
+            + "      \"name\": \"gender\",\n"
+            + "      \"type\": \"string\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"avg\",\n"
+            + "      \"type\": \"double\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"max\",\n"
+            + "      \"type\": \"long\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"min\",\n"
+            + "      \"type\": \"long\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"cnt\",\n"
+            + "      \"type\": \"long\"\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"datarows\": [\n"
+            + "    [\n"
+            + "      \"F\",\n"
+            + "      40488.0,\n"
+            + "      48086,\n"
+            + "      32838,\n"
+            + "      3\n"
+            + "    ],\n"
+            + "    [\n"
+            + "      \"M\",\n"
+            + "      16377.25,\n"
+            + "      39225,\n"
+            + "      4180,\n"
+            + "      4\n"
+            + "    ]\n"
+            + "  ],\n"
+            + "  \"total\": 2,\n"
+            + "  \"size\": 2\n"
+            + "}",
+        actual);
+  }
+
   @Test
   public void testAvgByField() {
     String actual =
@@ -268,8 +322,21 @@ public void testAvgBySpanAndFields() {
         actual);
   }
 
-  // TODO fallback to V2 because missing conversion LogicalAggregate[convention: NONE -> ENUMERABLE]
+  /**
+   * TODO Calcite doesn't support group by window, but it supports the Tumble table function. See
+   * `SqlToRelConverterTest`
+   */
   @Ignore
+  public void testAvgByTimeSpanAndFields() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | stats avg(balance) by span(birthdate, 1 day) as age_balance",
+                TEST_INDEX_BANK));
+    assertEquals("", actual);
+  }
+
+  @Test
   public void testCountDistinct() {
     String actual =
         execute(
@@ -278,22 +345,105 @@
         "{\n"
             + "  \"schema\": [\n"
             + "    {\n"
+            + "      \"name\": \"gender\",\n"
+            + "      \"type\": \"string\"\n"
+            + "    },\n"
+            + "    {\n"
             + "      \"name\": \"distinct_count(state)\",\n"
-            + "      \"type\": \"integer\"\n"
+            + "      \"type\": \"long\"\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"datarows\": [\n"
+            + "    [\n"
+            + "      \"F\",\n"
+            + "      3\n"
+            + "    ],\n"
+            + "    [\n"
+            + "      \"M\",\n"
+            + "      4\n"
+            + "    ]\n"
+            + "  ],\n"
+            + "  \"total\": 2,\n"
+            + "  \"size\": 2\n"
+            + "}",
+        actual);
+  }
+
+  @Test
+  public void testCountDistinctWithAlias() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | stats distinct_count(state) as dc by gender", TEST_INDEX_BANK));
+    assertEquals(
+        "{\n"
+            + "  \"schema\": [\n"
+            + "    {\n"
+            + "      \"name\": \"gender\",\n"
+            + "      \"type\": \"string\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"dc\",\n"
+            + "      \"type\": \"long\"\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"datarows\": [\n"
+            + "    [\n"
+            + "      \"F\",\n"
+            + "      3\n"
+            + "    ],\n"
+            + "    [\n"
+            + "      \"M\",\n"
+            + "      4\n"
+            + "    ]\n"
+            + "  ],\n"
+            + "  \"total\": 2,\n"
+            + "  \"size\": 2\n"
+            + "}",
+        actual);
+  }
+
+  @Ignore
+  public void testApproxCountDistinct() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | stats distinct_count_approx(state) by gender", TEST_INDEX_BANK));
+  }
+
+  @Test
+  public void testStddevSampStddevPop() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | stats stddev_samp(balance) as ss, stddev_pop(balance) as sp by gender",
+                TEST_INDEX_BANK));
+    assertEquals(
+        "{\n"
+            + "  \"schema\": [\n"
+            + "    {\n"
             + "      \"name\": \"gender\",\n"
             + "      \"type\": \"string\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"ss\",\n"
+            + "      \"type\": \"double\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"sp\",\n"
+            + "      \"type\": \"double\"\n"
             + "    }\n"
             + "  ],\n"
             + "  \"datarows\": [\n"
             + "    [\n"
-            + "      3,\n"
-            + "      \"f\"\n"
+            + "      \"F\",\n"
+            + "      7624.132999889233,\n"
+            + "      6225.078526947806\n"
             + "    ],\n"
             + "    [\n"
-            + "      4,\n"
-            + "      \"m\"\n"
+            + "      \"M\",\n"
+            + "      16177.114233282358,\n"
+            + "      14009.791885945344\n"
             + "    ]\n"
             + "  ],\n"
             + "  \"total\": 2,\n"
@@ -301,4 +451,84 @@
             + "}",
         actual);
   }
+
+  @Test
+  public void testAggWithEval() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | eval a = 1, b = a | stats avg(a) as avg_a by b", TEST_INDEX_BANK));
+    assertEquals(
+        "{\n"
+            + "  \"schema\": [\n"
+            + "    {\n"
+            + "      \"name\": \"b\",\n"
+            + "      \"type\": \"integer\"\n"
+            + "    },\n"
+            + "    {\n"
+            + "      \"name\": \"avg_a\",\n"
+            + "      \"type\": \"double\"\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"datarows\": [\n"
+            + "    [\n"
+            + "      1,\n"
+            + "      1.0\n"
+            + "    ]\n"
+            + "  ],\n"
+            + "  \"total\": 1,\n"
+            + "  \"size\": 1\n"
+            + "}",
+        actual);
+  }
+
+  @Test
+  public void testAggWithBackticksAlias() {
+    String actual =
+        execute(String.format("source=%s | stats sum(`balance`) as `sum_b`", TEST_INDEX_BANK));
+    assertEquals(
+        "{\n"
+            + "  \"schema\": [\n"
+            + "    {\n"
+            + "      \"name\": \"sum_b\",\n"
+            + "      \"type\": \"long\"\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"datarows\": [\n"
+            + "    [\n"
+            + "      186973\n"
+            + "    ]\n"
+            + "  ],\n"
+            + "  \"total\": 1,\n"
+            + "  \"size\": 1\n"
+            + "}",
+        actual);
+  }
+
+  @Test
+  public void testSimpleTwoLevelStats() {
+    String actual =
+        execute(
+            String.format(
+                "source=%s | stats avg(balance) as avg_by_gender by gender | stats"
+                    + " avg(avg_by_gender) as avg_avg",
+                TEST_INDEX_BANK));
+    assertEquals(
+        "{\n"
+            + "  \"schema\": [\n"
+            + "    {\n"
+            + "      \"name\": \"avg_avg\",\n"
+            + "      \"type\": \"double\"\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"datarows\": [\n"
+            + "    [\n"
+            + "      28432.625\n"
+            + "    ]\n"
+            + "  ],\n"
+            + "  \"total\": 1,\n"
+            + "  \"size\": 1\n"
+            + "}",
+        actual);
+  }
 }
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java
index 518e67d49f..e4db98ac4b 100644
--- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexEnumerator.java
@@ -11,6 +11,7 @@
 import lombok.EqualsAndHashCode;
 import lombok.ToString;
 import org.apache.calcite.linq4j.Enumerator;
+import org.opensearch.sql.data.model.ExprNullValue;
 import org.opensearch.sql.data.model.ExprValue;
 import org.opensearch.sql.opensearch.client.OpenSearchClient;
 import org.opensearch.sql.opensearch.request.OpenSearchRequest;
@@ -61,7 +62,10 @@ private void fetchNextBatch() {
 
   @Override
   public Object current() {
-    Object[] p = fields.stream().map(k -> current.tupleValue().get(k).valueForCalcite()).toArray();
+    Object[] p =
+        fields.stream()
+            .map(k -> current.tupleValue().getOrDefault(k, ExprNullValue.of()).valueForCalcite())
+            .toArray();
     return p;
   }
 
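Note: the enumerator change above guards against documents that lack one of the
projected fields. A minimal sketch of the difference, with an assumed field name
("balance"):

    // Before: tupleValue().get(k) returned null for a missing field, so the
    // chained valueForCalcite() call threw a NullPointerException.
    // After: a missing field falls back to ExprNullValue and surfaces as a
    // null cell in the Calcite row instead of an NPE.
    Map<String, ExprValue> tuple = current.tupleValue();
    Object cell = tuple.getOrDefault("balance", ExprNullValue.of()).valueForCalcite();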
diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java
index 12c4d5004b..22b360094d 100644
--- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java
+++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java
@@ -68,6 +68,32 @@ public void testMultipleAggregatesWithAliases() {
     verifyPPLToSparkSQL(root, expectedSparkSql);
   }
 
+  @Test
+  public void testMultipleAggregatesWithAliasesByClause() {
+    String ppl =
+        "source=EMP | stats avg(SAL) as avg_sal, max(SAL) as max_sal, min(SAL) as min_sal, count()"
+            + " as cnt by DEPTNO";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        "LogicalAggregate(group=[{7}], avg_sal=[AVG($5)], max_sal=[MAX($5)], min_sal=[MIN($5)],"
+            + " cnt=[COUNT()])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult =
+        ""
+            + "DEPTNO=20; avg_sal=2175.00; max_sal=3000.00; min_sal=800.00; cnt=5\n"
+            + "DEPTNO=10; avg_sal=2916.66; max_sal=5000.00; min_sal=1300.00; cnt=3\n"
+            + "DEPTNO=30; avg_sal=1566.66; max_sal=2850.00; min_sal=950.00; cnt=6\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        "SELECT `DEPTNO`, AVG(`SAL`) `avg_sal`, MAX(`SAL`) `max_sal`, MIN(`SAL`) `min_sal`,"
+            + " COUNT(*) `cnt`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
   @Test
   public void testAvgByField() {
     String ppl = "source=EMP | stats avg(SAL) by DEPTNO";
@@ -145,16 +171,38 @@ public void testAvgBySpanAndFields() {
     verifyPPLToSparkSQL(root, expectedSparkSql);
   }
 
-  @Ignore
+  /**
+   * TODO Calcite doesn't support group by window, but it supports the Tumble table function. See
+   * `SqlToRelConverterTest`
+   */
+  @Test
   public void testAvgByTimeSpanAndFields() {
     String ppl =
-        "source=EMP | stats avg(SAL) by span(HIREDATE, 1y) as hiredate_span, DEPTNO | sort DEPTNO,"
-            + " hiredate_span";
+        "source=EMP | stats avg(SAL) by span(HIREDATE, 1 day) as hiredate_span, DEPTNO | sort"
+            + " DEPTNO, hiredate_span";
     RelNode root = getRelNode(ppl);
-    String expectedLogical = "";
+    String expectedLogical =
+        ""
+            + "LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[ASC])\n"
+            + "  LogicalAggregate(group=[{1, 2}], avg(SAL)=[AVG($0)])\n"
+            + "    LogicalProject(SAL=[$5], DEPTNO=[$7], hiredate_span=[86400000:INTERVAL DAY])\n"
+            + "      LogicalTableScan(table=[[scott, EMP]])\n";
     verifyLogical(root, expectedLogical);
-    String expectedResult = "";
+
+    String expectedResult =
+        ""
+            + "DEPTNO=10; hiredate_span=+1; avg(SAL)=2916.66\n"
+            + "DEPTNO=20; hiredate_span=+1; avg(SAL)=2175.00\n"
+            + "DEPTNO=30; hiredate_span=+1; avg(SAL)=1566.66\n";
     verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `DEPTNO`, INTERVAL '1' DAY `hiredate_span`, AVG(`SAL`) `avg(SAL)`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`, INTERVAL '1' DAY\n"
+            + "ORDER BY `DEPTNO` NULLS LAST, INTERVAL '1' DAY NULLS LAST";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
   }
 
   @Test
@@ -181,14 +229,252 @@ public void testCountDistinct() {
     verifyPPLToSparkSQL(root, expectedSparkSql);
   }
 
+  @Test
+  public void testCountDistinctWithAlias() {
+    String ppl = "source=EMP | stats distinct_count(JOB) as dc by DEPTNO";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{7}], dc=[COUNT(DISTINCT $2)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult = "" + "DEPTNO=20; dc=3\n" + "DEPTNO=10; dc=3\n" + "DEPTNO=30; dc=3\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `DEPTNO`, COUNT(DISTINCT `JOB`) `dc`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
   @Ignore
-  public void testMultipleLevelStats() {
-    // TODO unsupported
-    String ppl = "source=EMP | stats avg(SAL) as avg_sal | stats avg(COMM) as avg_comm";
+  public void testApproxCountDistinct() {
+    String ppl = "source=EMP | stats distinct_count_approx(JOB) by DEPTNO";
     RelNode root = getRelNode(ppl);
-    String expectedLogical = "";
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{7}], distinct_count_approx(JOB)=[COUNT(DISTINCT $2)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
     verifyLogical(root, expectedLogical);
-    String expectedResult = "";
+    String expectedResult =
+        ""
+            + "DEPTNO=20; distinct_count(JOB)=3\n"
+            + "DEPTNO=10; distinct_count(JOB)=3\n"
+            + "DEPTNO=30; distinct_count(JOB)=3\n";
     verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `DEPTNO`, COUNT(DISTINCT `JOB`) `distinct_count(JOB)`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testStddevSampByField() {
+    String ppl = "source=EMP | stats stddev_samp(SAL) by DEPTNO";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{7}], stddev_samp(SAL)=[STDDEV_SAMP($5)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult =
+        ""
+            + "DEPTNO=20; stddev_samp(SAL)=1123.33\n"
+            + "DEPTNO=10; stddev_samp(SAL)=1893.62\n"
+            + "DEPTNO=30; stddev_samp(SAL)=668.33\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `DEPTNO`, STDDEV_SAMP(`SAL`) `stddev_samp(SAL)`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testStddevSampByFieldWithAlias() {
+    String ppl = "source=EMP | stats stddev_samp(SAL) as samp by span(EMPNO, 100) as empno_span";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{1}], samp=[STDDEV_SAMP($0)])\n"
+            + "  LogicalProject(SAL=[$5], empno_span=[*(FLOOR(/($0, 100)), 100)])\n"
+            + "    LogicalTableScan(table=[[scott, EMP]])\n";
    verifyLogical(root, expectedLogical);
+    String expectedResult =
+        ""
+            + "empno_span=7300.0; samp=null\n"
+            + "empno_span=7400.0; samp=null\n"
+            + "empno_span=7500.0; samp=1219.75\n"
+            + "empno_span=7600.0; samp=1131.37\n"
+            + "empno_span=7700.0; samp=388.90\n"
+            + "empno_span=7800.0; samp=2145.53\n"
+            + "empno_span=7900.0; samp=1096.58\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT FLOOR(`EMPNO` / 100) * 100 `empno_span`, STDDEV_SAMP(`SAL`) `samp`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY FLOOR(`EMPNO` / 100) * 100";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testStddevPopByField() {
+    String ppl = "source=EMP | stats stddev_pop(SAL) by DEPTNO";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{7}], stddev_pop(SAL)=[STDDEV_POP($5)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult =
+        ""
+            + "DEPTNO=20; stddev_pop(SAL)=1004.73\n"
+            + "DEPTNO=10; stddev_pop(SAL)=1546.14\n"
+            + "DEPTNO=30; stddev_pop(SAL)=610.10\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `DEPTNO`, STDDEV_POP(`SAL`) `stddev_pop(SAL)`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testStddevPopByFieldWithAlias() {
+    String ppl = "source=EMP | stats stddev_pop(SAL) as pop by DEPTNO";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{7}], pop=[STDDEV_POP($5)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult =
+        "" + "DEPTNO=20; pop=1004.73\n" + "DEPTNO=10; pop=1546.14\n" + "DEPTNO=30; pop=610.10\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `DEPTNO`, STDDEV_POP(`SAL`) `pop`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `DEPTNO`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testAggWithEval() {
+    String ppl = "source=EMP | eval a = 1 | stats avg(a) as avg_a";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{}], avg_a=[AVG($0)])\n"
+            + "  LogicalProject(a=[1])\n"
+            + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult = "avg_a=1.0\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql = "" + "SELECT AVG(1) `avg_a`\n" + "FROM `scott`.`EMP`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testAggByWithEval() {
+    String ppl = "source=EMP | eval a = 1, b = a | stats avg(a) as avg_a by b";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{1}], avg_a=[AVG($0)])\n"
+            + "  LogicalProject(a=[1], b=[1])\n"
+            + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult = "b=1; avg_a=1.0\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        "" + "SELECT 1 `b`, AVG(1) `avg_a`\n" + "FROM `scott`.`EMP`\n" + "GROUP BY 1";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testAggWithBackticksAlias() {
+    String ppl = "source=EMP | stats avg(`SAL`) as `avg_sal`";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{}], avg_sal=[AVG($5)])\n"
+            + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+    String expectedResult = "avg_sal=2073.21\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql = "" + "SELECT AVG(`SAL`) `avg_sal`\n" + "FROM `scott`.`EMP`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testSimpleTwoLevelStats() {
+    String ppl = "source=EMP | stats avg(SAL) as avg_sal | stats avg(avg_sal) as avg_avg_sal";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{}], avg_avg_sal=[AVG($0)])\n"
+            + "  LogicalAggregate(group=[{}], avg_sal=[AVG($5)])\n"
+            + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+
+    String expectedResult = "avg_avg_sal=2073.21\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT AVG(`avg_sal`) `avg_avg_sal`\n"
+            + "FROM (SELECT AVG(`SAL`) `avg_sal`\n"
+            + "FROM `scott`.`EMP`) `t`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
+
+  @Test
+  public void testTwoLevelStats() {
+    String ppl =
+        "source=EMP | stats avg(SAL) as avg_sal by DEPTNO, MGR | stats avg(avg_sal) as avg_avg_sal"
+            + " by MGR";
+    RelNode root = getRelNode(ppl);
+    String expectedLogical =
+        ""
+            + "LogicalAggregate(group=[{0}], avg_avg_sal=[AVG($2)])\n"
+            + "  LogicalAggregate(group=[{3, 7}], avg_sal=[AVG($5)])\n"
+            + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    verifyLogical(root, expectedLogical);
+
+    String expectedResult =
+        ""
+            + "MGR=null; avg_avg_sal=5000.00\n"
+            + "MGR=7698; avg_avg_sal=1310.00\n"
+            + "MGR=7782; avg_avg_sal=1300.00\n"
+            + "MGR=7788; avg_avg_sal=1100.00\n"
+            + "MGR=7902; avg_avg_sal=800.00\n"
+            + "MGR=7566; avg_avg_sal=3000.00\n"
+            + "MGR=7839; avg_avg_sal=2758.33\n";
+    verifyResult(root, expectedResult);
+
+    String expectedSparkSql =
+        ""
+            + "SELECT `MGR`, AVG(`avg_sal`) `avg_avg_sal`\n"
+            + "FROM (SELECT `MGR`, `DEPTNO`, AVG(`SAL`) `avg_sal`\n"
+            + "FROM `scott`.`EMP`\n"
+            + "GROUP BY `MGR`, `DEPTNO`) `t`\n"
+            + "GROUP BY `MGR`";
+    verifyPPLToSparkSQL(root, expectedSparkSql);
+  }
 }

From 60e10db02b4ef294e5f3dfb03bc24d47782e7e47 Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Thu, 27 Feb 2025 10:08:22 +0800
Subject: [PATCH 2/2] delete irrelevant code

Signed-off-by: Lantao Jin
---
 .../opensearch/sql/calcite/CalciteRexNodeVisitor.java | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java
index c75822a455..fd67d08bf1 100644
--- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java
+++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java
@@ -18,7 +18,6 @@
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.SqlIntervalQualifier;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.parser.SqlParserUtil;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.util.DateString;
@@ -30,7 +29,6 @@
 import org.opensearch.sql.ast.expression.Compare;
 import org.opensearch.sql.ast.expression.EqualTo;
 import org.opensearch.sql.ast.expression.Function;
-import org.opensearch.sql.ast.expression.Interval;
 import org.opensearch.sql.ast.expression.Let;
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
@@ -244,12 +242,4 @@ public RexNode visitFunction(Function node, CalcitePlanContext context) {
     return context.rexBuilder.makeCall(
         BuiltinFunctionUtils.translate(node.getFuncName()), arguments);
   }
-
-  @Override
-  public RexNode visitInterval(Interval node, CalcitePlanContext context) {
-    RexNode field = analyze(node.getValue(), context);
-    return context.rexBuilder.makeIntervalLiteral(
-        new BigDecimal(field.toString()),
-        new SqlIntervalQualifier(node.getUnit().name(), SqlParserPos.ZERO));
-  }
 }