Skip to content

Commit ca14e8a

Browse files
authored
fix: DH-18472 correct the ordering of result columns in update_by (#6630)
Previously: ![image](https://github.com/user-attachments/assets/1a3b21d3-5d2c-4a5b-9d21-2d2972c23dd9) After the change: <img width="1817" alt="image" src="https://github.com/user-attachments/assets/8456e9a6-f0d9-4854-a0e6-6ed71f64a37b" />
1 parent b6acfcd commit ca14e8a

File tree

3 files changed

+124
-26
lines changed

3 files changed

+124
-26
lines changed

engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateBy.java

+14-3
Original file line numberDiff line numberDiff line change
@@ -1392,15 +1392,26 @@ public static Table updateBy(@NotNull final QueryTable source,
13921392

13931393
final Map<String, ColumnSource<?>> resultSources = new LinkedHashMap<>(source.getColumnSourceMap());
13941394

1395-
// We have the source table and the row redirection; we can initialize the operators and add the output
1396-
// columns to the result sources
1395+
final Map<String, ColumnSource<?>> unorderedResultSources = new HashMap<>();
1396+
// We have the source table and the row redirection; we can initialize the operators and collect the output
1397+
// columns.
13971398
for (UpdateByWindow win : operatorCollection.windowArr) {
13981399
for (UpdateByOperator op : win.operators) {
13991400
op.initializeSources(source, rowRedirection);
1400-
resultSources.putAll(op.getOutputColumns());
1401+
unorderedResultSources.putAll(op.getOutputColumns());
14011402
}
14021403
}
14031404

1405+
// Add the output result sources to the table column map in the order specified by the updateBy call.
1406+
for (String outputColumnName : operatorCollection.outputColumnNames) {
1407+
final ColumnSource<?> cs = unorderedResultSources.get(outputColumnName);
1408+
if (cs == null) {
1409+
throw new IllegalStateException(
1410+
"Requested output column '" + outputColumnName + "' was not found in operator output");
1411+
}
1412+
resultSources.put(outputColumnName, cs);
1413+
}
1414+
14041415
if (operatorCollection.byColumnNames.length == 0) {
14051416
return LivenessScopeStack.computeEnclosed(() -> {
14061417
final ZeroKeyUpdateByManager zkm = new ZeroKeyUpdateByManager(

engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/UpdateByOperatorFactory.java

+15
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,21 @@ private class OutputColumnVisitor implements UpdateByOperation.Visitor<Void> {
212212
@Override
213213
public Void visit(@NotNull final ColumnUpdateOperation clause) {
214214
final UpdateBySpec spec = clause.spec();
215+
// Need to handle some specs uniquely
216+
if (spec instanceof CumCountWhereSpec) {
217+
outputColumns.add(((CumCountWhereSpec) spec).column().name());
218+
return null;
219+
}
220+
if (spec instanceof RollingCountWhereSpec) {
221+
outputColumns.add(((RollingCountWhereSpec) spec).column().name());
222+
return null;
223+
}
224+
if (spec instanceof RollingFormulaSpec && ((RollingFormulaSpec) spec).paramToken().isEmpty()) {
225+
// The presence of the paramToken indicates that this is a multi-column formula and we have a single
226+
// output column in #selectable()
227+
outputColumns.add(((RollingFormulaSpec) spec).selectable().newColumn().name());
228+
return null;
229+
}
215230
final MatchPair[] pairs =
216231
createColumnsToAddIfMissing(tableDef, parseMatchPairs(clause.columns()), spec, groupByColumns);
217232
for (MatchPair pair : pairs) {

engine/table/src/test/java/io/deephaven/engine/table/impl/updateby/TestUpdateByGeneral.java

+95-23
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import io.deephaven.engine.table.impl.QueryTable;
1313
import io.deephaven.engine.table.impl.UpdateErrorReporter;
1414
import io.deephaven.engine.table.impl.util.AsyncClientErrorNotifier;
15+
import io.deephaven.engine.table.impl.util.ColumnHolder;
1516
import io.deephaven.engine.testutil.ControlledUpdateGraph;
1617
import io.deephaven.engine.testutil.EvalNugget;
1718
import io.deephaven.engine.table.impl.TableDefaults;
@@ -29,18 +30,19 @@
2930
import junit.framework.TestCase;
3031
import org.jetbrains.annotations.NotNull;
3132
import org.junit.After;
33+
import org.junit.Assert;
3234
import org.junit.Before;
3335
import org.junit.Test;
3436
import org.junit.experimental.categories.Category;
3537

3638
import java.time.Duration;
3739
import java.util.*;
3840

41+
import static io.deephaven.api.updateby.UpdateByOperation.*;
3942
import static io.deephaven.engine.testutil.GenerateTableUpdates.generateAppends;
4043
import static io.deephaven.engine.testutil.TstUtils.*;
4144
import static io.deephaven.engine.testutil.testcase.RefreshingTableTestCase.simulateShiftAwareStep;
42-
import static io.deephaven.engine.util.TableTools.col;
43-
import static io.deephaven.engine.util.TableTools.intCol;
45+
import static io.deephaven.engine.util.TableTools.*;
4446
import static io.deephaven.time.DateTimeUtils.MINUTE;
4547

4648
@Category(OutOfBandTest.class)
@@ -138,9 +140,9 @@ protected Table e() {
138140
UpdateByOperation.RollingMin("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
139141
makeOpColNames(columnNamesArray, "_rollmintimerev", "Sym", "ts", "boolCol")),
140142

141-
UpdateByOperation.RollingMax(50, 50,
143+
RollingMax(50, 50,
142144
makeOpColNames(columnNamesArray, "_rollmaxticksrev", "Sym", "ts", "boolCol")),
143-
UpdateByOperation.RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
145+
RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
144146
makeOpColNames(columnNamesArray, "_rollmaxtimerev", "Sym", "ts", "boolCol")),
145147

146148
// Excluding 'bigDecimalCol' because we need fuzzy matching which doesn't exist for BD
@@ -154,8 +156,8 @@ protected Table e() {
154156
UpdateByOperation.Ema(skipControl, "ts", 10 * MINUTE,
155157
makeOpColNames(columnNamesArray, "_ema", "Sym", "ts", "boolCol")),
156158
UpdateByOperation.CumSum(makeOpColNames(columnNamesArray, "_sum", "Sym", "ts")),
157-
UpdateByOperation.CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
158-
UpdateByOperation.CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
159+
CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
160+
CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
159161
UpdateByOperation
160162
.CumProd(makeOpColNames(columnNamesArray, "_prod", "Sym", "ts", "boolCol")));
161163
final UpdateByControl control = UpdateByControl.builder().useRedirection(redirected).build();
@@ -272,38 +274,36 @@ public void testInMemoryColumn() {
272274
final Collection<? extends UpdateByOperation> clauses = List.of(
273275
UpdateByOperation.Fill(),
274276

275-
UpdateByOperation.RollingGroup(50, 50,
277+
RollingGroup(50, 50,
276278
makeOpColNames(columnNamesArray, "_rollgroupfwdrev", "Sym", "ts")),
277-
UpdateByOperation.RollingGroup("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
279+
RollingGroup("ts", Duration.ofMinutes(5), Duration.ofMinutes(5),
278280
makeOpColNames(columnNamesArray, "_rollgrouptimefwdrev", "Sym", "ts")),
279281

280-
UpdateByOperation.RollingSum(100, 0,
282+
RollingSum(100, 0,
281283
makeOpColNames(columnNamesArray, "_rollsumticksrev", "Sym", "ts", "boolCol")),
282-
UpdateByOperation.RollingSum("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
284+
RollingSum("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
283285
makeOpColNames(columnNamesArray, "_rollsumtimerev", "Sym", "ts", "boolCol")),
284286

285-
UpdateByOperation.RollingAvg(100, 0,
287+
RollingAvg(100, 0,
286288
makeOpColNames(columnNamesArray, "_rollavgticksrev", "Sym", "ts", "boolCol")),
287-
UpdateByOperation.RollingAvg("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
289+
RollingAvg("ts", Duration.ofMinutes(15), Duration.ofMinutes(0),
288290
makeOpColNames(columnNamesArray, "_rollavgtimerev", "Sym", "ts", "boolCol")),
289291

290-
UpdateByOperation.RollingMin(100, 0,
292+
RollingMin(100, 0,
291293
makeOpColNames(columnNamesArray, "_rollminticksrev", "Sym", "ts", "boolCol")),
292-
UpdateByOperation.RollingMin("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
294+
RollingMin("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
293295
makeOpColNames(columnNamesArray, "_rollmintimerev", "Sym", "ts", "boolCol")),
294296

295-
UpdateByOperation.RollingMax(100, 0,
297+
RollingMax(100, 0,
296298
makeOpColNames(columnNamesArray, "_rollmaxticksrev", "Sym", "ts", "boolCol")),
297-
UpdateByOperation.RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
299+
RollingMax("ts", Duration.ofMinutes(5), Duration.ofMinutes(0),
298300
makeOpColNames(columnNamesArray, "_rollmaxtimerev", "Sym", "ts", "boolCol")),
299301

300-
UpdateByOperation.Ema(skipControl, "ts", 10 * MINUTE,
301-
makeOpColNames(columnNamesArray, "_ema", "Sym", "ts", "boolCol")),
302-
UpdateByOperation.CumSum(makeOpColNames(columnNamesArray, "_sum", "Sym", "ts")),
303-
UpdateByOperation.CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
304-
UpdateByOperation.CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
305-
UpdateByOperation
306-
.CumProd(makeOpColNames(columnNamesArray, "_prod", "Sym", "ts", "boolCol")));
302+
Ema(skipControl, "ts", 10 * MINUTE, makeOpColNames(columnNamesArray, "_ema", "Sym", "ts", "boolCol")),
303+
CumSum(makeOpColNames(columnNamesArray, "_sum", "Sym", "ts")),
304+
CumMin(makeOpColNames(columnNamesArray, "_min", "boolCol")),
305+
CumMax(makeOpColNames(columnNamesArray, "_max", "boolCol")),
306+
CumProd(makeOpColNames(columnNamesArray, "_prod", "Sym", "ts", "boolCol")));
307307
final UpdateByControl control = UpdateByControl.builder().useRedirection(false).build();
308308

309309
final Table table = result.t.updateBy(control, clauses, ColumnName.from("Sym"));
@@ -322,4 +322,76 @@ public void run() {
322322
}
323323
});
324324
}
325+
326+
@Test
327+
public void testResultColumnOrdering() {
328+
final Table source = emptyTable(5).update("X=ii");
329+
330+
final ColumnHolder<?> x = longCol("X", 0, 1, 2, 3, 4);
331+
final ColumnHolder<?> cumMin = longCol("cumMin", 0, 0, 0, 0, 0);
332+
final ColumnHolder<?> cumMax = longCol("cumMax", 0, 1, 2, 3, 4);
333+
final ColumnHolder<?> rollingMin = longCol("rollingMin", 0, 0, 1, 2, 3);
334+
final ColumnHolder<?> rollingMax = longCol("rollingMax", 0, 1, 2, 3, 4);
335+
336+
final Table result_1 = source.updateBy(List.of(
337+
CumMin("cumMin=X"),
338+
CumMax("cumMax=X"),
339+
RollingMin(2, "rollingMin=X"),
340+
RollingMax(2, "rollingMax=X")));
341+
final Table expected_1 = TableTools.newTable(x, cumMin, cumMax, rollingMin, rollingMax);
342+
Assert.assertEquals("", diff(result_1, expected_1, 10));
343+
344+
final Table result_2 = source.updateBy(List.of(
345+
CumMax("cumMax=X"),
346+
CumMin("cumMin=X"),
347+
RollingMax(2, "rollingMax=X"),
348+
RollingMin(2, "rollingMin=X")));
349+
final Table expected_2 = TableTools.newTable(x, cumMax, cumMin, rollingMax, rollingMin);
350+
Assert.assertEquals("", diff(result_2, expected_2, 10));
351+
352+
final Table result_3 = source.updateBy(List.of(
353+
RollingMin(2, "rollingMin=X"),
354+
RollingMax(2, "rollingMax=X"),
355+
CumMin("cumMin=X"),
356+
CumMax("cumMax=X")));
357+
final Table expected_3 = TableTools.newTable(x, rollingMin, rollingMax, cumMin, cumMax);
358+
Assert.assertEquals("", diff(result_3, expected_3, 10));
359+
360+
final Table result_4 = source.updateBy(List.of(
361+
RollingMax(2, "rollingMax=X"),
362+
RollingMin(2, "rollingMin=X"),
363+
CumMax("cumMax=X"),
364+
CumMin("cumMin=X")));
365+
final Table expected_4 = TableTools.newTable(x, rollingMax, rollingMin, cumMax, cumMin);
366+
Assert.assertEquals("", diff(result_4, expected_4, 10));
367+
368+
final Table result_5 = source.updateBy(List.of(
369+
CumMin("cumMin=X"),
370+
RollingMin(2, "rollingMin=X"),
371+
CumMax("cumMax=X"),
372+
RollingMax(2, "rollingMax=X")));
373+
final Table expected_5 = TableTools.newTable(x, cumMin, rollingMin, cumMax, rollingMax);
374+
Assert.assertEquals("", diff(result_5, expected_5, 10));
375+
376+
// Trickiest one, since we internally combine groupBy operations.
377+
final Table source_2 = source.update("Y=ii % 2");
378+
final Table result_6 = source_2.updateBy(List.of(
379+
CumMin("cumMin=X"),
380+
RollingGroup(2, "rollingGroupY=Y"),
381+
RollingMin(2, "rollingMin=X"),
382+
CumMax("cumMax=X"),
383+
RollingGroup(2, "rollingGroupX=X"),
384+
RollingMax(2, "rollingMax=X")));
385+
386+
Assert.assertArrayEquals(result_6.getDefinition().getColumnNamesArray(),
387+
new String[] {
388+
"X",
389+
"Y",
390+
"cumMin",
391+
"rollingGroupY",
392+
"rollingMin",
393+
"cumMax",
394+
"rollingGroupX",
395+
"rollingMax"});
396+
}
325397
}

0 commit comments

Comments
 (0)