This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 22731393069a [SPARK-50601][SQL] Support withColumns / withColumnsRenamed in subqueries
22731393069a is described below
commit 22731393069a3f180a9e719e57a694347c0ce87b
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jan 13 18:22:16 2025 -0800
[SPARK-50601][SQL] Support withColumns / withColumnsRenamed in subqueries
### What changes were proposed in this pull request?
Supports `withColumns` / `withColumnsRenamed` in subqueries.
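Under the hood, both the classic `Dataset` API and the Spark Connect planner now defer the column replacement to the analyzer instead of requiring an already-analyzed child plan: they emit a `Project` over the new star-like expressions `UnresolvedStarWithColumns` / `UnresolvedStarWithColumnsRenames`, which expand against the child's output at analysis time. In essence (excerpted from the `Dataset.scala` hunk below):

```scala
// New body of Dataset.withColumnsRenamed (excerpt from the diff below):
// the unresolved star expands during analysis, so it also works when the
// child plan still contains unresolved outer references.
withPlan {
  Project(
    Seq(
      UnresolvedStarWithColumnsRenames(
        existingNames = colNames,
        newNames = newColNames)),
    logicalPlan)
}
```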
### Why are the changes needed?
When a query is used as a subquery by adding `col.outer()`, `withColumns` and `withColumnsRenamed` don't work, because they require an analyzed plan.
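For example, the following pattern, adapted from the tests added in this PR (it assumes a view `t1` with integer columns `c1` and `c2`), used to fail on Spark Connect with an `INTERNAL_ERROR` about an unresolved operator and now returns the expected result:

```scala
// Sketch adapted from the added tests; `t1` is an assumed view with columns c1 and c2.
import spark.implicits._

val t1 = spark.table("t1")
val result = t1.withColumn(
  "scalar",
  spark
    .range(1)
    .select($"c1".outer() + $"c2".outer()) // outer references into t1
    .scalar())
// equivalent to: t1.select($"*", ($"c1" + $"c2").as("scalar"))
```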
### Does this PR introduce _any_ user-facing change?
Yes, those APIs are available in subqueries.
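`withColumnsRenamed` works the same way inside a subquery, e.g. (adapted from the new "subquery in withColumnsRenamed" test):

```scala
// Sketch adapted from the new test; same assumed view `t1` as above.
t1.withColumn(
  "scalar",
  spark
    .range(1)
    .select($"c1".outer().as("c1"), $"c2".outer().as("c2"))
    .withColumnsRenamed(Map("c1" -> "x", "c2" -> "y")) // now resolved lazily
    .select($"x" + $"y")
    .scalar())
```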
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49386 from ueshin/issues/SPARK-50601/with_columns.
Lead-authored-by: Takuya Ueshin <[email protected]>
Co-authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
.../apache/spark/sql/DataFrameSubquerySuite.scala | 57 +++++++--
.../sql/tests/connect/test_parity_subquery.py | 4 -
python/pyspark/sql/tests/test_subquery.py | 39 +++++-
.../spark/sql/catalyst/analysis/unresolved.scala | 132 ++++++++++++++++++---
.../sql/connect/planner/SparkConnectPlanner.scala | 33 +++---
.../connect/planner/SparkConnectPlannerSuite.scala | 33 +++---
.../main/scala/org/apache/spark/sql/Dataset.scala | 56 +++------
.../apache/spark/sql/DataFrameSubquerySuite.scala | 48 +++++++-
8 files changed, 295 insertions(+), 107 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 4b36d36983a5..1d2165b668f6 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.SparkRuntimeException
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession}
@@ -665,15 +665,52 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
withView("t1") {
val t1 = table1()
- // TODO(SPARK-50601): Fix the SparkConnectPlanner to support this case
- checkError(
- intercept[SparkException] {
-          t1.withColumn("scalar", spark.range(1).select($"c1".outer() + $"c2".outer()).scalar())
- .collect()
- },
- "INTERNAL_ERROR",
- parameters = Map("message" -> "Found the unresolved operator: .*"),
- matchPVals = true)
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .select($"c1".outer() + $"c2".outer())
+ .scalar()),
+ t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .withColumn("c1", $"c1".outer())
+ .select($"c1" + $"c2".outer())
+ .scalar()),
+ t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .select($"c1".outer().as("c1"))
+ .withColumn("c2", $"c2".outer())
+ .select($"c1" + $"c2")
+ .scalar()),
+ t1.select($"*", ($"c1" + $"c2").as("scalar")))
+ }
+ }
+
+ test("subquery in withColumnsRenamed") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .select($"c1".outer().as("c1"), $"c2".outer().as("c2"))
+ .withColumnsRenamed(Map("c1" -> "x", "c2" -> "y"))
+ .select($"x" + $"y")
+ .scalar()),
+ t1.select($"*", ($"c1".as("x") + $"c2".as("y")).as("scalar")))
}
}
diff --git a/python/pyspark/sql/tests/connect/test_parity_subquery.py b/python/pyspark/sql/tests/connect/test_parity_subquery.py
index dae60a354d20..f3225fcb7f2d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_subquery.py
+++ b/python/pyspark/sql/tests/connect/test_parity_subquery.py
@@ -45,10 +45,6 @@ class SubqueryParityTests(SubqueryTestsMixin, ReusedConnectTestCase):
def test_subquery_in_unpivot(self):
self.check_subquery_in_unpivot(None, None)
-    @unittest.skip("SPARK-50601: Fix the SparkConnectPlanner to support this case")
- def test_subquery_in_with_columns(self):
- super().test_subquery_in_with_columns()
-
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_subquery import * # noqa: F401
diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index 99a22d7c2966..7c63ddb69458 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -939,7 +939,44 @@ class SubqueryTestsMixin:
.select(sf.col("c1").outer() + sf.col("c2").outer())
.scalar(),
),
- t1.withColumn("scalar", sf.col("c1") + sf.col("c2")),
+ t1.select("*", (sf.col("c1") + sf.col("c2")).alias("scalar")),
+ )
+ assertDataFrameEqual(
+ t1.withColumn(
+ "scalar",
+ self.spark.range(1)
+ .withColumn("c1", sf.col("c1").outer())
+ .select(sf.col("c1") + sf.col("c2").outer())
+ .scalar(),
+ ),
+ t1.select("*", (sf.col("c1") + sf.col("c2")).alias("scalar")),
+ )
+ assertDataFrameEqual(
+ t1.withColumn(
+ "scalar",
+ self.spark.range(1)
+ .select(sf.col("c1").outer().alias("c1"))
+ .withColumn("c2", sf.col("c2").outer())
+ .select(sf.col("c1") + sf.col("c2"))
+ .scalar(),
+ ),
+ t1.select("*", (sf.col("c1") + sf.col("c2")).alias("scalar")),
+ )
+
+ def test_subquery_in_with_columns_renamed(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ assertDataFrameEqual(
+ t1.withColumn(
+ "scalar",
+ self.spark.range(1)
+ .select(sf.col("c1").outer().alias("c1"),
sf.col("c2").outer().alias("c2"))
+ .withColumnsRenamed({"c1": "x", "c2": "y"})
+ .select(sf.col("x") + sf.col("y"))
+ .scalar(),
+ ),
+ t1.select("*", (sf.col("c1").alias("x") +
sf.col("c2").alias("y")).alias("scalar")),
)
def test_subquery_in_drop(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b47af90c651a..fabe551d054c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -429,7 +429,7 @@ object UnresolvedFunction {
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
*/
-abstract class Star extends LeafExpression with NamedExpression {
+trait Star extends NamedExpression {
override def name: String = throw new UnresolvedException("name")
override def exprId: ExprId = throw new UnresolvedException("exprId")
@@ -451,15 +451,20 @@ abstract class Star extends LeafExpression with NamedExpression {
* This is also used to expand structs. For example:
* "SELECT record.* from (SELECT struct(a,b,c) as record ...)
*
- * @param target an optional name that should be the target of the expansion. If omitted all
- *                 targets' columns are produced. This can either be a table name or struct name. This
- *                 is a list of identifiers that is the path of the expansion.
- *
- * This class provides the shared behavior between the classes for SELECT * ([[UnresolvedStar]])
- * and SELECT * EXCEPT ([[UnresolvedStarExceptOrReplace]]). [[UnresolvedStar]] is just a case class
- * of this, while [[UnresolvedStarExceptOrReplace]] adds some additional logic to the expand method.
+ * This trait provides the shared behavior among the classes for SELECT * ([[UnresolvedStar]])
+ * and SELECT * EXCEPT ([[UnresolvedStarExceptOrReplace]]), etc. [[UnresolvedStar]] is just a case
+ * class of this, while [[UnresolvedStarExceptOrReplace]] or other classes add some additional logic
+ * to the expand method.
*/
-abstract class UnresolvedStarBase(target: Option[Seq[String]]) extends Star with Unevaluable {
+trait UnresolvedStarBase extends Star with Unevaluable {
+
+ /**
+   * An optional name that should be the target of the expansion. If omitted all
+   * targets' columns are produced. This can either be a table name or struct name. This
+ * is a list of identifiers that is the path of the expansion.
+ */
+ def target: Option[Seq[String]]
+
/**
* Returns true if the nameParts is a subset of the last elements of qualifier of the attribute.
*
@@ -583,7 +588,7 @@ case class UnresolvedStarExceptOrReplace(
target: Option[Seq[String]],
excepts: Seq[Seq[String]],
replacements: Option[Seq[NamedExpression]])
- extends UnresolvedStarBase(target) {
+ extends LeafExpression with UnresolvedStarBase {
/**
* We expand the * EXCEPT by the following three steps:
@@ -712,6 +717,103 @@ case class UnresolvedStarExceptOrReplace(
}
}
+/**
+ * Represents some of the input attributes to a given relational operator, for example in
+ * `df.withColumn`.
+ *
+ * @param colNames a list of column names that should be replaced or produced.
+ *
+ * @param exprs the corresponding expressions for `colNames`.
+ *
+ * @param explicitMetadata an optional list of explicit metadata to associate with the columns.
+ */
+case class UnresolvedStarWithColumns(
+ colNames: Seq[String],
+ exprs: Seq[Expression],
+ explicitMetadata: Option[Seq[Metadata]] = None)
+ extends UnresolvedStarBase {
+
+ override def target: Option[Seq[String]] = None
+ override def children: Seq[Expression] = exprs
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UnresolvedStarWithColumns =
+ copy(exprs = newChildren)
+
+  override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
+ assert(colNames.size == exprs.size,
+ s"The size of column names: ${colNames.size} isn't equal to " +
+ s"the size of expressions: ${exprs.size}")
+ explicitMetadata.foreach { m =>
+ assert(colNames.size == m.size,
+ s"The size of column names: ${colNames.size} isn't equal to " +
+ s"the size of metadata elements: ${m.size}")
+ }
+
+ SchemaUtils.checkColumnNameDuplication(colNames, resolver)
+
+ val expandedCols = super.expand(input, resolver)
+
+ val columnSeq = explicitMetadata match {
+ case Some(ms) => colNames.zip(exprs).zip(ms.map(Some(_)))
+ case _ => colNames.zip(exprs).map((_, None))
+ }
+
+ val replacedAndExistingColumns = expandedCols.map { field =>
+ columnSeq.find { case ((colName, _), _) =>
+ resolver(field.name, colName)
+ } match {
+        case Some(((colName, expr), m)) => Alias(expr, colName)(explicitMetadata = m)
+ case _ => field
+ }
+ }
+
+ val newColumns = columnSeq.filter { case ((colName, _), _) =>
+ !expandedCols.exists(f => resolver(f.name, colName))
+ }.map {
+ case ((colName, expr), m) => Alias(expr, colName)(explicitMetadata = m)
+ }
+
+ replacedAndExistingColumns ++ newColumns
+ }
+}
+
+/**
+ * Represents some of the input attributes to a given relational operator, for example in
+ * `df.withColumnRenamed`.
+ *
+ * @param existingNames a list of column names that should be replaced.
+ * If the column does not exist, it is ignored.
+ *
+ * @param newNames a list of new column names that should be used to replace the existing columns.
+ */
+case class UnresolvedStarWithColumnsRenames(
+ existingNames: Seq[String],
+ newNames: Seq[String])
+ extends LeafExpression with UnresolvedStarBase {
+
+ override def target: Option[Seq[String]] = None
+
+  override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
+ assert(existingNames.size == newNames.size,
+ s"The size of existing column names: ${existingNames.size} isn't equal
to " +
+ s"the size of new column names: ${newNames.size}")
+
+ val expandedCols = super.expand(input, resolver)
+
+ existingNames.zip(newNames).foldLeft(expandedCols) {
+ case (attrs, (existingName, newName)) =>
+ attrs.map(attr =>
+ if (resolver(attr.name, existingName)) {
+ Alias(attr, newName)()
+ } else {
+ attr
+ }
+ )
+ }
+ }
+}
+
/**
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...".
@@ -723,7 +825,8 @@ case class UnresolvedStarExceptOrReplace(
*                 targets' columns are produced. This can either be a table name or struct name. This
* is a list of identifiers that is the path of the expansion.
*/
-case class UnresolvedStar(target: Option[Seq[String]]) extends UnresolvedStarBase(target)
+case class UnresolvedStar(target: Option[Seq[String]])
+ extends LeafExpression with UnresolvedStarBase
/**
* Represents all of the input attributes to a given relational operator, for example in
@@ -733,7 +836,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends UnresolvedStarBas
* tables' columns are produced.
*/
case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSensitive: Boolean)
- extends Star with Unevaluable {
+ extends LeafExpression with Star with Unevaluable {
override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
val pattern = if (caseSensitive) regexPattern else s"(?i)$regexPattern"
table match {
@@ -791,7 +894,8 @@ case class MultiAlias(child: Expression, names: Seq[String])
*
* @param expressions Expressions to expand.
*/
-case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
+case class ResolvedStar(expressions: Seq[NamedExpression])
+ extends LeafExpression with Star with Unevaluable {
override def newInstance(): NamedExpression = throw new UnresolvedException("newInstance")
override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6ab69aea12e5..acbbeb49b267 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunc [...]
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -1065,25 +1065,21 @@ class SparkConnectPlanner(
}
private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
- if (rel.getRenamesCount > 0) {
-      val (colNames, newColNames) = rel.getRenamesList.asScala.toSeq.map { rename =>
+ val (colNames, newColNames) = if (rel.getRenamesCount > 0) {
+ rel.getRenamesList.asScala.toSeq.map { rename =>
(rename.getColName, rename.getNewColName)
}.unzip
- Dataset
- .ofRows(session, transformRelation(rel.getInput))
- .withColumnsRenamed(colNames, newColNames)
- .logicalPlan
} else {
// for backward compatibility
- Dataset
- .ofRows(session, transformRelation(rel.getInput))
- .withColumnsRenamed(rel.getRenameColumnsMapMap)
- .logicalPlan
+ rel.getRenameColumnsMapMap.asScala.toSeq.unzip
}
+ Project(
+      Seq(UnresolvedStarWithColumnsRenames(existingNames = colNames, newNames = newColNames)),
+ transformRelation(rel.getInput))
}
private def transformWithColumns(rel: proto.WithColumns): LogicalPlan = {
- val (colNames, cols, metadata) =
+ val (colNames, exprs, metadata) =
rel.getAliasesList.asScala.toSeq.map { alias =>
if (alias.getNameCount != 1) {
throw InvalidPlanInput(s"""WithColumns require column name only contains one name part,
@@ -1096,13 +1092,16 @@ class SparkConnectPlanner(
Metadata.empty
}
-          (alias.getName(0), Column(transformExpression(alias.getExpr)), metadata)
+ (alias.getName(0), transformExpression(alias.getExpr), metadata)
}.unzip3
- Dataset
- .ofRows(session, transformRelation(rel.getInput))
- .withColumns(colNames, cols, metadata)
- .logicalPlan
+ Project(
+ Seq(
+ UnresolvedStarWithColumns(
+ colNames = colNames,
+ exprs = exprs,
+ explicitMetadata = Some(metadata))),
+ transformRelation(rel.getInput))
}
private def transformWithWatermark(rel: proto.WithWatermark): LogicalPlan = {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index aaeb5d9fe509..054a32179935 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -504,26 +504,27 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
}
test("Test duplicated names in WithColumns") {
- intercept[AnalysisException] {
- transform(
- proto.Relation
- .newBuilder()
- .setWithColumns(
- proto.WithColumns
- .newBuilder()
- .setInput(readRel)
- .addAliases(proto.Expression.Alias
+ val logical = transform(
+ proto.Relation
+ .newBuilder()
+ .setWithColumns(
+ proto.WithColumns
+ .newBuilder()
+ .setInput(readRel)
+ .addAliases(
+ proto.Expression.Alias
.newBuilder()
.addName("test")
.setExpr(proto.Expression.newBuilder
.setLiteral(proto.Expression.Literal.newBuilder.setInteger(32))))
- .addAliases(proto.Expression.Alias
- .newBuilder()
- .addName("test")
- .setExpr(proto.Expression.newBuilder
-              .setLiteral(proto.Expression.Literal.newBuilder.setInteger(32)))))
- .build())
- }
+ .addAliases(proto.Expression.Alias
+ .newBuilder()
+ .addName("test")
+ .setExpr(proto.Expression.newBuilder
+          .setLiteral(proto.Expression.Literal.newBuilder.setInteger(32)))))
+ .build())
+
+ intercept[AnalysisException](Dataset.ofRows(spark, logical))
}
test("Test multi nameparts for column names in WithColumns") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e4e782a50e3d..e41521cba533 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1275,29 +1275,14 @@ class Dataset[T] private[sql](
require(colNames.size == cols.size,
s"The size of column names: ${colNames.size} isn't equal to " +
s"the size of columns: ${cols.size}")
- SchemaUtils.checkColumnNameDuplication(
- colNames,
- sparkSession.sessionState.conf.caseSensitiveAnalysis)
-
- val resolver = sparkSession.sessionState.analyzer.resolver
- val output = queryExecution.analyzed.output
-
- val columnSeq = colNames.zip(cols)
-
- val replacedAndExistingColumns = output.map { field =>
- columnSeq.find { case (colName, _) =>
- resolver(field.name, colName)
- } match {
- case Some((colName: String, col: Column)) => col.as(colName)
- case _ => Column(field)
- }
+ withPlan {
+ Project(
+ Seq(
+ UnresolvedStarWithColumns(
+ colNames = colNames,
+ exprs = cols.map(_.expr))),
+ logicalPlan)
}
-
- val newColumns = columnSeq.filter { case (colName, col) =>
- !output.exists(f => resolver(f.name, colName))
- }.map { case (colName, col) => col.as(colName) }
-
- select(replacedAndExistingColumns ++ newColumns : _*)
}
/** @inheritdoc */
@@ -1324,26 +1309,13 @@ class Dataset[T] private[sql](
require(colNames.size == newColNames.size,
s"The size of existing column names: ${colNames.size} isn't equal to " +
s"the size of new column names: ${newColNames.size}")
-
- val resolver = sparkSession.sessionState.analyzer.resolver
- val output: Seq[NamedExpression] = queryExecution.analyzed.output
- var shouldRename = false
-
- val projectList = colNames.zip(newColNames).foldLeft(output) {
- case (attrs, (existingName, newName)) =>
- attrs.map(attr =>
- if (resolver(attr.name, existingName)) {
- shouldRename = true
- Alias(attr, newName)()
- } else {
- attr
- }
- )
- }
- if (shouldRename) {
- withPlan(Project(projectList, logicalPlan))
- } else {
- toDF()
+ withPlan {
+ Project(
+ Seq(
+ UnresolvedStarWithColumnsRenames(
+ existingNames = colNames,
+ newNames = newColNames)),
+ logicalPlan)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index fdfb909d9ba7..621d468454d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -777,9 +777,51 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
val t1 = table1()
checkAnswer(
- t1.withColumn("scalar", spark.range(1).select($"c1".outer() +
$"c2".outer()).scalar()),
- t1.withColumn("scalar", $"c1" + $"c2")
- )
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .select($"c1".outer() + $"c2".outer())
+ .scalar()),
+ t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .withColumn("c1", $"c1".outer())
+ .select($"c1" + $"c2".outer())
+ .scalar()),
+ t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .select($"c1".outer().as("c1"))
+ .withColumn("c2", $"c2".outer())
+ .select($"c1" + $"c2")
+ .scalar()),
+ t1.select($"*", ($"c1" + $"c2").as("scalar")))
+ }
+ }
+
+ test("subquery in withColumnsRenamed") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+ t1.withColumn(
+ "scalar",
+ spark
+ .range(1)
+ .select($"c1".outer().as("c1"), $"c2".outer().as("c2"))
+ .withColumnsRenamed(Map("c1" -> "x", "c2" -> "y"))
+ .select($"x" + $"y")
+ .scalar()),
+ t1.select($"*", ($"c1".as("x") + $"c2".as("y")).as("scalar")))
}
}