This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 37d09c109 [GLUTEN-4039][VL] Support array insert function for spark 3.4+ (#7123)
37d09c109 is described below
commit 37d09c1092a3916a28f23da9b27cc799a8885a5e
Author: Tengfei Huang <[email protected]>
AuthorDate: Fri Sep 6 11:55:30 2024 +0800
[GLUTEN-4039][VL] Support array insert function for spark 3.4+ (#7123)
---
.../execution/ScalarFunctionsValidateSuite.scala | 26 ++++++++++++++++++++++
.../gluten/expression/ExpressionConverter.scala | 8 +++++++
.../apache/gluten/expression/ExpressionNames.scala | 1 +
.../org/apache/gluten/sql/shims/SparkShims.scala | 4 ++++
.../gluten/sql/shims/spark34/Spark34Shims.scala | 8 ++++++-
.../gluten/sql/shims/spark35/Spark35Shims.scala | 8 ++++++-
6 files changed, 53 insertions(+), 2 deletions(-)
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index b8de30b1b..81da24f8e 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -1365,4 +1365,30 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
+
+ testWithSpecifiedSparkVersion("array insert", Some("3.4")) {
+ withTempPath {
+ path =>
+ Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
+ .toDF("value")
+ .write
+ .parquet(path.getCanonicalPath)
+
+ spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")
+
+ Seq("true", "false").foreach {
+ legacyNegativeIndex =>
+ withSQLConf("spark.sql.legacy.negativeIndexInArrayInsert" -> legacyNegativeIndex) {
+ runQueryAndCompare("""
+ |select
+ | array_insert(value, 1, 0), array_insert(value, 10, 0),
+ | array_insert(value, -1, 0), array_insert(value, -10, 0)
+ |from array_tbl
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[ProjectExecTransformer]
+ }
+ }
+ }
+ }
+ }
}
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 6f6e2cf12..606cbd96e 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -633,6 +633,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer0(a.function, attributeSeq,
expressionsMap),
a
)
+ case arrayInsert if arrayInsert.getClass.getSimpleName.equals("ArrayInsert") =>
+ // Since spark 3.4.0
+ val children = SparkShimLoader.getSparkShims.extractExpressionArrayInsert(arrayInsert)
+ GenericExpressionTransformer(
+ substraitExprName,
+ children.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)),
+ arrayInsert
+ )
case s: Shuffle =>
GenericExpressionTransformer(
substraitExprName,
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 96a615615..f198bb7e1 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -272,6 +272,7 @@ object ExpressionNames {
final val SHUFFLE = "shuffle"
final val ZIP_WITH = "zip_with"
final val FLATTEN = "flatten"
+ final val ARRAY_INSERT = "array_insert"
// Map functions
final val CREATE_MAP = "map"
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index fa6ed18e9..7671f236c 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -266,4 +266,8 @@ trait SparkShims {
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}
+
+ def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+ throw new UnsupportedOperationException("ArrayInsert not supported.")
+ }
}
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index b277139e8..5e42f66ba 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -81,7 +81,8 @@ class Spark34Shims extends SparkShims {
Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
Sig[RoundFloor](ExpressionNames.FLOOR),
Sig[RoundCeil](ExpressionNames.CEIL),
- Sig[Mask](ExpressionNames.MASK)
+ Sig[Mask](ExpressionNames.MASK),
+ Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
)
}
@@ -492,4 +493,9 @@ class Spark34Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}
+
+ override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+ val expr = arrayInsert.asInstanceOf[ArrayInsert]
+ Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+ }
}
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 6474c74fe..ddb023b5a 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -81,7 +81,8 @@ class Spark35Shims extends SparkShims {
Sig[Mask](ExpressionNames.MASK),
Sig[TimestampAdd](ExpressionNames.TIMESTAMP_ADD),
Sig[RoundFloor](ExpressionNames.FLOOR),
- Sig[RoundCeil](ExpressionNames.CEIL)
+ Sig[RoundCeil](ExpressionNames.CEIL),
+ Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT)
)
}
@@ -517,4 +518,9 @@ class Spark35Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}
+
+ override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+ val expr = arrayInsert.asInstanceOf[ArrayInsert]
+ Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]