This is an automated email from the ASF dual-hosted git repository.

jark pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new c086a91d977 [FLINK-27175][hive] Fix failure to call Hive UDAF when the 
UDAF accepts one parameter with array type
c086a91d977 is described below

commit c086a91d977a5cb51d1b7c962bcb51f7d2a867fc
Author: yuxia Luo <[email protected]>
AuthorDate: Fri Aug 26 17:49:19 2022 +0800

    [FLINK-27175][hive] Fix failure to call Hive UDAF when the UDAF accepts one 
parameter with array type
    
    This closes #19423
---
 .../table/functions/hive/HiveGenericUDAF.java      | 11 +++++++
 .../table/functions/hive/HiveGenericUDTF.java      |  2 +-
 .../table/functions/hive/HiveScalarFunction.java   |  2 +-
 .../functions/hive/util/HiveFunctionUtil.java      | 37 ----------------------
 .../table/functions/hive/HiveGenericUDAFTest.java  | 36 +++++++++++++++++++++
 .../src/test/resources/query-test/udaf.q           |  3 ++
 6 files changed, 52 insertions(+), 39 deletions(-)

diff --git 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java
 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java
index e396038077e..25e1ca058c3 100644
--- 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java
+++ 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java
@@ -30,6 +30,7 @@ import org.apache.flink.table.functions.FunctionContext;
 import org.apache.flink.table.functions.hive.conversion.HiveInspectors;
 import org.apache.flink.table.functions.hive.conversion.HiveObjectConversion;
 import org.apache.flink.table.functions.hive.conversion.IdentityConversion;
+import org.apache.flink.table.functions.hive.util.HiveFunctionUtil;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.inference.CallContext;
 import org.apache.flink.table.types.inference.TypeInference;
@@ -69,6 +70,7 @@ public class HiveGenericUDAF
     private transient GenericUDAFEvaluator partialEvaluator;
     private transient GenericUDAFEvaluator finalEvaluator;
     private transient ObjectInspector finalResultObjectInspector;
+    private transient boolean isArgsSingleArray;
     private transient HiveObjectConversion[] conversions;
     private transient boolean allIdentityConverter;
     private transient boolean initialized;
@@ -109,6 +111,7 @@ public class HiveGenericUDAF
                         GenericUDAFEvaluator.Mode.FINAL,
                         new ObjectInspector[] {partialResultObjectInspector});
 
