This is an automated email from the ASF dual-hosted git repository.

stigahuang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/impala.git

commit b71bdce5ebf8667ae0129ec5fb1d350492d09a46
Author: Daniel Becker <[email protected]>
AuthorDate: Tue Oct 11 14:01:12 2022 +0200

    IMPALA-11462: Constant folding causes cast on int literals to be lost
    
    The following query returns -128 instead of 128 and the return type is
    TINYINT instead of BIGINT:
    
      select shiftleft(cast(1 as bigint), z) c from (select 7 z ) x;
    
    However, if we disable expression rewrites, the result is correct.
    
    The expression rewrite rule 'FoldConstantsRule' folds the cast to
    bigint into the literal during expression rewrite. This modifies the
    expression, so re-analysis is needed. Re-analysis resets the literal
    expression (a NumericLiteral), which loses its type and becomes TINYINT
    again.
    
    'NumericLiteral's have three kinds of type:
     - natural type: the smallest type that can hold its value
     - explicit type: starts as the natural type and can be widened by
       explicit casts
     - implicit type: the result of casts the analyzer uses to
       adjust input types for functions and arithmetic operations
    See NumericLiteral.java for more.
    
    The problem is that when 'FoldConstantsRule' folds the cast into the
    literal, it doesn't set its explicit type (only its implicit type
    becomes BIGINT). When the expression is reset during re-analysis, the
    type is reverted to the incorrect explicit type.
    
    (In the case where expression rewrites are disabled, no re-analysis
    takes place so the literal and its type are not reset.)
    
    This patch fixes the error by setting the explicit type of
    'NumericLiteral' when folding an explicit cast into it.
    
    Testing:
     - Added test
       tests/query_test/test_exprs.py::TestConstantFoldingNoTypeLoss::test_shiftleft
       that tests 'shiftleft' with all integer types.
    
    Change-Id: Ie7f27b204792ef7c59dec5ead363d44ed0c3bc79
    Reviewed-on: http://gerrit.cloudera.org:8080/19124
    Reviewed-by: Impala Public Jenkins <[email protected]>
    Tested-by: Impala Public Jenkins <[email protected]>
---
 .../org/apache/impala/analysis/LiteralExpr.java    | 51 +++++++++++++++-------
 .../apache/impala/rewrite/FoldConstantsRule.java   |  4 +-
 tests/query_test/test_exprs.py                     | 42 ++++++++++++++++++
 3 files changed, 80 insertions(+), 17 deletions(-)

diff --git a/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java 
b/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java
index 17c5fa0b5..de8d57e12 100644
--- a/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java
+++ b/fe/src/main/java/org/apache/impala/analysis/LiteralExpr.java
@@ -198,10 +198,13 @@ public abstract class LiteralExpr extends Expr implements 
Comparable<LiteralExpr
    * in cases where the corresponding LiteralExpr is not able to represent the 
evaluation
    * result, e.g., NaN or infinity. Returns null if the expr evaluation 
encountered errors
    * or warnings in the BE.
+   * If 'keepOriginalIntType' is true, the type of the result will be the same 
as the type
+   * of 'constExpr'; otherwise for integers a shorter type may be used if it 
is big enough
+   * to hold the value.
    * TODO: Support non-scalar types.
    */
   public static LiteralExpr createBounded(Expr constExpr, TQueryCtx queryCtx,
-    int maxResultSize) throws AnalysisException {
+    int maxResultSize, boolean keepOriginalIntType) throws AnalysisException {
     Preconditions.checkState(constExpr.isConstant());
     Preconditions.checkState(constExpr.getType().isValid());
     if (constExpr instanceof LiteralExpr) return (LiteralExpr) constExpr;
@@ -224,24 +227,10 @@ public abstract class LiteralExpr extends Expr implements 
Comparable<LiteralExpr
         if (val.isSetBool_val()) result = new BoolLiteral(val.bool_val);
         break;
       case TINYINT:
-        if (val.isSetByte_val()) {
-          result = new NumericLiteral(BigDecimal.valueOf(val.byte_val));
-        }
-        break;
       case SMALLINT:
-        if (val.isSetShort_val()) {
-          result = new NumericLiteral(BigDecimal.valueOf(val.short_val));
-        }
-        break;
       case INT:
-        if (val.isSetInt_val()) {
-          result = new NumericLiteral(BigDecimal.valueOf(val.int_val));
-        }
-        break;
       case BIGINT:
-        if (val.isSetLong_val()) {
-          result = new NumericLiteral(BigDecimal.valueOf(val.long_val));
-        }
+        result = createIntegerLiteral(val, constExpr.getType(), 
keepOriginalIntType);
         break;
       case FLOAT:
       case DOUBLE:
@@ -312,6 +301,11 @@ public abstract class LiteralExpr extends Expr implements 
Comparable<LiteralExpr
     return result;
   }
 
