Skip to content

Commit 09cbf73

Browse files
author
Julien Ruaux
committed
Added support for aggregations
1 parent d32f9dd commit 09cbf73

15 files changed

+814
-495
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package com.redis.trino;
2+
3+
import static io.trino.spi.type.BigintType.BIGINT;
4+
import static io.trino.spi.type.DoubleType.DOUBLE;
5+
import static io.trino.spi.type.IntegerType.INTEGER;
6+
import static io.trino.spi.type.RealType.REAL;
7+
import static io.trino.spi.type.SmallintType.SMALLINT;
8+
import static io.trino.spi.type.TinyintType.TINYINT;
9+
10+
import java.util.Arrays;
11+
import java.util.List;
12+
import java.util.Map;
13+
import java.util.Objects;
14+
import java.util.Optional;
15+
16+
import com.fasterxml.jackson.annotation.JsonCreator;
17+
import com.fasterxml.jackson.annotation.JsonProperty;
18+
19+
import io.trino.spi.connector.AggregateFunction;
20+
import io.trino.spi.connector.ColumnHandle;
21+
import io.trino.spi.expression.Variable;
22+
import io.trino.spi.type.Type;
23+
24+
public class MetricAggregation {
25+
public static final String MAX = "max";
26+
public static final String MIN = "min";
27+
public static final String AVG = "avg";
28+
public static final String SUM = "sum";
29+
public static final String COUNT = "count";
30+
private static final List<String> SUPPORTED_AGGREGATION_FUNCTIONS = Arrays.asList(MAX, MIN, AVG, SUM, COUNT);
31+
private static final List<Type> NUMERIC_TYPES = Arrays.asList(REAL, DOUBLE, TINYINT, SMALLINT, INTEGER, BIGINT);
32+
private final String functionName;
33+
private final Type outputType;
34+
private final Optional<RediSearchColumnHandle> columnHandle;
35+
private final String alias;
36+
37+
@JsonCreator
38+
public MetricAggregation(@JsonProperty("functionName") String functionName,
39+
@JsonProperty("outputType") Type outputType,
40+
@JsonProperty("columnHandle") Optional<RediSearchColumnHandle> columnHandle,
41+
@JsonProperty("alias") String alias) {
42+
this.functionName = functionName;
43+
this.outputType = outputType;
44+
this.columnHandle = columnHandle;
45+
this.alias = alias;
46+
}
47+
48+
@JsonProperty
49+
public String getFunctionName() {
50+
return functionName;
51+
}
52+
53+
@JsonProperty
54+
public Type getOutputType() {
55+
return outputType;
56+
}
57+
58+
@JsonProperty
59+
public Optional<RediSearchColumnHandle> getColumnHandle() {
60+
return columnHandle;
61+
}
62+
63+
@JsonProperty
64+
public String getAlias() {
65+
return alias;
66+
}
67+
68+
public static boolean isNumericType(Type type) {
69+
return NUMERIC_TYPES.contains(type);
70+
}
71+
72+
public static Optional<MetricAggregation> handleAggregation(AggregateFunction function,
73+
Map<String, ColumnHandle> assignments, String alias) {
74+
if (!SUPPORTED_AGGREGATION_FUNCTIONS.contains(function.getFunctionName())) {
75+
return Optional.empty();
76+
}
77+
// check
78+
// 1. Function input can be found in assignments
79+
// 2. Target type of column being aggregate must be numeric type
80+
// 3. ColumnHandle support predicates(since text treats as VARCHAR, but text can
81+
// not be treats as term in es by default
82+
Optional<RediSearchColumnHandle> parameterColumnHandle = function.getArguments().stream()
83+
.filter(Variable.class::isInstance).map(Variable.class::cast).map(Variable::getName)
84+
.filter(assignments::containsKey).findFirst().map(assignments::get)
85+
.map(RediSearchColumnHandle.class::cast)
86+
.filter(column -> MetricAggregation.isNumericType(column.getType()));
87+
// only count can accept empty ElasticsearchColumnHandle
88+
if (!COUNT.equals(function.getFunctionName()) && parameterColumnHandle.isEmpty()) {
89+
return Optional.empty();
90+
}
91+
return Optional.of(new MetricAggregation(function.getFunctionName(), function.getOutputType(),
92+
parameterColumnHandle, alias));
93+
}
94+
95+
@Override
96+
public boolean equals(Object o) {
97+
if (this == o) {
98+
return true;
99+
}
100+
if (o == null || getClass() != o.getClass()) {
101+
return false;
102+
}
103+
MetricAggregation that = (MetricAggregation) o;
104+
return Objects.equals(functionName, that.functionName) && Objects.equals(outputType, that.outputType)
105+
&& Objects.equals(columnHandle, that.columnHandle) && Objects.equals(alias, that.alias);
106+
}
107+
108+
@Override
109+
public int hashCode() {
110+
return Objects.hash(functionName, outputType, columnHandle, alias);
111+
}
112+
113+
@Override
114+
public String toString() {
115+
return String.format("%s(%s)", functionName, columnHandle.map(RediSearchColumnHandle::getName).orElse(""));
116+
}
117+
}

