From aa244bef115af4d4f3c72e70354299bbc0d2cf01 Mon Sep 17 00:00:00 2001 From: Heidi Han Date: Mon, 19 May 2025 09:59:23 -0700 Subject: [PATCH 1/4] Extend visitCast and visitFunctionCall for different arguments containing custom types --- .../rewrite/NativeExecutionTypeRewrite.java | 131 +++++++++++++++--- 1 file changed, 115 insertions(+), 16 deletions(-) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java index ebd9e3d31a543..8ec6a183820f0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java @@ -30,6 +30,7 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.ArrayConstructor; import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LongLiteral; @@ -37,10 +38,13 @@ import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.WhenClause; import com.google.common.collect.ImmutableList; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -121,28 +125,73 @@ protected Node visitCast(Cast node, Void context) catch (IllegalArgumentException | UnknownTypeException e) { throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType()); } - return node; + + Expression expression = node.getExpression(); + expression = 
peelIfDereferenceExpression(expression); + return new Cast(node.getLocation(), expression, node.getType(), node.isSafe(), node.isTypeOnly()); } @Override protected Node visitFunctionCall(FunctionCall node, Void context) { - if (isValidEnumKeyFunctionCall(node)) { - Cast argument = (Cast) node.getArguments().get(0); - Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(argument.getType())); - if (argumentType instanceof TypeWithName) { - // Peel user defined type name. - argumentType = ((TypeWithName) argumentType).getType(); + QualifiedName functionName = node.getName(); + List arguments = node.getArguments(); + if (node.getArguments().size() == 1) { + Expression argument = arguments.get(0); + if (isValidEnumKeyFunctionCall(node)) { + functionName = QualifiedName.of(FUNCTION_ELEMENT_AT); + Type argumentType; + if (argument instanceof Cast) { + argumentType = functionAndTypeResolver.getType(parseTypeSignature(((Cast) argument).getType())); + } + else if (argument instanceof DereferenceExpression) { + argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString())); + } + else { + // ENUM_KEY is only supported with Cast or DereferenceExpression for now, return node without rewriting + return node; + } + argument = rewriteIfCastOrDereferenceExpression(argument); + if (argumentType instanceof TypeWithName) { + // Peel user defined type name. + argumentType = ((TypeWithName) argumentType).getType(); + if (argumentType instanceof EnumType) { + arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument); + } + } } - if (argumentType instanceof EnumType) { - // Convert enum_key to element_at. - List arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument.getExpression()); - return node.getLocation().isPresent() - ? 
new FunctionCall(node.getLocation().get(), QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) - : new FunctionCall(QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); + else { + arguments.set(0, rewriteIfCastOrDereferenceExpression(argument)); } } - return super.visitFunctionCall(node, context); + return node.getLocation().isPresent() + ? new FunctionCall(node.getLocation().get(), functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) + : new FunctionCall(functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); + } + + @Override + protected Node visitSimpleCaseExpression(SimpleCaseExpression node, Void context) + { + // SimpleCaseExpression has 3 parts: operand, whenClauses, and defaultValue. + Expression operand = node.getOperand(); + List whenClauses = node.getWhenClauses(); + Optional defaultValue = node.getDefaultValue(); + + // Rewrite each component + operand = rewriteIfCastOrDereferenceExpression(node.getOperand()); + List newWhenClauses = new ArrayList<>(); + for (WhenClause when : whenClauses) { + Expression whenOperand = rewriteIfCastOrDereferenceExpression(when.getOperand()); + Expression result = rewriteIfCastOrDereferenceExpression(when.getResult()); + newWhenClauses.add(new WhenClause(whenOperand, result)); + } + if (defaultValue.isPresent()) { + defaultValue = Optional.of(rewriteIfCastOrDereferenceExpression(defaultValue.get())); + } + + return node.getLocation().isPresent() + ? 
new SimpleCaseExpression(node.getLocation().get(), operand, newWhenClauses, defaultValue) + : new SimpleCaseExpression(operand, newWhenClauses, defaultValue); } @Override @@ -154,8 +203,7 @@ protected Node visitExpression(Expression node, Void context) private boolean isValidEnumKeyFunctionCall(FunctionCall node) { return node.getName().equals(QualifiedName.of(FUNCTION_ENUM_KEY)) - && node.getArguments().size() == 1 - && node.getArguments().get(0) instanceof Cast; + && node.getArguments().size() == 1; } private Expression convertEnumTypeToMapExpression(Type type) @@ -183,5 +231,56 @@ private Expression convertEnumTypeToMapExpression(Type type) new ArrayConstructor(keys.build()), new ArrayConstructor(values.build()))); } + + private Expression convertEnumTypeToLiteral(DereferenceExpression key, Type type) + { + String enumKey = key.getField().getValue().toUpperCase(); + if (type instanceof BigintEnumType) { + Map enumMap = ((EnumType) type).getEnumMap(); + Long enumValue = enumMap.get(enumKey); + if (enumValue == null) { + throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum BigintEnum"); + } + return new LongLiteral(enumValue.toString()); + } + else if (type instanceof VarcharEnumType) { + Map enumMap = ((EnumType) type).getEnumMap(); + String enumValue = enumMap.get(enumKey); + if (enumValue == null) { + throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum VarcharEnum"); + } + return new StringLiteral(enumValue); + } + return key; + } + + private Expression rewriteIfCastOrDereferenceExpression(Expression argument) + { + if (argument instanceof Cast) { + argument = (Expression) visitCast((Cast) argument, null); + } + return peelIfDereferenceExpression(argument); + } + + private Expression peelIfDereferenceExpression(Expression argument) + { + if (argument instanceof DereferenceExpression) { + try { + DereferenceExpression arg = (DereferenceExpression) argument; + Type argumentType = 
functionAndTypeResolver.getType(parseTypeSignature(arg.getBase().toString())); + if (argumentType instanceof TypeWithName) { + argumentType = ((TypeWithName) argumentType).getType(); + if (argumentType instanceof EnumType) { + return convertEnumTypeToLiteral(arg, argumentType); + } + } + } + catch (IllegalArgumentException | UnknownTypeException e) { + // return the original expression if rewrite fails + return argument; + } + } + return argument; + } } } From b86dbece876c38f9d20773f3e061711552be3bfc Mon Sep 17 00:00:00 2001 From: Heidi Han Date: Tue, 20 May 2025 13:47:08 -0700 Subject: [PATCH 2/4] Reformat comments --- .../presto/sql/rewrite/NativeExecutionTypeRewrite.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java index 8ec6a183820f0..30ec90f8ff245 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java @@ -148,7 +148,8 @@ else if (argument instanceof DereferenceExpression) { argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString())); } else { - // ENUM_KEY is only supported with Cast or DereferenceExpression for now, return node without rewriting + // ENUM_KEY is only supported with Cast or DereferenceExpression for now. + // Return node without rewriting. return node; } argument = rewriteIfCastOrDereferenceExpression(argument); @@ -177,7 +178,7 @@ protected Node visitSimpleCaseExpression(SimpleCaseExpression node, Void context List whenClauses = node.getWhenClauses(); Optional defaultValue = node.getDefaultValue(); - // Rewrite each component + // Rewrite each component. 
operand = rewriteIfCastOrDereferenceExpression(node.getOperand()); List newWhenClauses = new ArrayList<>(); for (WhenClause when : whenClauses) { @@ -276,7 +277,7 @@ private Expression peelIfDereferenceExpression(Expression argument) } } catch (IllegalArgumentException | UnknownTypeException e) { - // return the original expression if rewrite fails + // Returns the original expression if rewrite fails. return argument; } } From b52f141048277c57be2922293c3d4ff27587701e Mon Sep 17 00:00:00 2001 From: Heidi Han Date: Thu, 29 May 2025 10:13:33 -0700 Subject: [PATCH 3/4] move rewrite logic to ExpressionRewriter --- .../rewrite/NativeExecutionTypeRewrite.java | 231 +++++++++--------- 1 file changed, 110 insertions(+), 121 deletions(-) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java index 30ec90f8ff245..d10ea2de36a8d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.rewrite; +import com.facebook.airlift.log.Logger; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.UnknownTypeException; @@ -32,19 +33,18 @@ import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LongLiteral; import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.QualifiedName; 
-import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.StringLiteral; -import com.facebook.presto.sql.tree.WhenClause; import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -70,6 +70,8 @@ final class NativeExecutionTypeRewrite implements StatementRewrite.Rewrite { + private static final Logger LOG = Logger.get(ExpressionRewriter.class); + private static final String FUNCTION_ENUM_KEY = "enum_key"; private static final String FUNCTION_ELEMENT_AT = "element_at"; private static final String FUNCTION_MAP = "map"; @@ -94,19 +96,76 @@ public Statement rewrite( return node; } - private static final class Rewriter - extends DefaultTreeRewriter + public static Expression rewriteEnumExpressions(Expression expression, FunctionAndTypeResolver functionAndTypeResolver) + { + return ExpressionTreeRewriter.rewriteWith(new EnumExpressionRewriter(functionAndTypeResolver), expression); + } + + private static class EnumExpressionRewriter + extends ExpressionRewriter { private final FunctionAndTypeResolver functionAndTypeResolver; - public Rewriter(FunctionAndTypeResolver functionAndTypeResolver) + public EnumExpressionRewriter(FunctionAndTypeResolver functionAndTypeResolver) { - this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null"); + this.functionAndTypeResolver = functionAndTypeResolver; + } + + private Expression convertEnumTypeToLiteral(DereferenceExpression key, Type type) + { + String enumKey = key.getField().getValue().toUpperCase(); + if (type instanceof BigintEnumType) { + Map enumMap = ((EnumType) type).getEnumMap(); + Long enumValue = enumMap.get(enumKey); + if (enumValue == null) { + throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum BigintEnum"); + } + return new LongLiteral(enumValue.toString()); + } + else if 
(type instanceof VarcharEnumType) { + Map enumMap = ((EnumType) type).getEnumMap(); + String enumValue = enumMap.get(enumKey); + if (enumValue == null) { + throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum VarcharEnum"); + } + return new StringLiteral(enumValue); + } + return key; } @Override - protected Node visitCast(Cast node, Void context) + public Expression rewriteExpression(Expression expression, Void context, ExpressionTreeRewriter treeRewriter) { + return treeRewriter.defaultRewrite(expression, null); + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + try { + Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(node.getBase().toString())); + if (argumentType instanceof TypeWithName) { + argumentType = ((TypeWithName) argumentType).getType(); + if (argumentType instanceof EnumType) { + return convertEnumTypeToLiteral(node, argumentType); + } + } + } + catch (IllegalArgumentException | UnknownTypeException e) { + // Returns the original expression if rewrite fails. + LOG.warn(e.getMessage()); + return node; + } + return node; + } + + @Override + public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter treeRewriter) + { + // Rewrite any enum types to their values. + node = treeRewriter.defaultRewrite(node, null); + + // Rewrite type to base type. 
try { Type type = functionAndTypeResolver.getType(parseTypeSignature(node.getType())); if (type instanceof TypeWithName) { @@ -125,79 +184,6 @@ protected Node visitCast(Cast node, Void context) catch (IllegalArgumentException | UnknownTypeException e) { throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType()); } - - Expression expression = node.getExpression(); - expression = peelIfDereferenceExpression(expression); - return new Cast(node.getLocation(), expression, node.getType(), node.isSafe(), node.isTypeOnly()); - } - - @Override - protected Node visitFunctionCall(FunctionCall node, Void context) - { - QualifiedName functionName = node.getName(); - List arguments = node.getArguments(); - if (node.getArguments().size() == 1) { - Expression argument = arguments.get(0); - if (isValidEnumKeyFunctionCall(node)) { - functionName = QualifiedName.of(FUNCTION_ELEMENT_AT); - Type argumentType; - if (argument instanceof Cast) { - argumentType = functionAndTypeResolver.getType(parseTypeSignature(((Cast) argument).getType())); - } - else if (argument instanceof DereferenceExpression) { - argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString())); - } - else { - // ENUM_KEY is only supported with Cast or DereferenceExpression for now. - // Return node without rewriting. - return node; - } - argument = rewriteIfCastOrDereferenceExpression(argument); - if (argumentType instanceof TypeWithName) { - // Peel user defined type name. - argumentType = ((TypeWithName) argumentType).getType(); - if (argumentType instanceof EnumType) { - arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument); - } - } - } - else { - arguments.set(0, rewriteIfCastOrDereferenceExpression(argument)); - } - } - return node.getLocation().isPresent() - ? 
new FunctionCall(node.getLocation().get(), functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) - : new FunctionCall(functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); - } - - @Override - protected Node visitSimpleCaseExpression(SimpleCaseExpression node, Void context) - { - // SimpleCaseExpression has 3 parts: operand, whenClauses, and defaultValue. - Expression operand = node.getOperand(); - List whenClauses = node.getWhenClauses(); - Optional defaultValue = node.getDefaultValue(); - - // Rewrite each component. - operand = rewriteIfCastOrDereferenceExpression(node.getOperand()); - List newWhenClauses = new ArrayList<>(); - for (WhenClause when : whenClauses) { - Expression whenOperand = rewriteIfCastOrDereferenceExpression(when.getOperand()); - Expression result = rewriteIfCastOrDereferenceExpression(when.getResult()); - newWhenClauses.add(new WhenClause(whenOperand, result)); - } - if (defaultValue.isPresent()) { - defaultValue = Optional.of(rewriteIfCastOrDereferenceExpression(defaultValue.get())); - } - - return node.getLocation().isPresent() - ? 
new SimpleCaseExpression(node.getLocation().get(), operand, newWhenClauses, defaultValue) - : new SimpleCaseExpression(operand, newWhenClauses, defaultValue); - } - - @Override - protected Node visitExpression(Expression node, Void context) - { return node; } @@ -232,56 +218,59 @@ private Expression convertEnumTypeToMapExpression(Type type) new ArrayConstructor(keys.build()), new ArrayConstructor(values.build()))); } - - private Expression convertEnumTypeToLiteral(DereferenceExpression key, Type type) + @Override + public Expression rewriteFunctionCall(FunctionCall node, Void context, ExpressionTreeRewriter treeRewriter) { - String enumKey = key.getField().getValue().toUpperCase(); - if (type instanceof BigintEnumType) { - Map enumMap = ((EnumType) type).getEnumMap(); - Long enumValue = enumMap.get(enumKey); - if (enumValue == null) { - throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum BigintEnum"); + QualifiedName functionName = node.getName(); + List arguments = node.getArguments(); + if (isValidEnumKeyFunctionCall(node)) { + Expression argument = arguments.get(0); + functionName = QualifiedName.of(FUNCTION_ELEMENT_AT); + Type argumentType; + if (argument instanceof Cast) { + argumentType = functionAndTypeResolver.getType(parseTypeSignature(((Cast) argument).getType())); } - return new LongLiteral(enumValue.toString()); - } - else if (type instanceof VarcharEnumType) { - Map enumMap = ((EnumType) type).getEnumMap(); - String enumValue = enumMap.get(enumKey); - if (enumValue == null) { - throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum VarcharEnum"); + else if (argument instanceof DereferenceExpression) { + argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString())); } - return new StringLiteral(enumValue); + else { + // ENUM_KEY is only supported with Cast or DereferenceExpression for now. + // Return node without rewriting. 
+ return node; + } + if (argumentType instanceof TypeWithName) { + // Rewrite ENUM_KEY(EnumType) -> ELEMENT_AT(MAP(, VARCHAR)) + argumentType = ((TypeWithName) argumentType).getType(); + Expression enumMapExpression = convertEnumTypeToMapExpression(argumentType); + Expression enumValue = treeRewriter.rewrite(argument, null); + if (argumentType instanceof EnumType) { + arguments = ImmutableList.of(enumMapExpression, enumValue); + } + } + } else { + node = treeRewriter.defaultRewrite(node, null); } - return key; + + return node.getLocation().isPresent() + ? new FunctionCall(node.getLocation().get(), functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) + : new FunctionCall(functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); } + } + + private static final class Rewriter + extends DefaultTreeRewriter + { + private final FunctionAndTypeResolver functionAndTypeResolver; - private Expression rewriteIfCastOrDereferenceExpression(Expression argument) + public Rewriter(FunctionAndTypeResolver functionAndTypeResolver) { - if (argument instanceof Cast) { - argument = (Expression) visitCast((Cast) argument, null); - } - return peelIfDereferenceExpression(argument); + this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null"); } - private Expression peelIfDereferenceExpression(Expression argument) + @Override + protected Node visitExpression(Expression node, Void context) { - if (argument instanceof DereferenceExpression) { - try { - DereferenceExpression arg = (DereferenceExpression) argument; - Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(arg.getBase().toString())); - if (argumentType instanceof TypeWithName) { - argumentType = ((TypeWithName) argumentType).getType(); - if (argumentType instanceof EnumType) { - return convertEnumTypeToLiteral(arg, argumentType); - } - } 
- } - catch (IllegalArgumentException | UnknownTypeException e) { - // Returns the original expression if rewrite fails. - return argument; - } - } - return argument; + return rewriteEnumExpressions(node, this.functionAndTypeResolver); } } } From 0dc7a534dae4acea8b495910c556ea6550d68ea4 Mon Sep 17 00:00:00 2001 From: Heidi Han Date: Thu, 29 May 2025 11:00:41 -0700 Subject: [PATCH 4/4] Formatting --- .../presto/sql/rewrite/NativeExecutionTypeRewrite.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java index d10ea2de36a8d..7e479b7c3fff9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java @@ -102,7 +102,7 @@ public static Expression rewriteEnumExpressions(Expression expression, FunctionA } private static class EnumExpressionRewriter - extends ExpressionRewriter + extends ExpressionRewriter { private final FunctionAndTypeResolver functionAndTypeResolver; @@ -246,8 +246,9 @@ else if (argument instanceof DereferenceExpression) { if (argumentType instanceof EnumType) { arguments = ImmutableList.of(enumMapExpression, enumValue); } - } - } else { + } + } + else { node = treeRewriter.defaultRewrite(node, null); }