This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 263555c9adcca0abe194e9a6c1d85ec591c304e4 Author: fengli <[email protected]> AuthorDate: Mon Feb 27 17:15:47 2023 +0800 [FLINK-31239][hive] Fix native sum function can't get the corrected value when the argument type is string This closes #22031 --- .../table/functions/hive/HiveSumAggFunction.java | 55 +++++++++++++----- .../connectors/hive/HiveDialectAggITCase.java | 66 +++++++++++++++++----- .../resources/explain/testSumAggFunctionPlan.out | 8 +-- .../planner/expressions/ExpressionBuilder.java | 10 ++++ 4 files changed, 106 insertions(+), 33 deletions(-) diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java index 610a1d93239..48470f997df 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java @@ -22,13 +22,20 @@ import org.apache.flink.table.api.DataTypes; import org.apache.flink.table.api.TableException; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; +import org.apache.flink.table.expressions.ValueLiteralExpression; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.inference.CallContext; +import java.math.BigDecimal; + import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED; import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef; +import static org.apache.flink.table.expressions.ApiExpressionUtils.valueLiteral; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.and; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.coalesce; import static 
org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isTrue; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.tryCast; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral; @@ -40,7 +47,10 @@ import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getSc public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction { private final UnresolvedReferenceExpression sum = unresolvedRef("sum"); + private final UnresolvedReferenceExpression isEmpty = unresolvedRef("isEmpty"); + private DataType resultType; + private ValueLiteralExpression zero; @Override public int operandCount() { @@ -49,12 +59,12 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction { @Override public UnresolvedReferenceExpression[] aggBufferAttributes() { - return new UnresolvedReferenceExpression[] {sum}; + return new UnresolvedReferenceExpression[] {sum, isEmpty}; } @Override public DataType[] getAggBufferTypes() { - return new DataType[] {getResultType()}; + return new DataType[] {getResultType(), DataTypes.BOOLEAN()}; } @Override @@ -64,20 +74,19 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction { @Override public Expression[] initialValuesExpressions() { - return new Expression[] {/* sum = */ nullOf(getResultType())}; + return new Expression[] {/* sum = */ nullOf(getResultType()), valueLiteral(true)}; } @Override public Expression[] accumulateExpressions() { Expression tryCastOperand = tryCast(operand(0), typeLiteral(getResultType())); + Expression coalesceSum = coalesce(sum, zero); return new Expression[] { /* sum = */ ifThenElse( isNull(tryCastOperand), - sum, - ifThenElse( - isNull(sum), - 
tryCastOperand, - adjustedPlus(getResultType(), sum, tryCastOperand))) + coalesceSum, + adjustedPlus(getResultType(), coalesceSum, tryCastOperand)), + and(isEmpty, isNull(operand(0))) }; } @@ -88,20 +97,19 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction { @Override public Expression[] mergeExpressions() { + Expression coalesceSum = coalesce(sum, zero); return new Expression[] { /* sum = */ ifThenElse( isNull(mergeOperand(sum)), - sum, - ifThenElse( - isNull(sum), - mergeOperand(sum), - adjustedPlus(getResultType(), sum, mergeOperand(sum)))) + coalesceSum, + adjustedPlus(getResultType(), coalesceSum, mergeOperand(sum))), + and(isEmpty, mergeOperand(isEmpty)) }; } @Override public Expression getValueExpression() { - return sum; + return ifThenElse(isTrue(isEmpty), nullOf(getResultType()), sum); } @Override @@ -109,6 +117,7 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction { if (resultType == null) { checkArgumentNum(callContext.getArgumentDataTypes()); resultType = initResultType(callContext.getArgumentDataTypes().get(0)); + zero = defaultValue(resultType); } } @@ -141,4 +150,22 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction { argsType)); } } + + private ValueLiteralExpression defaultValue(DataType dataType) { + switch (dataType.getLogicalType().getTypeRoot()) { + case BIGINT: + return valueLiteral(0L); + case DOUBLE: + return valueLiteral(0.0); + case DECIMAL: + return valueLiteral( + BigDecimal.valueOf(0, getScale(dataType.getLogicalType())), + dataType.notNull()); + default: + throw new TableException( + String.format( + "Unsupported type %s is passed when initialize the default value.", + dataType)); + } + } } diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java index 11f80f04dcc..896d348cf5c 100644 
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java @@ -76,12 +76,12 @@ public class HiveDialectAggITCase { @Test public void testSimpleSumAggFunction() throws Exception { tableEnv.executeSql( - "create table test_sum(x string, y string, z int, d decimal(10,5), e float, f double, ts timestamp)"); + "create table test_sum(x string, y string, g string, z int, d decimal(10,5), e float, f double, ts timestamp)"); tableEnv.executeSql( - "insert into test_sum values (NULL, '2', 1, 1.11, 1.2, 1.3, '2021-08-04 16:26:33.4'), " - + "(NULL, 'b', 2, 2.22, 2.3, 2.4, '2021-08-07 16:26:33.4'), " - + "(NULL, '4', 3, 3.33, 3.5, 3.6, '2021-08-08 16:26:33.4'), " - + "(NULL, NULL, 4, 4.45, 4.7, 4.8, '2021-08-09 16:26:33.4')") + "insert into test_sum values (NULL, '2', 'b', 1, 1.11, 1.2, 1.3, '2021-08-04 16:26:33.4'), " + + "(NULL, 'b', 'b', 2, 2.22, 2.3, 2.4, '2021-08-07 16:26:33.4'), " + + "(NULL, '4', 'b', 3, 3.33, 3.5, 3.6, '2021-08-08 16:26:33.4'), " + + "(NULL, NULL, 'b', 4, 4.45, 4.7, 4.8, '2021-08-09 16:26:33.4')") .await(); // test sum with all elements are null @@ -96,37 +96,43 @@ public class HiveDialectAggITCase { tableEnv.executeSql("select sum(y) from test_sum").collect()); assertThat(result2.toString()).isEqualTo("[+I[6.0]]"); - // test decimal type + // test sum string type with all elements can't convert to double, result type is double List<Row> result3 = + CollectionUtil.iteratorToList( + tableEnv.executeSql("select sum(g) from test_sum").collect()); + assertThat(result3.toString()).isEqualTo("[+I[0.0]]"); + + // test decimal type + List<Row> result4 = CollectionUtil.iteratorToList( tableEnv.executeSql("select sum(d) from test_sum").collect()); - assertThat(result3.toString()).isEqualTo("[+I[11.11000]]"); + assertThat(result4.toString()).isEqualTo("[+I[11.11000]]"); // test sum int, 
result type is bigint - List<Row> result4 = + List<Row> result5 = CollectionUtil.iteratorToList( tableEnv.executeSql("select sum(z) from test_sum").collect()); - assertThat(result4.toString()).isEqualTo("[+I[10]]"); + assertThat(result5.toString()).isEqualTo("[+I[10]]"); // test float type - List<Row> result5 = + List<Row> result6 = CollectionUtil.iteratorToList( tableEnv.executeSql("select sum(e) from test_sum").collect()); - float actualFloatValue = ((Double) result5.get(0).getField(0)).floatValue(); + float actualFloatValue = ((Double) result6.get(0).getField(0)).floatValue(); assertThat(actualFloatValue).isEqualTo(11.7f); // test double type - List<Row> result6 = + List<Row> result7 = CollectionUtil.iteratorToList( tableEnv.executeSql("select sum(f) from test_sum").collect()); - actualFloatValue = ((Double) result6.get(0).getField(0)).floatValue(); + actualFloatValue = ((Double) result7.get(0).getField(0)).floatValue(); assertThat(actualFloatValue).isEqualTo(12.1f); // test sum string&int type simultaneously - List<Row> result7 = + List<Row> result8 = CollectionUtil.iteratorToList( tableEnv.executeSql("select sum(y), sum(z) from test_sum").collect()); - assertThat(result7.toString()).isEqualTo("[+I[6.0, 10]]"); + assertThat(result8.toString()).isEqualTo("[+I[6.0, 10]]"); // test unsupported timestamp type String expectedMessage = @@ -137,6 +143,36 @@ public class HiveDialectAggITCase { tableEnv.executeSql("drop table test_sum"); } + @Test + public void testSumDecimal() throws Exception { + tableEnv.executeSql( + "create table test_sum_dec(a int, x string, z decimal(10, 5), g decimal(18, 5))"); + tableEnv.executeSql( + "insert into test_sum_dec values (1, 'b', null, null), " + + "(1, 'b', 1.2, null), " + + "(2, 'b', null, null), " + + "(2, 'b', null, null)," + + "(4, '1', null, null)," + + "(4, 'b', null, null)") + .await(); + + List<Row> result = + CollectionUtil.iteratorToList( + tableEnv.executeSql("select a, sum(z), sum(g) from test_sum_dec group by a") + 
.collect()); + assertThat(result.toString()) + .isEqualTo("[+I[1, 1.20000, null], +I[2, null, null], +I[4, null, null]]"); + + List<Row> result2 = + CollectionUtil.iteratorToList( + tableEnv.executeSql( + "select a, sum(cast(x as decimal(10, 3))) from test_sum_dec group by a") + .collect()); + assertThat(result2.toString()).isEqualTo("[+I[1, 0.000], +I[2, 0.000], +I[4, 1.000]]"); + + tableEnv.executeSql("drop table test_sum_dec"); + } + @Test public void testSumAggWithGroupKey() throws Exception { tableEnv.executeSql( diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out index 702e09fb3f2..95be2ba4e72 100644 --- a/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out +++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out @@ -5,13 +5,13 @@ LogicalProject(x=[$0], _o__c1=[$1]) +- LogicalTableScan(table=[[test-catalog, default, foo]]) == Optimized Physical Plan == -HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0) AS $f1]) +HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0, isEmpty$1) AS $f1]) +- Exchange(distribution=[hash[x]]) - +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS sum$0]) + +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS (sum$0, isEmpty$1)]) +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y]) == Optimized Execution Plan == -HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0) AS $f1]) +HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0, isEmpty$1) AS $f1]) +- Exchange(distribution=[hash[x]]) - +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS sum$0]) + +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS (sum$0, isEmpty$1)]) +- TableSourceScan(table=[[test-catalog, default, foo]], 
fields=[x, y]) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java index 77cfbaa4baa..bb98614d561 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java @@ -33,6 +33,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DE import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CAST; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.COALESCE; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CONCAT; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.DIVIDE; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.EQUALS; @@ -40,6 +41,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.GREATE import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.HIVE_AGG_DECIMAL_PLUS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.IF; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.IS_NULL; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.IS_TRUE; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.LESS_THAN; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MINUS; @@ -97,6 +99,14 @@ public class ExpressionBuilder { return call(IS_NULL, input); } + 
public static UnresolvedCallExpression isTrue(Expression input) { + return call(IS_TRUE, input); + } + + public static UnresolvedCallExpression coalesce(Expression... args) { + return call(COALESCE, args); + } + public static UnresolvedCallExpression ifThenElse( Expression condition, Expression ifTrue, Expression ifFalse) { return call(IF, condition, ifTrue, ifFalse);
