Repository: spark
Updated Branches:
  refs/heads/master b73eb0efe -> 41a7de600


[SPARK-25084][SQL] "distribute by" on multiple columns (wrap in brackets) may lead to codegen issue

## What changes were proposed in this pull request?

"distribute by" on multiple columns (wrap in brackets) may lead to codegen 
issue.

Simple way to reproduce:
```scala
val df = spark.range(1000)
val columns = (0 until 400).map { i => s"id as id$i" }
val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
df.selectExpr(columns: _*).createTempView("test")
spark.sql(s"select * from test distribute by ($distributeExprs)").count()
```
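
The root cause is in `ctx.splitExpressions`: with enough struct fields, the hash
computation is split into helper methods, and the `input` string is reused as the
name of their `InternalRow` parameter. But `input` may be a compound Java
expression (e.g. a `getStruct(...)` call) rather than a plain identifier, so the
generated helpers fail to compile. The patch evaluates `input` once into a fresh
local variable and passes that instead. Below is a minimal sketch of the two
shapes of generated code; the names (`SplitSketch`, `splitBefore`, `splitAfter`,
`input_0`) are hypothetical and this is not Spark's actual codegen API:

```scala
object SplitSketch {
  // Before the fix: whatever string `input` holds becomes the parameter name
  // of every split-out helper method.
  def splitBefore(input: String, chunks: Seq[String]): String =
    chunks.zipWithIndex.map { case (body, i) =>
      s"private int computeHashForStruct_$i(InternalRow $input) { $body }"
    }.mkString("\n")

  // After the fix: bind `input` to a fresh local variable first; a generated
  // variable name is always a legal Java parameter name.
  def splitAfter(input: String, chunks: Seq[String]): String = {
    val tmpInput = "input_0" // stands in for ctx.freshName("input")
    s"final InternalRow $tmpInput = $input;\n" + splitBefore(tmpInput, chunks)
  }

  def main(args: Array[String]): Unit = {
    val chunks = Seq("return 1;", "return 2;")
    // A parameter named "i.getStruct(0, 400)" is not valid Java:
    println(splitBefore("i.getStruct(0, 400)", chunks))
    println()
    // A parameter named "input_0" compiles fine:
    println(splitAfter("i.getStruct(0, 400)", chunks))
  }
}
```

The same pattern is applied twice in the diff below: once in `HashExpression`
and once in `HiveHash`.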

## How was this patch tested?

Added a unit test.

Closes #22066 from yucai/SPARK-25084.

Authored-by: yucai <y...@ebay.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/41a7de60
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/41a7de60
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/41a7de60

Branch: refs/heads/master
Commit: 41a7de6002d071ba81321bbe02b46db4b3f8cda2
Parents: b73eb0e
Author: yucai <y...@ebay.com>
Authored: Sat Aug 11 21:38:31 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Sat Aug 11 21:38:31 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/hash.scala   | 23 +++++++++++++++-----
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 12 ++++++++++
 2 files changed, 29 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/41a7de60/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index cec00b6..a754e87 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
-      nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
+      nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
     }
     val hashResultType = CodeGenerator.javaType(dataType)
-    ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, hashResultType -> result),
+      arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
       returnType = hashResultType,
       makeSplitFunction = body =>
         s"""
@@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression {
            |return $result;
          """.stripMargin,
      foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |$code
+     """.stripMargin
   }
 
   @tailrec
@@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val childResult = ctx.freshName("childResult")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
       val computeFieldHash = nullSafeElementHash(
-        input, index.toString, field.nullable, field.dataType, childResult, ctx)
+        tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
       s"""
          |$childResult = 0;
          |$computeFieldHash
@@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
        """.stripMargin
     }
 
-    s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> 
result),
+      arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> 
result),
       returnType = CodeGenerator.JAVA_INT,
       makeSplitFunction = body =>
         s"""
@@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
            |return $result;
            """.stripMargin,
      foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |${CodeGenerator.JAVA_INT} $childResult = 0;
+       |$code
+     """.stripMargin
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/41a7de60/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c1a5f50..84efd2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2840,4 +2840,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen 
issue") {
+    withView("spark_25084") {
+      val count = 1000
+      val df = spark.range(count)
+      val columns = (0 until 400).map{ i => s"id as id$i" }
+      val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
+      df.selectExpr(columns : _*).createTempView("spark_25084")
+      assert(
+        spark.sql(s"select * from spark_25084 distribute by 
($distributeExprs)").count === count)
+    }
+  }
 }

