This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a9c1189 [SPARK-34649][SQL][DOCS]
org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name
having a dot
a9c1189 is described below
commit a9c11896a5db3cd6844d5e444ad59e65d9441e7c
Author: Amandeep Sharma <[email protected]>
AuthorDate: Tue Mar 9 11:47:01 2021 +0000
[SPARK-34649][SQL][DOCS]
org.apache.spark.sql.DataFrameNaFunctions.replace() fails for column name
having a dot
### What changes were proposed in this pull request?
Use resolved attributes instead of data-frame fields for replacing values.
### Why are the changes needed?
`df.na.replace()` does not work for a column whose name contains a dot.
### Does this PR introduce _any_ user-facing change?
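Yes. Column names passed to `DataFrameNaFunctions.replace()` are now resolved
like SQL column references:
- a name containing a dot must be escaped with backticks;
- an unknown column throws `AnalysisException`;
- a nested column throws `UnsupportedOperationException`.

See the entry added to `docs/sql-migration-guide.md`.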
### How was this patch tested?
Added unit tests to `DataFrameNaFunctionsSuite`.
Closes #31769 from amandeep-sharma/master.
Authored-by: Amandeep Sharma <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
docs/sql-migration-guide.md | 2 +
.../apache/spark/sql/DataFrameNaFunctions.scala | 42 ++++++++++++---------
.../spark/sql/DataFrameNaFunctionsSuite.scala | 43 ++++++++++++++++++++--
3 files changed, 67 insertions(+), 20 deletions(-)
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 0e96c6d..5551d56 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -66,6 +66,8 @@ license: |
- In Spark 3.2, the output schema of `SHOW TBLPROPERTIES` becomes `key:
string, value: string` whether you specify the table property key or not. In
Spark 3.1 and earlier, the output schema of `SHOW TBLPROPERTIES` is `value:
string` when you specify the table property key. To restore the old schema with
the builtin catalog, you can set `spark.sql.legacy.keepCommandOutputSchema` to
`true`.
- In Spark 3.2, we support typed literals in the partition spec of INSERT
and ADD/DROP/RENAME PARTITION. For example, `ADD PARTITION(dt =
date'2020-01-01')` adds a partition with date value `2020-01-01`. In Spark 3.1
and earlier, the partition value will be parsed as string value `date
'2020-01-01'`, which is an illegal date value, and we add a partition with null
value at the end.
+
+ - In Spark 3.2, `DataFrameNaFunctions.replace()` no longer uses exact string
match for the input column names, to match the SQL syntax and support qualified
column names. An input column name that contains a dot (and does not refer to a
nested field) needs to be escaped with backticks (\`), for example \`Col.1\`.
`replace()` now throws `AnalysisException` if the column is not found in the
data frame schema. It also throws `UnsupportedOperationException` if the input
column name refers to a nested column. In Spark 3.1 and earlier, it silently
ignored invalid input column names and nested column names.
## Upgrading from Spark SQL 3.0 to 3.1
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 308bb96..91905f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -327,9 +327,9 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*/
def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
if (col == "*") {
- replace0(df.columns, replacement)
+ replace0(df.logicalPlan.output, replacement)
} else {
- replace0(Seq(col), replacement)
+ replace(Seq(col), replacement)
}
}
@@ -352,10 +352,21 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame =
replace0(cols, replacement)
+ def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
+ val attrs = cols.map { colName =>
+ // Check column name exists
+ val attr = df.resolve(colName) match {
+ case a: Attribute => a
+ case _ => throw new UnsupportedOperationException(
+ s"Nested field ${colName} is not supported.")
+ }
+ attr
+ }
+ replace0(attrs, replacement)
+ }
- private def replace0[T](cols: Seq[String], replacement: Map[T, T]):
DataFrame = {
- if (replacement.isEmpty || cols.isEmpty) {
+ private def replace0[T](attrs: Seq[Attribute], replacement: Map[T, T]):
DataFrame = {
+ if (replacement.isEmpty || attrs.isEmpty) {
return df
}
@@ -379,15 +390,13 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
case _: String => StringType
}
- val columnEquals = df.sparkSession.sessionState.analyzer.resolver
- val projections = df.schema.fields.map { f =>
- val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
- if (f.dataType.isInstanceOf[NumericType] && targetColumnType ==
DoubleType && shouldReplace) {
- replaceCol(f, replacementMap)
- } else if (f.dataType == targetColumnType && shouldReplace) {
- replaceCol(f, replacementMap)
+ val output = df.queryExecution.analyzed.output
+ val projections = output.map { attr =>
+ if (attrs.contains(attr) && (attr.dataType == targetColumnType ||
+ (attr.dataType.isInstanceOf[NumericType] && targetColumnType ==
DoubleType))) {
+ replaceCol(attr, replacementMap)
} else {
- df.col(f.name)
+ Column(attr)
}
}
df.select(projections : _*)
@@ -453,13 +462,12 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* TODO: This can be optimized to use broadcast join when replacementMap is
large.
*/
- private def replaceCol[K, V](col: StructField, replacementMap: Map[K, V]):
Column = {
- val keyExpr = df.col(col.name).expr
- def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+ private def replaceCol[K, V](attr: Attribute, replacementMap: Map[K, V]):
Column = {
+ def buildExpr(v: Any) = Cast(Literal(v), attr.dataType)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(Literal(source), buildExpr(target))
}.toSeq
- new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
+ new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
}
private def convertToDouble(v: Any): Double = v match {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 23c2349..20ae995 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -461,7 +461,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with
SharedSparkSession {
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}
- test("SPARK-34417 - test fillMap() for column with a dot in the name") {
+ test("SPARK-34417: test fillMap() for column with a dot in the name") {
val na = "n/a"
checkAnswer(
Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot", "Col")
@@ -469,7 +469,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with
SharedSparkSession {
Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
}
- test("SPARK-34417 - test fillMap() for qualified-column with a dot in the
name") {
+ test("SPARK-34417: test fillMap() for qualified-column with a dot in the
name") {
val na = "n/a"
checkAnswer(
Seq(("abc", 23L), ("def", 44L), (null, 0L)).toDF("ColWith.Dot",
"Col").as("testDF")
@@ -477,7 +477,7 @@ class DataFrameNaFunctionsSuite extends QueryTest with
SharedSparkSession {
Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
}
- test("SPARK-34417 - test fillMap() for column without a dot in the name" +
+ test("SPARK-34417: test fillMap() for column without a dot in the name" +
" and dataframe with another column having a dot in the name") {
val na = "n/a"
checkAnswer(
@@ -485,4 +485,41 @@ class DataFrameNaFunctionsSuite extends QueryTest with
SharedSparkSession {
.na.fill(Map("Col" -> na)),
Row("abc", 23) :: Row("def", 44L) :: Row(na, 0L) :: Nil)
}
+
+ test("SPARK-34649: replace value of a column with dot in the name") {
+ checkAnswer(
+ Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+ .na.replace("`Col.1`", Map( "n/a" -> "unknown")),
+ Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+ }
+
+ test("SPARK-34649: replace value of a qualified-column with dot in the
name") {
+ checkAnswer(
+ Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1",
"Col.2").as("testDf")
+ .na.replace("testDf.`Col.1`", Map( "n/a" -> "unknown")),
+ Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+ }
+
+ test("SPARK-34649: replace value of a dataframe having dot in the all column
names") {
+ checkAnswer(
+ Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+ .na.replace("*", Map( "n/a" -> "unknown")),
+ Row("abc", 23) :: Row("def", 44L) :: Row("unknown", 0L) :: Nil)
+ }
+
+ test("SPARK-34649: replace value of a column not present in the dataframe") {
+ val df = Seq(("abc", 23), ("def", 44), ("n/a", 0)).toDF("Col.1", "Col.2")
+ val exception = intercept[AnalysisException] {
+ df.na.replace("aa", Map( "n/a" -> "unknown"))
+ }
+ assert(exception.getMessage.equals("Cannot resolve column name \"aa\"
among (Col.1, Col.2)"))
+ }
+
+ test("SPARK-34649: replace value of a nested column") {
+ val df = createDFWithNestedColumns
+ val exception = intercept[UnsupportedOperationException] {
+ df.na.replace("c1.c1-1", Map("b1" ->"a1"))
+ }
+ assert(exception.getMessage.equals("Nested field c1.c1-1 is not
supported."))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]