This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e2e405442c91 [SPARK-55395][SQL] Disable RDD cache in `DataFrame.zipWithIndex`
e2e405442c91 is described below
commit e2e405442c91c393e8297a84dedb3aed32ea4d00
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Feb 9 07:13:11 2026 +0900
[SPARK-55395][SQL] Disable RDD cache in `DataFrame.zipWithIndex`
### What changes were proposed in this pull request?
Disable RDD cache in DataFrame.zipWithIndex
### Why are the changes needed?
When `AttachDistributedSequence` was first introduced for Pandas API on
Spark in
https://github.com/apache/spark/commit/93cec49212fe82816fcadf69f429cebaec60e058,
the underlying RDD was always `localCheckpoint`ed to avoid re-computation.
Then we hit a serious executor memory issue, and in
https://github.com/apache/spark/commit/42790905668effc2c0c081bae7d081faa1e18424
we made the storage level configurable and released the cached data after
each stage via AQE.
Since we are reusing `AttachDistributedSequence` to implement
`DataFrame.zipWithIndex`, we start with a no-cache version to be more
conservative; it will be easy to enable caching in the future if necessary.
Moreover, there is an opportunity to further optimize the no-cache version:
https://github.com/apache/spark/pull/54169
This PR disables the RDD cache in `DistributedSequenceID` by default; the
pandas-on-Spark (PS) call sites explicitly set `cache=True`.
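For reference, a minimal usage sketch (not part of this patch) of how the existing cache is controlled on the PS side; the option name and the storage-level values are taken from `AttachDistributedSequenceExec` in the diff below:

```python
# Sketch, assuming pandas API on Spark is in use. The option name and the
# accepted values ("NONE", "LOCAL_CHECKPOINT", or a StorageLevel name such as
# "MEMORY_AND_DISK_SER", the default) come from AttachDistributedSequenceExec.
import pyspark.pandas as ps

# Use the 'distributed-sequence' default index, whose underlying RDD cache
# is controlled by the option below.
ps.set_option("compute.default_index_type", "distributed-sequence")
ps.set_option("compute.default_index_cache", "LOCAL_CHECKPOINT")

# With this change, `DataFrame.zipWithIndex` reuses the same operator but
# takes the no-cache path by default, regardless of the option above.
```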
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #54178 from zhengruifeng/zip_with_index_cache.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/internal.py | 4 ++-
.../catalyst/analysis/DeduplicateRelations.scala | 2 +-
.../analysis/ExtractDistributedSequenceID.scala | 6 +++-
.../expressions/DistributedSequenceID.scala | 9 ++++--
.../spark/sql/catalyst/optimizer/Optimizer.scala | 2 +-
.../plans/logical/pythonLogicalOperators.scala | 8 ++++--
.../org/apache/spark/sql/classic/Dataset.scala | 2 +-
.../spark/sql/execution/SparkStrategies.scala | 4 +--
.../python/AttachDistributedSequenceExec.scala | 33 +++++++++++++++-------
9 files changed, 48 insertions(+), 22 deletions(-)
diff --git a/python/pyspark/sql/internal.py b/python/pyspark/sql/internal.py
index 3007b28b0044..dd9ebbcdc182 100644
--- a/python/pyspark/sql/internal.py
+++ b/python/pyspark/sql/internal.py
@@ -104,7 +104,9 @@ class InternalFunction:
@staticmethod
def distributed_sequence_id() -> Column:
- return InternalFunction._invoke_internal_function_over_columns("distributed_sequence_id")
+ return InternalFunction._invoke_internal_function_over_columns(
+ "distributed_sequence_id", F.lit(True)
+ )
@staticmethod
def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index b8da376bead6..2a2440117e40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -441,7 +441,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))
- case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
+ case oldVersion @ AttachDistributedSequence(sequenceAttr, _, _)
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
newVersion.copyTagsFrom(oldVersion)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
index bf6ab8e50616..fe26122f3ac1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
@@ -34,8 +34,12 @@ object ExtractDistributedSequenceID extends Rule[LogicalPlan] {
plan.resolveOperatorsUpWithPruning(_.containsPattern(DISTRIBUTED_SEQUENCE_ID)) {
case plan: LogicalPlan if plan.resolved &&
plan.expressions.exists(_.exists(_.isInstanceOf[DistributedSequenceID])) =>
+ val cache = plan.expressions.exists(_.exists(e =>
+ e.isInstanceOf[DistributedSequenceID] &&
+ e.asInstanceOf[DistributedSequenceID].cache.eval().asInstanceOf[Boolean]))
val attr = AttributeReference("distributed_sequence_id", LongType, nullable = false)()
- val newPlan = plan.withNewChildren(plan.children.map(AttachDistributedSequence(attr, _)))
+ val newPlan = plan.withNewChildren(
+ plan.children.map(AttachDistributedSequence(attr, _, cache)))
.transformExpressions { case _: DistributedSequenceID => attr }
Project(plan.output, newPlan)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
index 5a0bff990e68..cd71ee858052 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
@@ -26,10 +26,15 @@ import org.apache.spark.sql.types.{DataType, LongType}
*
* @note this expression is dedicated for Pandas API on Spark to use.
*/
-case class DistributedSequenceID() extends LeafExpression with Unevaluable with NonSQLExpression {
+case class DistributedSequenceID(cache: Expression)
+ extends LeafExpression with Unevaluable with NonSQLExpression {
+
+ // This argument indicates whether the underlying RDD should be cached
+ // according to PS config "pandas_on_Spark.compute.default_index_cache".
+ def this() = this(Literal(false))
override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
- DistributedSequenceID()
+ DistributedSequenceID(cache)
}
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe15819bd44a..125db2752b20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1068,7 +1068,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
a.copy(child = Expand(newProjects, newOutput, grandChild))
// Prune and drop AttachDistributedSequence if the produced attribute is not referred.
- case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
+ case p @ Project(_, a @ AttachDistributedSequence(_, grandChild, _))
if !p.references.contains(a.sequenceAttr) =>
p.copy(child = prunedChild(grandChild, p.references))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index bcfcae2ee16c..db22a0781c0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -367,12 +367,14 @@ case class ArrowEvalPythonUDTF(
/**
* A logical plan that adds a new long column with the name `name` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is used in both 'distributed-sequence' index in pandas API on Spark
+ * and 'DataFrame.zipWithIndex'.
*/
case class AttachDistributedSequence(
sequenceAttr: Attribute,
- child: LogicalPlan) extends UnaryNode {
+ child: LogicalPlan,
+ cache: Boolean = false) extends UnaryNode {
override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 088df782a541..17d4640f22fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2062,7 +2062,7 @@ class Dataset[T] private[sql](
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
- select(Column(DistributedSequenceID()).alias(name), col("*"))
+ select(Column(DistributedSequenceID(Literal(true))).alias(name), col("*"))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5efad83bcba7..5c393b1db227 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -969,8 +969,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
case logical.MapInArrow(func, output, child, isBarrier, profile) =>
execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
- case logical.AttachDistributedSequence(attr, child) =>
- execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
+ case logical.AttachDistributedSequence(attr, child, cache) =>
+ execution.python.AttachDistributedSequenceExec(attr, planLater(child), cache) :: Nil
case logical.PythonWorkerLogs(jsonAttr) =>
execution.python.PythonWorkerLogsExec(jsonAttr) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
index e27bde38a6f5..507b632f5565 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
@@ -29,12 +29,16 @@ import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}
/**
* A physical plan that adds a new long column with `sequenceAttr` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is for 'distributed-sequence' default index in pandas API on Spark,
+ * and 'DataFrame.zipWithIndex'.
+ * When cache is true, the underlying RDD will be cached according to
+ * PS config "pandas_on_Spark.compute.default_index_cache".
*/
case class AttachDistributedSequenceExec(
sequenceAttr: Attribute,
- child: SparkPlan)
+ child: SparkPlan,
+ cache: Boolean)
extends UnaryExecNode {
override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
@@ -45,8 +49,9 @@ case class AttachDistributedSequenceExec(
@transient private var cached: RDD[InternalRow] = _
- override protected def doExecute(): RDD[InternalRow] = {
- val childRDD = child.execute()
+ // cache the underlying RDD according to
+ // PS config "pandas_on_Spark.compute.default_index_cache"
+ private def cacheRDD(rdd: RDD[InternalRow]): RDD[InternalRow] = {
// before `compute.default_index_cache` is explicitly set via
// `ps.set_option`, `SQLConf.get` can not get its value (as well as its default value);
// after `ps.set_option`, `SQLConf.get` can get its value:
@@ -74,22 +79,30 @@ case class AttachDistributedSequenceExec(
StorageLevelMapper.MEMORY_AND_DISK_SER.name()
).stripPrefix("\"").stripSuffix("\"")
- val cachedRDD = storageLevel match {
+ storageLevel match {
// zipWithIndex launches a Spark job only if #partition > 1
- case _ if childRDD.getNumPartitions <= 1 => childRDD
+ case _ if rdd.getNumPartitions <= 1 => rdd
- case "NONE" => childRDD
+ case "NONE" => rdd
case "LOCAL_CHECKPOINT" =>
// localcheckpointing is unreliable so should not eagerly release it in 'cleanupResources'
- childRDD.map(_.copy()).localCheckpoint()
+ rdd.map(_.copy()).localCheckpoint()
.setName(s"Temporary RDD locally checkpointed in
AttachDistributedSequenceExec($id)")
case _ =>
- cached = childRDD.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
+ cached = rdd.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
.setName(s"Temporary RDD cached in
AttachDistributedSequenceExec($id)")
cached
}
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val childRDD: RDD[InternalRow] = child.execute()
+
+ // if cache is true, the underlying rdd is cached according to
+ // PS config "pandas_on_Spark.compute.default_index_cache"
+ val cachedRDD = if (cache) this.cacheRDD(childRDD) else childRDD
cachedRDD.zipWithIndex().mapPartitions { iter =>
val unsafeProj = UnsafeProjection.create(output, output)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]