subprojects/trino-redisearch/src/main/java/com/redis/trino/RediSearchMetadata.java

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,27 @@
11
package com.redis.trino;
22

33
import static com.google.common.base.Preconditions.checkState;
4+
import static com.google.common.base.Verify.verify;
45
import static com.google.common.collect.ImmutableList.toImmutableList;
5-
import static java.lang.Math.toIntExact;
66
import static java.util.Objects.requireNonNull;
77
import static java.util.stream.Collectors.toList;
88

99
import java.util.Collection;
1010
import java.util.List;
1111
import java.util.Map;
1212
import java.util.Optional;
13-
import java.util.OptionalInt;
13+
import java.util.OptionalLong;
1414
import java.util.concurrent.atomic.AtomicReference;
1515

1616
import com.google.common.collect.ImmutableList;
1717
import com.google.common.collect.ImmutableMap;
18+
import com.redis.trino.RediSearchTableHandle.Type;
1819

20+
import io.airlift.log.Logger;
1921
import io.airlift.slice.Slice;
22+
import io.trino.spi.connector.AggregateFunction;
23+
import io.trino.spi.connector.AggregationApplicationResult;
24+
import io.trino.spi.connector.Assignment;
2025
import io.trino.spi.connector.ColumnHandle;
2126
import io.trino.spi.connector.ColumnMetadata;
2227
import io.trino.spi.connector.ConnectorInsertTableHandle;
@@ -35,11 +40,17 @@
3540
import io.trino.spi.connector.SchemaTableName;
3641
import io.trino.spi.connector.SchemaTablePrefix;
3742
import io.trino.spi.connector.TableNotFoundException;
43+
import io.trino.spi.expression.ConnectorExpression;
44+
import io.trino.spi.expression.Variable;
3845
import io.trino.spi.predicate.TupleDomain;
3946
import io.trino.spi.statistics.ComputedStatistics;
4047

4148
public class RediSearchMetadata implements ConnectorMetadata {
4249

50+
private static final Logger log = Logger.get(RediSearchMetadata.class);
51+
52+
private static final String SYNTHETIC_COLUMN_NAME_PREFIX = "syntheticColumn";
53+
4354
private final RediSearchSession rediSearchSession;
4455
private final String schemaName;
4556
private final AtomicReference<Runnable> rollbackAction = new AtomicReference<>();
@@ -199,12 +210,14 @@ public Optional<LimitApplicationResult<ConnectorTableHandle>> applyLimit(Connect
199210
return Optional.empty();
200211
}
201212

202-
if (handle.getLimit().isPresent() && handle.getLimit().getAsInt() <= limit) {
213+
if (handle.getLimit().isPresent() && handle.getLimit().getAsLong() <= limit) {
203214
return Optional.empty();
204215
}
205216

206-
return Optional.of(new LimitApplicationResult<>(new RediSearchTableHandle(handle.getSchemaTableName(),
207-
handle.getConstraint(), OptionalInt.of(toIntExact(limit))), true, false));
217+
return Optional.of(new LimitApplicationResult<>(
218+
new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), handle.getConstraint(),
219+
OptionalLong.of(limit), handle.getTermAggregations(), handle.getMetricAggregations()),
220+
true, false));
208221
}
209222

210223
@Override
@@ -218,11 +231,57 @@ public Optional<ConstraintApplicationResult<ConnectorTableHandle>> applyFilter(C
218231
return Optional.empty();
219232
}
220233

221-
handle = new RediSearchTableHandle(handle.getSchemaTableName(), newDomain, handle.getLimit());
234+
handle = new RediSearchTableHandle(handle.getType(), handle.getSchemaTableName(), newDomain, handle.getLimit(),
235+
handle.getTermAggregations(), handle.getMetricAggregations());
222236

223237
return Optional.of(new ConstraintApplicationResult<>(handle, constraint.getSummary(), false));
224238
}
225239

