This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e2e405442c91 [SPARK-55395][SQL] Disable RDD cache in 
`DataFrame.zipWithIndex`
e2e405442c91 is described below

commit e2e405442c91c393e8297a84dedb3aed32ea4d00
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Feb 9 07:13:11 2026 +0900

    [SPARK-55395][SQL] Disable RDD cache in `DataFrame.zipWithIndex`
    
    ### What changes were proposed in this pull request?
    Disable RDD cache in DataFrame.zipWithIndex
    
    ### Why are the changes needed?
    When `AttachDistributedSequence` was first introduced for the pandas API on
    Spark in
    https://github.com/apache/spark/commit/93cec49212fe82816fcadf69f429cebaec60e058,
    the underlying RDD was always cached via `localCheckpoint` to avoid
    re-computation.
    We then hit serious executor memory issues, so in
    https://github.com/apache/spark/commit/42790905668effc2c0c081bae7d081faa1e18424
    we made the storage level configurable and released the cached data after
    each stage via AQE.
    
    Since we are reusing `AttachDistributedSequence` to implement
    `DataFrame.zipWithIndex`, we conservatively start with a no-cache version;
    caching can easily be enabled later if it turns out to be necessary.
    
    Moreover, there is room to further optimize the no-cache version (see
    https://github.com/apache/spark/pull/54169).
    
    This PR disables the RDD cache in `DistributedSequenceID` by default, and
    explicitly sets `cache=True` at the pandas-on-Spark (PS) call sites.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #54178 from zhengruifeng/zip_with_index_cache.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/internal.py                     |  4 ++-
 .../catalyst/analysis/DeduplicateRelations.scala   |  2 +-
 .../analysis/ExtractDistributedSequenceID.scala    |  6 +++-
 .../expressions/DistributedSequenceID.scala        |  9 ++++--
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  2 +-
 .../plans/logical/pythonLogicalOperators.scala     |  8 ++++--
 .../org/apache/spark/sql/classic/Dataset.scala     |  2 +-
 .../spark/sql/execution/SparkStrategies.scala      |  4 +--
 .../python/AttachDistributedSequenceExec.scala     | 33 +++++++++++++++-------
 9 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/sql/internal.py b/python/pyspark/sql/internal.py
index 3007b28b0044..dd9ebbcdc182 100644
--- a/python/pyspark/sql/internal.py
+++ b/python/pyspark/sql/internal.py
@@ -104,7 +104,9 @@ class InternalFunction:
 
     @staticmethod
     def distributed_sequence_id() -> Column:
-        return 
InternalFunction._invoke_internal_function_over_columns("distributed_sequence_id")
+        return InternalFunction._invoke_internal_function_over_columns(
+            "distributed_sequence_id", F.lit(True)
+        )
 
     @staticmethod
     def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index b8da376bead6..2a2440117e40 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -441,7 +441,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
+      case oldVersion @ AttachDistributedSequence(sequenceAttr, _, _)
         if 
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(sequenceAttr = 
sequenceAttr.newInstance())
         newVersion.copyTagsFrom(oldVersion)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
index bf6ab8e50616..fe26122f3ac1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
@@ -34,8 +34,12 @@ object ExtractDistributedSequenceID extends 
Rule[LogicalPlan] {
     
plan.resolveOperatorsUpWithPruning(_.containsPattern(DISTRIBUTED_SEQUENCE_ID)) {
       case plan: LogicalPlan if plan.resolved &&
           
plan.expressions.exists(_.exists(_.isInstanceOf[DistributedSequenceID])) =>
+        val cache = plan.expressions.exists(_.exists(e =>
+          e.isInstanceOf[DistributedSequenceID] &&
+            
e.asInstanceOf[DistributedSequenceID].cache.eval().asInstanceOf[Boolean]))
         val attr = AttributeReference("distributed_sequence_id", LongType, 
nullable = false)()
-        val newPlan = 
plan.withNewChildren(plan.children.map(AttachDistributedSequence(attr, _)))
+        val newPlan = plan.withNewChildren(
+            plan.children.map(AttachDistributedSequence(attr, _, cache)))
           .transformExpressions { case _: DistributedSequenceID => attr }
         Project(plan.output, newPlan)
     }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
index 5a0bff990e68..cd71ee858052 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
@@ -26,10 +26,15 @@ import org.apache.spark.sql.types.{DataType, LongType}
  *
  * @note this expression is dedicated for Pandas API on Spark to use.
  */
-case class DistributedSequenceID() extends LeafExpression with Unevaluable 
with NonSQLExpression {
+case class DistributedSequenceID(cache: Expression)
+  extends LeafExpression with Unevaluable with NonSQLExpression {
+
+  // This argument indicate whether the underlying RDD should be cached
+  // according to PS config "pandas_on_Spark.compute.default_index_cache".
+  def this() = this(Literal(false))
 
   override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
-    DistributedSequenceID()
+    DistributedSequenceID(cache)
   }
 
   override def nullable: Boolean = false
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe15819bd44a..125db2752b20 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1068,7 +1068,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       a.copy(child = Expand(newProjects, newOutput, grandChild))
 
     // Prune and drop AttachDistributedSequence if the produced attribute is 
not referred.
-    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
+    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild, _))
         if !p.references.contains(a.sequenceAttr) =>
       p.copy(child = prunedChild(grandChild, p.references))
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index bcfcae2ee16c..db22a0781c0e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -367,12 +367,14 @@ case class ArrowEvalPythonUDTF(
 
 /**
  * A logical plan that adds a new long column with the name `name` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is used in both 'distributed-sequence' index in pandas API on Spark
+ * and 'DataFrame.zipWithIndex'.
  */
 case class AttachDistributedSequence(
     sequenceAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    cache: Boolean = false) extends UnaryNode {
 
   override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 088df782a541..17d4640f22fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2062,7 +2062,7 @@ class Dataset[T] private[sql](
    * This is for 'distributed-sequence' default index in pandas API on Spark.
    */
   private[sql] def withSequenceColumn(name: String) = {
-    select(Column(DistributedSequenceID()).alias(name), col("*"))
+    select(Column(DistributedSequenceID(Literal(true))).alias(name), col("*"))
   }
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5efad83bcba7..5c393b1db227 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -969,8 +969,8 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.python.MapInPandasExec(func, output, planLater(child), 
isBarrier, profile) :: Nil
       case logical.MapInArrow(func, output, child, isBarrier, profile) =>
         execution.python.MapInArrowExec(func, output, planLater(child), 
isBarrier, profile) :: Nil
-      case logical.AttachDistributedSequence(attr, child) =>
-        execution.python.AttachDistributedSequenceExec(attr, planLater(child)) 
:: Nil
+      case logical.AttachDistributedSequence(attr, child, cache) =>
+        execution.python.AttachDistributedSequenceExec(attr, planLater(child), 
cache) :: Nil
       case logical.PythonWorkerLogs(jsonAttr) =>
         execution.python.PythonWorkerLogsExec(jsonAttr) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
index e27bde38a6f5..507b632f5565 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
@@ -29,12 +29,16 @@ import org.apache.spark.storage.{StorageLevel, 
StorageLevelMapper}
 
 /**
  * A physical plan that adds a new long column with `sequenceAttr` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is for 'distributed-sequence' default index in pandas API on Spark,
+ * and 'DataFrame.zipWithIndex'
+ * When cache is true, the underlying RDD will be cached according to
+ * PS config "pandas_on_Spark.compute.default_index_cache".
  */
 case class AttachDistributedSequenceExec(
     sequenceAttr: Attribute,
-    child: SparkPlan)
+    child: SparkPlan,
+    cache: Boolean)
   extends UnaryExecNode {
 
   override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
@@ -45,8 +49,9 @@ case class AttachDistributedSequenceExec(
 
   @transient private var cached: RDD[InternalRow] = _
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    val childRDD = child.execute()
+  // cache the underlying RDD according to
+  // PS config "pandas_on_Spark.compute.default_index_cache"
+  private def cacheRDD(rdd: RDD[InternalRow]): RDD[InternalRow] = {
     // before `compute.default_index_cache` is explicitly set via
     // `ps.set_option`, `SQLConf.get` can not get its value (as well as its 
default value);
     // after `ps.set_option`, `SQLConf.get` can get its value:
@@ -74,22 +79,30 @@ case class AttachDistributedSequenceExec(
       StorageLevelMapper.MEMORY_AND_DISK_SER.name()
     ).stripPrefix("\"").stripSuffix("\"")
 
-    val cachedRDD = storageLevel match {
+    storageLevel match {
       // zipWithIndex launches a Spark job only if #partition > 1
-      case _ if childRDD.getNumPartitions <= 1 => childRDD
+      case _ if rdd.getNumPartitions <= 1 => rdd
 
-      case "NONE" => childRDD
+      case "NONE" => rdd
 
       case "LOCAL_CHECKPOINT" =>
         // localcheckpointing is unreliable so should not eagerly release it 
in 'cleanupResources'
-        childRDD.map(_.copy()).localCheckpoint()
+        rdd.map(_.copy()).localCheckpoint()
           .setName(s"Temporary RDD locally checkpointed in 
AttachDistributedSequenceExec($id)")
 
       case _ =>
-        cached = 
childRDD.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
+        cached = 
rdd.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
           .setName(s"Temporary RDD cached in 
AttachDistributedSequenceExec($id)")
         cached
     }
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val childRDD: RDD[InternalRow] = child.execute()
+
+    // if cache is true, the underlying rdd is cached according to
+    // PS config "pandas_on_Spark.compute.default_index_cache"
+    val cachedRDD = if (cache) this.cacheRDD(childRDD) else childRDD
 
     cachedRDD.zipWithIndex().mapPartitions { iter =>
       val unsafeProj = UnsafeProjection.create(output, output)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to