This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new f42c029fac5c [SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions
f42c029fac5c is described below
commit f42c029fac5c8015d80ad957fae325243a2ed30d
Author: Rui Wang <[email protected]>
AuthorDate: Mon May 27 22:40:13 2024 -0700
[SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions
### What changes were proposed in this pull request?

MapConcat holds mutable state (its lazily created map builder), so it is stateful:
```
private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
```
Similarly, `MapFromEntries`, `CreateMap`, `MapFromArrays`, `StringToMap`, and `TransformKeys` need the same change.
### Why are the changes needed?

Stateful expressions should be marked as stateful; otherwise the engine may share a single instance where each consumer needs a fresh copy of the expression's internal state.
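A minimal, self-contained sketch of that contract, with every name in it (`MapBuilderSketch`, `ExprSketch`, `CreateMapSketch`, `freshCopy`) invented for illustration rather than taken from Spark: an expression whose evaluation goes through a per-instance mutable helper must report `stateful = true` so the engine evaluates a fresh copy instead of sharing one instance.

```scala
// Illustrative sketch only -- these are NOT Spark's real classes.
import scala.collection.mutable.ArrayBuffer

// A builder that accumulates entries between put() and build(): mutable state.
final class MapBuilderSketch {
  private val entries = ArrayBuffer.empty[(String, String)]
  def put(k: String, v: String): Unit = entries += ((k, v))
  def build(): Map[String, String] = { val m = entries.toMap; entries.clear(); m }
}

trait ExprSketch {
  def stateful: Boolean = false       // default: one instance may be shared freely
  def freshCopy(): ExprSketch = this  // stateful expressions must return a real copy
  def eval(key: String, value: String): Map[String, String]
}

final class CreateMapSketch extends ExprSketch {
  private lazy val builder = new MapBuilderSketch  // per-instance mutable state
  override def stateful: Boolean = true            // tells the engine to copy us
  override def freshCopy(): ExprSketch = new CreateMapSketch
  def eval(key: String, value: String): Map[String, String] = {
    builder.put(key, value)  // mutates state...
    builder.build()          // ...assuming no other evaluator touched it in between
  }
}

object StatefulSketch extends App {
  val expr = new CreateMapSketch
  // An engine honoring the flag hands each consumer its own copy; sharing
  // `expr` across evaluators could interleave put()/build() and corrupt results.
  val perConsumer = Seq.fill(2)(if (expr.stateful) expr.freshCopy() else expr)
  perConsumer.foreach(e => println(e.eval("key", "value")))  // Map(key -> value), twice
}
```

In Spark itself the mutable helper is the lazily created `ArrayBasedMapBuilder` shown above, and overriding `stateful` is what tells the engine it must not reuse a single instance of these expressions.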
### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

N/A

### Was this patch authored or co-authored using generative AI tooling?

No
Closes #46721 from amaliujia/statefulexpr.
Authored-by: Rui Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit af1ac1edc2a96c9aba949e3100ddae37b6f0e5b2)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/expressions/collectionOperations.scala | 3 +++
.../spark/sql/catalyst/expressions/complexTypeCreator.scala | 6 ++++++
.../spark/sql/catalyst/expressions/higherOrderFunctions.scala | 2 ++
.../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 10 +++++++++-
4 files changed, 20 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 3ddbe38fdedf..45896382af67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -712,6 +712,7 @@ case class MapConcat(children: Seq[Expression])
}
}
+ override def stateful: Boolean = true
override def nullable: Boolean = children.exists(_.nullable)
private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
@@ -827,6 +828,8 @@ case class MapFromEntries(child: Expression)
override def nullable: Boolean = child.nullable || nullEntries
+ override def stateful: Boolean = true
+
@transient override lazy val dataType: MapType = dataTypeDetails.get._1
override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index c95a0987330d..1b6f86984be7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -242,6 +242,8 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
+ override def stateful: Boolean = true
+
override def eval(input: InternalRow): Any = {
var i = 0
while (i < keys.length) {
@@ -317,6 +319,8 @@ case class MapFromArrays(left: Expression, right: Expression)
valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
}
+ override def stateful: Boolean = true
+
private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
@@ -563,6 +567,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
this(child, Literal(","), Literal(":"))
}
+ override def stateful: Boolean = true
+
override def first: Expression = text
override def second: Expression = pairDelim
override def third: Expression = keyValueDelim
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index fec1df108bcc..5b10b401af98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -918,6 +918,8 @@ case class TransformKeys(
override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull)
+ override def stateful: Boolean = true
+
override def checkInputDataTypes(): TypeCheckResult = {
TypeUtils.checkForMapKeyType(function.dataType)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c586da6105fd..260ecaa5ece1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
@@ -3636,6 +3636,14 @@ class DataFrameSuite extends QueryTest
assert(row.getInt(0).toString == row.getString(2))
assert(row.getInt(0).toString == row.getString(3))
}
+
+ val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value"))))
+ val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback
+ df.select(v3, v3, v4, v4).collect().foreach { row =>
+ assert(row.getMap(0).toString() == row.getMap(1).toString())
+ assert(row.getString(2) == s"{key -> ${row.getMap(0).get("key").get}}")
+ assert(row.getString(3) == s"{key -> ${row.getMap(0).get("key").get}}")
+ }
}
test("SPARK-41219: IntegralDivide use decimal(1, 0) to represent 0") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]