Repository: spark Updated Branches: refs/heads/branch-1.3 c0101d392 -> c699e2b76
[SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes Due to a recent change that made `StructType` a `Seq` we started inadvertently turning `StructType`s into generic `Traversable` when attempting nested tree transformations. In this PR we explicitly avoid descending into `DataType`s to avoid this bug. Author: Michael Armbrust <[email protected]> Closes #5157 from marmbrus/udfFix and squashes the following commits: 26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes (cherry picked from commit 3fa3d121dfec60f9768d3859e8450ee482b2d4e8) Signed-off-by: Michael Armbrust <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c699e2b7 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c699e2b7 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c699e2b7 Branch: refs/heads/branch-1.3 Commit: c699e2b766a7cb9e03762bf278d7b19f631cb4e8 Parents: c0101d3 Author: Michael Armbrust <[email protected]> Authored: Tue Mar 24 12:28:01 2015 -0700 Committer: Michael Armbrust <[email protected]> Committed: Tue Mar 24 12:28:15 2015 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 ++ .../spark/sql/catalyst/trees/TreeNode.scala | 20 +++++++++++++++++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 6 ++++++ 3 files changed, 25 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c699e2b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 48191f3..bd9291e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) case other => other @@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) case other => other http://git-wip-us.apache.org/repos/asf/spark/blob/c699e2b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cfd0203..8fa4fc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.types.DataType /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */ private class MutableInt(var i: Int) @@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformDown(rule) @@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { Some(arg) } case m: Map[_,_] => m + case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => val newChild = arg.asInstanceOf[BaseType].transformUp(rule) @@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { + val defaultCtor = + getClass.getConstructors + .find(_.getParameterTypes.size != 0) + .headOption + .getOrElse(sys.error(s"No valid constructor for $nodeName")) + try { CurrentOrigin.withOrigin(origin) { // Skip no-arg constructors that are just there for kryo. - val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head if (otherCopyArgs.isEmpty) { defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] } else { @@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " - + s"Exception message: ${e.getMessage}.") + this, + s""" + |Failed to copy node. + |Is otherCopyArgs specified correctly for $nodeName. + |Exception message: ${e.getMessage} + |ctor: $defaultCtor? + |args: ${newArgs.mkString(", ")} + """.stripMargin) } } http://git-wip-us.apache.org/repos/asf/spark/blob/c699e2b7/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index be105c6..d615542 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -50,4 +50,10 @@ class UDFSuite extends QueryTest { .select($"ret.f1").head().getString(0) assert(result === "test") } + + test("udf that is transformed") { + udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + // 1 + 1 is constant folded causing a transformation. + assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
