diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 51fa468a822a8..d13e6208c005f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -591,10 +591,18 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) if (node.getType().equals(INNER)) { canonicalCriteria.stream() .filter(clause -> clause.getLeft().getType().equals(clause.getRight().getType()) && clause.getLeft().getType().equalValuesAreIdentical()) - .filter(clause -> node.getOutputVariables().contains(clause.getLeft())) + .filter(clause -> node.getOutputVariables().contains(clause.getRight())) .forEach(clause -> map(clause.getRight(), clause.getLeft())); } + List canonicalizedOutput = canonicalizeAndDistinct(node.getOutputVariables()); + if (canonicalCriteria.stream().map(x -> x.getRight().getName()).anyMatch(mapping::containsKey)) { + canonicalizedOutput = ImmutableList.builder() + .addAll(canonicalizedOutput.stream().filter(left.getOutputVariables()::contains).collect(Collectors.toList())) + .addAll(canonicalizedOutput.stream().filter(right.getOutputVariables()::contains).collect(Collectors.toList())) + .build(); + } + return new JoinNode( node.getSourceLocation(), node.getId(), @@ -602,7 +610,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) left, right, canonicalCriteria, - canonicalizeAndDistinct(node.getOutputVariables()), + canonicalizedOutput, canonicalFilter, canonicalLeftHashVariable, canonicalRightHashVariable, diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnaliasSymbolReferences.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnaliasSymbolReferences.java index a843f0b1b9a2b..de09cfd2623e5 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnaliasSymbolReferences.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnaliasSymbolReferences.java @@ -15,12 +15,15 @@ import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; import static com.facebook.presto.spi.plan.JoinType.INNER; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; public class TestUnaliasSymbolReferences @@ -68,4 +71,15 @@ public void testIdenticalValuesCollapseAssignments() .withNumberOfOutputColumns(1) .withExactOutputs("LEFT_BAR"))); } + + @Test + public void testJoinKeyOutput() + { + assertPlan("select o.orderkey, l.quantity from lineitem l join orders o on l.orderkey=o.orderkey", + output( + ImmutableList.of("orderkey", "quantity"), + join(INNER, ImmutableList.of(equiJoinClause("orderkey", "orderkey_0")), + anyTree(tableScan("lineitem", ImmutableMap.of("orderkey", "orderkey", "quantity", "quantity"))), + anyTree(tableScan("orders", ImmutableMap.of("orderkey_0", "orderkey")))))); + } }