This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a42b63  [SPARK-34713][SQL] Fix group by CreateStruct with ExtractValue
6a42b63 is described below

commit 6a42b633bf39981242f6f0d13ae40919f3fa7f8b
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Mar 11 09:21:58 2021 -0800

    [SPARK-34713][SQL] Fix group by CreateStruct with ExtractValue
    
    ### What changes were proposed in this pull request?
    
    This is a bug caused by https://issues.apache.org/jira/browse/SPARK-31670.
    We remove the `Alias` when resolving column references in grouping
    expressions, which breaks `ResolveCreateNamedStruct`.
    
    ### Why are the changes needed?
    
    bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    new tests
    
    Closes #31808 from cloud-fan/bug.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala |  2 ++
 .../catalyst/expressions/complexTypeExtractors.scala  | 11 ++++++++++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala    | 19 +++++++++++++++++++
 3 files changed, 31 insertions(+), 1 deletion(-)
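
For context, a minimal reproduction adapted from the new test at the end of this patch (a sketch; assumes a local SparkSession). `col._1` resolves to an `ExtractValue` rather than a `NamedExpression`, so before this fix `ResolveCreateNamedStruct` could not fill in the struct field's `NamePlaceholder` and the query failed to analyze; with the fix it returns a count of 1:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.struct

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Group by a struct whose field comes from a nested-field extraction.
    val df = Seq(Tuple1(1 -> 1)).toDF("col")
    df.groupBy(struct($"col._1")).count().show()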

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 027a8c6..2c1fade 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3977,6 +3977,8 @@ object ResolveCreateNamedStruct extends Rule[LogicalPlan] {
       val children = e.children.grouped(2).flatMap {
         case Seq(NamePlaceholder, e: NamedExpression) if e.resolved =>
           Seq(Literal(e.name), e)
+        case Seq(NamePlaceholder, e: ExtractValue) if e.resolved && e.name.isDefined =>
+          Seq(Literal(e.name.get), e)
         case kv =>
           kv
       }
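
In the rule above, `CreateNamedStruct`'s children alternate name/value pairs, and an unresolved `struct(...)` call carries a `NamePlaceholder` for each name. The existing case fills the placeholder from a `NamedExpression`; the new case also lets a resolved `ExtractValue` such as `col._1` supply a name. A self-contained sketch of that pairing logic (simplified stand-in types, not the actual Catalyst rule):

    // Simplified stand-ins for the Catalyst expressions involved.
    sealed trait Expr
    case object NamePlaceholder extends Expr
    case class Literal(value: String) extends Expr
    case class Named(name: String) extends Expr           // mimics NamedExpression
    case class Extract(name: Option[String]) extends Expr // mimics ExtractValue

    // Mirrors the grouped(2).flatMap rewrite in the hunk above.
    def fillNames(children: Seq[Expr]): Seq[Expr] =
      children.grouped(2).flatMap {
        case Seq(NamePlaceholder, e: Named) => Seq(Literal(e.name), e)
        case Seq(NamePlaceholder, e: Extract) if e.name.isDefined => Seq(Literal(e.name.get), e)
        case kv => kv
      }.toSeq

    // struct(col._1): the placeholder is filled from the extraction's name.
    println(fillNames(Seq(NamePlaceholder, Extract(Some("_1")))))
    // List(Literal(_1), Extract(Some(_1)))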
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 139d9a5..4413a3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -90,7 +90,10 @@ object ExtractValue {
   }
 }
 
-trait ExtractValue extends Expression
+trait ExtractValue extends Expression {
+  // The name that is used to extract the value.
+  def name: Option[String]
+}
 
 /**
  * Returns the value of fields in the Struct `child`.
@@ -156,6 +159,7 @@ case class GetArrayStructFields(
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
   override def toString: String = s"$child.${field.name}"
   override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
+  override def name: Option[String] = Some(field.name)
 
   protected override def nullSafeEval(input: Any): Any = {
     val array = input.asInstanceOf[ArrayData]
@@ -233,6 +237,7 @@ case class GetArrayItem(
 
   override def toString: String = s"$child[$ordinal]"
   override def sql: String = s"${child.sql}[${ordinal.sql}]"
+  override def name: Option[String] = None
 
   override def left: Expression = child
   override def right: Expression = ordinal
@@ -448,6 +453,10 @@ case class GetMapValue(
 
   override def toString: String = s"$child[$key]"
   override def sql: String = s"${child.sql}[${key.sql}]"
+  override def name: Option[String] = key match {
+    case NonNullLiteral(s, StringType) => Some(s.toString)
+    case _ => None
+  }
 
   override def left: Expression = child
   override def right: Expression = key
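
Taken together, the overrides above give `name` these semantics: field accesses (`GetArrayStructFields` here, and `GetStructField`, which already exposes an optional `name` and so needed no change) return the field name; ordinal array access (`GetArrayItem`) returns `None`; and map lookups (`GetMapValue`) return the key only when it is a non-null string literal. A quick check against the internal Catalyst API (a sketch; constructor default arguments are as of this commit and may differ across Spark versions):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    val m   = AttributeReference("m", MapType(StringType, IntegerType))()
    val arr = AttributeReference("arr", ArrayType(IntegerType))()

    println(GetMapValue(m, Literal("a")).name)  // Some(a): string-literal key
    println(GetMapValue(m, AttributeReference("k", StringType)()).name) // None: non-literal key
    println(GetArrayItem(arr, Literal(0)).name) // None: ordinal access carries no name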
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 07e6a40..3e137d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1091,6 +1091,25 @@ class DataFrameAggregateSuite extends QueryTest
     val df = spark.sql(query)
     checkAnswer(df, Row(0, "0", 0, 0) :: Row(-1, "1", 1, 1) :: Row(-2, "2", 2, 2) :: Nil)
   }
+
+  test("SPARK-34713: group by CreateStruct with ExtractValue") {
+    val structDF = Seq(Tuple1(1 -> 1)).toDF("col")
+    checkAnswer(structDF.groupBy(struct($"col._1")).count().select("count"), Row(1))
+
+    val arrayOfStructDF = Seq(Tuple1(Seq(1 -> 1))).toDF("col")
+    checkAnswer(arrayOfStructDF.groupBy(struct($"col._1")).count().select("count"), Row(1))
+
+    val mapDF = Seq(Tuple1(Map("a" -> "a"))).toDF("col")
+    checkAnswer(mapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))
+
+    val nonStringMapDF = Seq(Tuple1(Map(1 -> 1))).toDF("col")
+    // Spark implicit casts string literal "a" to int to match the key type.
+    checkAnswer(nonStringMapDF.groupBy(struct($"col.a")).count().select("count"), Row(1))
+
+    val arrayDF = Seq(Tuple1(Seq(1))).toDF("col")
+    val e = intercept[AnalysisException](arrayDF.groupBy(struct($"col.a")).count())
+    assert(e.message.contains("requires integral type"))
+  }
 }
 
 case class B(c: Option[Double])