+  public static LiteralExpr createBounded(Expr constExpr, TQueryCtx queryCtx,
+    int maxResultSize) throws AnalysisException {
+    return createBounded(constExpr, queryCtx, maxResultSize, false);
+  }
+
   // Order NullLiterals based on the SQL ORDER BY default behavior: NULLS LAST.
   @Override
   public int compareTo(LiteralExpr other) {
@@ -321,4 +315,29 @@ public abstract class LiteralExpr extends Expr implements 
Comparable<LiteralExpr
     if (getClass() != other.getClass()) return -1;
     return 0;
   }
+
+  static private NumericLiteral createIntegerLiteral(TColumnValue val, Type 
type,
+      boolean keepOriginalIntType) throws SqlCastException {
+    BigDecimal value = null;
+    switch (type.getPrimitiveType()) {
+      case TINYINT:
+        if (val.isSetByte_val()) value = BigDecimal.valueOf(val.byte_val);
+        break;
+      case SMALLINT:
+        if (val.isSetShort_val()) value = BigDecimal.valueOf(val.short_val);
+        break;
+      case INT:
+        if (val.isSetInt_val()) value = BigDecimal.valueOf(val.int_val);
+        break;
+      case BIGINT:
+        if (val.isSetLong_val()) value = BigDecimal.valueOf(val.long_val);
+        break;
+      default:
+        Preconditions.checkState(false,
+            String.format("Integer type expected, got '%s'.", type.toSql()));
+    }
+    if (value == null) return null;
+    return keepOriginalIntType ?
+        new NumericLiteral(value, type) : new NumericLiteral(value);
+  }
 }
diff --git a/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java 
b/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java
index a88d2131f..f11ed40ff 100644
--- a/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java
+++ b/fe/src/main/java/org/apache/impala/rewrite/FoldConstantsRule.java
@@ -63,8 +63,10 @@ public class FoldConstantsRule implements ExprRewriteRule {
       expr.analyze(analyzer);
       if (!expr.isConstant()) return expr;
     }
+    // Force the type to be preserved if it is an explicit cast (see 
IMPALA-11462).
+    boolean isExplicitCast = expr instanceof CastExpr && 
!expr.isImplicitCast();
     Expr result = LiteralExpr.createBounded(expr, analyzer.getQueryCtx(),
-      LiteralExpr.MAX_STRING_LITERAL_SIZE);
+      LiteralExpr.MAX_STRING_LITERAL_SIZE, isExplicitCast);
 
     // Preserve original type so parent Exprs do not need to be re-analyzed.
     if (result != null) return result.castTo(expr.getType());
diff --git a/tests/query_test/test_exprs.py b/tests/query_test/test_exprs.py
index cf98fc7c8..03e6a3b15 100644
--- a/tests/query_test/test_exprs.py
+++ b/tests/query_test/test_exprs.py
@@ -246,3 +246,45 @@ class TestUtcTimestampFunctions(ImpalaTestSuite):
     vector.get_value('exec_option')['enable_expr_rewrites'] = \
         vector.get_value('enable_expr_rewrites')
     self.run_test_case('QueryTest/utc-timestamp-functions', vector)
+
+
+class TestConstantFoldingNoTypeLoss(ImpalaTestSuite):
+  """"Regression tests for IMPALA-11462."""
+
+  @classmethod
+  def get_workload(self):
+    return "functional-query"
+
+  @classmethod
+  def add_test_dimensions(cls):
+    super(TestConstantFoldingNoTypeLoss, cls).add_test_dimensions()
+    # Test with and without expr rewrites to verify that constant folding does 
not change
+    # the behaviour.
+    cls.ImpalaTestMatrix.add_dimension(
+        ImpalaTestDimension('enable_expr_rewrites', *[0,1]))
+    # We don't actually use a table so one file format is enough.
+    cls.ImpalaTestMatrix.add_constraint(lambda v:
+        v.get_value('table_format').file_format in ['parquet'])
+
+  def test_shiftleft(self, vector):
+    """ Tests that the return values of the 'shiftleft' functions are correct 
for the
+    input types (the return type should be the same as the first argument)."""
+    types_and_widths = [
+      ("TINYINT", 8),
+      ("SMALLINT", 16),
+      ("INT", 32),
+      ("BIGINT", 64)
+    ]
+    query_template = ("select shiftleft(cast(1 as {typename}), z) c "
+        "from (select {shift_val} z ) x")
+    for (typename, width) in types_and_widths:
+      shift_val = width - 2  # Valid and positive for signed types.
+      expected_value = 1 << shift_val
+      result = self.execute_query_expect_success(self.client,
+          query_template.format(typename=typename, shift_val=shift_val))
+      assert result.data == [str(expected_value)]
+
+  def test_addition(self, vector):
+    query = "select typeof(cast(1 as bigint) + cast(rand() as tinyint))"
+    result = self.execute_query_expect_success(self.client, query)
+    assert result.data == ["BIGINT"]

Reply via email to