This is an automated email from the ASF dual-hosted git repository. stigahuang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/impala.git
commit d5f43ff19a6c72c2b293da6b38ad1f19f179c251 Author: Steve Carlin <[email protected]> AuthorDate: Wed Nov 6 08:01:51 2024 -0800 IMPALA-13523: Decimal precision and scale needs to be in return type The inferred return type needs to contain a decimal precision and scale. The return type is calculated by taking the most compatible type of all the arguments. One query in the e2e tests that will be fixed because of this is (previously it was throwing an analysis exception): select appx_median(c1), appx_median(c2), appx_median(c3) from decimal_tiny; The CoerceOperandShuttle also ensures that if the return type is a decimal in certain cases, all the arguments for the function will be cast to that specific return type. The 2 cases here are: 1) when the function is a case statement, in which case all cases need to be the same precision and scale; 2) when the function contains varargs, in which case all the comparisons need to be of the same precision and scale. Change-Id: Ie10521b587a74930a01c08b711364f897bb2dc33 Reviewed-on: http://gerrit.cloudera.org:8080/22086 Tested-by: Impala Public Jenkins <[email protected]> Reviewed-by: Aman Sinha <[email protected]> --- .../calcite/coercenodes/CoerceOperandShuttle.java | 87 ++++++++++++---------- .../calcite/operators/CommonOperatorFunctions.java | 7 +- .../queries/QueryTest/decimal-exprs.test | 18 +++++ .../functional-query/queries/QueryTest/exprs.test | 11 +++ 4 files changed, 82 insertions(+), 41 deletions(-) diff --git a/java/calcite-planner/src/main/java/org/apache/impala/calcite/coercenodes/CoerceOperandShuttle.java b/java/calcite-planner/src/main/java/org/apache/impala/calcite/coercenodes/CoerceOperandShuttle.java index 8c845e0f5..654afad1b 100644 --- a/java/calcite-planner/src/main/java/org/apache/impala/calcite/coercenodes/CoerceOperandShuttle.java +++ b/java/calcite-planner/src/main/java/org/apache/impala/calcite/coercenodes/CoerceOperandShuttle.java @@ -130,17 +130,17 @@ public class CoerceOperandShuttle extends RexShuttle { 
// returns a decimal type. The Decimal type from the function resolver would // have to calculate the precision and scale based on operand types. If // necessary, this code should be added later. - Preconditions.checkState(retType.getSqlTypeName() != SqlTypeName.DECIMAL || - castedOperandsCall.getType().getSqlTypeName() == SqlTypeName.DECIMAL); + Preconditions.checkState(!SqlTypeUtil.isDecimal(retType) || + SqlTypeUtil.isDecimal(castedOperandsCall.getType())); // So if the original return type is Decimal and the function resolves to // decimal, the precision and scale are saved from the original function. - if (retType.getSqlTypeName().equals(SqlTypeName.DECIMAL)) { + if (SqlTypeUtil.isDecimal(retType)) { retType = castedOperandsCall.getType(); } - List<RexNode> newOperands = - getCastedArgTypes(fn, castedOperandsCall.getOperands(), factory, rexBuilder); + List<RexNode> newOperands = getCastedArgTypes(fn, castedOperandsCall.getOperands(), + retType, factory, rexBuilder); // keep the original call if nothing changed, else build a new RexCall. return retType.equals(castedOperandsCall.getType()) @@ -165,7 +165,7 @@ public class CoerceOperandShuttle extends RexShuttle { RelDataType retType = getReturnType(castedOver, fn.getReturnType()); List<RexNode> newOperands = - getCastedArgTypes(fn, castedOver.getOperands(), factory, rexBuilder); + getCastedArgTypes(fn, castedOver.getOperands(), retType, factory, rexBuilder); return retType.equals(castedOver.getType()) && newOperands.equals(castedOver.getOperands()) @@ -215,12 +215,12 @@ public class CoerceOperandShuttle extends RexShuttle { // returns a decimal type. The Decimal type from the function resolver would // have to calculate the precision and scale based on operand types. If // necessary, this code should be added later. 
- Preconditions.checkState(retType.getSqlTypeName() != SqlTypeName.DECIMAL || - rexNode.getType().getSqlTypeName() == SqlTypeName.DECIMAL); + Preconditions.checkState(!SqlTypeUtil.isDecimal(retType) || + SqlTypeUtil.isDecimal(rexNode.getType())); // So if the original return type is Decimal and the function resolves to // decimal, the precision and scale are saved from the original function. - if (retType.getSqlTypeName().equals(SqlTypeName.DECIMAL)) { + if (SqlTypeUtil.isDecimal(retType)) { retType = rexNode.getType(); } @@ -283,8 +283,7 @@ public class CoerceOperandShuttle extends RexShuttle { * Return a list of the operands, casting whenever needed. */ private static List<RexNode> getCastedArgTypes(Function fn, List<RexNode> operands, - RelDataTypeFactory factory, RexBuilder rexBuilder) { - + RelDataType retType, RelDataTypeFactory factory, RexBuilder rexBuilder) { List<RelDataType> argTypes = Util.transform(operands, RexNode::getType); List<RexNode> newOperands = new ArrayList<>(); // The "Case" operator is special because the operands alternate between @@ -292,6 +291,7 @@ public class CoerceOperandShuttle extends RexShuttle { // boolean, so they don't need casting. boolean isCaseFunction = isCaseFunction(fn); boolean castedOperand = false; + Preconditions.checkState(argTypes.size() == 0 || fn.getNumArgs() > 0); for (int i = 0; i < argTypes.size(); ++i) { if (isCaseFunction && FunctionResolver.shouldSkipOperandForCase(argTypes.size(), i)) { @@ -299,10 +299,15 @@ public class CoerceOperandShuttle extends RexShuttle { newOperands.add(operands.get(i)); continue; } - // if there are varargs, the last arg in the signature will match all - // remaining args. - int sigIndex = getArgIndex(fn, i, isCaseFunction); - RexNode operand = castOperand(operands.get(i), fn.getArgs()[sigIndex], + + // in the case of varargs, take the last argument in the signature. 
+ int indexToUse = Math.min(i, fn.getNumArgs() - 1); + Type toImpalaType = fn.getArgs()[indexToUse]; + RelDataType toType = useReturnTypeForCastingArg(fn, argTypes.get(indexToUse)) + ? retType + : getCastedToType(argTypes.get(i), toImpalaType, factory); + + RexNode operand = castOperand(operands.get(i), toType, factory, rexBuilder); Preconditions.checkNotNull(operand); newOperands.add(operand); @@ -314,28 +319,36 @@ public class CoerceOperandShuttle extends RexShuttle { return castedOperand ? newOperands : operands; } - /** - * Return the argIndex. If it's a case statement, the index is always 0 - * If there are varargs, the last index returns if the "i" value passed - * in overflows the size of the operands. - */ - private static int getArgIndex(Function fn, int i, boolean isCaseFn) { - if (isCaseFn) { - return 0; + private static boolean useReturnTypeForCastingArg(Function fn, RelDataType argType) { + // case functions use the precalculated return type from the function resolver. + if (isCaseFunction(fn)) { + return true; } - return Math.min(i, fn.getNumArgs() - 1); + // For functions that have decimal varargs and return a decimal + // (e.g. greatest, least), the type has been calculated at validation time. 
+ return SqlTypeUtil.isDecimal(argType) && + fn.getReturnType().isDecimal() && fn.hasVarArgs(); } private static boolean isCaseFunction(Function fn) { return fn.functionName().equals("case"); } - private static RexNode castOperand(RexNode node, Type toImpalaType, - RelDataTypeFactory factory, RexBuilder rexBuilder) { + private static RelDataType getCastedToType(RelDataType fromType, + Type toImpalaType, RelDataTypeFactory factory) { - RelDataType toType = ImpalaTypeConverter.getRelDataType(toImpalaType); - return castOperand(node, toType, factory, rexBuilder); + if (!toImpalaType.isDecimal() || SqlTypeUtil.isNull(fromType)) { + return ImpalaTypeConverter.getRelDataType(toImpalaType); + } + + // Integer based type needs special conversion to Decimal types based on the + // size of the type of Integer (e.g. TINYINT, SMALLINT, etc...), but don't change + // the type if the from type is also DECIMAL. + ScalarType impalaType = (ScalarType) ImpalaTypeConverter.createImpalaType(fromType); + ScalarType decimalType = impalaType.getMinResolutionDecimal(); + return factory.createSqlType(SqlTypeName.DECIMAL, + decimalType.decimalPrecision(), decimalType.decimalScale()); } /** @@ -356,11 +369,17 @@ public class CoerceOperandShuttle extends RexShuttle { } // No need to cast if types are the same - if (fromType.getSqlTypeName().equals(toType.getSqlTypeName())) { + if (fromType.getSqlTypeName().equals(toType.getSqlTypeName()) && + fromType.getPrecision() == toType.getPrecision() && + fromType.getScale() == toType.getScale()) { return node; } - if (fromType.getSqlTypeName().equals(SqlTypeName.NULL)) { + if (SqlTypeUtil.isNull(fromType)) { + if (SqlTypeUtil.isDecimal(toType)) { + Type impalaType = ImpalaTypeConverter.createImpalaType(Type.DECIMAL, 1, 0); + toType = ImpalaTypeConverter.createRelDataType(impalaType); + } return rexBuilder.makeCast(toType, node); } @@ -368,14 +387,6 @@ public class CoerceOperandShuttle extends RexShuttle { return null; } - // Integer based type needs 
special conversion to Decimal types based on the - // size of the type of Integer (e.g. TINYINT, SMALLINT, etc...) - if (toType.getSqlTypeName().equals(SqlTypeName.DECIMAL)) { - ScalarType impalaType = (ScalarType) ImpalaTypeConverter.createImpalaType(fromType); - ScalarType decimalType = impalaType.getMinResolutionDecimal(); - toType = factory.createSqlType(SqlTypeName.DECIMAL, - decimalType.decimalPrecision(), decimalType.decimalScale()); - } return rexBuilder.makeCast(toType, node); } diff --git a/java/calcite-planner/src/main/java/org/apache/impala/calcite/operators/CommonOperatorFunctions.java b/java/calcite-planner/src/main/java/org/apache/impala/calcite/operators/CommonOperatorFunctions.java index 2c30b0007..cf9d4adc3 100644 --- a/java/calcite-planner/src/main/java/org/apache/impala/calcite/operators/CommonOperatorFunctions.java +++ b/java/calcite-planner/src/main/java/org/apache/impala/calcite/operators/CommonOperatorFunctions.java @@ -49,7 +49,6 @@ import java.util.List; * operators. */ public class CommonOperatorFunctions { - // Allow any count because this is used for all functions. Validation for specific // number of parameters will be done when Impala function resolving is done. public static SqlOperandCountRange ANY_COUNT_RANGE = SqlOperandCountRanges.any(); @@ -70,8 +69,10 @@ public class CommonOperatorFunctions { + name + "; operand types: " + operandTypes); } - RelDataType returnType = - ImpalaTypeConverter.getRelDataType(fn.getReturnType()); + RelDataType returnType = fn.getReturnType().equals(Type.DECIMAL) + ? ImpalaTypeConverter.getCompatibleType(operandTypes, factory) + : ImpalaTypeConverter.getRelDataType(fn.getReturnType()); + return isNullable(operandTypes) ? 
returnType : factory.createTypeWithNullability(returnType, true); diff --git a/testdata/workloads/functional-query/queries/QueryTest/decimal-exprs.test b/testdata/workloads/functional-query/queries/QueryTest/decimal-exprs.test index c8eb10215..4a081c8d1 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/decimal-exprs.test +++ b/testdata/workloads/functional-query/queries/QueryTest/decimal-exprs.test @@ -553,3 +553,21 @@ select typeof(mod(9.6, 3)); ---- RESULTS 'DECIMAL(4,1)' ==== +---- QUERY +# make sure coalesce function works ok +set decimal_v2=true; +select coalesce(cast(18.11 as decimal(4,2)), cast(18.1 as decimal(3,1))) +---- RESULTS +18.11 +---- TYPES +DECIMAL +==== +---- QUERY +# make sure coalesce function works ok +set decimal_v2=true; +select coalesce(cast(18.1 as decimal(3,1)), cast(18.11 as decimal(4,2))) +---- RESULTS +18.10 +---- TYPES +DECIMAL +==== diff --git a/testdata/workloads/functional-query/queries/QueryTest/exprs.test b/testdata/workloads/functional-query/queries/QueryTest/exprs.test index d0cbd576e..b79067ede 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/exprs.test +++ b/testdata/workloads/functional-query/queries/QueryTest/exprs.test @@ -3379,3 +3379,14 @@ select cast(round(1/3, 20) as string); ---- TYPES string ==== +---- QUERY +select +greatest(cast(18.3 as decimal(3,1)), cast(19.44 as decimal(4,2))), +least(cast(18.3 as decimal(3,1)), cast(19.44 as decimal(4,2))), +greatest(cast(19.44 as decimal(4,2)), cast(18.3 as decimal(3,1))), +least(cast(19.44 as decimal(4,2)), cast(18.3 as decimal(3,1))); +---- RESULTS +19.44,18.30,19.44,18.30 +---- TYPES +DECIMAL,DECIMAL,DECIMAL,DECIMAL +====
