Repository: spark
Updated Branches:
  refs/heads/branch-1.3 c0101d392 -> c699e2b76


[SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes

Due to a recent change that made `StructType` a `Seq`, we started inadvertently
turning `StructType`s into generic `Traversable`s when attempting nested tree
transformations. In this PR we explicitly avoid descending into `DataType`s to
prevent this bug.

Author: Michael Armbrust <[email protected]>

Closes #5157 from marmbrus/udfFix and squashes the following commits:

26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes

(cherry picked from commit 3fa3d121dfec60f9768d3859e8450ee482b2d4e8)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c699e2b7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c699e2b7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c699e2b7

Branch: refs/heads/branch-1.3
Commit: c699e2b766a7cb9e03762bf278d7b19f631cb4e8
Parents: c0101d3
Author: Michael Armbrust <[email protected]>
Authored: Tue Mar 24 12:28:01 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Tue Mar 24 12:28:15 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/QueryPlan.scala    |  2 ++
 .../spark/sql/catalyst/trees/TreeNode.scala     | 20 +++++++++++++++++---
 .../scala/org/apache/spark/sql/UDFSuite.scala   |  6 ++++++
 3 files changed, 25 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c699e2b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 48191f3..bd9291e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] 
extends TreeNode[PlanTy
       case e: Expression => transformExpressionDown(e)
       case Some(e: Expression) => Some(transformExpressionDown(e))
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map {
         case e: Expression => transformExpressionDown(e)
         case other => other
@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] 
extends TreeNode[PlanTy
       case e: Expression => transformExpressionUp(e)
       case Some(e: Expression) => Some(transformExpressionUp(e))
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map {
         case e: Expression => transformExpressionUp(e)
         case other => other

http://git-wip-us.apache.org/repos/asf/spark/blob/c699e2b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index cfd0203..8fa4fc3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.trees
 
 import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.types.DataType
 
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given 
number */
 private class MutableInt(var i: Int)
@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
           Some(arg)
         }
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if children contains arg =>
           val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
           Some(arg)
         }
       case m: Map[_,_] => m
+      case d: DataType => d // Avoid unpacking Structs
       case args: Traversable[_] => args.map {
         case arg: TreeNode[_] if children contains arg =>
           val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
    * @param newArgs the new product arguments.
    */
   def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, 
"makeCopy") {
+    val defaultCtor =
+      getClass.getConstructors
+        .find(_.getParameterTypes.size != 0)
+        .headOption
+        .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+
     try {
       CurrentOrigin.withOrigin(origin) {
         // Skip no-arg constructors that are just there for kryo.
-        val defaultCtor = 
getClass.getConstructors.find(_.getParameterTypes.size != 0).head
         if (otherCopyArgs.isEmpty) {
           defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
         } else {
@@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
     } catch {
       case e: java.lang.IllegalArgumentException =>
         throw new TreeNodeException(
-          this, s"Failed to copy node.  Is otherCopyArgs specified correctly 
for $nodeName? "
-            + s"Exception message: ${e.getMessage}.")
+          this,
+          s"""
+             |Failed to copy node.
+             |Is otherCopyArgs specified correctly for $nodeName.
+             |Exception message: ${e.getMessage}
+             |ctor: $defaultCtor?
+             |args: ${newArgs.mkString(", ")}
+           """.stripMargin)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c699e2b7/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index be105c6..d615542 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
         .select($"ret.f1").head().getString(0)
     assert(result === "test")
   }
+
+  test("udf that is transformed") {
+    udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+    // 1 + 1 is constant folded causing a transformation.
+    assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 
2))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to