Skip to content

Extend native type rewrite for cast, function call, and simple case expression #25142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,21 @@
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.ArrayConstructor;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.Node;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.Parameter;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SimpleCaseExpression;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.sql.tree.StringLiteral;
import com.facebook.presto.sql.tree.WhenClause;
import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -121,28 +125,74 @@ protected Node visitCast(Cast node, Void context)
catch (IllegalArgumentException | UnknownTypeException e) {
throw new SemanticException(TYPE_MISMATCH, node, "Unknown type: " + node.getType());
}
return node;

Expression expression = node.getExpression();
expression = peelIfDereferenceExpression(expression);
return new Cast(node.getLocation(), expression, node.getType(), node.isSafe(), node.isTypeOnly());
}

@Override
protected Node visitFunctionCall(FunctionCall node, Void context)
{
if (isValidEnumKeyFunctionCall(node)) {
Cast argument = (Cast) node.getArguments().get(0);
Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(argument.getType()));
if (argumentType instanceof TypeWithName) {
// Peel user defined type name.
argumentType = ((TypeWithName) argumentType).getType();
QualifiedName functionName = node.getName();
List<Expression> arguments = node.getArguments();
if (node.getArguments().size() == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we specifically check for 1 argument? why not visit all the arguments and rewrite as needed?

Expression argument = arguments.get(0);
if (isValidEnumKeyFunctionCall(node)) {
functionName = QualifiedName.of(FUNCTION_ELEMENT_AT);
Type argumentType;
if (argument instanceof Cast) {
argumentType = functionAndTypeResolver.getType(parseTypeSignature(((Cast) argument).getType()));
}
else if (argument instanceof DereferenceExpression) {
argumentType = functionAndTypeResolver.getType(parseTypeSignature(((DereferenceExpression) argument).getBase().toString()));
}
else {
// ENUM_KEY is only supported with Cast or DereferenceExpression for now.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if there is a reason why, but based on all the cases I've encountered, custom types were either a cast expression or dereference expression

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's most common, but not sure it's guaranteed. The rewriter should be able to support all query shapes. If we leave an enum_key function in, queries will fail during execution.

// Return node without rewriting.
return node;
}
argument = rewriteIfCastOrDereferenceExpression(argument);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of this custom rewriteIfCastOrdereferenceExpression, why don't we visit all the argument s

if (argumentType instanceof TypeWithName) {
// Peel user defined type name.
argumentType = ((TypeWithName) argumentType).getType();
if (argumentType instanceof EnumType) {
arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument);
}
}
}
if (argumentType instanceof EnumType) {
// Convert enum_key to element_at.
List<Expression> arguments = ImmutableList.of(convertEnumTypeToMapExpression(argumentType), argument.getExpression());
return node.getLocation().isPresent()
? new FunctionCall(node.getLocation().get(), QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments)
: new FunctionCall(QualifiedName.of(FUNCTION_ELEMENT_AT), node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments);
else {
arguments.set(0, rewriteIfCastOrDereferenceExpression(argument));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arguments will be immutable. Need to make a new list.

}
}
return super.visitFunctionCall(node, context);
return node.getLocation().isPresent()
? new FunctionCall(node.getLocation().get(), functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments)
: new FunctionCall(functionName, node.getWindow(), node.getFilter(), node.getOrderBy(), node.isDistinct(), node.isIgnoreNulls(), arguments);
}

@Override
protected Node visitSimpleCaseExpression(SimpleCaseExpression node, Void context)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think if you use an ExpressionTreeRewriter for rewriting expressions generally, you won't need to add custom logic for handling this or other expressions that might have arguments that have enum functions. And it will work better because it can be nested as many layers deep as someone cares to write.

{
// SimpleCaseExpression has 3 parts: operand, whenClauses, and defaultValue.
Expression operand = node.getOperand();
List<WhenClause> whenClauses = node.getWhenClauses();
Optional<Expression> defaultValue = node.getDefaultValue();

// Rewrite each component.
operand = rewriteIfCastOrDereferenceExpression(node.getOperand());
List<WhenClause> newWhenClauses = new ArrayList<>();
for (WhenClause when : whenClauses) {
Expression whenOperand = rewriteIfCastOrDereferenceExpression(when.getOperand());
Expression result = rewriteIfCastOrDereferenceExpression(when.getResult());
newWhenClauses.add(new WhenClause(whenOperand, result));
}
if (defaultValue.isPresent()) {
defaultValue = Optional.of(rewriteIfCastOrDereferenceExpression(defaultValue.get()));
}

return node.getLocation().isPresent()
? new SimpleCaseExpression(node.getLocation().get(), operand, newWhenClauses, defaultValue)
: new SimpleCaseExpression(operand, newWhenClauses, defaultValue);
}

@Override
Expand All @@ -154,8 +204,7 @@ protected Node visitExpression(Expression node, Void context)
private boolean isValidEnumKeyFunctionCall(FunctionCall node)
{
return node.getName().equals(QualifiedName.of(FUNCTION_ENUM_KEY))
&& node.getArguments().size() == 1
&& node.getArguments().get(0) instanceof Cast;
&& node.getArguments().size() == 1;
}

private Expression convertEnumTypeToMapExpression(Type type)
Expand Down Expand Up @@ -183,5 +232,56 @@ private Expression convertEnumTypeToMapExpression(Type type)
new ArrayConstructor(keys.build()),
new ArrayConstructor(values.build())));
}

