diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java index ebd9e3d31a543..30ec90f8ff245 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/rewrite/NativeExecutionTypeRewrite.java @@ -30,6 +30,7 @@ import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.tree.ArrayConstructor; import com.facebook.presto.sql.tree.Cast; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.LongLiteral; @@ -37,10 +38,13 @@ import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.Parameter; import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.Statement; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.WhenClause; import com.google.common.collect.ImmutableList; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -121,28 +125,74 @@ protected Node visitCast(Cast node, Void context) catch (IllegalArgumentException | UnknownTypeException e) { throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType()); } - return node; + + Expression expression = node.getExpression(); + expression = peelIfDereferenceExpression(expression); + return new Cast(node.getLocation(), expression, node.getType(), node.isSafe(), node.isTypeOnly()); } @Override protected Node visitFunctionCall(FunctionCall node, Void context) { - if (isValidEnumKeyFunctionCall(node)) { - Cast argument = (Cast) node.getArguments().get(0); - Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(argument.getType())); - if (argumentType instanceof TypeWithName) { - // Peel user defined type name. - argumentType = ((TypeWithName) argumentType).getType(); + QualifiedName functionName = node.getName(); + List arguments = node.getArguments(); + if (node.getArguments().size() == 1) { + Expression argument = arguments.get(0); + if (isValidEnumKeyFunctionCall(node)) { + functionName = QualifiedName.of(FUNCTION_ELEMENT_AT); + Type argumentType; + if (argument instanceof Cast) { + argumentType = functionAndTypeResolver.getType(parseTypeSignature(((Cast) argument).getType())); + } + else if (argument instanceof DereferenceExpression) { + argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString())); + } + else { + // ENUM_KEY is only supported with Cast or DereferenceExpression for now. + // Return node without rewriting. + return node; + } + argument = rewriteIfCastOrDereferenceExpression(argument); + if (argumentType instanceof TypeWithName) { + // Peel user defined type name. + argumentType = ((TypeWithName) argumentType).getType(); + if (argumentType instanceof EnumType) { + arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument); + } + } } - if (argumentType instanceof EnumType) { - // Convert enum_key to element_at. - List arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument.getExpression()); - return node.getLocation().isPresent() - ? new FunctionCall(node.getLocation().get(), QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) - : new FunctionCall(QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); + else { + arguments.set(0, rewriteIfCastOrDereferenceExpression(argument)); } } - return super.visitFunctionCall(node, context); + return node.getLocation().isPresent() + ? new FunctionCall(node.getLocation().get(), functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments) + : new FunctionCall(functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments); + } + + @Override + protected Node visitSimpleCaseExpression(SimpleCaseExpression node, Void context) + { + // SimpleCaseExpression has 3 parts: operand, whenClauses, and defaultValue. + Expression operand = node.getOperand(); + List whenClauses = node.getWhenClauses(); + Optional defaultValue = node.getDefaultValue(); + + // Rewrite each component. + operand = rewriteIfCastOrDereferenceExpression(node.getOperand()); + List newWhenClauses = new ArrayList<>(); + for (WhenClause when : whenClauses) { + Expression whenOperand = rewriteIfCastOrDereferenceExpression(when.getOperand()); + Expression result = rewriteIfCastOrDereferenceExpression(when.getResult()); + newWhenClauses.add(new WhenClause(whenOperand, result)); + } + if (defaultValue.isPresent()) { + defaultValue = Optional.of(rewriteIfCastOrDereferenceExpression(defaultValue.get())); + } + + return node.getLocation().isPresent() + ? new SimpleCaseExpression(node.getLocation().get(), operand, newWhenClauses, defaultValue) + : new SimpleCaseExpression(operand, newWhenClauses, defaultValue); } @Override @@ -154,8 +204,7 @@ protected Node visitExpression(Expression node, Void context) private boolean isValidEnumKeyFunctionCall(FunctionCall node) { return node.getName().equals(QualifiedName.of(FUNCTION_ENUM_KEY)) - && node.getArguments().size() == 1 - && node.getArguments().get(0) instanceof Cast; + && node.getArguments().size() == 1; } private Expression convertEnumTypeToMapExpression(Type type) @@ -183,5 +232,56 @@ private Expression convertEnumTypeToMapExpression(Type type) new ArrayConstructor(keys.build()), new ArrayConstructor(values.build()))); } + + private Expression convertEnumTypeToLiteral(DereferenceExpression key, Type type) + { + String enumKey = key.getField().getValue().toUpperCase(); + if (type instanceof BigintEnumType) { + Map enumMap = ((EnumType) type).getEnumMap(); + Long enumValue = enumMap.get(enumKey); + if (enumValue == null) { + throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum BigintEnum"); + } + return new LongLiteral(enumValue.toString()); + } + else if (type instanceof VarcharEnumType) { + Map enumMap = ((EnumType) type).getEnumMap(); + String enumValue = enumMap.get(enumKey); + if (enumValue == null) { + throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum VarcharEnum"); + } + return new StringLiteral(enumValue); + } + return key; + } + + private Expression rewriteIfCastOrDereferenceExpression(Expression argument) + { + if (argument instanceof Cast) { + argument = (Expression) visitCast((Cast) argument, null); + } + return peelIfDereferenceExpression(argument); + } + + private Expression peelIfDereferenceExpression(Expression argument) + { + if (argument instanceof DereferenceExpression) { + try { + DereferenceExpression arg = (DereferenceExpression) argument; + Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(arg.getBase().toString())); + if (argumentType instanceof TypeWithName) { + argumentType = ((TypeWithName) argumentType).getType(); + if (argumentType instanceof EnumType) { + return convertEnumTypeToLiteral(arg, argumentType); + } + } + } + catch (IllegalArgumentException | UnknownTypeException e) { + // Returns the original expression if rewrite fails. + return argument; + } + } + return argument; + } } }