Repository: spark
Updated Branches:
  refs/heads/master 8019f66df -> 5e2b44474

[SPARK-11802][SQL] Kryo-based encoder for opaque types in Datasets

I also found a bug with self-joins returning incorrect results in the
Dataset API. Two test cases are attached, and SPARK-11803 has been filed
to track the fix.

Author: Reynold Xin <[email protected]>

Closes #9789 from rxin/SPARK-11802.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5e2b4447
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5e2b4447
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5e2b4447

Branch: refs/heads/master
Commit: 5e2b44474c2b838bebeffe5ba5cd72961b0cd31e
Parents: 8019f66
Author: Reynold Xin <[email protected]>
Authored: Wed Nov 18 00:09:29 2015 -0800
Committer: Reynold Xin <[email protected]>
Committed: Wed Nov 18 00:09:29 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Encoder.scala    | 31 ++++++++-
 .../catalyst/encoders/ExpressionEncoder.scala   |  4 +-
 .../sql/catalyst/encoders/ProductEncoder.scala  |  2 +-
 .../sql/catalyst/expressions/objects.scala      | 69 ++++++++++++++++++-
 .../catalyst/encoders/FlatEncoderSuite.scala    | 18 +++++
 .../scala/org/apache/spark/sql/Dataset.scala    |  6 ++
 .../org/apache/spark/sql/GroupedDataset.scala   |  1 -
 .../org/apache/spark/sql/DatasetSuite.scala     | 70 +++++++++++++++-----
 8 files changed, 178 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index c8b017e..79c2255 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql
 
-import scala.reflect.ClassTag
+import scala.reflect.{ClassTag, classTag}
 
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.{DeserializeWithKryo, BoundReference, SerializeWithKryo}
+import org.apache.spark.sql.types._
 
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
@@ -37,7 +38,33 @@ trait Encoder[T] extends Serializable {
   def clsTag: ClassTag[T]
 }
 
+/**
+ * Methods for creating encoders.
+ */
 object Encoders {
+
+  /**
+   * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+   * This encoder maps T into a single byte array (binary) field.
+   */
+  def kryo[T: ClassTag]: Encoder[T] = {
+    val ser = SerializeWithKryo(BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true))
+    val deser = DeserializeWithKryo[T](BoundReference(0, BinaryType, nullable = true), classTag[T])
+    ExpressionEncoder[T](
+      schema = new StructType().add("value", BinaryType),
+      flat = true,
+      toRowExpressions = Seq(ser),
+      fromRowExpression = deser,
+      clsTag = classTag[T]
+    )
+  }
+
+  /**
+   * Creates an encoder that serializes objects of type T using Kryo.
+   * This encoder maps T into a single byte array (binary) field.
+   */
+  def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
   def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
   def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
   def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
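
To make the new API concrete, here is a minimal usage sketch (not part of
the patch). It assumes a SQLContext named `sqlContext` is available, and
the `Point` class is hypothetical, standing in for any type with no
built-in encoder:

    import org.apache.spark.sql.{Encoder, Encoders, SQLContext}

    // Hypothetical "opaque" type: not a case class, so no Product-based
    // encoder can be derived for it.
    class Point(val x: Int, val y: Int) {
      override def equals(other: Any): Boolean = other match {
        case p: Point => p.x == x && p.y == y
        case _ => false
      }
      override def hashCode: Int = 31 * x + y
    }

    def example(sqlContext: SQLContext): Unit = {
      // An implicit Kryo-based encoder lets createDataset accept Point;
      // each element is stored as the single binary "value" column that
      // Encoders.kryo declares in its schema.
      implicit val pointEncoder: Encoder[Point] = Encoders.kryo[Point]
      val ds = sqlContext.createDataset(Seq(new Point(1, 2), new Point(3, 4)))
      assert(ds.collect().toSet == Set(new Point(1, 2), new Point(3, 4)))
    }

Java callers get the same behavior through the kryo(Class[T]) overload,
e.g. Encoders.kryo(Point.class).
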
http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9a1a8f5..b977f27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -161,7 +161,9 @@ case class ExpressionEncoder[T](
 
   @transient
   private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
-  private val inputRow = new GenericMutableRow(1)
+
+  @transient
+  private lazy val inputRow = new GenericMutableRow(1)
 
   @transient
   private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index 414adb2..55c4ee1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -230,7 +230,7 @@ object ProductEncoder {
         Invoke(inputObject, "booleanValue", BooleanType)
 
       case other =>
-        throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+        throw new UnsupportedOperationException(s"Encoder for type $other is not supported")
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 5cd19de..489c612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.language.existentials
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.{KryoSerializerInstance, KryoSerializer}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
-
-import scala.language.existentials
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
@@ -514,3 +516,64 @@ case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataTy
     """
   }
 }
+
+/** Serializes an input object using Kryo serializer. */
+case class SerializeWithKryo(child: Expression) extends UnaryExpression {
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val input = child.gen(ctx)
+    val kryo = ctx.freshName("kryoSerializer")
+    val kryoClass = classOf[KryoSerializer].getName
+    val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+    val sparkConfClass = classOf[SparkConf].getName
+    ctx.addMutableState(
+      kryoInstanceClass,
+      kryo,
+      s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+    s"""
+      ${input.code}
+      final boolean ${ev.isNull} = ${input.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = $kryo.serialize(${input.value}, null).array();
+      }
+    """
+  }
+
+  override def dataType: DataType = BinaryType
+}
+
+/**
+ * Deserializes an input object using Kryo serializer. Note that the ClassTag is not an implicit
+ * parameter because TreeNode cannot copy implicit parameters.
+ */
+case class DeserializeWithKryo[T](child: Expression, tag: ClassTag[T]) extends UnaryExpression {
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val input = child.gen(ctx)
+    val kryo = ctx.freshName("kryoSerializer")
+    val kryoClass = classOf[KryoSerializer].getName
+    val kryoInstanceClass = classOf[KryoSerializerInstance].getName
+    val sparkConfClass = classOf[SparkConf].getName
+    ctx.addMutableState(
+      kryoInstanceClass,
+      kryo,
+      s"$kryo = ($kryoInstanceClass) new $kryoClass(new $sparkConfClass()).newInstance();")
+
+    s"""
+      ${input.code}
+      final boolean ${ev.isNull} = ${input.isNull};
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${ev.isNull}) {
+        ${ev.value} = (${ctx.javaType(dataType)})
+          $kryo.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null);
+      }
+    """
+  }
+
+  override def dataType: DataType = ObjectType(tag.runtimeClass)
+}
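
Outside of generated code, the round trip that SerializeWithKryo and
DeserializeWithKryo compile down to looks roughly like the standalone
sketch below (the `null` ClassTag arguments in the generated Java play the
role of the implicit tags here):

    import java.nio.ByteBuffer

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    object KryoRoundTripSketch {
      def main(args: Array[String]): Unit = {
        // Each generated expression holds one serializer instance,
        // registered through ctx.addMutableState in genCode above.
        val instance = new KryoSerializer(new SparkConf()).newInstance()

        // SerializeWithKryo: object -> byte[], stored in a BinaryType column.
        val bytes: Array[Byte] = instance.serialize("hello").array()

        // DeserializeWithKryo: byte[] -> object, cast back to the runtime
        // class carried by the ClassTag.
        val back = instance.deserialize[String](ByteBuffer.wrap(bytes))
        assert(back == "hello")
      }
    }
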
http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
index 55821c4..2729db8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.encoders
 
 import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.Encoders
 
 class FlatEncoderSuite extends ExpressionEncoderSuite {
   encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
@@ -71,4 +72,21 @@ class FlatEncoderSuite extends ExpressionEncoderSuite {
   encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
   encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
     FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
+
+  // Kryo encoders
+  encodeDecodeTest(
+    "hello",
+    Encoders.kryo[String].asInstanceOf[ExpressionEncoder[String]],
+    "kryo string")
+  encodeDecodeTest(
+    new NotJavaSerializable(15),
+    Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]],
+    "kryo object serialization")
+}
+
+
+class NotJavaSerializable(val value: Int) {
+  override def equals(other: Any): Boolean = {
+    this.value == other.asInstanceOf[NotJavaSerializable].value
+  }
 }
"kryo string") + encodeDecodeTest( + new NotJavaSerializable(15), + Encoders.kryo[NotJavaSerializable].asInstanceOf[ExpressionEncoder[NotJavaSerializable]], + "kryo object serialization") +} + + +class NotJavaSerializable(val value: Int) { + override def equals(other: Any): Boolean = { + this.value == other.asInstanceOf[NotJavaSerializable].value + } } http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 718ed81..817c20f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -147,6 +147,12 @@ class Dataset[T] private[sql]( } } + /** + * Returns the number of elements in the [[Dataset]]. + * @since 1.6.0 + */ + def count(): Long = toDF().count() + /* *********************** * * Functional Operations * * *********************** */ http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 467cd42..c66162e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql - import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index ea29428..a522894 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,21 +24,6 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -case class ClassData(a: String, b: Int) - -/** - * A class used to test serialization using encoders. This class throws exceptions when using - * Java serialization -- so the only way it can be "serialized" is through our encoders. 
http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 467cd42..c66162e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql
 
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental

http://git-wip-us.apache.org/repos/asf/spark/blob/5e2b4447/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index ea29428..a522894 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,21 +24,6 @@ import scala.language.postfixOps
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-case class ClassData(a: String, b: Int)
-
-/**
- * A class used to test serialization using encoders. This class throws exceptions when using
- * Java serialization -- so the only way it can be "serialized" is through our encoders.
- */
-case class NonSerializableCaseClass(value: String) extends Externalizable {
-  override def readExternal(in: ObjectInput): Unit = {
-    throw new UnsupportedOperationException
-  }
-
-  override def writeExternal(out: ObjectOutput): Unit = {
-    throw new UnsupportedOperationException
-  }
-}
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -362,8 +347,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(joined, ("2", 2))
   }
 
+  ignore("self join") {
+    val ds = Seq("1", "2").toDS().as("a")
+    val joined = ds.joinWith(ds, lit(true))
+    checkAnswer(joined, ("1", "1"), ("1", "2"), ("2", "1"), ("2", "2"))
+  }
+
   test("toString") {
     val ds = Seq((1, 2)).toDS()
     assert(ds.toString == "[_1: int, _2: int]")
   }
+
+  test("kryo encoder") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+
+    assert(ds.groupBy(p => p).count().collect().toSeq ==
+      Seq((KryoData(1), 1L), (KryoData(2), 1L)))
+  }
+
+  ignore("kryo encoder self join") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = sqlContext.createDataset(Seq(KryoData(1), KryoData(2)))
+    assert(ds.joinWith(ds, lit(true)).collect().toSet ==
+      Set(
+        (KryoData(1), KryoData(1)),
+        (KryoData(1), KryoData(2)),
+        (KryoData(2), KryoData(1)),
+        (KryoData(2), KryoData(2))))
+  }
+}
+
+
+case class ClassData(a: String, b: Int)
+
+/**
+ * A class used to test serialization using encoders. This class throws exceptions when using
+ * Java serialization -- so the only way it can be "serialized" is through our encoders.
+ */
+case class NonSerializableCaseClass(value: String) extends Externalizable {
+  override def readExternal(in: ObjectInput): Unit = {
+    throw new UnsupportedOperationException
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = {
+    throw new UnsupportedOperationException
+  }
+}
+
+/** Used to test Kryo encoder. */
+class KryoData(val a: Int) {
+  override def equals(other: Any): Boolean = {
+    a == other.asInstanceOf[KryoData].a
+  }
+  override def hashCode: Int = a
+  override def toString: String = s"KryoData($a)"
+}
+
+object KryoData {
+  def apply(a: Int): KryoData = new KryoData(a)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
