This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4321604fe0b [SPARK-41049][SQL] Revisit stateful expression handling
4321604fe0b is described below
commit 4321604fe0b8f54b7cc5d1372f9b89a91fbad8b2
Author: Wenchen Fan <[email protected]>
AuthorDate: Sat Dec 31 10:05:41 2022 +0800
[SPARK-41049][SQL] Revisit stateful expression handling
### What changes were proposed in this pull request?
Spark has a `Stateful` trait for stateful expressions. The basic idea is to
have fresh copies of stateful expressions before evaluating them. This is to
avoid issues caused by the flexible DataFrame APIs:
1. A single expression instance may appear more than once in the expression
tree. We have to replace it with fresh copies to avoid sharing states.
2. An expression tree can be evaluated by multiple callers at the same
time. We have to use fresh copies before expression evaluation to avoid sharing
states.
However, the handling of stateful expressions has several problems. This PR
fixes all of them:
1. We should use fresh copies with codegen as well. If the root expression
extends `CodegenFallback`, then the expression tree will be evaluated using the
interpreted mode, even with the codegen code path.
2. The fresh copies will be dropped if the stateful expression is deeply
nested (3 layers).
3. `InterpretedSafeProjection` never implemented initialize() for
initializing Nondeterministic expressions.
4. `ConvertToLocalRelation` called an `InterpretedMutableProjection`
constructor which did not implement the existing Stateful-copying logic. I
fixed this by moving that logic out of a factory method and into the class's
main constructor, guaranteeing that it will always run.
5. A stateful expression is not always nondeterministic, e.g. `ScalaUDF`. I
removed the `Stateful` trait and added a `def stateful: Boolean` function in
`Expression`.
### Why are the changes needed?
Fix stateful expression handling
### Does this PR introduce _any_ user-facing change?
Yes, now we never share state between stateful expressions. Previously, shared
state could produce wrong results.
### How was this patch tested?
new tests
Closes #39248 from cloud-fan/expr.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/Expression.scala | 79 ++++++++++++++--------
.../expressions/ExpressionsEvaluator.scala | 45 ++++++++++++
.../expressions/InterpretedMutableProjection.scala | 16 +----
.../expressions/InterpretedSafeProjection.scala | 21 +++---
.../expressions/InterpretedUnsafeProjection.scala | 14 +---
.../expressions/MonotonicallyIncreasingID.scala | 10 ++-
.../sql/catalyst/expressions/Projection.scala | 12 ++--
.../spark/sql/catalyst/expressions/ScalaUDF.scala | 5 ++
.../expressions/codegen/CodeGenerator.scala | 7 +-
.../expressions/collectionOperations.scala | 8 +--
.../spark/sql/catalyst/expressions/misc.scala | 6 +-
.../spark/sql/catalyst/expressions/package.scala | 10 +--
.../sql/catalyst/expressions/predicates.scala | 17 +----
.../catalyst/expressions/randomExpressions.scala | 8 +--
.../apache/spark/sql/catalyst/trees/TreeNode.scala | 3 +-
.../expressions/CollectionExpressionsSuite.scala | 6 --
.../expressions/MiscExpressionsSuite.scala | 6 --
.../spark/sql/catalyst/trees/TreeNodeSuite.scala | 44 ++++++++++++
.../org/apache/spark/sql/DataFrameSuite.scala | 11 +++
19 files changed, 203 insertions(+), 125 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0ddf1a7df19..de0e90285f5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -26,7 +26,7 @@ import
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike,
QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin,
LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE,
TreePattern}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
@@ -47,8 +47,6 @@ import org.apache.spark.sql.types._
* There are a few important traits or abstract classes:
*
* - [[Nondeterministic]]: an expression that is not deterministic.
- * - [[Stateful]]: an expression that contains mutable state. For example,
MonotonicallyIncreasingID
- * and Rand. A stateful expression is always non-deterministic.
* - [[Unevaluable]]: an expression that is not supposed to be evaluated.
* - [[CodegenFallback]]: an expression that does not have code gen
implemented and falls back to
* interpreted mode.
@@ -127,6 +125,54 @@ abstract class Expression extends TreeNode[Expression] {
def references: AttributeSet = _references
+ /**
+ * Returns true if the expression contains mutable state.
+ *
+ * A stateful expression should never be evaluated multiple times for a
single row. This should
+ * only be a problem for interpreted execution. This can be prevented by
creating fresh copies
+ * of the stateful expression before execution. A common example to trigger
this issue:
+ * {{{
+ * val rand = functions.rand()
+ * df.select(rand, rand) // These 2 rand should not share a state.
+ * }}}
+ */
+ def stateful: Boolean = false
+
+ /**
+ * Returns a copy of this expression where all stateful expressions are
replaced with fresh
+ * uninitialized copies. If the expression contains no stateful expressions
then the original
+ * expression is returned.
+ */
+ def freshCopyIfContainsStatefulExpression(): Expression = {
+ val childrenIndexedSeq: IndexedSeq[Expression] = children match {
+ case types: IndexedSeq[Expression] => types
+ case other => other.toIndexedSeq
+ }
+ val newChildren =
childrenIndexedSeq.map(_.freshCopyIfContainsStatefulExpression())
+ // A more efficient version of `children.zip(newChildren).exists(_ ne _)`
+ val anyChildChanged = {
+ val size = newChildren.length
+ var i = 0
+ var res: Boolean = false
+ while (!res && i < size) {
+ res |= (childrenIndexedSeq(i) ne newChildren(i))
+ i += 1
+ }
+ res
+ }
+ // If the children contain stateful expressions and get copied, or this
expression is stateful,
+ // copy this expression with the new children.
+ if (anyChildChanged || stateful) {
+ CurrentOrigin.withOrigin(origin) {
+ val res = withNewChildrenInternal(newChildren)
+ res.copyTagsFrom(this)
+ res
+ }
+ } else {
+ this
+ }
+ }
+
/** Returns the result of evaluating this expression on a given input Row */
def eval(input: InternalRow = null): Any
@@ -472,33 +518,6 @@ trait ConditionalExpression extends Expression {
def branchGroups: Seq[Seq[Expression]]
}
-/**
- * An expression that contains mutable state. A stateful expression is always
non-deterministic
- * because the results it produces during evaluation are not only dependent on
the given input
- * but also on its internal state.
- *
- * The state of the expressions is generally not exposed in the parameter list
and this makes
- * comparing stateful expressions problematic because similar stateful
expressions (with the same
- * parameter list) but with different internal state will be considered equal.
This is especially
- * problematic during tree transformations. In order to counter this the
`fastEquals` method for
- * stateful expressions only returns `true` for the same reference.
- *
- * A stateful expression should never be evaluated multiple times for a single
row. This should
- * only be a problem for interpreted execution. This can be prevented by
creating fresh copies
- * of the stateful expression before execution, these can be made using the
`freshCopy` function.
- */
-trait Stateful extends Nondeterministic {
- /**
- * Return a fresh uninitialized copy of the stateful expression.
- */
- def freshCopy(): Stateful
-
- /**
- * Only the same reference is considered equal.
- */
- override def fastEquals(other: TreeNode[_]): Boolean = this eq other
-}
-
/**
* A leaf expression, i.e. one without any child expressions.
*/
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
new file mode 100644
index 00000000000..dcbc6926cd3
--- /dev/null
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionsEvaluator.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.internal.SQLConf
+
+// A helper class to evaluate expressions.
+trait ExpressionsEvaluator {
+ protected lazy val runtime =
+ new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
+
+ protected def prepareExpressions(
+ exprs: Seq[Expression],
+ subExprEliminationEnabled: Boolean): Seq[Expression] = {
+ // We need to make sure that we do not reuse stateful expressions.
+ val cleanedExpressions =
exprs.map(_.freshCopyIfContainsStatefulExpression())
+ if (subExprEliminationEnabled) {
+ runtime.proxyExpressions(cleanedExpressions)
+ } else {
+ cleanedExpressions
+ }
+ }
+
+ /**
+ * Initializes internal states given the current partition index.
+ * This is used by nondeterministic expressions to set initial states.
+ * The default implementation does nothing.
+ */
+ def initialize(partitionIndex: Int): Unit = {}
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 5d95ac71be8..682604b9bf7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -36,18 +36,12 @@ class InterpretedMutableProjection(expressions:
Seq[Expression]) extends Mutable
this(bindReferences(expressions, inputSchema))
private[this] val subExprEliminationEnabled =
SQLConf.get.subexpressionEliminationEnabled
- private[this] lazy val runtime =
- new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
- private[this] val exprs = if (subExprEliminationEnabled) {
- runtime.proxyExpressions(expressions)
- } else {
- expressions
- }
+ private[this] val exprs = prepareExpressions(expressions,
subExprEliminationEnabled)
private[this] val buffer = new Array[Any](expressions.size)
override def initialize(partitionIndex: Int): Unit = {
- expressions.foreach(_.foreach {
+ exprs.foreach(_.foreach {
case n: Nondeterministic => n.initialize(partitionIndex)
case _ =>
})
@@ -117,10 +111,6 @@ object InterpretedMutableProjection {
* Returns a [[MutableProjection]] for given sequence of bound Expressions.
*/
def createProjection(exprs: Seq[Expression]): MutableProjection = {
- // We need to make sure that we do not reuse stateful expressions.
- val cleanedExpressions = exprs.map(_.transform {
- case s: Stateful => s.freshCopy()
- })
- new InterpretedMutableProjection(cleanedExpressions)
+ new InterpretedMutableProjection(exprs)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
index 0e71892db66..84263d97f5d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedSafeProjection.scala
@@ -32,13 +32,7 @@ import org.apache.spark.sql.types._
class InterpretedSafeProjection(expressions: Seq[Expression]) extends
Projection {
private[this] val subExprEliminationEnabled =
SQLConf.get.subexpressionEliminationEnabled
- private[this] lazy val runtime =
- new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
- private[this] val exprs = if (subExprEliminationEnabled) {
- runtime.proxyExpressions(expressions)
- } else {
- expressions
- }
+ private[this] val exprs = prepareExpressions(expressions,
subExprEliminationEnabled)
private[this] val mutableRow = new
SpecificInternalRow(expressions.map(_.dataType))
@@ -106,6 +100,13 @@ class InterpretedSafeProjection(expressions:
Seq[Expression]) extends Projection
case _ => identity
}
+ override def initialize(partitionIndex: Int): Unit = {
+ expressions.foreach(_.foreach {
+ case n: Nondeterministic => n.initialize(partitionIndex)
+ case _ =>
+ })
+ }
+
override def apply(row: InternalRow): InternalRow = {
if (subExprEliminationEnabled) {
runtime.setInput(row)
@@ -130,10 +131,6 @@ object InterpretedSafeProjection {
* Returns an [[SafeProjection]] for given sequence of bound Expressions.
*/
def createProjection(exprs: Seq[Expression]): Projection = {
- // We need to make sure that we do not reuse stateful expressions.
- val cleanedExpressions = exprs.map(_.transform {
- case s: Stateful => s.freshCopy()
- })
- new InterpretedSafeProjection(cleanedExpressions)
+ new InterpretedSafeProjection(exprs)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 9a9a41b1f18..d87c0c006cf 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -35,13 +35,7 @@ class InterpretedUnsafeProjection(expressions:
Array[Expression]) extends Unsafe
import InterpretedUnsafeProjection._
private[this] val subExprEliminationEnabled =
SQLConf.get.subexpressionEliminationEnabled
- private[this] lazy val runtime =
- new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
- private[this] val exprs = if (subExprEliminationEnabled) {
- runtime.proxyExpressions(expressions)
- } else {
- expressions.toSeq
- }
+ private[this] val exprs = prepareExpressions(expressions,
subExprEliminationEnabled)
/** Number of (top level) fields in the resulting row. */
private[this] val numFields = expressions.length
@@ -106,11 +100,7 @@ object InterpretedUnsafeProjection {
* Returns an [[UnsafeProjection]] for given sequence of bound Expressions.
*/
def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
- // We need to make sure that we do not reuse stateful expressions.
- val cleanedExpressions = exprs.map(_.transform {
- case s: Stateful => s.freshCopy()
- })
- new InterpretedUnsafeProjection(cleanedExpressions.toArray)
+ new InterpretedUnsafeProjection(exprs.toArray)
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index ecf254f65f5..8dc1ba4846a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -48,7 +48,7 @@ import org.apache.spark.sql.types.{DataType, LongType}
""",
since = "1.4.0",
group = "misc_funcs")
-case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
+case class MonotonicallyIncreasingID() extends LeafExpression with
Nondeterministic {
/**
* Record ID within each partition. By being transient, count's value is
reset to 0 every time
@@ -58,11 +58,17 @@ case class MonotonicallyIncreasingID() extends
LeafExpression with Stateful {
@transient private[this] var partitionMask: Long = _
+ override def stateful: Boolean = true
+
override protected def initializeInternal(partitionIndex: Int): Unit = {
count = 0L
partitionMask = partitionIndex.toLong << 33
}
+ override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]):
Expression = {
+ MonotonicallyIncreasingID()
+ }
+
override def nullable: Boolean = false
override def dataType: DataType = LongType
@@ -88,6 +94,4 @@ case class MonotonicallyIncreasingID() extends LeafExpression
with Stateful {
override def nodeName: String = "monotonically_increasing_id"
override def sql: String = s"$prettyName()"
-
- override def freshCopy(): MonotonicallyIncreasingID =
MonotonicallyIncreasingID()
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b4a85e3e50b..20969fa584a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -33,16 +33,20 @@ class InterpretedProjection(expressions: Seq[Expression])
extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(bindReferences(expressions, inputSchema))
+ // null check is required for when Kryo invokes the no-arg constructor.
+ protected val exprArray = if (expressions != null) {
+ prepareExpressions(expressions, subExprEliminationEnabled = false).toArray
+ } else {
+ null
+ }
+
override def initialize(partitionIndex: Int): Unit = {
- expressions.foreach(_.foreach {
+ exprArray.foreach(_.foreach {
case n: Nondeterministic => n.initialize(partitionIndex)
case _ =>
})
}
- // null check is required for when Kryo invokes the no-arg constructor.
- protected val exprArray = if (expressions != null) expressions.toArray else
null
-
def apply(input: InternalRow): InternalRow = {
val outputArray = new Array[Any](exprArray.length)
var i = 0
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index f8ff5f583f6..137a8976a40 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -58,6 +58,11 @@ case class ScalaUDF(
override lazy val deterministic: Boolean = udfDeterministic &&
children.forall(_.deterministic)
+ // `ScalaUDF` uses `ExpressionEncoder` to convert the function result to
Catalyst internal format.
+ // `ExpressionEncoder` is stateful as it reuses the `UnsafeRow` instance,
thus `ScalaUDF` is
+ // stateful as well.
+ override def stateful: Boolean = true
+
final override val nodePatterns: Seq[TreePattern] = Seq(SCALA_UDF)
override def toString: String = s"$name(${children.mkString(", ")})"
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 61f888f17b1..12103ceef6e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1271,8 +1271,11 @@ class CodegenContext extends Logging {
def generateExpressions(
expressions: Seq[Expression],
doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
- if (doSubexpressionElimination) subexpressionElimination(expressions)
- expressions.map(e => e.genCode(this))
+ // We need to make sure that we do not reuse stateful expressions. This is
needed for codegen
+ // as well because some expressions may implement `CodegenFallback`.
+ val cleanedExpressions =
expressions.map(_.freshCopyIfContainsStatefulExpression())
+ if (doSubexpressionElimination)
subexpressionElimination(cleanedExpressions)
+ cleanedExpressions.map(e => e.genCode(this))
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 229987fc0c8..22584a64f7d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1125,11 +1125,13 @@ case class SortArray(base: Expression, ascendingOrder:
Expression)
""",
group = "array_funcs",
since = "2.4.0")
-case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
- extends UnaryExpression with ExpectsInputTypes with Stateful with
ExpressionWithRandomSeed {
+case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends
UnaryExpression
+ with ExpectsInputTypes with Nondeterministic with ExpressionWithRandomSeed {
def this(child: Expression) = this(child, None)
+ override def stateful: Boolean = true
+
override def seedExpression: Expression =
randomSeed.map(Literal.apply).getOrElse(UnresolvedSeed)
override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
@@ -1195,8 +1197,6 @@ case class Shuffle(child: Expression, randomSeed:
Option[Long] = None)
""".stripMargin
}
- override def freshCopy(): Shuffle = Shuffle(child, randomSeed)
-
override def withNewChildInternal(newChild: Expression): Shuffle =
copy(child = newChild)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index eb21bd555db..bf9dd700dfa 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -201,7 +201,7 @@ case class CurrentCatalog() extends LeafExpression with
Unevaluable {
since = "2.3.0",
group = "misc_funcs")
// scalastyle:on line.size.limit
-case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with
Stateful
+case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with
Nondeterministic
with ExpressionWithRandomSeed {
def this() = this(None)
@@ -216,6 +216,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends
LeafExpression with Sta
override def dataType: DataType = StringType
+ override def stateful: Boolean = true
+
@transient private[this] var randomGenerator: RandomUUIDGenerator = _
override protected def initializeInternal(partitionIndex: Int): Unit =
@@ -235,8 +237,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends
LeafExpression with Sta
ev.copy(code = code"final UTF8String ${ev.value} =
$randomGen.getNextUUIDUTF8String();",
isNull = FalseLiteral)
}
-
- override def freshCopy(): Uuid = Uuid(randomSeed)
}
// scalastyle:off line.size.limit
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index ededac3d917..44813ac7b61 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -68,15 +68,7 @@ package object expressions {
* column of the new row. If the schema of the input row is specified, then
the given expression
* will be bound to that schema.
*/
- abstract class Projection extends (InternalRow => InternalRow) {
-
- /**
- * Initializes internal states given the current partition index.
- * This is used by nondeterministic expressions to set initial states.
- * The default implementation does nothing.
- */
- def initialize(partitionIndex: Int): Unit = {}
- }
+ abstract class Projection extends (InternalRow => InternalRow) with
ExpressionsEvaluator
/**
* An identity projection. This returns the input row.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f51c9fd5ef3..4e4ac6ee492 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -36,26 +36,13 @@ import org.apache.spark.sql.types._
/**
* A base class for generated/interpreted predicate
*/
-abstract class BasePredicate {
+abstract class BasePredicate extends ExpressionsEvaluator {
def eval(r: InternalRow): Boolean
-
- /**
- * Initializes internal states given the current partition index.
- * This is used by nondeterministic expressions to set initial states.
- * The default implementation does nothing.
- */
- def initialize(partitionIndex: Int): Unit = {}
}
case class InterpretedPredicate(expression: Expression) extends BasePredicate {
private[this] val subExprEliminationEnabled =
SQLConf.get.subexpressionEliminationEnabled
- private[this] lazy val runtime =
- new
SubExprEvaluationRuntime(SQLConf.get.subexpressionEliminationCacheMaxEntries)
- private[this] val expr = if (subExprEliminationEnabled) {
- runtime.proxyExpressions(Seq(expression)).head
- } else {
- expression
- }
+ private[this] val expr = prepareExpressions(Seq(expression),
subExprEliminationEnabled).head
override def eval(r: InternalRow): Boolean = {
if (subExprEliminationEnabled) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index e2eb7fb1643..db78415a0cc 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.random.XORShiftRandom
*
* Since this expression is stateful, it cannot be a case object.
*/
-abstract class RDG extends UnaryExpression with ExpectsInputTypes with Stateful
+abstract class RDG extends UnaryExpression with ExpectsInputTypes with
Nondeterministic
with ExpressionWithRandomSeed {
/**
* Record ID within each partition. By being transient, the Random Number
Generator is
@@ -40,6 +40,8 @@ abstract class RDG extends UnaryExpression with
ExpectsInputTypes with Stateful
*/
@transient protected var rng: XORShiftRandom = _
+ override def stateful: Boolean = true
+
override protected def initializeInternal(partitionIndex: Int): Unit = {
rng = new XORShiftRandom(seed + partitionIndex)
}
@@ -108,8 +110,6 @@ case class Rand(child: Expression, hideSeed: Boolean =
false) extends RDG {
isNull = FalseLiteral)
}
- override def freshCopy(): Rand = Rand(child, hideSeed)
-
override def flatArguments: Iterator[Any] = Iterator(child)
override def sql: String = {
s"rand(${if (hideSeed) "" else child.sql})"
@@ -161,8 +161,6 @@ case class Randn(child: Expression, hideSeed: Boolean =
false) extends RDG {
isNull = FalseLiteral)
}
- override def freshCopy(): Randn = Randn(child, hideSeed)
-
override def flatArguments: Iterator[Any] = Iterator(child)
override def sql: String = {
s"randn(${if (hideSeed) "" else child.sql})"
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a7573fc1bd9..9510aa4d9e7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -1113,7 +1113,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
extends Product with Tre
trait LeafLike[T <: TreeNode[T]] { self: TreeNode[T] =>
override final def children: Seq[T] = Nil
override final def mapChildren(f: T => T): T = this.asInstanceOf[T]
- override final def withNewChildrenInternal(newChildren: IndexedSeq[T]): T =
this.asInstanceOf[T]
+ // Stateful expressions should override this method to return a new instance.
+ override def withNewChildrenInternal(newChildren: IndexedSeq[T]): T =
this.asInstanceOf[T]
}
trait UnaryLike[T <: TreeNode[T]] { self: TreeNode[T] =>
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index f6c529ec4ce..32b3840760f 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2103,12 +2103,6 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
evaluateWithMutableProjection(Shuffle(ai0, seed2)))
assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !==
evaluateWithUnsafeProjection(Shuffle(ai0, seed2)))
-
- val shuffle = Shuffle(ai0, seed1)
- assert(shuffle.fastEquals(shuffle))
- assert(!shuffle.fastEquals(Shuffle(ai0, seed1)))
- assert(!shuffle.fastEquals(shuffle.freshCopy()))
- assert(!shuffle.fastEquals(Shuffle(ai0, seed2)))
}
test("Array Except") {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
index 15a0695943b..d449de3defb 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala
@@ -70,12 +70,6 @@ class MiscExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
evaluateWithMutableProjection(Uuid(seed2)))
assert(evaluateWithUnsafeProjection(Uuid(seed1)) !==
evaluateWithUnsafeProjection(Uuid(seed2)))
-
- val uuid = Uuid(seed1)
- assert(uuid.fastEquals(uuid))
- assert(!uuid.fastEquals(Uuid(seed1)))
- assert(!uuid.fastEquals(uuid.freshCopy()))
- assert(!uuid.fastEquals(Uuid(seed2)))
}
test("PrintToStderr") {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index bef88a7c0a3..286d3dddae6 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -820,6 +820,50 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
assert(leaf.child.eq(leafCloned.asInstanceOf[FakeLeafPlan].child))
}
+ test("Expression.freshCopyIfContainsStatefulExpression()") {
+ val tag = TreeNodeTag[String]("test")
+
+ def makeExprWithPositionAndTag(block: => Expression): Expression = {
+ CurrentOrigin.setPosition(1, 1)
+ val expr = block
+ CurrentOrigin.reset()
+ expr.setTagValue(tag, "tagValue")
+ expr
+ }
+
+ // Test generic assertions which should always hold for any value returned
+ // from freshCopyIfContainsStatefulExpression()
+ def genericAssertions(before: Expression, after: Expression): Unit = {
+ assert(before == after)
+ assert(before.origin == after.origin)
+ assert(before.getTagValue(tag) == after.getTagValue(tag))
+ }
+
+ // Doesn't transform for non-stateful expressions:
+ val onePlusOneBefore = makeExprWithPositionAndTag(Add(Literal(1),
Literal(1)))
+ val onePlusOneAfter =
onePlusOneBefore.freshCopyIfContainsStatefulExpression()
+ genericAssertions(onePlusOneBefore, onePlusOneAfter)
+ assert(onePlusOneBefore eq onePlusOneAfter)
+
+ // Transforms stateful expressions with no nesting:
+ val statefulExprBefore = makeExprWithPositionAndTag(Rand(Literal(1)))
+ val statefulExprAfter =
statefulExprBefore.freshCopyIfContainsStatefulExpression()
+ genericAssertions(statefulExprBefore, statefulExprAfter)
+ assert(statefulExprBefore ne statefulExprAfter)
+
+ // Transforms expressions nested three levels deep:
+ val withNestedStatefulBefore = makeExprWithPositionAndTag(
+ Add(Literal(1), Add(Literal(1), Rand(Literal(1))))
+ )
+ val withNestedStatefulAfter =
withNestedStatefulBefore.freshCopyIfContainsStatefulExpression()
+ genericAssertions(withNestedStatefulBefore, withNestedStatefulAfter)
+ assert(withNestedStatefulBefore ne withNestedStatefulAfter)
+ def getStateful(e: Expression): Expression = {
+ e.collect { case e if e.stateful => e }.head
+ }
+ assert(getStateful(withNestedStatefulBefore) ne
getStateful(withNestedStatefulAfter))
+ }
+
object MalformedClassObject extends Serializable {
case class MalformedNameExpression(child: Expression) extends
TaggingExpression {
override protected def withNewChildInternal(newChild: Expression):
Expression =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4269aaea0df..a7bb0a2d1bd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3567,6 +3567,17 @@ class DataFrameSuite extends QueryTest
}.isEmpty)
}
}
+
+ test("SPARK-41049: stateful expression should be copied correctly") {
+ val df = spark.sparkContext.parallelize(1 to 5).toDF("x")
+ val v1 = (rand() * 10000).cast(IntegerType)
+ val v2 = to_csv(struct(v1.as("a"))) // to_csv is CodegenFallback
+ df.select(v1, v1, v2, v2).collect.foreach { row =>
+ assert(row.getInt(0) == row.getInt(1))
+ assert(row.getInt(0).toString == row.getString(2))
+ assert(row.getInt(0).toString == row.getString(3))
+ }
+ }
}
case class GroupByKey(a: Int, b: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]