This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ddc4005a93fc [SPARK-52503][SQL][CONNECT] Fix `drop` when the input 
column is not existent
ddc4005a93fc is described below

commit ddc4005a93fc4293caeea605fb54ec3811b462b6
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jul 2 12:22:28 2025 +0800

    [SPARK-52503][SQL][CONNECT] Fix `drop` when the input column is not existent
    
    ### What changes were proposed in this pull request?
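    Fix `drop` when the input column does not exist in the DataFrame. With Spark Connect, a column reference whose plan is found but whose column is not (for example, a column that `withColumn` has since replaced) is now ignored, instead of causing a different column with the same name to be dropped.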
    
    ### Why are the changes needed?
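    This is a bug fix. In the example below, `drop` removes the new `colB` even though it is not the column referenced by `df1["colB"]`: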
    
    ```
    import pyspark.sql.functions as F
    
    df1 = spark.createDataFrame([("a", "b", "c")], schema='colA string, colB string, colC string')
    
    df2 = spark.createDataFrame([("c", "d", "")], schema='colC string, colD string, colE string')
    
    df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn("colB", F.when(df1["colB"] == "b", F.concat(df1["colB"].cast("string"), F.lit("_newValue"))).otherwise(df1["colB"]))
    
    df3
    DataFrame[colA: string, colB: string, colC: string, colC: string, colD: string, colE: string]
    
    df3.drop(df1["colB"])
    DataFrame[colA: string, colC: string, colC: string, colD: string, colE: string]
    ```
    
    `df3` does not contain `df1["colB"]` (that column was replaced by `withColumn("colB", ...)`), so no columns should be dropped.
    
    ### Does this PR introduce _any_ user-facing change?
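    Yes, this is a bug fix: `drop` no longer removes a column when the given column reference does not exist in the DataFrame.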
    
    ### How was this patch tested?
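    Added tests in `python/pyspark/sql/tests/test_dataframe.py` and in a new Spark Connect Scala suite, `DataFrameSuite.scala`.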
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #51196 from zhengruifeng/fix_drop_col.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 python/pyspark/sql/tests/test_dataframe.py         | 33 +++++++++++++++-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  6 ++-
 .../catalyst/analysis/ColumnResolutionHelper.scala | 27 +++++++++++++
 .../analysis/ResolveDataFrameDropColumns.scala     | 19 ++++++---
 .../apache/spark/sql/connect/DataFrameSuite.scala  | 45 ++++++++++++++++++++++
 5 files changed, 122 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index 890ae56ffa52..2d578d749517 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,18 @@ import io
 from contextlib import redirect_stdout
 
 from pyspark.sql import Row, functions, DataFrame
-from pyspark.sql.functions import col, lit, count, struct, date_format, 
to_date, array, explode
+from pyspark.sql.functions import (
+    col,
+    lit,
+    count,
+    struct,
+    date_format,
+    to_date,
+    array,
+    explode,
+    when,
+    concat,
+)
 from pyspark.sql.types import (
     StringType,
     IntegerType,
@@ -189,6 +200,26 @@ class DataFrameTestsMixin:
         self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
         self.assertEqual(df.drop(col("name"), col("age"), 
col("random")).columns, ["active"])
 
+    def test_drop_notexistent_col(self):
+        df1 = self.spark.createDataFrame(
+            [("a", "b", "c")],
+            schema="colA string, colB string, colC string",
+        )
+        df2 = self.spark.createDataFrame(
+            [("c", "d", "e")],
+            schema="colC string, colD string, colE string",
+        )
+        df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
+            "colB",
+            when(df1["colB"] == "b", concat(df1["colB"].cast("string"), 
lit("x"))).otherwise(
+                df1["colB"]
+            ),
+        )
+        df4 = df3.drop(df1["colB"])
+
+        self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", 
"colE"])
+        self.assertEqual(df4.count(), 1)
+
     def test_drop_join(self):
         left_df = self.spark.createDataFrame(
             [(1, "a"), (2, "b"), (3, "c")],
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a14efc0c5fb3..9ff0401cd026 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -454,7 +454,6 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
       ResolveNaturalAndUsingJoin ::
       ResolveOutputRelation ::
       new ResolveTableConstraints(catalogManager) ::
-      new ResolveDataFrameDropColumns(catalogManager) ::
       new ResolveSetVariable(catalogManager) ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
@@ -1483,6 +1482,8 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
       new ResolveReferencesInUpdate(catalogManager)
     private val resolveReferencesInSort =
       new ResolveReferencesInSort(catalogManager)
+    private val resolveDataFrameDropColumns =
+      new ResolveDataFrameDropColumns(catalogManager)
 
     /**
      * Return true if there're conflicting attributes among children's outputs 
of a plan
@@ -1793,6 +1794,9 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
       // Pass for Execute Immediate as arguments will be resolved by 
[[SubstituteExecuteImmediate]].
       case e : ExecuteImmediateQuery => e
 
+      case d: DataFrameDropColumns if !d.resolved =>
+        resolveDataFrameDropColumns(d)
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve 
${q.simpleString(conf.maxToStringFields)}")
         q.mapExpressions(resolveExpressionByPlanChildren(_, q, 
includeLastResort = true))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 69591ed8c5f9..d3e52d11b465 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -509,6 +509,33 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
       includeLastResort = includeLastResort)
   }
 
+  // Try to resolve `UnresolvedAttribute` by the children with Plan Ids.
+  // The `UnresolvedAttribute` must have a Plan Id:
+  //  - If Plan Id not found in the plan, raise 
CANNOT_RESOLVE_DATAFRAME_COLUMN.
+  //  - If Plan Id found in the plan, but column not found, return None.
+  //  - Otherwise, return the resolved expression.
+  private[sql] def tryResolveColumnByPlanChildren(
+      u: UnresolvedAttribute,
+      q: LogicalPlan,
+      includeLastResort: Boolean = false): Option[Expression] = {
+    assert(u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty,
+      s"UnresolvedAttribute $u should have a Plan Id tag")
+
+    resolveDataFrameColumn(u, q.children).map { r =>
+      resolveExpression(
+        r,
+        resolveColumnByName = nameParts => {
+          q.resolveChildren(nameParts, conf.resolver)
+        },
+        getAttrCandidates = () => {
+          assert(q.children.length == 1)
+          q.children.head.output
+        },
+        throws = true,
+        includeLastResort = includeLastResort)
+    }
+  }
+
   /**
    * The last resort to resolve columns. Currently it does two things:
    *  - Try to resolve column names as outer references
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
index 0f9b93cc2986..a0f67fa3f445 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, 
LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
 import org.apache.spark.sql.connector.catalog.CatalogManager
 
@@ -27,17 +27,24 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
  * Note that DataFrameDropColumns allows and ignores non-existing columns.
  */
 class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
-  extends Rule[LogicalPlan] with ColumnResolutionHelper  {
+  extends SQLConfHelper with ColumnResolutionHelper  {
 
-  override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsWithPruning(
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
     _.containsPattern(DF_DROP_COLUMNS)) {
     case d: DataFrameDropColumns if d.childrenResolved =>
       // expressions in dropList can be unresolved, e.g.
       //   df.drop(col("non-existing-column"))
-      val dropped = d.dropList.map {
+      val dropped = d.dropList.flatMap {
         case u: UnresolvedAttribute =>
-          resolveExpressionByPlanChildren(u, d)
-        case e => e
+          if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
+            // Plan Id comes from Spark Connect,
+            // Here we ignore the `UnresolvedAttribute` if its Plan Id can be 
found
+            // but column not found.
+            tryResolveColumnByPlanChildren(u, d)
+          } else {
+            Some(resolveExpressionByPlanChildren(u, d))
+          }
+        case e => Some(e)
       }
       val remaining = d.child.output.filterNot(attr => 
dropped.exists(_.semanticEquals(attr)))
       if (remaining.size == d.child.output.size) {
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
new file mode 100644
index 000000000000..2993f44efceb
--- /dev/null
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions.{concat, lit, when}
+
+class DataFrameSuite extends QueryTest with RemoteSparkSession {
+
+  test("drop") {
+    val sparkSession = spark
+    import sparkSession.implicits._
+
+    val df1 = Seq[(String, String, String)](("a", "b", "c")).toDF("colA", 
"colB", "colC")
+
+    val df2 = Seq[(String, String, String)](("c", "d", "e")).toDF("colC", 
"colD", "colE")
+
+    val df3 = df1
+      .join(df2, df1.col("colC") === df2.col("colC"))
+      .withColumn(
+        "colB",
+        when(df1.col("colB") === "b", concat(df1.col("colB").cast("string"), 
lit("x")))
+          .otherwise(df1.col("colB")))
+
+    val df4 = df3.drop(df1.col("colB"))
+
+    assert(df4.columns === Array("colA", "colB", "colC", "colC", "colD", 
"colE"))
+    assert(df4.count() === 1)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to