This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ddc4005a93fc [SPARK-52503][SQL][CONNECT] Fix `drop` when the input column does not exist
ddc4005a93fc is described below
commit ddc4005a93fc4293caeea605fb54ec3811b462b6
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jul 2 12:22:28 2025 +0800
[SPARK-52503][SQL][CONNECT] Fix `drop` when the input column does not exist
### What changes were proposed in this pull request?
Fix `drop` when the input column does not exist. The `ResolveDataFrameDropColumns` rule now resolves drop targets with Plan Id awareness: if a column's Plan Id is found in the plan but the column itself is not, the column is ignored instead of being re-resolved by name.
### Why are the changes needed?
bugfix
```
import pyspark.sql.functions as F

df1 = spark.createDataFrame([("a", "b", "c")], schema='colA string, colB string, colC string')
df2 = spark.createDataFrame([("c", "d", "")], schema='colC string, colD string, colE string')
df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
    "colB",
    F.when(df1["colB"] == "b", F.concat(df1["colB"].cast("string"), F.lit("_newValue"))).otherwise(df1["colB"]),
)

df3
# DataFrame[colA: string, colB: string, colC: string, colC: string, colD: string, colE: string]

df3.drop(df1["colB"])
# DataFrame[colA: string, colC: string, colC: string, colD: string, colE: string]
```
`df3` no longer contains `df1["colB"]` (it was replaced via `withColumn`), so no column should be dropped.
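With this fix, the same `drop` call is a no-op and the schema is left intact. A sketch of the expected REPL output, consistent with the tests added below:
```
df3.drop(df1["colB"])
# DataFrame[colA: string, colB: string, colC: string, colC: string, colD: string, colE: string]
```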
### Does this PR introduce _any_ user-facing change?
yes, bug fix
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #51196 from zhengruifeng/fix_drop_col.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
python/pyspark/sql/tests/test_dataframe.py | 33 +++++++++++++++-
.../spark/sql/catalyst/analysis/Analyzer.scala | 6 ++-
.../catalyst/analysis/ColumnResolutionHelper.scala | 27 +++++++++++++
.../analysis/ResolveDataFrameDropColumns.scala | 19 ++++++---
.../apache/spark/sql/connect/DataFrameSuite.scala | 45 ++++++++++++++++++++++
5 files changed, 122 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 890ae56ffa52..2d578d749517 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,18 @@ import io
from contextlib import redirect_stdout
from pyspark.sql import Row, functions, DataFrame
-from pyspark.sql.functions import col, lit, count, struct, date_format, to_date, array, explode
+from pyspark.sql.functions import (
+    col,
+    lit,
+    count,
+    struct,
+    date_format,
+    to_date,
+    array,
+    explode,
+    when,
+    concat,
+)
from pyspark.sql.types import (
    StringType,
    IntegerType,
@@ -189,6 +200,26 @@ class DataFrameTestsMixin:
        self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
        self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
+    def test_drop_notexistent_col(self):
+        df1 = self.spark.createDataFrame(
+            [("a", "b", "c")],
+            schema="colA string, colB string, colC string",
+        )
+        df2 = self.spark.createDataFrame(
+            [("c", "d", "e")],
+            schema="colC string, colD string, colE string",
+        )
+        df3 = df1.join(df2, df1["colC"] == df2["colC"]).withColumn(
+            "colB",
+            when(df1["colB"] == "b", concat(df1["colB"].cast("string"), lit("x"))).otherwise(
+                df1["colB"]
+            ),
+        )
+        df4 = df3.drop(df1["colB"])
+
+        self.assertEqual(df4.columns, ["colA", "colB", "colC", "colC", "colD", "colE"])
+        self.assertEqual(df4.count(), 1)
+
    def test_drop_join(self):
        left_df = self.spark.createDataFrame(
            [(1, "a"), (2, "b"), (3, "c")],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a14efc0c5fb3..9ff0401cd026 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -454,7 +454,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      ResolveNaturalAndUsingJoin ::
      ResolveOutputRelation ::
      new ResolveTableConstraints(catalogManager) ::
-      new ResolveDataFrameDropColumns(catalogManager) ::
      new ResolveSetVariable(catalogManager) ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
@@ -1483,6 +1482,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    new ResolveReferencesInUpdate(catalogManager)
  private val resolveReferencesInSort =
    new ResolveReferencesInSort(catalogManager)
+  private val resolveDataFrameDropColumns =
+    new ResolveDataFrameDropColumns(catalogManager)
/**
 * Return true if there're conflicting attributes among children's outputs of a plan
@@ -1793,6 +1794,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      // Pass for Execute Immediate as arguments will be resolved by [[SubstituteExecuteImmediate]].
      case e : ExecuteImmediateQuery => e
+      case d: DataFrameDropColumns if !d.resolved =>
+        resolveDataFrameDropColumns(d)
+
      case q: LogicalPlan =>
        logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
        q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 69591ed8c5f9..d3e52d11b465 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -509,6 +509,33 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
      includeLastResort = includeLastResort)
  }
+  // Try to resolve `UnresolvedAttribute` by the children with Plan Ids.
+  // The `UnresolvedAttribute` must have a Plan Id:
+  // - If the Plan Id is not found in the plan, raise CANNOT_RESOLVE_DATAFRAME_COLUMN.
+  // - If the Plan Id is found in the plan but the column is not, return None.
+  // - Otherwise, return the resolved expression.
+  private[sql] def tryResolveColumnByPlanChildren(
+      u: UnresolvedAttribute,
+      q: LogicalPlan,
+      includeLastResort: Boolean = false): Option[Expression] = {
+    assert(u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty,
+      s"UnresolvedAttribute $u should have a Plan Id tag")
+
+    resolveDataFrameColumn(u, q.children).map { r =>
+      resolveExpression(
+        r,
+        resolveColumnByName = nameParts => {
+          q.resolveChildren(nameParts, conf.resolver)
+        },
+        getAttrCandidates = () => {
+          assert(q.children.length == 1)
+          q.children.head.output
+        },
+        throws = true,
+        includeLastResort = includeLastResort)
+    }
+  }
+
/**
* The last resort to resolve columns. Currently it does two things:
* - Try to resolve column names as outer references
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
index 0f9b93cc2986..a0f67fa3f445 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -17,8 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
import org.apache.spark.sql.connector.catalog.CatalogManager
@@ -27,17 +27,24 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
* Note that DataFrameDropColumns allows and ignores non-existing columns.
*/
class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
-  extends Rule[LogicalPlan] with ColumnResolutionHelper {
+  extends SQLConfHelper with ColumnResolutionHelper {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
    _.containsPattern(DF_DROP_COLUMNS)) {
    case d: DataFrameDropColumns if d.childrenResolved =>
      // expressions in dropList can be unresolved, e.g.
      //   df.drop(col("non-existing-column"))
-      val dropped = d.dropList.map {
+      val dropped = d.dropList.flatMap {
        case u: UnresolvedAttribute =>
-          resolveExpressionByPlanChildren(u, d)
-        case e => e
+          if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
+            // The Plan Id comes from Spark Connect. Ignore the `UnresolvedAttribute`
+            // if its Plan Id is found in the plan but the column is not.
+            tryResolveColumnByPlanChildren(u, d)
+          } else {
+            Some(resolveExpressionByPlanChildren(u, d))
+          }
+        case e => Some(e)
      }
      val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
      if (remaining.size == d.child.output.size) {
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
new file mode 100644
index 000000000000..2993f44efceb
--- /dev/null
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/DataFrameSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.functions.{concat, lit, when}
+
+class DataFrameSuite extends QueryTest with RemoteSparkSession {
+
+ test("drop") {
+ val sparkSession = spark
+ import sparkSession.implicits._
+
+    val df1 = Seq[(String, String, String)](("a", "b", "c")).toDF("colA", "colB", "colC")
+
+    val df2 = Seq[(String, String, String)](("c", "d", "e")).toDF("colC", "colD", "colE")
+
+    val df3 = df1
+      .join(df2, df1.col("colC") === df2.col("colC"))
+      .withColumn(
+        "colB",
+        when(df1.col("colB") === "b", concat(df1.col("colB").cast("string"), lit("x")))
+          .otherwise(df1.col("colB")))
+
+    val df4 = df3.drop(df1.col("colB"))
+
+    assert(df4.columns === Array("colA", "colB", "colC", "colC", "colD", "colE"))
+    assert(df4.count() === 1)
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]