240+
@Override
241+
public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(ConnectorSession session,
242+
ConnectorTableHandle handle, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments,
243+
List<List<ColumnHandle>> groupingSets) {
244+
log.info("applyAggregation aggregates=%s groupingSets=%s", aggregates, groupingSets);
245+
RediSearchTableHandle table = (RediSearchTableHandle) handle;
246+
// Global aggregation is represented by [[]]
247+
verify(!groupingSets.isEmpty(), "No grouping sets provided");
248+
if (!table.getTermAggregations().isEmpty()) {
249+
return Optional.empty();
250+
}
251+
ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
252+
ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
253+
ImmutableList.Builder<MetricAggregation> metricAggregations = ImmutableList.builder();
254+
ImmutableList.Builder<TermAggregation> termAggregations = ImmutableList.builder();
255+
for (int i = 0; i < aggregates.size(); i++) {
256+
AggregateFunction function = aggregates.get(i);
257+
String colName = SYNTHETIC_COLUMN_NAME_PREFIX + i;
258+
Optional<MetricAggregation> metricAggregation = MetricAggregation.handleAggregation(function, assignments,
259+
colName);
260+
if (metricAggregation.isEmpty()) {
261+
return Optional.empty();
262+
}
263+
RediSearchColumnHandle newColumn = new RediSearchColumnHandle(colName, function.getOutputType(), false);
264+
projections.add(new Variable(colName, function.getOutputType()));
265+
resultAssignments.add(new Assignment(colName, newColumn, function.getOutputType()));
266+
metricAggregations.add(metricAggregation.get());
267+
}
268+
for (ColumnHandle columnHandle : groupingSets.get(0)) {
269+
Optional<TermAggregation> termAggregation = TermAggregation.fromColumnHandle(columnHandle);
270+
if (termAggregation.isEmpty()) {
271+
return Optional.empty();
272+
}
273+
termAggregations.add(termAggregation.get());
274+
}
275+
ImmutableList<MetricAggregation> metrics = metricAggregations.build();
276+
if (metrics.isEmpty()) {
277+
return Optional.empty();
278+
}
279+
RediSearchTableHandle tableHandle = new RediSearchTableHandle(Type.AGGREGATE, table.getSchemaTableName(),
280+
table.getConstraint(), table.getLimit(), termAggregations.build(), metrics);
281+
return Optional.of(new AggregationApplicationResult<>(tableHandle, projections.build(),
282+
resultAssignments.build(), ImmutableMap.of(), false));
283+
}
284+
226285
private void setRollback(Runnable action) {
227286
checkState(rollbackAction.compareAndSet(null, action), "rollback action is already set");
228287
}
Lines changed: 10 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,29 @@
11
package com.redis.trino;
22

33
import static com.google.common.base.Verify.verify;
4-
import static com.redis.trino.TypeUtils.isJsonType;
5-
import static io.airlift.slice.Slices.utf8Slice;
6-
import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse;
7-
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
8-
import static io.trino.spi.type.BigintType.BIGINT;
9-
import static io.trino.spi.type.Chars.truncateToLengthAndTrimSpaces;
10-
import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone;
11-
import static io.trino.spi.type.DateType.DATE;
12-
import static io.trino.spi.type.Decimals.encodeScaledValue;
13-
import static io.trino.spi.type.Decimals.encodeShortScaledValue;
14-
import static io.trino.spi.type.IntegerType.INTEGER;
15-
import static io.trino.spi.type.RealType.REAL;
16-
import static io.trino.spi.type.SmallintType.SMALLINT;
17-
import static io.trino.spi.type.TimeZoneKey.UTC_KEY;
18-
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
19-
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
20-
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND;
21-
import static io.trino.spi.type.TinyintType.TINYINT;
22-
import static java.lang.Float.floatToIntBits;
234
import static java.util.stream.Collectors.toList;
245

256
import java.io.IOException;
267
import java.io.OutputStream;
27-
import java.math.BigDecimal;
28-
import java.time.LocalDate;
29-
import java.time.format.DateTimeFormatter;
308
import java.util.Iterator;
319
import java.util.List;
3210

3311
import com.fasterxml.jackson.core.JsonFactory;
3412
import com.fasterxml.jackson.core.JsonGenerator;
35-
import com.google.common.primitives.SignedBytes;
3613
import com.redis.lettucemod.search.Document;
3714

38-
import io.airlift.slice.Slice;
3915
import io.airlift.slice.SliceOutput;
4016
import io.trino.spi.Page;
4117
import io.trino.spi.PageBuilder;
42-
import io.trino.spi.TrinoException;
4318
import io.trino.spi.block.BlockBuilder;
4419
import io.trino.spi.connector.ConnectorPageSource;
45-
import io.trino.spi.type.CharType;
46-
import io.trino.spi.type.DecimalType;
4720
import io.trino.spi.type.Type;
48-
import io.trino.spi.type.VarcharType;
4921

