This is an automated email from the ASF dual-hosted git repository.
zhli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new e870de8ce [VL] Fall back collect_set, min and max when input is complex type (#5934)
e870de8ce is described below
commit e870de8cea51d138d49e675e1031f090bfc5bf19
Author: Zhen Li <[email protected]>
AuthorDate: Fri May 31 16:55:53 2024 +0800
[VL] Fall back collect_set, min and max when input is complex type (#5934)
[VL] Fall back collect_set, min and max when input is complex type.
---
.../execution/VeloxAggregateFunctionsSuite.scala | 23 ++++++++++++++++++++++
.../substrait/SubstraitToVeloxPlanValidator.cc | 10 ++++++++++
2 files changed, 33 insertions(+)
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 4f6f4eb22..ae6306cc0 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -22,7 +22,9 @@ import org.apache.gluten.extension.columnar.validator.FallbackInjects
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
abstract class VeloxAggregateFunctionsSuite extends
VeloxWholeStageTransformerSuite {
@@ -1112,6 +1114,27 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}
}
+
+ test("complex type with null") { // Smoke test: collect_set, min and max over a struct column that has a null field.
+ val jsonStr =
"""{"txn":{"appId":"txnId","version":0,"lastUpdated":null}}"""
+ val jsonSchema = StructType( // Explicit read schema: one nullable struct column named `txn`.
+ Seq(
+ StructField(
+ "txn",
+ StructType(
+ Seq(
+ StructField("appId", StringType, true),
+ StructField("lastUpdated", LongType, true), // This field is null in the sample row above.
+ StructField("version", LongType, true))),
+ true)))
+ val df = spark.read.schema(jsonSchema).json(Seq(jsonStr).toDS) // NOTE(review): `.toDS` presumably relies on Spark implicits imported elsewhere in the suite; confirm.
+ df.select(collect_set(col("txn"))).collect // Results are not asserted; each query only has to complete.
+
+ df.select(min(col("txn"))).collect // Per the commit message, these aggregates are expected to fall back for complex-type input.
+
+ df.select(max(col("txn"))).collect
+
+ }
}
class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite
{
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
index abb2bbc56..a3b46d7d0 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1045,6 +1045,16 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait
LOG_VALIDATION_MSG("Validation failed for function " + funcName + "
resolve type in AggregateRel.");
return false;
}
+ static const std::unordered_set<std::string>
notSupportComplexTypeAggFuncs = {"set_agg", "min", "max"};
+ if (notSupportComplexTypeAggFuncs.find(baseFuncName) !=
notSupportComplexTypeAggFuncs.end() &&
+ exec::isRawInput(funcStep)) {
+ auto type = binder.tryResolveType(signature->argumentTypes()[0]);
+ if (type->isArray() || type->isMap() || type->isRow()) {
+ LOG_VALIDATION_MSG("Validation failed for function " +
baseFuncName + " complex type is not supported.");
+ return false;
+ }
+ }
+
resolved = true;
break;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]