private Expression convertEnumTypeToLiteral(DereferenceExpression key, Type type)
{
String enumKey = key.getField().getValue().toUpperCase();
if (type instanceof BigintEnumType) {
Map<String, Long> enumMap = ((EnumType) type).getEnumMap();
Long enumValue = enumMap.get(enumKey);
if (enumValue == null) {
throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum BigintEnum");
}
return new LongLiteral(enumValue.toString());
}
else if (type instanceof VarcharEnumType) {
Map<String, String> enumMap = ((EnumType) type).getEnumMap();
String enumValue = enumMap.get(enumKey);
if (enumValue == null) {
throw new SemanticException(TYPE_MISMATCH, "No value " + enumKey + " in enum VarcharEnum");
}
return new StringLiteral(enumValue);
}
return key;
}

private Expression rewriteIfCastOrDereferenceExpression(Expression argument)
{
if (argument instanceof Cast) {
argument = (Expression) visitCast((Cast) argument, null);
}
return peelIfDereferenceExpression(argument);
}

private Expression peelIfDereferenceExpression(Expression argument)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any downsides to unwrapping the dereference expression? Can we just have this happen always? i.e. have a visitDereferenceExpression() that handles it. (or have visitExpression call to an ExpressionTreeRewriter that handles all the expression rewrites including Dereference, Cast, SimpleCaseExpression, etc. that way we'll automatically recurse into e.g. cast contained in an argument to some other custom function no matter how many layers deep.)

{
if (argument instanceof DereferenceExpression) {
try {
DereferenceExpression arg = (DereferenceExpression) argument;
Type argumentType = functionAndTypeResolver.getType(parseTypeSignature(arg.getBase().toString()));
if (argumentType instanceof TypeWithName) {
argumentType = ((TypeWithName) argumentType).getType();
if (argumentType instanceof EnumType) {
return convertEnumTypeToLiteral(arg, argumentType);
}
}
}
catch (IllegalArgumentException | UnknownTypeException e) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in what cases does the rewrite fail. Could we write the code so it doesn't happen instead of relying on exception handling, which could mask bugs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that it can fail while getting the type (functionAndTypeResolver.getType) for other dereference expressions, for example for a named field in a row or getting some values from a CTE, as it is not recognized as a valid type. I'm not sure what else we can do if this throws other than return the original node. Would logging help?

// Returns the original expression if rewrite fails.
return argument;
}
}
return argument;
}
}
}
Loading