5022
public class RediSearchPageSource implements ConnectorPageSource {
23+
5124
private static final int ROWS_PER_REQUEST = 1024;
5225

26+
private final RediSearchPageSourceResultWriter writer = new RediSearchPageSourceResultWriter();
5327
private final Iterator<Document<String, String>> cursor;
5428
private final List<String> columnNames;
5529
private final List<Type> columnTypes;
@@ -63,7 +37,7 @@ public RediSearchPageSource(RediSearchSession rediSearchSession, RediSearchTable
6337
List<RediSearchColumnHandle> columns) {
6438
this.columnNames = columns.stream().map(RediSearchColumnHandle::getName).collect(toList());
6539
this.columnTypes = columns.stream().map(RediSearchColumnHandle::getType).collect(toList());
66-
this.cursor = rediSearchSession.execute(tableHandle).iterator();
40+
this.cursor = rediSearchSession.search(tableHandle).iterator();
6741
this.currentDoc = null;
6842
this.pageBuilder = new PageBuilder(columnTypes);
6943
}
@@ -103,7 +77,12 @@ public Page getNextPage() {
10377
pageBuilder.declarePosition();
10478
for (int column = 0; column < columnTypes.size(); column++) {
10579
BlockBuilder output = pageBuilder.getBlockBuilder(column);
106-
appendTo(columnTypes.get(column), currentDoc.get(columnNames.get(column)), output);
80+
String value = currentDoc.get(columnNames.get(column));
81+
if (value == null) {
82+
output.appendNull();
83+
} else {
84+
writer.appendTo(columnTypes.get(column), value, output);
85+
}
10786
}
10887
}
10988

@@ -112,67 +91,12 @@ public Page getNextPage() {
11291
return page;
11392
}
11493

115-
private void appendTo(Type type, String value, BlockBuilder output) {
116-
if (value == null) {
117-
output.appendNull();
118-
return;
119-
}
120-
Class<?> javaType = type.getJavaType();
121-
if (javaType == boolean.class) {
122-
type.writeBoolean(output, Boolean.parseBoolean(value));
123-
} else if (javaType == long.class) {
124-
if (type.equals(BIGINT)) {
125-
type.writeLong(output, Long.parseLong(value));
126-
} else if (type.equals(INTEGER)) {
127-
type.writeLong(output, Integer.parseInt(value));
128-
} else if (type.equals(SMALLINT)) {
129-
type.writeLong(output, Short.parseShort(value));
130-
} else if (type.equals(TINYINT)) {
131-
type.writeLong(output, SignedBytes.checkedCast(Long.parseLong(value)));
132-
} else if (type.equals(REAL)) {
133-
type.writeLong(output, floatToIntBits((Float.parseFloat(value))));
134-
} else if (type instanceof DecimalType) {
135-
type.writeLong(output, encodeShortScaledValue(new BigDecimal(value), ((DecimalType) type).getScale()));
136-
} else if (type.equals(DATE)) {
137-
type.writeLong(output, LocalDate.from(DateTimeFormatter.ISO_DATE.parse(value)).toEpochDay());
138-
} else if (type.equals(TIMESTAMP_MILLIS)) {
139-
type.writeLong(output, Long.parseLong(value) * MICROSECONDS_PER_MILLISECOND);
140-
} else if (type.equals(TIMESTAMP_TZ_MILLIS)) {
141-
type.writeLong(output, packDateTimeWithZone(Long.parseLong(value), UTC_KEY));
142-
} else {
143-
throw new TrinoException(GENERIC_INTERNAL_ERROR,
144-
"Unhandled type for " + javaType.getSimpleName() + ":" + type.getTypeSignature());
145-
}
146-
} else if (javaType == double.class) {
147-
type.writeDouble(output, Double.parseDouble(value));
148-
} else if (javaType == Slice.class) {
149-
writeSlice(output, type, value);
150-
} else {
151-
throw new TrinoException(GENERIC_INTERNAL_ERROR,
152-
"Unhandled type for " + javaType.getSimpleName() + ":" + type.getTypeSignature());
153-
}
154-
}
155-
156-
private void writeSlice(BlockBuilder output, Type type, String value) {
157-
if (type instanceof VarcharType) {
158-
type.writeSlice(output, utf8Slice(value));
159-
} else if (type instanceof CharType) {
160-
type.writeSlice(output, truncateToLengthAndTrimSpaces(utf8Slice(value), ((CharType) type)));
161-
} else if (type instanceof DecimalType) {
162-
type.writeObject(output, encodeScaledValue(new BigDecimal(value), ((DecimalType) type).getScale()));
163-
} else if (isJsonType(type)) {
164-
type.writeSlice(output, jsonParse(utf8Slice(value)));
165-
} else {
166-
throw new TrinoException(GENERIC_INTERNAL_ERROR, "Unhandled type for Slice: " + type.getTypeSignature());
167-
}
168-
}
169-
17094
public static JsonGenerator createJsonGenerator(JsonFactory factory, SliceOutput output) throws IOException {
17195
return factory.createGenerator((OutputStream) output);
17296
}
17397

17498
@Override
17599
public void close() {
176-
100+
// nothing to do
177101
}
178102
}

0 commit comments

Comments
 (0)