Repository: spark
Updated Branches:
  refs/heads/master fc1efb720 -> a425a37a5


[SPARK-17426][SQL] Refactor `TreeNode.toJSON` to avoid OOM when converting 
unknown fields to JSON

## What changes were proposed in this pull request?

This PR is a follow-up to SPARK-17356. The current implementation of 
`TreeNode.toJSON` recursively converts all fields of a `TreeNode` to JSON, even 
if a field is of type `Seq` or `Map`. This may trigger an out-of-memory error 
in cases like the following:

1. The `Seq` or `Map` can be very large. Converting it to JSON may consume a 
huge amount of memory and trigger an out-of-memory error.
2. User-space input may also be propagated into the plan. Such input can be of 
arbitrary type and may even be self-referencing, so converting it to JSON may 
trigger an out-of-memory error or a stack-overflow error.

For a complete code example, please check the JIRA description of SPARK-17426; 
a minimal illustrative sketch is also included below.
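
As a rough illustration of case 2 (a minimal sketch mirroring the 
`SelfReferenceUDF` test case added by this patch, not the exact JIRA example), 
a user-supplied object that references itself makes a naive recursive 
field-by-field conversion non-terminating:

```scala
// Minimal sketch: a user-space function object whose configuration map
// holds a reference back to the object itself. Walking its fields
// recursively to build JSON would loop until the stack or heap is exhausted.
case class SelfReferenceUDF(
    var config: Map[String, Any] = Map.empty[String, Any]) extends (String => Boolean) {
  config += "self" -> this
  def apply(key: String): Boolean = config.contains(key)
}
```

When such a function is wrapped in a plan node (e.g. a `ScalaUDF`) and `toJSON` 
is called on the plan, the old implementation could hit a `StackOverflowError`; 
the new regression test added to `TreeNodeSuite` in this patch covers exactly 
this case.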

In this PR, we refactor `TreeNode.toJSON` so that a field is converted to a 
JSON string only if it is of a known-safe type, as sketched below and 
implemented in the diff that follows.
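
A minimal, self-contained sketch of the whitelist idea (simplified for 
illustration and not the Spark code itself — the real change, shown in the 
`TreeNode.scala` diff below, whitelists Catalyst types such as `ExprId`, 
`StructField`, `TableIdentifier`, and `Partitioning` via a new 
`shouldConvertToJson` helper):

```scala
// Illustrative only: values of explicitly known-safe types are expanded,
// everything else is rendered as null so that arbitrary or self-referencing
// user objects can never blow up the recursive walk.
object WhitelistJsonSketch {
  sealed trait JValue
  case object JNull extends JValue
  case class JString(s: String) extends JValue
  case class JInt(i: BigInt) extends JValue
  case class JArray(elements: List[JValue]) extends JValue

  def toJson(value: Any): JValue = value match {
    case s: String => JString(s)
    case i: Int    => JInt(i)
    // Expand a Seq only when every element is itself a safe type.
    case xs: Seq[_] if xs.forall(x => x.isInstanceOf[String] || x.isInstanceOf[Int]) =>
      JArray(xs.map(toJson).toList)
    // Maps, unknown products, and arbitrary user objects are not expanded.
    case _ => JNull
  }

  def main(args: Array[String]): Unit = {
    println(toJson(Seq("a", "b")))   // JArray(List(JString(a), JString(b)))
    println(toJson(Map("k" -> "v"))) // JNull
    println(toJson(new Object {}))   // JNull
  }
}
```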

## How was this patch tested?

Unit tests (new `toJSON` cases added to `TreeNodeSuite`).

Author: Sean Zhong <seanzh...@databricks.com>

Closes #14990 from clockfly/json_oom2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a425a37a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a425a37a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a425a37a

Branch: refs/heads/master
Commit: a425a37a5d894e0d7462c8faa81a913495189ece
Parents: fc1efb7
Author: Sean Zhong <seanzh...@databricks.com>
Authored: Fri Sep 16 19:37:30 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Sep 16 19:37:30 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/trees/TreeNode.scala     | 218 +++-----------
 .../sql/catalyst/trees/TreeNodeSuite.scala      | 294 ++++++++++++++++++-
 .../scala/org/apache/spark/sql/QueryTest.scala  | 136 ---------
 3 files changed, 333 insertions(+), 315 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a425a37a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 893af51..83cb375 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -30,10 +30,15 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, 
CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
Partitioning}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -597,7 +602,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product {
       // this child in all children.
       case (name, value: TreeNode[_]) if containsChild(value) =>
         name -> JInt(children.indexOf(value))
-      case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) 
=>
+      case (name, value: Seq[BaseType]) if value.forall(containsChild) =>
         name -> JArray(
           value.map(v => 
JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList
         )
@@ -621,194 +626,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] 
extends Product {
     // SPARK-17356: In usage of mllib, Metadata may store a huge vector of 
data, transforming
     // it to JSON may trigger OutOfMemoryError.
     case m: Metadata => Metadata.empty.jsonValue
+    case clazz: Class[_] => JString(clazz.getName)
     case s: StorageLevel =>
       ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" 
-> s.useOffHeap) ~
         ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication)
     case n: TreeNode[_] => n.jsonValue
     case o: Option[_] => o.map(parseToJson)
-    case t: Seq[_] => JArray(t.map(parseToJson).toList)
-    case m: Map[_, _] =>
-      val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) }
-      JObject(fields)
-    case r: RDD[_] => JNothing
+    // Recursively scan Seq[TreeNode], Seq[Partitioning], and Seq[DataType]
+    case t: Seq[_] if t.forall(_.isInstanceOf[TreeNode[_]]) ||
+      t.forall(_.isInstanceOf[Partitioning]) || 
t.forall(_.isInstanceOf[DataType]) =>
+      JArray(t.map(parseToJson).toList)
+    case t: Seq[_] if t.length > 0 && t.head.isInstanceOf[String] =>
+      JString(Utils.truncatedString(t, "[", ", ", "]"))
+    case t: Seq[_] => JNull
+    case m: Map[_, _] => JNull
     // if it's a scala object, we can simply keep the full class path.
     // TODO: currently if the class name ends with "$", we think it's a scala 
object, there is
     // probably a better way to check it.
     case obj if obj.getClass.getName.endsWith("$") => "object" -> 
obj.getClass.getName
-    // returns null if the product type doesn't have a primary constructor, 
e.g. HiveFunctionWrapper
-    case p: Product => try {
-      val fieldNames = getConstructorParameterNames(p.getClass)
-      val fieldValues = p.productIterator.toSeq
-      assert(fieldNames.length == fieldValues.length)
-      ("product-class" -> JString(p.getClass.getName)) :: 
fieldNames.zip(fieldValues).map {
-        case (name, value) => name -> parseToJson(value)
-      }.toList
-    } catch {
-      case _: RuntimeException => null
-    }
-    case _ => JNull
-  }
-}
-
-object TreeNode {
-  def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: 
SparkContext): BaseType = {
-    val jsonAST = parse(json)
-    assert(jsonAST.isInstanceOf[JArray])
-    reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType]
-  }
-
-  private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] 
= {
-    assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject]))
-    val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*)
-
-    def parseNextNode(): TreeNode[_] = {
-      val nextNode = jsonNodes.pop()
-
-      val cls = Utils.classForName((nextNode \ 
"class").asInstanceOf[JString].s)
-      if (cls == classOf[Literal]) {
-        Literal.fromJSON(nextNode)
-      } else if (cls.getName.endsWith("$")) {
-        cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]]
-      } else {
-        val numChildren = (nextNode \ 
"num-children").asInstanceOf[JInt].num.toInt
-
-        val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => 
parseNextNode())
-        val fields = getConstructorParameters(cls)
-
-        val parameters: Array[AnyRef] = fields.map {
-          case (fieldName, fieldType) =>
-            parseFromJson(nextNode \ fieldName, fieldType, children, sc)
-        }.toArray
-
-        val maybeCtor = cls.getConstructors.find { p =>
-          val expectedTypes = p.getParameterTypes
-          expectedTypes.length == fields.length && 
expectedTypes.zip(fields.map(_._2)).forall {
-            case (cls, tpe) => cls == getClassFromType(tpe)
-          }
-        }
-        if (maybeCtor.isEmpty) {
-          sys.error(s"No valid constructor for ${cls.getName}")
-        } else {
-          try {
-            maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
-          } catch {
-            case e: java.lang.IllegalArgumentException =>
-              throw new RuntimeException(
-                s"""
-                  |Failed to construct tree node: ${cls.getName}
-                  |ctor: ${maybeCtor.get}
-                  |types: ${parameters.map(_.getClass).mkString(", ")}
-                  |args: ${parameters.mkString(", ")}
-                """.stripMargin, e)
-          }
-        }
-      }
-    }
-
-    parseNextNode()
-  }
-
-  import universe._
-
-  private def parseFromJson(
-      value: JValue,
-      expectedType: Type,
-      children: Seq[TreeNode[_]],
-      sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
-    if (value == JNull) return null
-
-    expectedType match {
-      case t if t <:< definitions.BooleanTpe =>
-        value.asInstanceOf[JBool].value: java.lang.Boolean
-      case t if t <:< definitions.ByteTpe =>
-        value.asInstanceOf[JInt].num.toByte: java.lang.Byte
-      case t if t <:< definitions.ShortTpe =>
-        value.asInstanceOf[JInt].num.toShort: java.lang.Short
-      case t if t <:< definitions.IntTpe =>
-        value.asInstanceOf[JInt].num.toInt: java.lang.Integer
-      case t if t <:< definitions.LongTpe =>
-        value.asInstanceOf[JInt].num.toLong: java.lang.Long
-      case t if t <:< definitions.FloatTpe =>
-        value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
-      case t if t <:< definitions.DoubleTpe =>
-        value.asInstanceOf[JDouble].num: java.lang.Double
-
-      case t if t <:< localTypeOf[java.lang.Boolean] =>
-        value.asInstanceOf[JBool].value: java.lang.Boolean
-      case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
-      case t if t <:< localTypeOf[java.lang.String] => 
value.asInstanceOf[JString].s
-      case t if t <:< localTypeOf[UUID] => 
UUID.fromString(value.asInstanceOf[JString].s)
-      case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
-      case t if t <:< localTypeOf[Metadata] => 
Metadata.fromJObject(value.asInstanceOf[JObject])
-      case t if t <:< localTypeOf[StorageLevel] =>
-        val JBool(useDisk) = value \ "useDisk"
-        val JBool(useMemory) = value \ "useMemory"
-        val JBool(useOffHeap) = value \ "useOffHeap"
-        val JBool(deserialized) = value \ "deserialized"
-        val JInt(replication) = value \ "replication"
-        StorageLevel(useDisk, useMemory, useOffHeap, deserialized, 
replication.toInt)
-      case t if t <:< localTypeOf[TreeNode[_]] => value match {
-        case JInt(i) => children(i.toInt)
-        case arr: JArray => reconstruct(arr, sc)
-        case _ => throw new RuntimeException(s"$value is not a valid json 
value for tree node.")
+    case p: Product if shouldConvertToJson(p) =>
+      try {
+        val fieldNames = getConstructorParameterNames(p.getClass)
+        val fieldValues = p.productIterator.toSeq
+        assert(fieldNames.length == fieldValues.length)
+        ("product-class" -> JString(p.getClass.getName)) :: 
fieldNames.zip(fieldValues).map {
+          case (name, value) => name -> parseToJson(value)
+        }.toList
+      } catch {
+        case _: RuntimeException => null
       }
-      case t if t <:< localTypeOf[Option[_]] =>
-        if (value == JNothing) {
-          None
-        } else {
-          val TypeRef(_, _, Seq(optType)) = t
-          Option(parseFromJson(value, optType, children, sc))
-        }
-      case t if t <:< localTypeOf[Seq[_]] =>
-        val TypeRef(_, _, Seq(elementType)) = t
-        val JArray(elements) = value
-        elements.map(parseFromJson(_, elementType, children, sc)).toSeq
-      case t if t <:< localTypeOf[Map[_, _]] =>
-        val TypeRef(_, _, Seq(keyType, valueType)) = t
-        val JObject(fields) = value
-        fields.map {
-          case (name, value) => name -> parseFromJson(value, valueType, 
children, sc)
-        }.toMap
-      case t if t <:< localTypeOf[RDD[_]] =>
-        new EmptyRDD[Any](sc)
-      case _ if isScalaObject(value) =>
-        val JString(clsName) = value \ "object"
-        val cls = Utils.classForName(clsName)
-        cls.getField("MODULE$").get(cls)
-      case t if t <:< localTypeOf[Product] =>
-        val fields = getConstructorParameters(t)
-        val clsName = getClassNameFromType(t)
-        parseToProduct(clsName, fields, value, children, sc)
-      // There maybe some cases that the parameter type signature is not 
Product but the value is,
-      // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`, handle 
it here.
-      case _ if isScalaProduct(value) =>
-        val JString(clsName) = value \ "product-class"
-        val fields = getConstructorParameters(Utils.classForName(clsName))
-        parseToProduct(clsName, fields, value, children, sc)
-      case _ => sys.error(s"Do not support type $expectedType with json 
$value.")
-    }
-  }
-
-  private def parseToProduct(
-      clsName: String,
-      fields: Seq[(String, Type)],
-      value: JValue,
-      children: Seq[TreeNode[_]],
-      sc: SparkContext): AnyRef = {
-    val parameters: Array[AnyRef] = fields.map {
-      case (fieldName, fieldType) => parseFromJson(value \ fieldName, 
fieldType, children, sc)
-    }.toArray
-    val ctor = 
Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size)
-    ctor.newInstance(parameters: _*).asInstanceOf[AnyRef]
-  }
-
-  private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") 
match {
-    case JString(str) if str.endsWith("$") => true
-    case _ => false
+    case _ => JNull
   }
 
-  private def isScalaProduct(jValue: JValue): Boolean = (jValue \ 
"product-class") match {
-    case _: JString => true
+  private def shouldConvertToJson(product: Product): Boolean = product match {
+    case exprId: ExprId => true
+    case field: StructField => true
+    case id: TableIdentifier => true
+    case join: JoinType => true
+    case id: FunctionIdentifier => true
+    case spec: BucketSpec => true
+    case catalog: CatalogTable => true
+    case boundary: FrameBoundary => true
+    case frame: WindowFrame => true
+    case partition: Partitioning => true
+    case resource: FunctionResource => true
+    case broadcast: BroadcastMode => true
+    case table: CatalogTableType => true
+    case storage: CatalogStorageFormat => true
     case _ => false
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a425a37a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 6246380..cb0426c 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -17,13 +17,29 @@
 
 package org.apache.spark.sql.catalyst.trees
 
+import java.math.BigInteger
+import java.util.UUID
+
 import scala.collection.mutable.ArrayBuffer
 
+import org.json4s.jackson.JsonMethods
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, 
CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, 
JarResource}
+import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types.{IntegerType, NullType, StringType}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
+import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, 
RoundRobinPartitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, 
IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
 
 case class Dummy(optKey: Option[Expression]) extends Expression with 
CodegenFallback {
   override def children: Seq[Expression] = optKey.toSeq
@@ -45,6 +61,20 @@ case class ExpressionInMap(map: Map[String, Expression]) 
extends Expression with
   override lazy val resolved = true
 }
 
+case class JsonTestTreeNode(arg: Any) extends LeafNode {
+  override def output: Seq[Attribute] = Seq.empty[Attribute]
+}
+
+case class NameValue(name: String, value: Any)
+
+case object DummyObject
+
+case class SelfReferenceUDF(
+    var config: Map[String, Any] = Map.empty[String, Any]) extends 
Function1[String, Boolean] {
+  config += "self" -> this
+  def apply(key: String): Boolean = config.contains(key)
+}
+
 class TreeNodeSuite extends SparkFunSuite {
   test("top node changed") {
     val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -261,4 +291,264 @@ class TreeNodeSuite extends SparkFunSuite {
       assert(actual === expected)
     }
   }
+
+  test("toJSON") {
+    def assertJSON(input: Any, json: JValue): Unit = {
+      val expected =
+        s"""
+           |[{
+           |  "class": "${classOf[JsonTestTreeNode].getName}",
+           |  "num-children": 0,
+           |  "arg": ${compact(render(json))}
+           |}]
+         """.stripMargin
+      compareJSON(JsonTestTreeNode(input).toJSON, expected)
+    }
+
+    // Converts simple types to JSON
+    assertJSON(true, true)
+    assertJSON(33.toByte, 33)
+    assertJSON(44, 44)
+    assertJSON(55L, 55L)
+    assertJSON(3.0, 3.0)
+    assertJSON(4.0D, 4.0D)
+    assertJSON(BigInt(BigInteger.valueOf(88L)), 88L)
+    assertJSON(null, JNull)
+    assertJSON("text", "text")
+    assertJSON(Some("text"), "text")
+    compareJSON(JsonTestTreeNode(None).toJSON,
+      s"""[
+         |  {
+         |    "class": "${classOf[JsonTestTreeNode].getName}",
+         |    "num-children": 0
+         |  }
+         |]
+       """.stripMargin)
+
+    val uuid = UUID.randomUUID()
+    assertJSON(uuid, uuid.toString)
+
+    // Converts Spark Sql DataType to JSON
+    assertJSON(IntegerType, "integer")
+    assertJSON(Metadata.empty, JObject(Nil))
+    assertJSON(
+      StorageLevel.NONE,
+      JObject(
+        "useDisk" -> false,
+        "useMemory" -> false,
+        "useOffHeap" -> false,
+        "deserialized" -> false,
+        "replication" -> 1)
+    )
+
+    // Converts TreeNode argument to JSON
+    assertJSON(
+      Literal(333),
+      List(
+        JObject(
+          "class" -> classOf[Literal].getName,
+          "num-children" -> 0,
+          "value" -> "333",
+          "dataType" -> "integer")))
+
+    // Converts Seq[String] to JSON
+    assertJSON(Seq("1", "2", "3"), "[1, 2, 3]")
+
+    // Converts Seq[DataType] to JSON
+    assertJSON(Seq(IntegerType, DoubleType, FloatType), List("integer", 
"double", "float"))
+
+    // Converts Seq[Partitioning] to JSON
+    assertJSON(
+      Seq(SinglePartition, RoundRobinPartitioning(numPartitions = 3)),
+      List(
+        JObject("object" -> JString(SinglePartition.getClass.getName)),
+        JObject(
+          "product-class" -> classOf[RoundRobinPartitioning].getName,
+          "numPartitions" -> 3)))
+
+    // Converts case object to JSON
+    assertJSON(DummyObject, JObject("object" -> 
JString(DummyObject.getClass.getName)))
+
+    // Converts ExprId to JSON
+    assertJSON(
+      ExprId(0, uuid),
+      JObject(
+        "product-class" -> classOf[ExprId].getName,
+        "id" -> 0,
+        "jvmId" -> uuid.toString))
+
+    // Converts StructField to JSON
+    assertJSON(
+      StructField("field", IntegerType),
+      JObject(
+        "product-class" -> classOf[StructField].getName,
+        "name" -> "field",
+        "dataType" -> "integer",
+        "nullable" -> true,
+        "metadata" -> JObject(Nil)))
+
+    // Converts TableIdentifier to JSON
+    assertJSON(
+      TableIdentifier("table"),
+      JObject(
+        "product-class" -> classOf[TableIdentifier].getName,
+        "table" -> "table"))
+
+    // Converts JoinType to JSON
+    assertJSON(
+      NaturalJoin(LeftOuter),
+      JObject(
+        "product-class" -> classOf[NaturalJoin].getName,
+        "tpe" -> JObject("object" -> JString(LeftOuter.getClass.getName))))
+
+    // Converts FunctionIdentifier to JSON
+    assertJSON(
+      FunctionIdentifier("function", None),
+      JObject(
+        "product-class" -> JString(classOf[FunctionIdentifier].getName),
+          "funcName" -> "function"))
+
+    // Converts BucketSpec to JSON
+    assertJSON(
+      BucketSpec(1, Seq("bucket"), Seq("sort")),
+      JObject(
+        "product-class" -> classOf[BucketSpec].getName,
+        "numBuckets" -> 1,
+        "bucketColumnNames" -> "[bucket]",
+        "sortColumnNames" -> "[sort]"))
+
+    // Converts FrameBoundary to JSON
+    assertJSON(
+      ValueFollowing(3),
+      JObject(
+        "product-class" -> classOf[ValueFollowing].getName,
+        "value" -> 3))
+
+    // Converts WindowFrame to JSON
+    assertJSON(
+      SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow),
+      JObject(
+        "product-class" -> classOf[SpecifiedWindowFrame].getName,
+        "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)),
+        "frameStart" -> JObject("object" -> 
JString(UnboundedFollowing.getClass.getName)),
+        "frameEnd" -> JObject("object" -> 
JString(CurrentRow.getClass.getName))))
+
+    // Converts Partitioning to JSON
+    assertJSON(
+      RoundRobinPartitioning(numPartitions = 3),
+      JObject(
+        "product-class" -> classOf[RoundRobinPartitioning].getName,
+        "numPartitions" -> 3))
+
+    // Converts FunctionResource to JSON
+    assertJSON(
+      FunctionResource(JarResource, "file:///"),
+      JObject(
+        "product-class" -> JString(classOf[FunctionResource].getName),
+        "resourceType" -> JObject("object" -> 
JString(JarResource.getClass.getName)),
+        "uri" -> "file:///"))
+
+    // Converts BroadcastMode to JSON
+    assertJSON(
+      IdentityBroadcastMode,
+      JObject("object" -> JString(IdentityBroadcastMode.getClass.getName)))
+
+    // Converts CatalogTable to JSON
+    assertJSON(
+      CatalogTable(
+        TableIdentifier("table"),
+        CatalogTableType.MANAGED,
+        CatalogStorageFormat.empty,
+        StructType(StructField("a", IntegerType, true) :: Nil),
+        createTime = 0L),
+
+      JObject(
+        "product-class" -> classOf[CatalogTable].getName,
+        "identifier" -> JObject(
+          "product-class" -> classOf[TableIdentifier].getName,
+          "table" -> "table"
+        ),
+        "tableType" -> JObject(
+          "product-class" -> classOf[CatalogTableType].getName,
+          "name" -> "MANAGED"
+        ),
+        "storage" -> JObject(
+          "product-class" -> classOf[CatalogStorageFormat].getName,
+          "compressed" -> false,
+          "properties" -> JNull
+        ),
+        "schema" -> JObject(
+          "type" -> "struct",
+          "fields" -> List(
+            JObject(
+              "name" -> "a",
+              "type" -> "integer",
+              "nullable" -> true,
+              "metadata" -> JObject(Nil)))),
+        "partitionColumnNames" -> List.empty[String],
+        "owner" -> "",
+        "createTime" -> 0,
+        "lastAccessTime" -> -1,
+        "properties" -> JNull,
+        "unsupportedFeatures" -> List.empty[String]))
+
+    // For unknown case class, returns JNull.
+    val bigValue = new Array[Int](10000)
+    assertJSON(NameValue("name", bigValue), JNull)
+
+    // Converts Seq[TreeNode] to JSON recursively
+    assertJSON(
+      Seq(Literal(1), Literal(2)),
+      List(
+        List(
+          JObject(
+            "class" -> JString(classOf[Literal].getName),
+            "num-children" -> 0,
+            "value" -> "1",
+            "dataType" -> "integer")),
+        List(
+          JObject(
+            "class" -> JString(classOf[Literal].getName),
+            "num-children" -> 0,
+            "value" -> "2",
+            "dataType" -> "integer"))))
+
+    // Other Seq types are converted to JNull, to reduce the risk of out-of-memory errors
+    assertJSON(Seq(1, 2, 3), JNull)
+
+    // All Map types are converted to JNull, to reduce the risk of out-of-memory errors
+    assertJSON(Map("key" -> "value"), JNull)
+
+    // Unknown types are converted to JNull, to reduce the risk of out-of-memory errors
+    assertJSON(new Object {}, JNull)
+
+    // Convert all TreeNode children to JSON
+    assertJSON(
+      Union(Seq(JsonTestTreeNode("0"), JsonTestTreeNode("1"))),
+      List(
+        JObject(
+          "class" -> classOf[Union].getName,
+          "num-children" -> 2,
+          "children" -> List(0, 1)),
+        JObject(
+          "class" -> classOf[JsonTestTreeNode].getName,
+          "num-children" -> 0,
+          "arg" -> "0"),
+        JObject(
+          "class" -> classOf[JsonTestTreeNode].getName,
+          "num-children" -> 0,
+          "arg" -> "1")))
+  }
+
+  test("toJSON should not throws java.lang.StackOverflowError") {
+    val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr))
+    // Should not throw java.lang.StackOverflowError
+    udf.toJSON
+  }
+
+  private def compareJSON(leftJson: String, rightJson: String): Unit = {
+    val left = JsonMethods.parse(leftJson)
+    val right = JsonMethods.parse(rightJson)
+    assert(left == right)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a425a37a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d361f61..34fa626 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -120,7 +120,6 @@ abstract class QueryTest extends PlanTest {
           throw ae
         }
     }
-    checkJsonFormat(analyzedDS)
     assertEmptyMissingInput(analyzedDS)
 
     try ds.collect() catch {
@@ -168,8 +167,6 @@ abstract class QueryTest extends PlanTest {
         }
     }
 
-    checkJsonFormat(analyzedDF)
-
     assertEmptyMissingInput(analyzedDF)
 
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
@@ -228,139 +225,6 @@ abstract class QueryTest extends PlanTest {
         planWithCaching)
   }
 
-  private def checkJsonFormat(ds: Dataset[_]): Unit = {
-    // Get the analyzed plan and rewrite the PredicateSubqueries in order to 
make sure that
-    // RDD and Data resolution does not break.
-    val logicalPlan = ds.queryExecution.analyzed
-
-    // bypass some cases that we can't handle currently.
-    logicalPlan.transform {
-      case _: ObjectConsumer => return
-      case _: ObjectProducer => return
-      case _: AppendColumns => return
-      case _: TypedFilter => return
-      case _: LogicalRelation => return
-      case p if p.getClass.getSimpleName == "MetastoreRelation" => return
-      case _: MemoryPlan => return
-      case p: InMemoryRelation =>
-        p.child.transform {
-          case _: ObjectConsumerExec => return
-          case _: ObjectProducerExec => return
-        }
-        p
-    }.transformAllExpressions {
-      case _: ImperativeAggregate => return
-      case _: TypedAggregateExpression => return
-      case Literal(_, _: ObjectType) => return
-      case _: UserDefinedGenerator => return
-    }
-
-    // bypass hive tests before we fix all corner cases in hive module.
-    if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
-
-    val jsonString = try {
-      logicalPlan.toJSON
-    } catch {
-      case NonFatal(e) =>
-        fail(
-          s"""
-             |Failed to parse logical plan to JSON:
-             |${logicalPlan.treeString}
-           """.stripMargin, e)
-    }
-
-    // scala function is not serializable to JSON, use null to replace them so 
that we can compare
-    // the plans later.
-    val normalized1 = logicalPlan.transformAllExpressions {
-      case udf: ScalaUDF => udf.copy(function = null)
-      case gen: UserDefinedGenerator => gen.copy(function = null)
-      // After SPARK-17356: the JSON representation no longer has the 
Metadata. We need to remove
-      // the Metadata from the normalized plan so that we can compare this 
plan with the
-      // JSON-deserialzed plan.
-      case a @ Alias(child, name) if a.explicitMetadata.isDefined =>
-        Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), 
a.isGenerated)
-      case a: AttributeReference if a.metadata != Metadata.empty =>
-        AttributeReference(a.name, a.dataType, a.nullable, 
Metadata.empty)(a.exprId, a.qualifier,
-          a.isGenerated)
-    }
-
-    // RDDs/data are not serializable to JSON, so we need to collect 
LogicalPlans that contains
-    // these non-serializable stuff, and use these original ones to replace 
the null-placeholders
-    // in the logical plans parsed from JSON.
-    val logicalRDDs = new ArrayDeque[LogicalRDD]()
-    val localRelations = new ArrayDeque[LocalRelation]()
-    val inMemoryRelations = new ArrayDeque[InMemoryRelation]()
-    def collectData: (LogicalPlan => Unit) = {
-      case l: LogicalRDD =>
-        logicalRDDs.offer(l)
-      case l: LocalRelation =>
-        localRelations.offer(l)
-      case i: InMemoryRelation =>
-        inMemoryRelations.offer(i)
-      case p =>
-        p.expressions.foreach {
-          _.foreach {
-            case s: SubqueryExpression =>
-              s.plan.foreach(collectData)
-            case _ =>
-          }
-        }
-    }
-    logicalPlan.foreach(collectData)
-
-
-    val jsonBackPlan = try {
-      TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext)
-    } catch {
-      case NonFatal(e) =>
-        fail(
-          s"""
-             |Failed to rebuild the logical plan from JSON:
-             |${logicalPlan.treeString}
-             |
-             |${logicalPlan.prettyJson}
-           """.stripMargin, e)
-    }
-
-    def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = {
-      case l: LogicalRDD =>
-        val origin = logicalRDDs.pop()
-        LogicalRDD(l.output, origin.rdd)(spark)
-      case l: LocalRelation =>
-        val origin = localRelations.pop()
-        l.copy(data = origin.data)
-      case l: InMemoryRelation =>
-        val origin = inMemoryRelations.pop()
-        InMemoryRelation(
-          l.output,
-          l.useCompression,
-          l.batchSize,
-          l.storageLevel,
-          origin.child,
-          l.tableName)(
-          origin.cachedColumnBuffers,
-          origin.batchStats)
-      case p =>
-        p.transformExpressions {
-          case s: SubqueryExpression =>
-            s.withNewPlan(s.plan.transformDown(renormalize))
-        }
-    }
-    val normalized2 = jsonBackPlan.transformDown(renormalize)
-
-    assert(logicalRDDs.isEmpty)
-    assert(localRelations.isEmpty)
-    assert(inMemoryRelations.isEmpty)
-
-    if (normalized1 != normalized2) {
-      fail(
-        s"""
-           |== FAIL: the logical plan parsed from json does not match the 
original one ===
-           |${sideBySide(logicalPlan.treeString, 
normalized2.treeString).mkString("\n")}
-          """.stripMargin)
-    }
-  }
-
   /**
    * Asserts that a given [[Dataset]] does not have missing inputs in all the 
analyzed plans.
    */