+        isArgsSingleArray = HiveFunctionUtil.isSingleBoxedArray(arguments);
         conversions = new HiveObjectConversion[inputInspectors.length];
         for (int i = 0; i < inputInspectors.length; i++) {
             conversions[i] =
@@ -182,6 +185,14 @@ public class HiveGenericUDAF
 
     public void accumulate(GenericUDAFEvaluator.AggregationBuffer acc, 
Object... inputs)
             throws HiveException {
+        // When the parameter of the function is (Integer, Array[Double]), 
Flink calls
+        // udf.accumulate(AggregationBuffer, Integer, Array[Double]), which is 
not a problem.
+        // But when the parameter is a single array, Flink calls 
udf.accumulate(AggregationBuffer,
+        // Array[Double]), at this point java's var-args will cast 
Array[Double] to Array[Object]
+        // and let it be Object... args, So we need wrap it.
+        if (isArgsSingleArray) {
+            inputs = new Object[] {inputs};
+        }
         if (!allIdentityConverter) {
             for (int i = 0; i < inputs.length; i++) {
                 inputs[i] = conversions[i].toHiveObject(inputs[i]);
diff --git 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java
 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java
index 0eba5ac7317..bf91ba25479 100644
--- 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java
+++ 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java
@@ -115,7 +115,7 @@ public class HiveGenericUDTF extends TableFunction<Row> 
implements HiveFunction<
 
         // When the parameter is (Integer, Array[Double]), Flink calls 
udf.eval(Integer,
         // Array[Double]), which is not a problem.
-        // But when the parameter is an single array, Flink calls 
udf.eval(Array[Double]),
+        // But when the parameter is a single array, Flink calls 
udf.eval(Array[Double]),
         // at this point java's var-args will cast Array[Double] to 
Array[Object] and let it be
         // Object... args, So we need wrap it.
         if (isArgsSingleArray) {
diff --git 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java
 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java
index 459575303c2..3f0b7ad9fc5 100644
--- 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java
+++ 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java
@@ -79,7 +79,7 @@ public abstract class HiveScalarFunction<UDFType> extends 
ScalarFunction
 
         // When the parameter is (Integer, Array[Double]), Flink calls 
udf.eval(Integer,
         // Array[Double]), which is not a problem.
-        // But when the parameter is an single array, Flink calls 
udf.eval(Array[Double]),
+        // But when the parameter is a single array, Flink calls 
udf.eval(Array[Double]),
         // at this point java's var-args will cast Array[Double] to 
Array[Object] and let it be
         // Object... args, So we need wrap it.
         if (isArgsSingleArray) {
diff --git 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/util/HiveFunctionUtil.java
 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/util/HiveFunctionUtil.java
index 681334a78cd..f0200850647 100644
--- 
a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/util/HiveFunctionUtil.java
+++ 
b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/util/HiveFunctionUtil.java
@@ -19,54 +19,17 @@
 package org.apache.flink.table.functions.hive.util;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.table.functions.hive.FlinkHiveUDFException;
 import org.apache.flink.table.functions.hive.HiveFunctionArguments;
 import org.apache.flink.table.types.DataType;
-import org.apache.flink.table.types.logical.ArrayType;
-import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
 
 /** Util for Hive functions. */
 @Internal
 public class HiveFunctionUtil {
     public static boolean isSingleBoxedArray(HiveFunctionArguments arguments) {
-        for (int i = 0; i < arguments.size(); i++) {
-            if (HiveFunctionUtil.isPrimitiveArray(arguments.getDataType(i))) {
-                throw new FlinkHiveUDFException(
-                        "Flink doesn't support primitive array for Hive 
function yet.");
-            }
-        }
         return arguments.size() == 1 && 
HiveFunctionUtil.isArrayType(arguments.getDataType(0));
     }
 
-    private static boolean isPrimitiveArray(DataType dataType) {
-        if (isArrayType(dataType)) {
-            ArrayType arrayType = (ArrayType) dataType.getLogicalType();
-
-            LogicalType elementType = arrayType.getElementType();
-            return !(elementType.isNullable() || !isPrimitive(elementType));
-        } else {
-            return false;
-        }
-    }
-
-    // This is copied from PlannerTypeUtils in flink-table-runtime that we 
shouldn't depend on
-    // TODO: remove this and use the original code when it's moved to 
accessible, dependable module
-    private static boolean isPrimitive(LogicalType type) {
-        switch (type.getTypeRoot()) {
-            case BOOLEAN:
-            case TINYINT:
-            case SMALLINT:
-            case INTEGER:
-            case BIGINT:
-            case FLOAT:
-            case DOUBLE:
-                return true;
-            default:
-                return false;
-        }
-    }
-
     private static boolean isArrayType(DataType dataType) {
         return 
dataType.getLogicalType().getTypeRoot().equals(LogicalTypeRoot.ARRAY);
     }
diff --git 
a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java
 
b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java
index 49ca87f3484..6234c50e14a 100644
--- 
a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java
+++ 
b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java
@@ -24,6 +24,8 @@ import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.inference.utils.CallContextMock;
 import org.apache.flink.types.Row;
 
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFContextNGrams;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCount;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
@@ -40,6 +42,7 @@ import java.util.Optional;
 import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 
 /** Test for {@link HiveGenericUDAF}. */
 public class HiveGenericUDAFTest {
@@ -139,6 +142,39 @@ public class HiveGenericUDAFTest {
         assertThat(Arrays.toString((Row[]) 
udaf.getValue(acc))).isEqualTo("[+I[[think], 1.0]]");
     }
 
+    @Test
+    public void testUDAFWithSingleArrayAsParameter() throws Exception {
+        Object[] constantArgs = new Object[] {null};
+
+        DataType[] argTypes = new DataType[] 
{DataTypes.ARRAY(DataTypes.INT().notNull())};
+
+        // test CollectList
+        HiveGenericUDAF udf =
+                init(GenericUDAFCollectList.class, constantArgs, argTypes, 
false, false);
+        GenericUDAFEvaluator.AggregationBuffer acc = udf.createAccumulator();
+
+        udf.accumulate(acc, new Integer[] {1, 2});
+        udf.accumulate(acc, new Integer[] {2, 3});
+
+        udf.merge(acc, Collections.emptyList());
+
+        Integer[][] expectedResult = new Integer[][] {new Integer[] {1, 2}, 
new Integer[] {2, 3}};
+        assertArrayEquals(expectedResult, (Integer[][]) udf.getValue(acc));
+
+        // test CollectSet
+        udf = init(GenericUDAFCollectSet.class, constantArgs, argTypes, false, 
false);
+        acc = udf.createAccumulator();
+
+        udf.accumulate(acc, new Integer[] {1, 2});
+        udf.accumulate(acc, new Integer[] {2, 3});
+        udf.accumulate(acc, new Integer[] {1, 2});
+
+        udf.merge(acc, Collections.emptySet());
+
+        expectedResult = new Integer[][] {new Integer[] {1, 2}, new Integer[] 
{2, 3}};
+        assertArrayEquals(expectedResult, (Integer[][]) udf.getValue(acc));
+    }
+
     private static HiveGenericUDAF init(
             Class<?> hiveUdfClass,
             Object[] constantArgs,
diff --git 
a/flink-connectors/flink-connector-hive/src/test/resources/query-test/udaf.q 
b/flink-connectors/flink-connector-hive/src/test/resources/query-test/udaf.q
index 3f4568626c9..436950fc7b2 100644
--- a/flink-connectors/flink-connector-hive/src/test/resources/query-test/udaf.q
+++ b/flink-connectors/flink-connector-hive/src/test/resources/query-test/udaf.q
@@ -7,3 +7,6 @@ select sum(x) from foo;
 select percentile_approx(x, 0.5) from foo;
 
 [+I[2.5]]
+
+select i, collect_list(array(s)) from bar group by i;
+[+I[1, [[a], [aa]]], +I[2, [[b]]]]

Reply via email to