This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6a42b63 [SPARK-34713][SQL] Fix group by CreateStruct with ExtractValue
6a42b63 is described below
commit 6a42b633bf39981242f6f0d13ae40919f3fa7f8b
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Mar 11 09:21:58 2021 -0800
[SPARK-34713][SQL] Fix group by CreateStruct with ExtractValue
### What changes were proposed in this pull request?
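This is a bug caused by https://issues.apache.org/jira/browse/SPARK-31670.
We remove the `Alias` when resolving column references in grouping expressions,
which breaks `ResolveCreateNamedStruct`: it only derives a field name for `NamedExpression`
children, so an unaliased `ExtractValue` such as `col._1` no longer gets a field name.
This PR adds a `name` method to `ExtractValue` and lets `ResolveCreateNamedStruct` use it.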
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests in `DataFrameAggregateSuite`.
Closes #31808 from cloud-fan/bug.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 ++
.../catalyst/expressions/complexTypeExtractors.scala | 11 ++++++++++-
.../apache/spark/sql/DataFrameAggregateSuite.scala | 19 +++++++++++++++++++
3 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 027a8c6..2c1fade 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3977,6 +3977,8 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
val children = e.children.grouped(2).flatMap {
case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
Seq(Literal(e.name), e)
+ case Seq(NamePlaceholder, e: ExtractValue) if e.resolved &&
e.name.isDefined =>
+ Seq(Literal(e.name.get), e)
case kv =>
kv
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 139d9a5..4413a3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -90,7 +90,10 @@ object ExtractValue {
}
}
-trait ExtractValue extends Expression
+trait ExtractValue extends Expression {
+ // The name that is used to extract the value.
+ def name: Option[String]
+}
/**
* Returns the value of fields in the Struct `child`.
@@ -156,6 +159,7 @@ case class GetArrayStructFields(
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
+ override def name: Option[String] = Some(field.name)
protected override def nullSafeEval(input: Any): Any = {
val array = input.asInstanceOf[ArrayData]
@@ -233,6 +237,7 @@ case class GetArrayItem(
override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"
+ override def name: Option[String] = None
override def left: Expression = child
override def right: Expression = ordinal
@@ -448,6 +453,10 @@ case class GetMapValue(
override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"
+ override def name: Option[String] = key match {
+ case NonNullLiteral(s, StringType) => Some(s.toString)
+ case _ => None
+ }
override def left: Expression = child
override def right: Expression = key
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 07e6a40..3e137d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1091,6 +1091,25 @@ class DataFrameAggregateSuite extends QueryTest
val df = spark.sql(query)
checkAnswer(df, Row(0, "0", 0, 0) :: Row(-1, "1", 1, 1) :: Row(-2, "2", 2,
2) :: Nil)
}
+
+ test("SPARK-34713: group by CreateStruct with ExtractValue") {
+ val structDF = Seq(Tuple1(1 -> 1)).toDF("col")
+ checkAnswer(structDF.groupBy(struct($"col._1")).count().select("count"),
Row(1))
+
+ val arrayOfStructDF = Seq(Tuple1(Seq(1 -> 1))).toDF("col")
+
checkAnswer(arrayOfStructDF.groupBy(struct($"col._1")).count().select("count"),
Row(1))
+
+ val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col")
+ checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"),
Row(1))
+
+ val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col")
+ // Spark implicit casts string literal "a" to int to match the key type.
+
checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"),
Row(1))
+
+ val arrayDF = Seq(Tuple1(Seq(1))).toDF("col")
+ val e =
intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count())
+ assert(e.message.contains("requires integral type"))
+ }
}
case class B(c: Option[Double])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]