This is an automated email from the ASF dual-hosted git repository. stigahuang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/impala.git
commit b71bdce5ebf8667ae0129ec5fb1d350492d09a46 Author: Daniel Becker <[email protected]> AuthorDate: Tue Oct 11 14:01:12 2022 +0200 IMPALA-11462: Constant folding causes cast on int literals to be lost The following query returns -128 instead of 128 and the return type is TINYINT instead of BIGINT: select shiftleft(cast(1 as bigint), z) c from (select 7 z ) x; However, if we disable expression rewrites, the result is correct. The expression rewrite rule 'FoldConstantsRule' folds the cast to bigint into the literal during expression rewrite. This modifies the expression, so re-analysis is needed. Re-analysis resets the literal expression (a NumericLiteral), which loses its type and becomes TINYINT again. 'NumericLiteral's have three kinds of type: - natural type: the smallest type that can hold its value - explicit type: starts as the natural type and can be widened by explicit casts - implicit type: the result of casts the analyzer uses to adjust input types for functions and arithmetic operations See NumericLiteral.java for more. The problem is that when 'FoldConstantsRule' folds the cast into the literal, it doesn't set its explicit type (only its implicit type becomes BIGINT). When the expression is reset during re-analysis, the type is reverted to the incorrect explicit type. (In the case where expression rewrites are disabled, no re-analysis takes place so the literal and its type are not reset.) This patch fixes the error by setting the explicit type of 'NumericLiteral' when folding an explicit cast into it. Testing: - Added test tests/query_test/test_exprs.py::TestConstantFoldingNoTypeLoss::test_shiftleft that tests 'shiftleft' with all integer types. 
Change-Id: Ie7f27b204792ef7c59dec5ead363d44ed0c3bc79 Reviewed-on: http://gerrit.cloudera.org:8080/19124 Reviewed-by: Impala Public Jenkins <[email protected]> Tested-by: Impala Public Jenkins <[email protected]> --- .../org/apache/impala/analysis/LiteralExpr.java | 51 +++++++++++++++------- .../apache/impala/rewrite/FoldConstantsRule.java | 4 +- tests/query_test/test_exprs.py | 42 ++++++++++++++++++ 3 files changed, 80 insertions(+), 17 deletions(-) diff --git a/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java b/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java index 17c5fa0b5..de8d57e12 100644 --- a/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java +++ b/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java @@ -198,10 +198,13 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr * in cases where the corresponding LiteralExpr is not able to represent the evaluation * result, e.g., NaN or infinity. Returns null if the expr evaluation encountered errors * or warnings in the BE. + * If 'keepOriginalIntType' is true, the type of the result will be the same as the type + * of 'constExpr'; otherwise for integers a shorter type may be used if it is big enough + * to hold the value. * TODO: Support non-scalar types. 
*/ public static LiteralExpr createBounded(Expr constExpr, TQueryCtx queryCtx, - int maxResultSize) throws AnalysisException { + int maxResultSize, boolean keepOriginalIntType) throws AnalysisException { Preconditions.checkState(constExpr.isConstant()); Preconditions.checkState(constExpr.getType().isValid()); if (constExpr instanceof LiteralExpr) return (LiteralExpr) constExpr; @@ -224,24 +227,10 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr if (val.isSetBool_val()) result = new BoolLiteral(val.bool_val); break; case TINYINT: - if (val.isSetByte_val()) { - result = new NumericLiteral(BigDecimal.valueOf(val.byte_val)); - } - break; case SMALLINT: - if (val.isSetShort_val()) { - result = new NumericLiteral(BigDecimal.valueOf(val.short_val)); - } - break; case INT: - if (val.isSetInt_val()) { - result = new NumericLiteral(BigDecimal.valueOf(val.int_val)); - } - break; case BIGINT: - if (val.isSetLong_val()) { - result = new NumericLiteral(BigDecimal.valueOf(val.long_val)); - } + result = createIntegerLiteral(val, constExpr.getType(), keepOriginalIntType); break; case FLOAT: case DOUBLE: @@ -312,6 +301,11 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr return result; } + public static LiteralExpr createBounded(Expr constExpr, TQueryCtx queryCtx, + int maxResultSize) throws AnalysisException { + return createBounded(constExpr, queryCtx, maxResultSize, false); + } + // Order NullLiterals based on the SQL ORDER BY default behavior: NULLS LAST. 
@Override public int compareTo(LiteralExpr other) { @@ -321,4 +315,29 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr if (getClass() != other.getClass()) return -1; return 0; } + + static private NumericLiteral createIntegerLiteral(TColumnValue val, Type type, + boolean keepOriginalIntType) throws SqlCastException { + BigDecimal value = null; + switch (type.getPrimitiveType()) { + case TINYINT: + if (val.isSetByte_val()) value = BigDecimal.valueOf(val.byte_val); + break; + case SMALLINT: + if (val.isSetShort_val()) value = BigDecimal.valueOf(val.short_val); + break; + case INT: + if (val.isSetInt_val()) value = BigDecimal.valueOf(val.int_val); + break; + case BIGINT: + if (val.isSetLong_val()) value = BigDecimal.valueOf(val.long_val); + break; + default: + Preconditions.checkState(false, + String.format("Integer type expected, got '%s'.", type.toSql())); + } + if (value == null) return null; + return keepOriginalIntType ? + new NumericLiteral(value, type) : new NumericLiteral(value); + } } diff --git a/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java b/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java index a88d2131f..f11ed40ff 100644 --- a/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java +++ b/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java @@ -63,8 +63,10 @@ public class FoldConstantsRule implements ExprRewriteRule { expr.analyze(analyzer); if (!expr.isConstant()) return expr; } + // Force the type to be preserved if it is an explicit cast (see IMPALA-11462). + boolean isExplicitCast = expr instanceof CastExpr && !expr.isImplicitCast(); Expr result = LiteralExpr.createBounded(expr, analyzer.getQueryCtx(), - LiteralExpr.MAX_STRING_LITERAL_SIZE); + LiteralExpr.MAX_STRING_LITERAL_SIZE, isExplicitCast); // Preserve original type so parent Exprs do not need to be re-analyzed. 
if (result != null) return result.castTo(expr.getType()); diff --git a/tests/query_test/test_exprs.py b/tests/query_test/test_exprs.py index cf98fc7c8..03e6a3b15 100644 --- a/tests/query_test/test_exprs.py +++ b/tests/query_test/test_exprs.py @@ -246,3 +246,45 @@ class TestUtcTimestampFunctions(ImpalaTestSuite): vector.get_value('exec_option')['enable_expr_rewrites'] = \ vector.get_value('enable_expr_rewrites') self.run_test_case('QueryTest/utc-timestamp-functions', vector) + + +class TestConstantFoldingNoTypeLoss(ImpalaTestSuite): + """"Regression tests for IMPALA-11462.""" + + @classmethod + def get_workload(self): + return "functional-query" + + @classmethod + def add_test_dimensions(cls): + super(TestConstantFoldingNoTypeLoss, cls).add_test_dimensions() + # Test with and without expr rewrites to verify that constant folding does not change + # the behaviour. + cls.ImpalaTestMatrix.add_dimension( + ImpalaTestDimension('enable_expr_rewrites', *[0,1])) + # We don't actually use a table so one file format is enough. + cls.ImpalaTestMatrix.add_constraint(lambda v: + v.get_value('table_format').file_format in ['parquet']) + + def test_shiftleft(self, vector): + """ Tests that the return values of the 'shiftleft' functions are correct for the + input types (the return type should be the same as the first argument).""" + types_and_widths = [ + ("TINYINT", 8), + ("SMALLINT", 16), + ("INT", 32), + ("BIGINT", 64) + ] + query_template = ("select shiftleft(cast(1 as {typename}), z) c " + "from (select {shift_val} z ) x") + for (typename, width) in types_and_widths: + shift_val = width - 2 # Valid and positive for signed types. 
+ expected_value = 1 << shift_val + result = self.execute_query_expect_success(self.client, + query_template.format(typename=typename, shift_val=shift_val)) + assert result.data == [str(expected_value)] + + def test_addition(self, vector): + query = "select typeof(cast(1 as bigint) + cast(rand() as tinyint))" + result = self.execute_query_expect_success(self.client, query) + assert result.data == ["BIGINT"]
