Repository: spark
Updated Branches:
  refs/heads/master 06f1fdba6 -> b0c3fd34e


[SPARK-11743] [SQL] Add UserDefinedType support to RowEncoder

JIRA: https://issues.apache.org/jira/browse/SPARK-11743

RowEncoder doesn't currently support UserDefinedType. This patch adds support for it.
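
For context, a minimal sketch of the behavior this enables, modeled on the new RowEncoderSuite test in this patch (ExamplePoint and ExamplePointUDT are the example classes added there):

  import org.apache.spark.sql.Row
  import org.apache.spark.sql.catalyst.encoders.RowEncoder
  import org.apache.spark.sql.types._

  // A schema containing a user-defined type column.
  val schema = new StructType().add("point", new ExamplePointUDT, false)
  val encoder = RowEncoder(schema)

  // Round-trip an external Row through the internal representation.
  val input = Row(new ExamplePoint(0.1, 0.2))
  val internalRow = encoder.toRow(input)       // ExamplePointUDT.serialize() maps to udt.sqlType
  val restored = encoder.fromRow(internalRow)  // ExamplePointUDT.deserialize() maps back
  assert(input.getAs[ExamplePoint](0) == restored.getAs[ExamplePoint](0))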

Author: Liang-Chi Hsieh <[email protected]>

Closes #9712 from viirya/rowencoder-udt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b0c3fd34
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b0c3fd34
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b0c3fd34

Branch: refs/heads/master
Commit: b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87
Parents: 06f1fdb
Author: Liang-Chi Hsieh <[email protected]>
Authored: Mon Nov 16 09:03:42 2015 -0800
Committer: Davies Liu <[email protected]>
Committed: Mon Nov 16 09:03:42 2015 -0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   | 14 +++-
 .../sql/catalyst/encoders/RowEncoder.scala      | 24 +++++-
 .../sql/catalyst/expressions/objects.scala      | 48 ++++++------
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 82 +++++++++++++++++++-
 4 files changed, 139 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b0c3fd34/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index ed2fdf9..0f0f200 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -152,7 +152,7 @@ trait Row extends Serializable {
    *   BinaryType -> byte array
    *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
    *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
-   *   StructType -> org.apache.spark.sql.Row
+   *   StructType -> org.apache.spark.sql.Row (or Product)
    * }}}
    */
   def apply(i: Int): Any = get(i)
@@ -177,7 +177,7 @@ trait Row extends Serializable {
    *   BinaryType -> byte array
    *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
    *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
-   *   StructType -> org.apache.spark.sql.Row
+   *   StructType -> org.apache.spark.sql.Row (or Product)
    * }}}
    */
   def get(i: Int): Any
@@ -306,7 +306,15 @@ trait Row extends Serializable {
    *
    * @throws ClassCastException when data type does not match.
    */
-  def getStruct(i: Int): Row = getAs[Row](i)
+  def getStruct(i: Int): Row = {
+    // Both Product and Row are recognized as StructType in a Row
+    val t = get(i)
+    if (t.isInstanceOf[Product]) {
+      Row.fromTuple(t.asInstanceOf[Product])
+    } else {
+      t.asInstanceOf[Row]
+    }
+  }
 
   /**
    * Returns the value at position i.
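
A quick illustration of the getStruct change above (a sketch; the new "encode/decode: Product" test at the end of this patch exercises the same path): a struct-typed field may now hold either a Row or a Product, and getStruct normalizes both to a Row.

  val row = Row((100, "test", 0.123))  // field 0 is a Tuple3 (a Product), not a Row
  row.getStruct(0)                     // Row(100, "test", 0.123), via Row.fromTuple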

http://git-wip-us.apache.org/repos/asf/spark/blob/b0c3fd34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e0be896..9bb1602 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -50,6 +50,14 @@ object RowEncoder {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType => inputObject
 
+    case udt: UserDefinedType[_] =>
+      val obj = NewInstance(
+        udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+        Nil,
+        false,
+        dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+      Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
     case TimestampType =>
       StaticInvoke(
         DateTimeUtils,
@@ -109,11 +117,16 @@ object RowEncoder {
 
     case StructType(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+        val method = if (f.dataType.isInstanceOf[StructType]) {
+          "getStruct"
+        } else {
+          "get"
+        }
         If(
           Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, f.dataType),
           extractorsFor(
-            Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+            Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
             f.dataType))
       }
       CreateStruct(convertedFields)
@@ -137,6 +150,7 @@ object RowEncoder {
     case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
     case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
     case _: StructType => ObjectType(classOf[Row])
+    case udt: UserDefinedType[_] => ObjectType(udt.userClass)
   }
 
   private def constructorFor(schema: StructType): Expression = {
@@ -155,6 +169,14 @@ object RowEncoder {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType => input
 
+    case udt: UserDefinedType[_] =>
+      val obj = NewInstance(
+        udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+        Nil,
+        false,
+        dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+      Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
+
     case TimestampType =>
       StaticInvoke(
         DateTimeUtils,
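
Note that both extractorsFor and constructorFor above resolve the concrete UDT class through the SQLUserDefinedType annotation on the user class, rather than through the udt instance in the schema, so the user class must carry the annotation, as the test classes added below do:

  // Required for RowEncoder to locate the UDT (sketch; see ExamplePoint in RowEncoderSuite below).
  @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
  class ExamplePoint(val x: Double, val y: Double) extends Serializable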

http://git-wip-us.apache.org/repos/asf/spark/blob/b0c3fd34/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 4f58464..5cd19de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -113,7 +113,7 @@ case class Invoke(
     arguments: Seq[Expression] = Nil) extends Expression {
 
   override def nullable: Boolean = true
-  override def children: Seq[Expression] = targetObject :: Nil
+  override def children: Seq[Expression] = arguments.+:(targetObject)
 
   override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -343,33 +343,35 @@ case class MapObjects(
   private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
   private lazy val completeFunction = function(loopAttribute)
 
+  private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
+    case IntegerType => (i: String) => s".getInt($i)"
+    case LongType => (i: String) => s".getLong($i)"
+    case FloatType => (i: String) => s".getFloat($i)"
+    case DoubleType => (i: String) => s".getDouble($i)"
+    case ByteType => (i: String) => s".getByte($i)"
+    case ShortType => (i: String) => s".getShort($i)"
+    case BooleanType => (i: String) => s".getBoolean($i)"
+    case StringType => (i: String) => s".getUTF8String($i)"
+    case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
+    case a: ArrayType => (i: String) => s".getArray($i)"
+    case _: MapType => (i: String) => s".getMap($i)"
+    case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+  }
+
   private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
     case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
       (".size()", (i: String) => s".apply($i)", false)
     case ObjectType(cls) if cls.isArray =>
       (".length", (i: String) => s"[$i]", false)
-    case ArrayType(s: StructType, _) =>
-      (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
-    case ArrayType(a: ArrayType, _) =>
-      (".numElements()", (i: String) => s".getArray($i)", true)
-    case ArrayType(IntegerType, _) =>
-      (".numElements()", (i: String) => s".getInt($i)", true)
-    case ArrayType(LongType, _) =>
-      (".numElements()", (i: String) => s".getLong($i)", true)
-    case ArrayType(FloatType, _) =>
-      (".numElements()", (i: String) => s".getFloat($i)", true)
-    case ArrayType(DoubleType, _) =>
-      (".numElements()", (i: String) => s".getDouble($i)", true)
-    case ArrayType(ByteType, _) =>
-      (".numElements()", (i: String) => s".getByte($i)", true)
-    case ArrayType(ShortType, _) =>
-      (".numElements()", (i: String) => s".getShort($i)", true)
-    case ArrayType(BooleanType, _) =>
-      (".numElements()", (i: String) => s".getBoolean($i)", true)
-    case ArrayType(StringType, _) =>
-      (".numElements()", (i: String) => s".getUTF8String($i)", false)
-    case ArrayType(_: MapType, _) =>
-      (".numElements()", (i: String) => s".getMap($i)", false)
+    case ArrayType(t, _) =>
+      val (sqlType, primitiveElement) = t match {
+        case m: MapType => (m, false)
+        case s: StructType => (s, false)
+        case s: StringType => (s, false)
+        case udt: UserDefinedType[_] => (udt.sqlType, false)
+        case o => (o, true)
+      }
+      (".numElements()", itemAccessorMethod(sqlType), primitiveElement)
   }
 
   override def nullable: Boolean = true
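
This refactoring collapses the per-element-type ArrayType cases into a single itemAccessorMethod lookup, which recurses into udt.sqlType for UDT elements. A sketch of the accessor fragments it produces, assuming ExamplePointUDT (whose sqlType is ArrayType(DoubleType, false)) from the test suite below:

  itemAccessorMethod(IntegerType)("i")          // ".getInt(i)"
  itemAccessorMethod(new ExamplePointUDT)("i")  // recurses into sqlType: ".getArray(i)"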

http://git-wip-us.apache.org/repos/asf/spark/blob/b0c3fd34/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index e8301e8..c868dde 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+class ExamplePoint(val x: Double, val y: Double) extends Serializable {
+  override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
+  override def equals(that: Any): Boolean = {
+    if (that.isInstanceOf[ExamplePoint]) {
+      val e = that.asInstanceOf[ExamplePoint]
+    (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
+        (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
+    } else {
+      false
+    }
+  }
+}
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+  override def sqlType: DataType = ArrayType(DoubleType, false)
+
+  override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
+
+  override def serialize(obj: Any): GenericArrayData = {
+    obj match {
+      case p: ExamplePoint =>
+        val output = new Array[Any](2)
+        output(0) = p.x
+        output(1) = p.y
+        new GenericArrayData(output)
+    }
+  }
+
+  override def deserialize(datum: Any): ExamplePoint = {
+    datum match {
+      case values: ArrayData =>
+        new ExamplePoint(values.getDouble(0), values.getDouble(1))
+    }
+  }
+
+  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+
+  private[spark] override def asNullable: ExamplePointUDT = this
+}
+
 class RowEncoderSuite extends SparkFunSuite {
 
   private val structOfString = new StructType().add("str", StringType)
+  private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
   private val arrayOfString = ArrayType(StringType)
   private val mapOfString = MapType(StringType, StringType)
+  private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
 
   encodeDecodeTest(
     new StructType()
@@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("string", StringType)
       .add("binary", BinaryType)
       .add("date", DateType)
-      .add("timestamp", TimestampType))
+      .add("timestamp", TimestampType)
+      .add("udt", new ExamplePointUDT, false))
 
   encodeDecodeTest(
     new StructType()
@@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("structOfArray", new StructType().add("array", arrayOfString))
       .add("structOfMap", new StructType().add("map", mapOfString))
       .add("structOfArrayAndMap",
-        new StructType().add("array", arrayOfString).add("map", mapOfString)))
+        new StructType().add("array", arrayOfString).add("map", mapOfString))
+      .add("structOfUDT", structOfUDT))
+
+  test(s"encode/decode: arrayOfUDT") {
+    val schema = new StructType()
+      .add("arrayOfUDT", arrayOfUDT)
+
+    val encoder = RowEncoder(schema)
+
+    val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
+  }
+
+  test(s"encode/decode: Product") {
+    val schema = new StructType()
+      .add("structAsProduct",
+        new StructType()
+          .add("int", IntegerType)
+          .add("string", StringType)
+          .add("double", DoubleType))
+
+    val encoder = RowEncoder(schema)
+
+    val input: Row = Row((100, "test", 0.123))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    assert(input.getStruct(0) == convertedBack.getStruct(0))
+  }
 
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {

