This is an automated email from the ASF dual-hosted git repository.
yangzy pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 8a2e6790b [VL] Support array transform function (#5410)
8a2e6790b is described below
commit 8a2e6790badf29d0447effafe160aea60b89c31a
Author: Yang Zhang <[email protected]>
AuthorDate: Tue Apr 23 14:58:12 2024 +0800
[VL] Support array transform function (#5410)
---
.../gluten/backendsapi/velox/SparkPlanExecApiImpl.scala | 16 +++++++++++++++-
.../gluten/execution/ScalarFunctionsValidateSuite.scala | 10 ++++++++++
.../org/apache/gluten/backendsapi/SparkPlanExecApi.scala | 9 +++++++++
.../apache/gluten/expression/ExpressionConverter.scala | 13 +++++++++++++
.../apache/gluten/expression/ExpressionMappings.scala | 1 +
.../org/apache/gluten/expression/ExpressionNames.scala | 1 +
6 files changed, 49 insertions(+), 1 deletion(-)
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
index 1c1750e59..9952147b9 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, BloomFilterM
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayExists,
ArrayFilter, ArrayForAll, Ascending, Attribute, Cast, CreateNamedStruct,
ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue,
GetStructField, If, IsNaN, LambdaFunction, Literal, Murmur3Hash,
NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim,
TryEval, Uuid, VeloxBloomFilterMightContain}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayExists,
ArrayFilter, ArrayForAll, ArrayTransform, Ascending, Attribute, Cast,
CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator,
GetArrayItem, GetMapValue, GetStructField, If, IsNaN, LambdaFunction, Literal,
Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit,
StringTrim, TryEval, Uuid, VeloxBloomFilterMightContain}
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
HLLAdapter, VeloxBloomFilterAggregate}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -231,6 +231,20 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
}
}
+ /** Transform array transform to Substrait. */
+ override def genArrayTransformTransformer(
+ substraitExprName: String,
+ argument: ExpressionTransformer,
+ function: ExpressionTransformer,
+ expr: ArrayTransform): ExpressionTransformer = {
+ expr.function match {
+ // A two-argument lambda, e.g. (x, i) -> ..., also receives the element index.
+ case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
+ throw new GlutenNotSupportException(
+ "transform on array with lambda using index argument is not supported yet")
+ case _ => GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
+ }
+ }
+
/** Transform posexplode to Substrait. */
override def genPosExplodeTransformer(
substraitExprName: String,
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index 93f82f752..b1f01537d 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -747,6 +747,16 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
}
}
+ test("test array transform") {
+ // Single-argument lambda (x -> x + 1) over an INT array that includes a null element.
+ withTable("t") {
+ sql("create table t (arr ARRAY<INT>) using parquet")
+ sql("insert into t values(array(1, 2, 3, null))")
+ runQueryAndCompare("select transform(arr, x -> x + 1) from t") {
+ checkGlutenOperatorMatch[ProjectExecTransformer]
+ }
+ }
+ }
+
test("weekofyear") {
withTable("t") {
sql("create table t (dt date) using parquet")
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 12c17ad2d..74e03e329 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -270,6 +270,15 @@ trait SparkPlanExecApi {
throw new GlutenNotSupportException("any_match is not supported")
}
+ /** Transform array transform to Substrait. */
+ def genArrayTransformTransformer(
+ substraitExprName: String,
+ argument: ExpressionTransformer,
+ function: ExpressionTransformer,
+ expr: ArrayTransform): ExpressionTransformer = {
+ // Unsupported by default; a backend that supports `transform` on arrays overrides this.
+ throw new GlutenNotSupportException("transform(on array) is not supported")
+ }
+
/** Transform inline to Substrait. */
def genInlineTransformer(
substraitExprName: String,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 60912132f..80c8c6348 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -592,6 +592,19 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformerInternal(f.function, attributeSeq,
expressionsMap),
f
)
+ case arrayTransform: ArrayTransform =>
+
BackendsApiManager.getSparkPlanExecApiInstance.genArrayTransformTransformer(
+ substraitExprName,
+ replaceWithExpressionTransformerInternal(
+ arrayTransform.argument,
+ attributeSeq,
+ expressionsMap),
+ replaceWithExpressionTransformerInternal(
+ arrayTransform.function,
+ attributeSeq,
+ expressionsMap),
+ arrayTransform
+ )
case tryEval @ TryEval(a: Add) =>
BackendsApiManager.getSparkPlanExecApiInstance.genTryAddTransformer(
substraitExprName,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index c2bbb3345..6be5b0f9b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -203,6 +203,7 @@ object ExpressionMappings {
Sig[Sha2](SHA2),
Sig[Crc32](CRC32),
// Array functions
+ Sig[ArrayTransform](TRANSFORM),
Sig[Size](SIZE),
Sig[Slice](SLICE),
Sig[Sequence](SEQUENCE),
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index e57e75849..44384be72 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -249,6 +249,7 @@ object ExpressionNames {
final val FILTER = "filter"
final val FORALL = "forall"
final val EXISTS = "exists"
+ final val TRANSFORM = "transform"
final val SHUFFLE = "shuffle"
// Map functions
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]