Repository: spark
Updated Branches:
  refs/heads/master 3abe2b734 -> db8cc6f28


[SPARK-1845] [SQL] Use AllScalaRegistrar for SparkSqlSerializer to register
serializers of Scala collections.

When I execute `orderBy` or `limit` on a `SchemaRDD` that includes `ArrayType` 
or `MapType` columns, `SparkSqlSerializer` throws one of the following 
exceptions:

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.$colon$colon
```

or

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.Vector
```

or

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.HashMap$HashTrieMap
```

and so on.
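
For reference, a hypothetical minimal reproduction, assuming the pre-fix 
serializer and the implicits of `TestSQLContext` in scope (the case class and 
table name mirror the test data added in this patch; everything else is 
illustrative):

```
// A SchemaRDD with an ArrayType column; the Seq is backed by
// scala.collection.immutable.$colon$colon (List cons cells).
case class ArrayData(data: Seq[Int])

val arrayData = sparkContext.parallelize(
  ArrayData(Seq(1, 2, 3)) :: ArrayData(Seq(2, 3, 4)) :: Nil)
arrayData.registerAsTable("arrayData")

// Sorting ships rows through SparkSqlSerializer, and Kryo fails because
// $colon$colon has no no-arg constructor and no registered serializer.
sql("SELECT * FROM arrayData ORDER BY data[0]").collect()
```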

This happens because `SparkSqlSerializer` does not register serializers for 
these concrete collection classes.
I believe it should use `AllScalaRegistrar` instead:
`AllScalaRegistrar` registers serializers for the concrete `Seq` and `Map` 
implementations that back `ArrayType` and `MapType` values.
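
As a rough sketch of what the fix amounts to (the `AllScalaRegistrar` call is 
exactly the one added in the diff below; the standalone `Kryo` setup around it 
is illustrative):

```
import com.esotericsoftware.kryo.Kryo
import com.twitter.chill.AllScalaRegistrar

val kryo = new Kryo()
kryo.setRegistrationRequired(false)
// chill's AllScalaRegistrar registers serializers for the concrete Scala
// collection classes ($colon$colon, Vector, HashMap$HashTrieMap, ...),
// none of which have the no-arg constructor Kryo needs by default.
new AllScalaRegistrar().apply(kryo)
```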

Author: Takuya UESHIN <[email protected]>

Closes #790 from ueshin/issues/SPARK-1845 and squashes the following commits:

d1ed992 [Takuya UESHIN] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/db8cc6f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/db8cc6f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/db8cc6f2

Branch: refs/heads/master
Commit: db8cc6f28abe4326cea6f53feb604920e4867a27
Parents: 3abe2b7
Author: Takuya UESHIN <[email protected]>
Authored: Thu May 15 11:20:21 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu May 15 11:20:21 2014 -0700

----------------------------------------------------------------------
 .../sql/execution/SparkSqlSerializer.scala      | 28 ++----------------
 .../org/apache/spark/sql/DslQuerySuite.scala    | 24 ++++++++++++++++
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 30 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/TestData.scala   | 10 +++++++
 4 files changed, 66 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/db8cc6f2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 94c2a24..34b355e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.twitter.chill.AllScalaRegistrar
 
 import org.apache.spark.{SparkEnv, SparkConf}
 import org.apache.spark.serializer.KryoSerializer
@@ -35,22 +36,14 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     val kryo = new Kryo()
     kryo.setRegistrationRequired(false)
     kryo.register(classOf[MutablePair[_, _]])
-    kryo.register(classOf[Array[Any]])
-    // This is kinda hacky...
-    kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
-    kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
     kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
                   new HyperLogLogSerializer)
-    kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
     kryo.setClassLoader(Utils.getSparkClassLoader)
+    new AllScalaRegistrar().apply(kryo)
     kryo
   }
 }
@@ -97,20 +90,3 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
     HyperLogLog.Builder.build(bytes)
   }
 }
-
-/**
- * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
- * them as `Array[(k,v)]`.
- */
-private[sql] class MapSerializer extends Serializer[Map[_,_]] {
-  def write(kryo: Kryo, output: Output, map: Map[_,_]) {
-    kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
-    kryo.readObject(input, classOf[Array[Any]])
-      .sliding(2,2)
-      .map { case Array(k,v) => (k,v) }
-      .toMap
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/db8cc6f2/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 92a707e..f43e98d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -69,12 +69,36 @@ class DslQuerySuite extends QueryTest {
     checkAnswer(
       testData2.orderBy('a.desc, 'b.asc),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).asc),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).desc),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).asc),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).desc),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
   }
 
   test("limit") {
     checkAnswer(
       testData.limit(10),
       testData.take(10).toSeq)
+
+    checkAnswer(
+      arrayData.limit(1),
+      arrayData.take(1).toSeq)
+
+    checkAnswer(
+      mapData.limit(1),
+      mapData.take(1).toSeq)
   }
 
   test("average") {

http://git-wip-us.apache.org/repos/asf/spark/blob/db8cc6f2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 524549e..189dccd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -85,6 +85,36 @@ class SQLQuerySuite extends QueryTest {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
+  }
+
+  test("limit") {
+    checkAnswer(
+      sql("SELECT * FROM testData LIMIT 10"),
+      testData.take(10).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData LIMIT 1"),
+      arrayData.collect().take(1).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData LIMIT 1"),
+      mapData.collect().take(1).toSeq)
   }
 
   test("average") {

http://git-wip-us.apache.org/repos/asf/spark/blob/db8cc6f2/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index aa71e27..1aca387 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -74,6 +74,16 @@ object TestData {
       ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
   arrayData.registerAsTable("arrayData")
 
+  case class MapData(data: Map[Int, String])
+  val mapData =
+    TestSQLContext.sparkContext.parallelize(
+      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+      MapData(Map(1 -> "a4", 2 -> "b4")) ::
+      MapData(Map(1 -> "a5")) :: Nil)
+  mapData.registerAsTable("mapData")
+
   case class StringData(s: String)
   val repeatedData =
     TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
