This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6a42b63 [SPARK-34713][SQL] Fix group by CreateStruct with ExtractValue 6a42b63 is described below commit 6a42b633bf39981242f6f0d13ae40919f3fa7f8b Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Thu Mar 11 09:21:58 2021 -0800 [SPARK-34713][SQL] Fix group by CreateStruct with ExtractValue ### What changes were proposed in this pull request? This is a bug introduced by SPARK-31670 (https://issues.apache.org/jira/browse/SPARK-31670). Since that change, the analyzer removes the `Alias` when resolving column references in grouping expressions. This breaks `ResolveCreateNamedStruct`, which can only derive a struct field name from a `NamedExpression` child; an un-aliased `ExtractValue` such as `col._1` no longer gets a field name. This PR gives `ExtractValue` an optional `name` and lets `ResolveCreateNamedStruct` use it when defined. ### Why are the changes needed? Bug fix. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New tests. Closes #31808 from cloud-fan/bug. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 ++ .../catalyst/expressions/complexTypeExtractors.scala | 11 ++++++++++- .../apache/spark/sql/DataFrameAggregateSuite.scala | 19 +++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 027a8c6..2c1fade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3977,6 +3977,8 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] { val children = e.children.grouped(2).flatMap { case Seq(NamePlaceholder, e: NamedExpression) if e.resolved => Seq(Literal(e.name), e) + case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined => + Seq(Literal(e.name.get), e) case kv => kv } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 139d9a5..4413a3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -90,7 +90,10 @@ object ExtractValue { } } -trait ExtractValue extends Expression +trait ExtractValue extends Expression { + // The name that is used to extract the value. + def name: Option[String] +} /** * Returns the value of fields in the Struct `child`. @@ -156,6 +159,7 @@ case class GetArrayStructFields( override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" + override def name: Option[String] = Some(field.name) protected override def nullSafeEval(input: Any): Any = { val array = input.asInstanceOf[ArrayData] @@ -233,6 +237,7 @@ case class GetArrayItem( override def toString: String = s"$child[$ordinal]" override def sql: String = s"${child.sql}[${ordinal.sql}]" + override def name: Option[String] = None override def left: Expression = child override def right: Expression = ordinal @@ -448,6 +453,10 @@ case class GetMapValue( override def toString: String = s"$child[$key]" override def sql: String = s"${child.sql}[${key.sql}]" + override def name: Option[String] = key match { + case NonNullLiteral(s, StringType) => Some(s.toString) + case _ => None + } override def left: Expression = child override def right: Expression = key diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 07e6a40..3e137d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ 
-1091,6 +1091,25 @@ class DataFrameAggregateSuite extends QueryTest val df = spark.sql(query) checkAnswer(df, Row(0, "0", 0, 0) :: Row(-1, "1", 1, 1) :: Row(-2, "2", 2, 2) :: Nil) } + + test("SPARK-34713: group by CreateStruct with ExtractValue") { + val structDF = Seq(Tuple1(1 -> 1)).toDF("col") + checkAnswer(structDF.groupBy(struct($"col._1")).count().select("count"), Row(1)) + + val arrayOfStructDF = Seq(Tuple1(Seq(1 -> 1))).toDF("col") + checkAnswer(arrayOfStructDF.groupBy(struct($"col._1")).count().select("count"), Row(1)) + + val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col") + checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) + + val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col") + // Spark implicit casts string literal "a" to int to match the key type. + checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1)) + + val arrayDF = Seq(Tuple1(Seq(1))).toDF("col") + val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count()) + assert(e.message.contains("requires integral type")) + } } case class B(c: Option[Double]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org