This is an automated email from the ASF dual-hosted git repository.

zhli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new e870de8ce [VL] Fall back collect_set, min and max when input is complex type (#5934)
e870de8ce is described below

commit e870de8cea51d138d49e675e1031f090bfc5bf19
Author: Zhen Li <[email protected]>
AuthorDate: Fri May 31 16:55:53 2024 +0800

    [VL] Fall back collect_set, min and max when input is complex type (#5934)
    
    [VL] Fall back collect_set, min and max when input is complex type.
---
 .../execution/VeloxAggregateFunctionsSuite.scala   | 23 ++++++++++++++++++++++
 .../substrait/SubstraitToVeloxPlanValidator.cc     | 10 ++++++++++
 2 files changed, 33 insertions(+)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 4f6f4eb22..ae6306cc0 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -22,7 +22,9 @@ import 
org.apache.gluten.extension.columnar.validator.FallbackInjects
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
 
 abstract class VeloxAggregateFunctionsSuite extends 
VeloxWholeStageTransformerSuite {
 
@@ -1112,6 +1114,27 @@ abstract class VeloxAggregateFunctionsSuite extends 
VeloxWholeStageTransformerSu
         }
     }
   }
+
+  test("complex type with null") {
+    val jsonStr = 
"""{"txn":{"appId":"txnId","version":0,"lastUpdated":null}}"""
+    val jsonSchema = StructType(
+      Seq(
+        StructField(
+          "txn",
+          StructType(
+            Seq(
+              StructField("appId", StringType, true),
+              StructField("lastUpdated", LongType, true),
+              StructField("version", LongType, true))),
+          true)))
+    val df = spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS)
+    df.select(collect_set(col("txn"))).collect
+
+    df.select(min(col("txn"))).collect
+
+    df.select(max(col("txn"))).collect
+
+  }
 }
 
 class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite 
{
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index abb2bbc56..a3b46d7d0 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1045,6 +1045,16 @@ bool 
SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait
           LOG_VALIDATION_MSG("Validation failed for function " + funcName + " 
resolve type in AggregateRel.");
           return false;
         }
+        static const std::unordered_set<std::string> 
notSupportComplexTypeAggFuncs = {"set_agg", "min", "max"};
+        if (notSupportComplexTypeAggFuncs.find(baseFuncName) != 
notSupportComplexTypeAggFuncs.end() &&
+            exec::isRawInput(funcStep)) {
+          auto type = binder.tryResolveType(signature->argumentTypes()[0]);
+          if (type->isArray() || type->isMap() || type->isRow()) {
+            LOG_VALIDATION_MSG("Validation failed for function " + 
baseFuncName + " complex type is not supported.");
+            return false;
+          }
+        }
+
         resolved = true;
         break;
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to