This is an automated email from the ASF dual-hosted git repository.
yuanzhou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 8e55381e18 [GLUTEN-11088] Fix GlutenDataFrameFunctionsSuite in
Spark-4.0 (#11195)
8e55381e18 is described below
commit 8e55381e1806b2045f2a0993f8d861e76fca1135
Author: Mingliang Zhu <[email protected]>
AuthorDate: Fri Nov 28 18:24:57 2025 +0800
[GLUTEN-11088] Fix GlutenDataFrameFunctionsSuite in Spark-4.0 (#11195)
https://github.com/apache/spark/blob/29434ea766b0fc3c3bf6eaadb43a8f931133649e/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L2928-L2937
Vanilla Spark throws SparkRuntimeException, while Gluten throws SparkException.
This patch modifies the tests to adapt them to Gluten's behavior.
---
.../gluten/utils/velox/VeloxTestSettings.scala | 2 +-
.../spark/sql/GlutenDataFrameFunctionsSuite.scala | 229 +++++++++++++++++++++
2 files changed, 230 insertions(+), 1 deletion(-)
diff --git
a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 74ace889e9..5fa43d50f7 100644
---
a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -758,7 +758,7 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("aggregate function - array for non-primitive type")
// Rewrite this test because Velox sorts rows by key for primitive data
types, which disrupts the original row sequence.
.exclude("map_zip_with function - map of primitive types")
- // TODO: fix in Spark-4.0
+ // Vanilla spark throw SparkRuntimeException, gluten throw SparkException.
.exclude("map_concat function")
.exclude("transform keys function - primitive data types")
enableSuite[GlutenDataFrameHintSuite]
diff --git
a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790a..49f6052b20 100644
---
a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++
b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -16,7 +16,10 @@
*/
package org.apache.spark.sql
+import org.apache.spark.SparkException
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, MapType, StringType,
StructField, StructType}
class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with
GlutenSQLTestsTrait {
import testImplicits._
@@ -49,4 +52,230 @@ class GlutenDataFrameFunctionsSuite extends
DataFrameFunctionsSuite with GlutenS
false
)
}
+
+ testGluten("map_concat function") {
+ val df1 = Seq(
+ (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 4 -> 400)),
+ (Map[Int, Int](1 -> 100, 2 -> 200), Map[Int, Int](3 -> 300, 1 -> 400)),
+ (null, Map[Int, Int](3 -> 300, 4 -> 400))
+ ).toDF("map1", "map2")
+
+ val expected1a = Seq(
+ Row(Map(1 -> 100, 2 -> 200, 3 -> 300, 4 -> 400)),
+ Row(Map(1 -> 400, 2 -> 200, 3 -> 300)),
+ Row(null)
+ )
+
+ intercept[SparkException](df1.selectExpr("map_concat(map1,
map2)").collect())
+ intercept[SparkException](df1.select(map_concat($"map1",
$"map2")).collect())
+ withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key ->
SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+ checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+ checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
+ }
+
+ val expected1b = Seq(
+ Row(Map(1 -> 100, 2 -> 200)),
+ Row(Map(1 -> 100, 2 -> 200)),
+ Row(null)
+ )
+
+ checkAnswer(df1.selectExpr("map_concat(map1)"), expected1b)
+ checkAnswer(df1.select(map_concat($"map1")), expected1b)
+
+ val df2 = Seq(
+ (
+ Map[Array[Int], Int](Array(1) -> 100, Array(2) -> 200),
+ Map[String, Int]("3" -> 300, "4" -> 400)
+ )
+ ).toDF("map1", "map2")
+
+ val expected2 = Seq(Row(Map()))
+
+ checkAnswer(df2.selectExpr("map_concat()"), expected2)
+ checkAnswer(df2.select(map_concat()), expected2)
+
+ val df3 = {
+ val schema = StructType(
+ StructField("map1", MapType(StringType, IntegerType, true), false) ::
+ StructField("map2", MapType(StringType, IntegerType, false), false)
:: Nil
+ )
+ val data = Seq(
+ Row(Map[String, Any]("a" -> 1, "b" -> null), Map[String, Any]("c" ->
3, "d" -> 4)),
+ Row(Map[String, Any]("a" -> 1, "b" -> 2), Map[String, Any]("c" -> 3,
"d" -> 4))
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ }
+
+ val expected3 = Seq(
+ Row(Map[String, Any]("a" -> 1, "b" -> null, "c" -> 3, "d" -> 4)),
+ Row(Map[String, Any]("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4))
+ )
+
+ checkAnswer(df3.selectExpr("map_concat(map1, map2)"), expected3)
+ checkAnswer(df3.select(map_concat($"map1", $"map2")), expected3)
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ df2.selectExpr("map_concat(map1, map2)").collect()
+ },
+ condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+ sqlState = None,
+ parameters = Map(
+ "sqlExpr" -> "\"map_concat(map1, map2)\"",
+ "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+ "functionName" -> "`map_concat`"),
+ context = ExpectedContext(fragment = "map_concat(map1, map2)", start =
0, stop = 21)
+ )
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ df2.select(map_concat($"map1", $"map2")).collect()
+ },
+ condition = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+ sqlState = None,
+ parameters = Map(
+ "sqlExpr" -> "\"map_concat(map1, map2)\"",
+ "dataType" -> "(\"MAP<ARRAY<INT>, INT>\" or \"MAP<STRING, INT>\")",
+ "functionName" -> "`map_concat`"),
+ context =
+ ExpectedContext(fragment = "map_concat", callSitePattern =
getCurrentClassCallSitePattern)
+ )
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ df2.selectExpr("map_concat(map1, 12)").collect()
+ },
+ condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+ sqlState = None,
+ parameters = Map(
+ "sqlExpr" -> "\"map_concat(map1, 12)\"",
+ "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+ "functionName" -> "`map_concat`"),
+ context = ExpectedContext(fragment = "map_concat(map1, 12)", start = 0,
stop = 19)
+ )
+
+ checkError(
+ exception = intercept[AnalysisException] {
+ df2.select(map_concat($"map1", lit(12))).collect()
+ },
+ condition = "DATATYPE_MISMATCH.MAP_CONCAT_DIFF_TYPES",
+ sqlState = None,
+ parameters = Map(
+ "sqlExpr" -> "\"map_concat(map1, 12)\"",
+ "dataType" -> "[\"MAP<ARRAY<INT>, INT>\", \"INT\"]",
+ "functionName" -> "`map_concat`"),
+ context =
+ ExpectedContext(fragment = "map_concat", callSitePattern =
getCurrentClassCallSitePattern)
+ )
+ }
+
+ testGluten("transform keys function - primitive data types") {
+ val dfExample1 = Seq(
+ Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
+ ).toDF("i")
+
+ val dfExample2 = Seq(
+ Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
+ ).toDF("j")
+
+ val dfExample3 = Seq(
+ Map[Int, Boolean](25 -> true, 26 -> false)
+ ).toDF("x")
+
+ val dfExample4 = Seq(
+ Map[Array[Int], Boolean](Array(1, 2) -> false)
+ ).toDF("y")
+
+ def testMapOfPrimitiveTypesCombination(): Unit = {
+ checkAnswer(
+ dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
+ Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+ checkAnswer(
+ dfExample1.select(transform_keys(col("i"), (k, v) => k + v)),
+ Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
+
+ checkAnswer(
+ dfExample2.selectExpr(
+ "transform_keys(j, " +
+ "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two',
'three'))[k])"),
+ Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))
+ )
+
+ checkAnswer(
+ dfExample2.select(
+ transform_keys(
+ col("j"),
+ (k, v) =>
+ element_at(
+ map_from_arrays(
+ array(lit(1), lit(2), lit(3)),
+ array(lit("one"), lit("two"), lit("three"))
+ ),
+ k
+ )
+ )
+ ),
+ Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))
+ )
+
+ checkAnswer(
+ dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS
BIGINT) + k)"),
+ Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+ checkAnswer(
+ dfExample2.select(transform_keys(col("j"), (k, v) => (v *
2).cast("bigint") + k)),
+ Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
+
+ checkAnswer(
+ dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
+ Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+ checkAnswer(
+ dfExample2.select(transform_keys(col("j"), (k, v) => k + v)),
+ Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
+
+ intercept[SparkException] {
+ dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR
v)").collect()
+ }
+ intercept[SparkException] {
+ dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 ||
v)).collect()
+ }
+ withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key ->
SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
+ checkAnswer(
+ dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR
v)"),
+ Seq(Row(Map(true -> true, true -> false))))
+
+ checkAnswer(
+ dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 ||
v)),
+ Seq(Row(Map(true -> true, true -> false))))
+ }
+
+ checkAnswer(
+ dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 *
k))"),
+ Seq(Row(Map(50 -> true, 78 -> false))))
+
+ checkAnswer(
+ dfExample3.select(transform_keys(col("x"), (k, v) => when(v, k *
2).otherwise(k * 3))),
+ Seq(Row(Map(50 -> true, 78 -> false))))
+
+ checkAnswer(
+ dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k,
3) AND v)"),
+ Seq(Row(Map(false -> false))))
+
+ checkAnswer(
+ dfExample4.select(transform_keys(col("y"), (k, v) => array_contains(k,
lit(3)) && v)),
+ Seq(Row(Map(false -> false))))
+ }
+
+ // Test with local relation, the Project will be evaluated without codegen
+ testMapOfPrimitiveTypesCombination()
+ dfExample1.cache()
+ dfExample2.cache()
+ dfExample3.cache()
+ dfExample4.cache()
+ // Test with cached relation, the Project will be evaluated with codegen
+ testMapOfPrimitiveTypesCombination()
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]