Repository: spark
Updated Branches:
  refs/heads/master ddd1b1e8a -> cb5ea201d


[SPARK-25746][SQL] Refactoring ExpressionEncoder to get rid of flat flag

## What changes were proposed in this pull request?

This was inspired while implementing #21732. Currently `ScalaReflection` has to 
account for how `ExpressionEncoder` uses the serializers and deserializers it 
generates, and `ExpressionEncoder` carries an awkward `flat` flag. After 
discussion with cloud-fan, refactoring `ExpressionEncoder` seems the better 
approach. It should also make SPARK-24762 easier to implement.

To summarize the proposed changes:

1. `serializerFor` and `deserializerFor` return expressions for 
serializing/deserializing an input expression for a given type. They are 
private and should not be called directly.
2. `serializerForType` and `deserializerForType` return an expression for 
serializing/deserializing an object of type T to/from its Spark SQL 
representation. They assume the input object/Spark SQL representation is 
located at ordinal 0 of a row.

In other words, `serializerForType` and `deserializerForType` return 
expressions for atomically serializing/deserializing a JVM object to/from a 
Spark SQL value.

A serializer returned by `serializerForType` will serialize an object at 
`row(0)` to the corresponding Spark SQL representation, e.g., a primitive 
value, array, map, or struct.

A deserializer returned by `deserializerForType` will deserialize an input 
field at `row(0)` to an object of the given type.

3. The `ExpressionEncoder` constructor takes a pair of atomic serializer and 
deserializer for type `T` and uses them to create the serializer and 
deserializer for T <-> row conversion. `ExpressionEncoder` no longer needs to 
remember whether the serializer is flat. When we need to construct a new 
`ExpressionEncoder` from existing ones, we only need to change the input 
location in their atomic serializers and deserializers.
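A minimal toy model of this design, assuming simplified stand-ins (`Expr`, `BoundRef`, `GetField`, `NamedStruct`, `Alias`, `toColumns`, `relocate`, `tupleSerializer`) that are hypothetical and not Spark's actual classes. It shows the three ideas above: the atomic serializer reads the object at ordinal 0, the encoder alone decides whether the result becomes one column per struct field or a single `value` column, and composing encoders only rewrites where the input is read.

```scala
// Toy stand-ins for Catalyst expressions (hypothetical, for illustration only).
sealed trait Expr
case class BoundRef(ordinal: Int) extends Expr                   // like BoundReference
case class GetField(child: Expr, name: String) extends Expr      // field access on an object
case class NamedStruct(fields: Seq[(String, Expr)]) extends Expr // like CreateNamedStruct
case class Alias(child: Expr, name: String) extends Expr         // a named output column

// Atomic serializer for a flat type such as Int: the input object at ordinal 0.
val intSerializer: Expr = BoundRef(0)

// Atomic serializer for a case class Point(x: Int, y: Int): one struct value.
val pointSerializer: Expr = NamedStruct(Seq(
  "x" -> GetField(BoundRef(0), "x"),
  "y" -> GetField(BoundRef(0), "y")))

// The encoder turns the atomic serializer into row columns: a struct is split
// into one column per field, anything else becomes a single "value" column.
// Flatness is read off the expression itself instead of a stored flag.
def toColumns(objSerializer: Expr): Seq[Alias] = objSerializer match {
  case NamedStruct(fields) => fields.map { case (name, e) => Alias(e, name) }
  case other               => Seq(Alias(other, "value"))
}

// Composing encoders only changes where the input is read:
// every BoundRef(0) is replaced by `newInput`.
def relocate(e: Expr, newInput: Expr): Expr = e match {
  case BoundRef(0)     => newInput
  case GetField(c, n)  => GetField(relocate(c, newInput), n)
  case NamedStruct(fs) => NamedStruct(fs.map { case (n, c) => n -> relocate(c, newInput) })
  case Alias(c, n)     => Alias(relocate(c, newInput), n)
  case other           => other
}

// A tuple encoder built from existing ones: element i is read from field _i
// of the tuple object at ordinal 0.
def tupleSerializer(parts: Seq[Expr]): Expr =
  NamedStruct(parts.zipWithIndex.map { case (s, i) =>
    val name = s"_${i + 1}"
    name -> relocate(s, GetField(BoundRef(0), name))
  })

val pairSerializer = tupleSerializer(Seq(intSerializer, pointSerializer))
println(toColumns(intSerializer))
println(toColumns(pairSerializer).map(_.name))
```

The toy ignores null handling (for example the `AssertNotNull` wrapping of top-level products) and data types; it only illustrates why one expression per type is enough once the encoder owns the row layout.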

## How was this patch tested?

Existing tests.

Closes #22749 from viirya/SPARK-24762-refactor.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cb5ea201
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cb5ea201
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cb5ea201

Branch: refs/heads/master
Commit: cb5ea201df5fae8aacb653ffb4147b9288bca1e9
Parents: ddd1b1e
Author: Liang-Chi Hsieh <[email protected]>
Authored: Thu Oct 25 19:27:45 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Thu Oct 25 19:27:45 2018 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Encoders.scala   |   8 +-
 .../spark/sql/catalyst/JavaTypeInference.scala  |  78 ++++---
 .../spark/sql/catalyst/ScalaReflection.scala    | 182 ++++++++---------
 .../catalyst/encoders/ExpressionEncoder.scala   | 201 +++++++++++--------
 .../sql/catalyst/encoders/RowEncoder.scala      |  16 +-
 .../sql/catalyst/ScalaReflectionSuite.scala     |  70 ++++---
 .../encoders/ExpressionEncoderSuite.scala       |   6 +-
 .../sql/catalyst/encoders/RowEncoderSuite.scala |   2 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |  10 +-
 .../spark/sql/KeyValueGroupedDataset.scala      |   2 +-
 .../aggregate/TypedAggregateExpression.scala    |  12 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |   2 +-
 12 files changed, 304 insertions(+), 285 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index b47ec0b..8a30c81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -203,12 +203,10 @@ object Encoders {
     validatePublicClass[T]()
 
     ExpressionEncoder[T](
-      schema = new StructType().add("value", BinaryType),
-      flat = true,
-      serializer = Seq(
+      objSerializer =
         EncodeUsingSerializer(
-          BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
-      deserializer =
+          BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
+      objDeserializer =
         DecodeUsingSerializer[T](
           Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
           classTag[T],

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 60dd4a5..f32e080 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -187,26 +187,23 @@ object JavaTypeInference {
   }
 
   /**
-   * Returns an expression that can be used to deserialize an internal row to an object of java bean
-   * `T` with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
-   * of the same name as the constructor arguments.  Nested classes will have their fields accessed
-   * using UnresolvedExtractValue.
+   * Returns an expression that can be used to deserialize a Spark SQL representation to an object
+   * of java bean `T` with a compatible schema.  The Spark SQL representation is located at ordinal
+   * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
+   * using `UnresolvedExtractValue`.
    */
   def deserializerFor(beanClass: Class[_]): Expression = {
-    deserializerFor(TypeToken.of(beanClass), None)
+    val typeToken = TypeToken.of(beanClass)
+    deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1))
   }
 
-  private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+  private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
     /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = path
-      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-      .getOrElse(UnresolvedAttribute(part))
-
-    /** Returns the current path or `GetColumnByOrdinal`. */
-    def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
+    def addToPath(part: String): Expression = UnresolvedExtractValue(path,
+      expressions.Literal(part))
 
     typeToken.getRawType match {
-      case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
+      case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
 
       case c if c == classOf[java.lang.Short] ||
                 c == classOf[java.lang.Integer] ||
@@ -219,7 +216,7 @@ object JavaTypeInference {
           c,
           ObjectType(c),
           "valueOf",
-          getPath :: Nil,
+          path :: Nil,
           returnNullable = false)
 
       case c if c == classOf[java.sql.Date] =>
@@ -227,7 +224,7 @@ object JavaTypeInference {
           DateTimeUtils.getClass,
           ObjectType(c),
           "toJavaDate",
-          getPath :: Nil,
+          path :: Nil,
           returnNullable = false)
 
       case c if c == classOf[java.sql.Timestamp] =>
@@ -235,14 +232,14 @@ object JavaTypeInference {
           DateTimeUtils.getClass,
           ObjectType(c),
           "toJavaTimestamp",
-          getPath :: Nil,
+          path :: Nil,
           returnNullable = false)
 
       case c if c == classOf[java.lang.String] =>
-        Invoke(getPath, "toString", ObjectType(classOf[String]))
+        Invoke(path, "toString", ObjectType(classOf[String]))
 
       case c if c == classOf[java.math.BigDecimal] =>
-        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+        Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
 
       case c if c.isArray =>
         val elementType = c.getComponentType
@@ -258,12 +255,12 @@ object JavaTypeInference {
         }
 
         primitiveMethod.map { method =>
-          Invoke(getPath, method, ObjectType(c))
+          Invoke(path, method, ObjectType(c))
         }.getOrElse {
           Invoke(
             MapObjects(
-              p => deserializerFor(typeToken.getComponentType, Some(p)),
-              getPath,
+              p => deserializerFor(typeToken.getComponentType, p),
+              path,
               inferDataType(elementType)._1),
             "array",
             ObjectType(c))
@@ -272,8 +269,8 @@ object JavaTypeInference {
       case c if listType.isAssignableFrom(typeToken) =>
         val et = elementType(typeToken)
         UnresolvedMapObjects(
-          p => deserializerFor(et, Some(p)),
-          getPath,
+          p => deserializerFor(et, p),
+          path,
           customCollectionCls = Some(c))
 
       case _ if mapType.isAssignableFrom(typeToken) =>
@@ -282,16 +279,16 @@ object JavaTypeInference {
         val keyData =
           Invoke(
             UnresolvedMapObjects(
-              p => deserializerFor(keyType, Some(p)),
-              GetKeyArrayFromMap(getPath)),
+              p => deserializerFor(keyType, p),
+              GetKeyArrayFromMap(path)),
             "array",
             ObjectType(classOf[Array[Any]]))
 
         val valueData =
           Invoke(
             UnresolvedMapObjects(
-              p => deserializerFor(valueType, Some(p)),
-              GetValueArrayFromMap(getPath)),
+              p => deserializerFor(valueType, p),
+              GetValueArrayFromMap(path)),
             "array",
             ObjectType(classOf[Array[Any]]))
 
@@ -307,7 +304,7 @@ object JavaTypeInference {
           other,
           ObjectType(other),
           "valueOf",
-          Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
+          Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
           returnNullable = false)
 
       case other =>
@@ -316,7 +313,7 @@ object JavaTypeInference {
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
           val (_, nullable) = inferDataType(fieldType)
-          val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
+          val constructor = deserializerFor(fieldType, addToPath(fieldName))
           val setter = if (nullable) {
             constructor
           } else {
@@ -328,28 +325,23 @@ object JavaTypeInference {
         val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
         val result = InitializeJavaBean(newInstance, setters)
 
-        if (path.nonEmpty) {
-          expressions.If(
-            IsNull(getPath),
-            expressions.Literal.create(null, ObjectType(other)),
-            result
-          )
-        } else {
+        expressions.If(
+          IsNull(path),
+          expressions.Literal.create(null, ObjectType(other)),
           result
-        }
+        )
     }
   }
 
   /**
-   * Returns an expression for serializing an object of the given type to an internal row.
+   * Returns an expression for serializing an object of the given type to a Spark SQL
+   * representation. The input object is located at ordinal 0 of a row, i.e.,
+   * `BoundReference(0, _)`.
    */
-  def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
+  def serializerFor(beanClass: Class[_]): Expression = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
     val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
-    serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
-      case expressions.If(_, _, s: CreateNamedStruct) => s
-      case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
-    }
+    serializerFor(nullSafeInput, TypeToken.of(beanClass))
   }
 
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
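The change of `path` from `Option[Expression]` to `Expression` in `deserializerFor` can be illustrated with a toy model. The names `Path`, `ColumnByOrdinal`, `Attribute`, `ExtractValue`, `addToPathOld`, `addToPathNew`, `render` and the bean shape are hypothetical stand-ins for `GetColumnByOrdinal`, `UnresolvedAttribute` and `UnresolvedExtractValue`; the sketch only compares the access paths built for a nested field before and after the refactoring.

```scala
// Toy stand-ins for the analysis expressions used as deserializer paths (hypothetical).
sealed trait Path
case class ColumnByOrdinal(ordinal: Int) extends Path            // like GetColumnByOrdinal
case class Attribute(name: String) extends Path                  // like UnresolvedAttribute
case class ExtractValue(child: Path, field: String) extends Path // like UnresolvedExtractValue

// Before: a top-level call had no path, so fields were resolved by name against the row.
def addToPathOld(path: Option[Path], part: String): Path =
  path.map(p => ExtractValue(p, part)).getOrElse(Attribute(part))

// After: the path always exists; the top level starts at column 0 of the row.
def addToPathNew(path: Path, part: String): Path = ExtractValue(path, part)

// Human-readable form of a path, for printing only.
def render(p: Path): String = p match {
  case ColumnByOrdinal(i)     => s"col#$i"
  case Attribute(n)           => n
  case ExtractValue(child, f) => s"${render(child)}.$f"
}

// Path to `address.city` for a hypothetical bean Person { Address address; },
// where Address { String city; }.
val oldCity = addToPathOld(Some(addToPathOld(None, "address")), "city")
val newCity = addToPathNew(addToPathNew(ColumnByOrdinal(0), "address"), "city")

println(render(oldCity))
println(render(newCity))
```

Because the path is now always defined, the `If(IsNull(path), null, ...)` wrapper around the bean construction is emitted unconditionally, including for the top-level object.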

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c27180e..40074b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -24,7 +24,7 @@ import scala.util.Properties
 import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, MapData}
@@ -129,21 +129,44 @@ object ScalaReflection extends ScalaReflection {
   }
 
   /**
-   * Returns an expression that can be used to deserialize an input row to an object of type `T`
-   * with a compatible schema.  Fields of the row will be extracted using UnresolvedAttributes
-   * of the same name as the constructor arguments.  Nested classes will have their fields accessed
-   * using UnresolvedExtractValue.
+   * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
+   * and lost the required data type, which may lead to runtime error if the real type doesn't
+   * match the encoder's schema.
+   * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
+   * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
+   * `Data` with int and long, because we lost the information that `b` should be a string.
    *
-   * When used on a primitive type, the constructor will instead default to extracting the value
-   * from ordinal 0 (since there are no names to map to).  The actual location can be moved by
-   * calling resolve/bind with a new schema.
+   * This method help us "remember" the required data type by adding a `UpCast`. Note that we
+   * only need to do this for leaf nodes.
    */
-  def deserializerFor[T : TypeTag]: Expression = {
-    val tpe = localTypeOf[T]
+  private def upCastToExpectedType(expr: Expression, expected: DataType,
+      walkedTypePath: Seq[String]): Expression = expected match {
+    case _: StructType => expr
+    case _: ArrayType => expr
+    // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
+    // it's not trivial to support by-name resolution for StructType inside MapType.
+    case _ => UpCast(expr, expected, walkedTypePath)
+  }
+
+  /**
+   * Returns an expression that can be used to deserialize a Spark SQL representation to an object
+   * of type `T` with a compatible schema. The Spark SQL representation is located at ordinal 0 of
+   * a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed using
+   * `UnresolvedExtractValue`.
+   *
+   * The returned expression is used by `ExpressionEncoder`. The encoder will resolve and bind this
+   * deserializer expression when using it.
+   */
+  def deserializerForType(tpe: `Type`): Expression = {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
-    val expr = deserializerFor(tpe, None, walkedTypePath)
-    val Schema(_, nullable) = schemaFor(tpe)
+    val Schema(dataType, nullable) = schemaFor(tpe)
+
+    // Assumes we are deserializing the first column of a row.
+    val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
+      walkedTypePath)
+
+    val expr = deserializerFor(tpe, input, walkedTypePath)
     if (nullable) {
       expr
     } else {
@@ -151,16 +174,22 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
+  /**
+   * Returns an expression that can be used to deserialize an input expression to an object of type
+   * `T` with a compatible schema.
+   *
+   * @param tpe The `Type` of deserialized object.
+   * @param path The expression which can be used to extract serialized value.
+   * @param walkedTypePath The paths from top to bottom to access current field when deserializing.
+   */
   private def deserializerFor(
       tpe: `Type`,
-      path: Option[Expression],
+      path: Expression,
       walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
 
     /** Returns the current path with a sub-field extracted. */
     def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
-      val newPath = path
-        .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-        .getOrElse(UnresolvedAttribute.quoted(part))
+      val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
       upCastToExpectedType(newPath, dataType, walkedTypePath)
     }
 
@@ -169,46 +198,12 @@ object ScalaReflection extends ScalaReflection {
         ordinal: Int,
         dataType: DataType,
         walkedTypePath: Seq[String]): Expression = {
-      val newPath = path
-        .map(p => GetStructField(p, ordinal))
-        .getOrElse(GetColumnByOrdinal(ordinal, dataType))
+      val newPath = GetStructField(path, ordinal)
       upCastToExpectedType(newPath, dataType, walkedTypePath)
     }
 
-    /** Returns the current path or `GetColumnByOrdinal`. */
-    def getPath: Expression = {
-      val dataType = schemaFor(tpe).dataType
-      if (path.isDefined) {
-        path.get
-      } else {
-        upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
-      }
-    }
-
-    /**
-     * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
-     * and lost the required data type, which may lead to runtime error if the real type doesn't
-     * match the encoder's schema.
-     * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
-     * is [a: int, b: long], then we will hit runtime error and say that we can't construct class
-     * `Data` with int and long, because we lost the information that `b` should be a string.
-     *
-     * This method help us "remember" the required data type by adding a `UpCast`. Note that we
-     * only need to do this for leaf nodes.
-     */
-    def upCastToExpectedType(
-        expr: Expression,
-        expected: DataType,
-        walkedTypePath: Seq[String]): Expression = expected match {
-      case _: StructType => expr
-      case _: ArrayType => expr
-      // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
-      // it's not trivial to support by-name resolution for StructType inside MapType.
-      case _ => UpCast(expr, expected, walkedTypePath)
-    }
-
     tpe.dealias match {
-      case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
+      case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path
 
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
@@ -219,44 +214,44 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Long] =>
         val boxedType = classOf[java.lang.Long]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Double] =>
         val boxedType = classOf[java.lang.Double]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Float] =>
         val boxedType = classOf[java.lang.Float]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Short] =>
         val boxedType = classOf[java.lang.Short]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Byte] =>
         val boxedType = classOf[java.lang.Byte]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.Boolean] =>
         val boxedType = classOf[java.lang.Boolean]
         val objectType = ObjectType(boxedType)
-        StaticInvoke(boxedType, objectType, "valueOf", getPath :: Nil, returnNullable = false)
+        StaticInvoke(boxedType, objectType, "valueOf", path :: Nil, returnNullable = false)
 
       case t if t <:< localTypeOf[java.sql.Date] =>
         StaticInvoke(
           DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Date]),
           "toJavaDate",
-          getPath :: Nil,
+          path :: Nil,
           returnNullable = false)
 
       case t if t <:< localTypeOf[java.sql.Timestamp] =>
@@ -264,25 +259,25 @@ object ScalaReflection extends ScalaReflection {
           DateTimeUtils.getClass,
           ObjectType(classOf[java.sql.Timestamp]),
           "toJavaTimestamp",
-          getPath :: Nil,
+          path :: Nil,
           returnNullable = false)
 
       case t if t <:< localTypeOf[java.lang.String] =>
-        Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false)
+        Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
-        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
+        Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]),
           returnNullable = false)
 
       case t if t <:< localTypeOf[BigDecimal] =>
-        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
+        Invoke(path, "toBigDecimal", ObjectType(classOf[BigDecimal]), returnNullable = false)
 
       case t if t <:< localTypeOf[java.math.BigInteger] =>
-        Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
+        Invoke(path, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]),
           returnNullable = false)
 
       case t if t <:< localTypeOf[scala.math.BigInt] =>
-        Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
+        Invoke(path, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]),
           returnNullable = false)
 
       case t if t <:< localTypeOf[Array[_]] =>
@@ -294,7 +289,7 @@ object ScalaReflection extends ScalaReflection {
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
           val casted = upCastToExpectedType(element, dataType, newTypePath)
-          val converter = deserializerFor(elementType, Some(casted), newTypePath)
+          val converter = deserializerFor(elementType, casted, newTypePath)
           if (elementNullable) {
             converter
           } else {
@@ -302,7 +297,7 @@ object ScalaReflection extends ScalaReflection {
           }
         }
 
-        val arrayData = UnresolvedMapObjects(mapFunction, getPath)
+        val arrayData = UnresolvedMapObjects(mapFunction, path)
         val arrayCls = arrayClassFor(elementType)
 
         if (elementNullable) {
@@ -334,7 +329,7 @@ object ScalaReflection extends ScalaReflection {
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
           val casted = upCastToExpectedType(element, dataType, newTypePath)
-          val converter = deserializerFor(elementType, Some(casted), newTypePath)
+          val converter = deserializerFor(elementType, casted, newTypePath)
           if (elementNullable) {
             converter
           } else {
@@ -349,16 +344,16 @@ object ScalaReflection extends ScalaReflection {
             classOf[scala.collection.Set[_]]
           case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
-        UnresolvedMapObjects(mapFunction, getPath, Some(cls))
+        UnresolvedMapObjects(mapFunction, path, Some(cls))
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
         CatalystToExternalMap(
-          p => deserializerFor(keyType, Some(p), walkedTypePath),
-          p => deserializerFor(valueType, Some(p), walkedTypePath),
-          getPath,
+          p => deserializerFor(keyType, p, walkedTypePath),
+          p => deserializerFor(valueType, p, walkedTypePath),
+          path,
           mirror.runtimeClass(t.typeSymbol.asClass)
         )
 
@@ -368,7 +363,7 @@ object ScalaReflection extends ScalaReflection {
           udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
           Nil,
           dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
 
       case t if UDTRegistration.exists(getClassNameFromType(t)) =>
         val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
@@ -377,7 +372,7 @@ object ScalaReflection extends ScalaReflection {
           udt.getClass,
           Nil,
           dataType = ObjectType(udt.getClass))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
 
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
@@ -392,12 +387,12 @@ object ScalaReflection extends ScalaReflection {
           val constructor = if (cls.getName startsWith "scala.Tuple") {
             deserializerFor(
               fieldType,
-              Some(addToPathOrdinal(i, dataType, newTypePath)),
+              addToPathOrdinal(i, dataType, newTypePath),
               newTypePath)
           } else {
             deserializerFor(
               fieldType,
-              Some(addToPath(fieldName, dataType, newTypePath)),
+              addToPath(fieldName, dataType, newTypePath),
               newTypePath)
           }
 
@@ -410,20 +405,17 @@ object ScalaReflection extends ScalaReflection {
 
         val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
 
-        if (path.nonEmpty) {
-          expressions.If(
-            IsNull(getPath),
-            expressions.Literal.create(null, ObjectType(cls)),
-            newInstance
-          )
-        } else {
+        expressions.If(
+          IsNull(path),
+          expressions.Literal.create(null, ObjectType(cls)),
           newInstance
-        }
+        )
     }
   }
 
   /**
-   * Returns an expression for serializing an object of type T to an internal row.
+   * Returns an expression for serializing an object of type T to Spark SQL representation. The
+   * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
    *
    * If the given type is not supported, i.e. there is no encoder can be built for this type,
    * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
@@ -434,17 +426,21 @@ object ScalaReflection extends ScalaReflection {
    *  * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
    *  * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
    */
-  def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
-    val tpe = localTypeOf[T]
+  def serializerForType(tpe: `Type`): Expression = ScalaReflection.cleanUpReflectionObjects {
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
-    serializerFor(inputObject, tpe, walkedTypePath) match {
-      case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
-      case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
-    }
+
+    // The input object to `ExpressionEncoder` is located at first column of an row.
+    val inputObject = BoundReference(0, dataTypeFor(tpe),
+      nullable = !tpe.typeSymbol.asClass.isPrimitive)
+
+    serializerFor(inputObject, tpe, walkedTypePath)
   }
 
-  /** Helper for extracting internal fields from a case class. */
+  /**
+   * Returns an expression for serializing the value of an input expression into Spark SQL
+   * internal representation.
+   */
   private def serializerFor(
       inputObject: Expression,
       tpe: `Type`,

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index cbea3c0..29f6136 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
@@ -43,8 +44,8 @@ import org.apache.spark.util.Utils
  *    to the name `value`.
  */
 object ExpressionEncoder {
+
   def apply[T : TypeTag](): ExpressionEncoder[T] = {
-    // We convert the not-serializable TypeTag into StructType and ClassTag.
     val mirror = ScalaReflection.mirror
     val tpe = typeTag[T].in(mirror).tpe
 
@@ -58,25 +59,11 @@ object ExpressionEncoder {
     }
 
     val cls = mirror.runtimeClass(tpe)
-    val flat = !ScalaReflection.definedByConstructorParams(tpe)
-
-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = !cls.isPrimitive)
-    val nullSafeInput = if (flat) {
-      inputObject
-    } else {
-      // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
-      // doesn't allow top-level row to be null, only its columns can be null.
-      AssertNotNull(inputObject, Seq("top level Product input object"))
-    }
-    val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
-    val deserializer = ScalaReflection.deserializerFor[T]
-
-    val schema = serializer.dataType
+    val serializer = ScalaReflection.serializerForType(tpe)
+    val deserializer = ScalaReflection.deserializerForType(tpe)
 
     new ExpressionEncoder[T](
-      schema,
-      flat,
-      serializer.flatten,
+      serializer,
       deserializer,
       ClassTag[T](cls))
   }
@@ -86,14 +73,12 @@ object ExpressionEncoder {
     val schema = JavaTypeInference.inferDataType(beanClass)._1
     assert(schema.isInstanceOf[StructType])
 
-    val serializer = JavaTypeInference.serializerFor(beanClass)
-    val deserializer = JavaTypeInference.deserializerFor(beanClass)
+    val objSerializer = JavaTypeInference.serializerFor(beanClass)
+    val objDeserializer = JavaTypeInference.deserializerFor(beanClass)
 
     new ExpressionEncoder[T](
-      schema.asInstanceOf[StructType],
-      flat = false,
-      serializer.flatten,
-      deserializer,
+      objSerializer,
+      objDeserializer,
       ClassTag[T](beanClass))
   }
 
@@ -103,75 +88,59 @@ object ExpressionEncoder {
    * name/positional binding is preserved.
    */
   def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+    // TODO: check if encoders length is more than 22 and throw exception for it.
+
     encoders.foreach(_.assertUnresolved())
 
     val schema = StructType(encoders.zipWithIndex.map {
       case (e, i) =>
-        val (dataType, nullable) = if (e.flat) {
-          e.schema.head.dataType -> e.schema.head.nullable
-        } else {
-          e.schema -> true
-        }
-        StructField(s"_${i + 1}", dataType, nullable)
+        StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable)
     })
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    val serializer = encoders.zipWithIndex.map { case (enc, index) =>
-      val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head
+    val serializers = encoders.zipWithIndex.map { case (enc, index) =>
+      val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
+      assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
+        s"there are ${boundRefs.size}")
+
+      val originalInputObject = boundRefs.head
       val newInputObject = Invoke(
         BoundReference(0, ObjectType(cls), nullable = true),
         s"_${index + 1}",
-        originalInputObject.dataType)
-
-      val newSerializer = enc.serializer.map(_.transformUp {
-        case b: BoundReference if b == originalInputObject => newInputObject
-      })
+        originalInputObject.dataType,
+        returnNullable = originalInputObject.nullable)
 
-      val serializerExpr = if (enc.flat) {
-        newSerializer.head
-      } else {
-        // For non-flat encoder, the input object is not top level anymore after being combined to
-        // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and
-        // null check to handle null case correctly.
-        // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and is
-        // not able to handle the case when the input tuple is null. This is not a problem as there
-        // is a check to make sure the input object won't be null. However, if this encoder is used
-        // to create a bigger tuple encoder, the original input object becomes a filed of the new
-        // input tuple and can be null. So instead of creating a struct directly here, we should add
-        // a null/None check and return a null struct if the null/None check fails.
-        val struct = CreateStruct(newSerializer)
-        val nullCheck = Or(
-          IsNull(newInputObject),
-          Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
-        If(nullCheck, Literal.create(null, struct.dataType), struct)
+      val newSerializer = enc.objSerializer.transformUp {
+        case b: BoundReference => newInputObject
       }
-      Alias(serializerExpr, s"_${index + 1}")()
+
+      Alias(newSerializer, s"_${index + 1}")()
     }
 
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
-      if (enc.flat) {
-        enc.deserializer.transform {
-          case g: GetColumnByOrdinal => g.copy(ordinal = index)
-        }
+      val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
+        .distinct
+      assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
+        s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
+
+      val input = GetStructField(GetColumnByOrdinal(0, schema), index)
+      val newDeserializer = enc.objDeserializer.transformUp {
+        case GetColumnByOrdinal(0, _) => input
+      }
+      if (schema(index).nullable) {
+        If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer)
       } else {
-        val input = GetColumnByOrdinal(index, enc.schema)
-        val deserialized = enc.deserializer.transformUp {
-          case UnresolvedAttribute(nameParts) =>
-            assert(nameParts.length == 1)
-            UnresolvedExtractValue(input, Literal(nameParts.head))
-          case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
-        }
-        If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
+        newDeserializer
       }
     }
 
+    val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)),
+      Literal.create(null, schema), CreateStruct(serializers))
     val deserializer =
       NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
-      schema,
-      flat = false,
       serializer,
       deserializer,
       ClassTag(cls))
@@ -212,21 +181,91 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
  * and a `deserializer`.
  *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param serializer A set of expressions, one for each top-level field that can be used to
- *                   extract the values from a raw object into an [[InternalRow]].
- * @param deserializer An expression that will construct an object given an [[InternalRow]].
+ * @param objSerializer An expression that can be used to encode a raw object to corresponding
+ *                   Spark SQL representation that can be a primitive column, array, map or a
+ *                   struct. This represents how Spark SQL generally serializes an object of
+ *                   type `T`.
+ * @param objDeserializer An expression that will construct an object given a Spark SQL
+ *                        representation. This represents how Spark SQL generally deserializes
+ *                        a serialized value in Spark SQL representation back to an object of
+ *                        type `T`.
  * @param clsTag A classtag for `T`.
  */
 case class ExpressionEncoder[T](
-    schema: StructType,
-    flat: Boolean,
-    serializer: Seq[Expression],
-    deserializer: Expression,
+    objSerializer: Expression,
+    objDeserializer: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
 
-  if (flat) require(serializer.size == 1)
+  /**
+   * A sequence of expressions, one for each top-level field that can be used to
+   * extract the values from a raw object into an [[InternalRow]]:
+   * 1. If `serializer` encodes a raw object to a struct, strip the outer If-IsNull and get
+   *    the `CreateNamedStruct`.
+   * 2. For other cases, wrap the single serializer with `CreateNamedStruct`.
+   */
+  val serializer: Seq[NamedExpression] = {
+    val clsName = Utils.getSimpleName(clsTag.runtimeClass)
+
+    if (isSerializedAsStruct) {
+      val nullSafeSerializer = objSerializer.transformUp {
+        case r: BoundReference =>
+          // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
+          // doesn't allow top-level row to be null, only its columns can be null.
+          AssertNotNull(r, Seq("top level Product or row object"))
+      }
+      nullSafeSerializer match {
+        case If(_: IsNull, _, s: CreateNamedStruct) => s
+        case s: CreateNamedStruct => s
+        case _ =>
+          throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
+      }
+    } else {
+      // For other input objects like primitive, array, map, etc., we construct a struct to wrap
+      // the serializer which is a column of an row.
+      CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
+    }
+  }.flatten
+
+  /**
+   * Returns an expression that can be used to deserialize an input row to an object of type `T`
+   * with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`.
+   * of the same name as the constructor arguments.
+   *
+   * For complex objects that are encoded to structs, Fields of the struct will be extracted using
+   * `GetColumnByOrdinal` with corresponding ordinal.
+   */
+  val deserializer: Expression = {
+    if (isSerializedAsStruct) {
+      // We serialized this kind of objects to root-level row. The input of general deserializer
+      // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
+      // transform attributes accessors.
+      objDeserializer.transform {
+        case UnresolvedExtractValue(GetColumnByOrdinal(0, _),
+            Literal(part: UTF8String, StringType)) =>
+          UnresolvedAttribute.quoted(part.toString)
+        case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) =>
+          GetColumnByOrdinal(ordinal, dt)
+        case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n
+        case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i
+      }
+    } else {
+      // For other input objects like primitive, array, map, etc., we deserialize the first column
+      // of a row to the object.
+      objDeserializer
+    }
+  }
+
+  // The schema after converting `T` to a Spark SQL row. This schema is dependent on the given
+  // serialier.
+  val schema: StructType = StructType(serializer.map { s =>
+    StructField(s.name, s.dataType, s.nullable)
+  })
+
+  /**
+   * Returns true if the type `T` is serialized as a struct.
+   */
+  def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
 
   // serializer expressions are used to encode an object to a row, while the object is usually an
   // intermediate value produced inside an operator, not from the output of the child operator. This
@@ -258,7 +297,7 @@ case class ExpressionEncoder[T](
     analyzer.checkAnalysis(analyzedPlan)
     val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
     val bound = BindReferences.bindReference(resolved, attrs)
-    copy(deserializer = bound)
+    copy(objDeserializer = bound)
   }
 
   @transient
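The new `serializer` member derives the top-level row layout from the single `objSerializer`, following the two cases in its scaladoc. The sketch below models that derivation on a heavily simplified expression ADT. `BoundRef`, `Field`, `NamedStruct`, `NullGuard` and the type objects are hypothetical stand-ins, not Catalyst classes.

The sketch omits the `AssertNotNull` rewrite. It only shows two things: a struct-serialized object loses its outer null check and contributes one column per field, and any other type is wrapped as a single `value` column.

```scala
// Hypothetical, simplified stand-ins for Catalyst data types and expressions.
sealed trait DataType
case object IntT extends DataType
case object StringT extends DataType
final case class ObjT(cls: String) extends DataType
final case class StructT(fields: List[(String, DataType, Boolean)]) extends DataType

sealed trait Expr { def dataType: DataType; def nullable: Boolean }
// The input object, bound at ordinal 0 of the input row.
final case class BoundRef(ordinal: Int, dataType: DataType, nullable: Boolean) extends Expr
// Reads one field of the input object (stands in for Invoke / field access).
final case class Field(obj: Expr, name: String, dataType: DataType, nullable: Boolean) extends Expr
// Builds a struct from named children, like CreateNamedStruct.
final case class NamedStruct(fields: List[(String, Expr)]) extends Expr {
  def dataType: DataType = StructT(fields.map { case (n, e) => (n, e.dataType, e.nullable) })
  def nullable: Boolean = false
}
// Models If(IsNull(input), null, struct): the struct is null when the object is null.
final case class NullGuard(input: Expr, struct: NamedStruct) extends Expr {
  def dataType: DataType = struct.dataType
  def nullable: Boolean = true
}

def isSerializedAsStruct(objSerializer: Expr): Boolean =
  objSerializer.dataType.isInstanceOf[StructT]

// Top-level (name, expression) pairs, following the two cases in the scaladoc.
def topLevelSerializer(objSerializer: Expr): List[(String, Expr)] =
  if (isSerializedAsStruct(objSerializer)) {
    objSerializer match {
      case NullGuard(_, s) => s.fields // 1. strip the outer null check
      case s: NamedStruct  => s.fields
      case other => throw new RuntimeException(s"unexpected serializer: $other")
    }
  } else {
    List("value" -> objSerializer)     // 2. wrap a non-struct value as column `value`
  }

// The row schema is read off the top-level serializer, as `schema` now is.
def schemaOf(objSerializer: Expr): StructT =
  StructT(topLevelSerializer(objSerializer).map { case (n, e) => (n, e.dataType, e.nullable) })

// Usage: a case class Person(name: String, age: Int), and a plain Int.
val person = BoundRef(0, ObjT("Person"), nullable = true)
val personSer = NullGuard(person, NamedStruct(List(
  "name" -> Field(person, "name", StringT, nullable = true),
  "age"  -> Field(person, "age", IntT, nullable = false))))
val intSer = BoundRef(0, IntT, nullable = false)
```

With this split, the encoder no longer needs a separate `flat` flag. Whether the object maps to one column or to several is recovered from the serializer's data type alone.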

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ae89f98..d905f8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -58,12 +58,10 @@ object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
-    val deserializer = deserializerFor(schema)
+    val serializer = serializerFor(inputObject, schema)
+    val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema)
     new ExpressionEncoder[Row](
-      schema,
-      flat = false,
-      serializer.asInstanceOf[CreateNamedStruct].flatten,
+      serializer,
       deserializer,
       ClassTag(cls))
   }
@@ -237,13 +235,9 @@ object RowEncoder {
     case udt: UserDefinedType[_] => ObjectType(udt.userClass)
   }
 
-  private def deserializerFor(schema: StructType): Expression = {
+  private def deserializerFor(input: Expression, schema: StructType): Expression = {
     val fields = schema.zipWithIndex.map { case (f, i) =>
-      val dt = f.dataType match {
-        case p: PythonUserDefinedType => p.sqlType
-        case other => other
-      }
-      deserializerFor(GetColumnByOrdinal(i, dt))
+      deserializerFor(GetStructField(input, i))
     }
     CreateExternalRow(fields, schema)
   }
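After this change, `RowEncoder` builds its deserializer like the other encoders do: it reads fields through `GetStructField(GetColumnByOrdinal(0, ...), i)`. `ExpressionEncoder.deserializer` then rewrites those accessors into top-level `GetColumnByOrdinal(i, ...)` reads and drops the outer null check around `NewInstance`.

Below is a minimal sketch of that top-down rewrite on a toy expression tree. `Col`, `GetField`, `If`, `IsNull`, `NewObj`, `NullLit` and `transformDown` are hypothetical names that only mimic the Catalyst nodes and the `transform` method.

```scala
// Hypothetical toy expression tree; names only mimic the Catalyst nodes.
sealed trait Expr
final case class Col(ordinal: Int) extends Expr                    // GetColumnByOrdinal(ordinal, _)
final case class GetField(child: Expr, ordinal: Int) extends Expr  // GetStructField(child, ordinal)
final case class IsNull(child: Expr) extends Expr
final case class If(cond: Expr, whenTrue: Expr, whenFalse: Expr) extends Expr
final case class NewObj(cls: String, args: List[Expr]) extends Expr // NewInstance
case object NullLit extends Expr

// Top-down rewrite: apply `rule` to a node first, then recurse into the result's children.
def transformDown(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr =
  rule.applyOrElse(e, identity[Expr]) match {
    case GetField(c, i)    => GetField(transformDown(c)(rule), i)
    case IsNull(c)         => IsNull(transformDown(c)(rule))
    case If(c, t, f)       => If(transformDown(c)(rule), transformDown(t)(rule), transformDown(f)(rule))
    case NewObj(cls, args) => NewObj(cls, args.map(a => transformDown(a)(rule)))
    case leaf              => leaf
  }

// The rewrite applied when the object is serialized as a struct.
val toRootRow: PartialFunction[Expr, Expr] = {
  // A field of the struct at column 0 becomes a top-level column of the row.
  case GetField(Col(0), i) => Col(i)
  // The top-level row is not allowed to be null, so the outer null check is dropped.
  case If(IsNull(Col(0)), _, n: NewObj) => n
}

// Usage: a two-field row deserializer shaped like the one RowEncoder now builds.
val root = Col(0)
val objDeser = If(IsNull(root), NullLit, NewObj("Row", List(GetField(root, 0), GetField(root, 1))))
val rowDeser = transformDown(objDeser)(toRootRow)
// rowDeser == NewObj("Row", List(Col(0), Col(1)))
```

Nested accessors are only rewritten at the root. For example, `GetField(GetField(Col(0), 1), 0)` becomes `GetField(Col(1), 0)`, which selects field 0 of the struct in top-level column 1.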

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f9ee948..d98589d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
+import scala.reflect.runtime.universe.TypeTag
+
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
+import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 case class PrimitiveData(
     intField: Int,
@@ -112,6 +113,14 @@ object TestingUDT {
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
 
+  // A helper method used to test `ScalaReflection.serializerForType`.
+  private def serializerFor[T: TypeTag]: Expression =
+    serializerForType(ScalaReflection.localTypeOf[T])
+
+  // A helper method used to test `ScalaReflection.deserializerForType`.
+  private def deserializerFor[T: TypeTag]: Expression =
+    deserializerForType(ScalaReflection.localTypeOf[T])
+
   test("SQLUserDefinedType annotation on Scala structure") {
     val schema = schemaFor[TestingUDT.NestedStruct]
     assert(schema === Schema(
@@ -263,13 +272,9 @@ class ScalaReflectionSuite extends SparkFunSuite {
 
   test("SPARK-15062: Get correct serializer for List[_]") {
     val list = List(1, 2, 3)
-    val serializer = serializerFor[List[Int]](BoundReference(
-      0, ObjectType(list.getClass), nullable = false))
-    assert(serializer.children.size == 2)
-    assert(serializer.children.head.isInstanceOf[Literal])
-    assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value"))
-    assert(serializer.children.last.isInstanceOf[NewInstance])
-    assert(serializer.children.last.asInstanceOf[NewInstance]
+    val serializer = serializerFor[List[Int]]
+    assert(serializer.isInstanceOf[NewInstance])
+    assert(serializer.asInstanceOf[NewInstance]
       .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
   }
 
@@ -280,59 +285,58 @@ class ScalaReflectionSuite extends SparkFunSuite {
 
   test("serialize and deserialize arbitrary sequence types") {
     import scala.collection.immutable.Queue
-    val queueSerializer = serializerFor[Queue[Int]](BoundReference(
-      0, ObjectType(classOf[Queue[Int]]), nullable = false))
-    assert(queueSerializer.dataType.head.dataType ==
+    val queueSerializer = serializerFor[Queue[Int]]
+    assert(queueSerializer.dataType ==
       ArrayType(IntegerType, containsNull = false))
     val queueDeserializer = deserializerFor[Queue[Int]]
     assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
 
     import scala.collection.mutable.ArrayBuffer
-    val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
-      0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
-    assert(arrayBufferSerializer.dataType.head.dataType ==
+    val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]]
+    assert(arrayBufferSerializer.dataType ==
       ArrayType(IntegerType, containsNull = false))
     val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
   }
 
   test("serialize and deserialize arbitrary map types") {
-    val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
-      0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
-    assert(mapSerializer.dataType.head.dataType ==
+    val mapSerializer = serializerFor[Map[Int, Int]]
+    assert(mapSerializer.dataType ==
       MapType(IntegerType, IntegerType, valueContainsNull = false))
     val mapDeserializer = deserializerFor[Map[Int, Int]]
     assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
 
     import scala.collection.immutable.HashMap
-    val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
-      0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
-    assert(hashMapSerializer.dataType.head.dataType ==
+    val hashMapSerializer = serializerFor[HashMap[Int, Int]]
+    assert(hashMapSerializer.dataType ==
       MapType(IntegerType, IntegerType, valueContainsNull = false))
     val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
     assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
 
     import scala.collection.mutable.{LinkedHashMap => LHMap}
-    val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
-      0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
-    assert(linkedHashMapSerializer.dataType.head.dataType ==
+    val linkedHashMapSerializer = serializerFor[LHMap[Long, String]]
+    assert(linkedHashMapSerializer.dataType ==
       MapType(LongType, StringType, valueContainsNull = true))
     val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
     assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
   }
 
   test("SPARK-22442: Generate correct field names for special characters") {
-    val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
-      0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
+    val serializer = serializerFor[SpecialCharAsFieldData]
+      .collect {
+        case If(_, _, s: CreateNamedStruct) => s
+      }.head
     val deserializer = deserializerFor[SpecialCharAsFieldData]
     assert(serializer.dataType(0).name == "field.1")
     assert(serializer.dataType(1).name == "field 2")
 
-    val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
-      case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
+    val newInstance = deserializer.collect { case n: NewInstance => n }.head
+
+    val argumentsFields = newInstance.arguments.flatMap { _.collect {
+      case UpCast(u: UnresolvedExtractValue, _, _) => u.extraction.toString
     }}
-    assert(argumentsFields(0) == Seq("field.1"))
-    assert(argumentsFields(1) == Seq("field 2"))
+    assert(argumentsFields(0) == "field.1")
+    assert(argumentsFields(1) == "field 2")
   }
 
   test("SPARK-22472: add null check for top-level primitive values") {
@@ -351,8 +355,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
 
   test("SPARK-23835: add null check to non-nullable types in Tuples") {
     def numberOfCheckedArguments(deserializer: Expression): Int = {
-      assert(deserializer.isInstanceOf[NewInstance])
-      deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
+      val newInstance = deserializer.collect { case n: NewInstance => n}.head
+      newInstance.arguments.count(_.isInstanceOf[AssertNotNull])
     }
     assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
     assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f0d61de..e9b100b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -28,9 +28,9 @@ import org.apache.spark.sql.{Encoder, Encoders}
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -348,7 +348,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
 
   test("nullable of encoder serializer") {
     def checkNullable[T: Encoder](nullable: Boolean): Unit = {
-      assert(encoderFor[T].serializer.forall(_.nullable === nullable))
+      assert(encoderFor[T].objSerializer.nullable === nullable)
     }
 
     // test for flat encoders

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 2357321..ab819be 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -239,7 +239,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     val encoder = RowEncoder(schema)
     val e = intercept[RuntimeException](encoder.toRow(null))
     assert(e.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getMessage.contains("top level row object"))
+    assert(e.getMessage.contains("top level Product or row object"))
   }
 
   test("RowEncoder should validate external type") {

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 0fb3301..c91b0d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1087,7 +1087,7 @@ class Dataset[T] private[sql](
     // Note that we do this before joining them, to enable the join operator to return null for one
     // side, in cases like outer-join.
     val left = {
-      val combined = if (this.exprEnc.flat) {
+      val combined = if (!this.exprEnc.isSerializedAsStruct) {
         assert(joined.left.output.length == 1)
         Alias(joined.left.output.head, "_1")()
       } else {
@@ -1097,7 +1097,7 @@ class Dataset[T] private[sql](
     }
 
     val right = {
-      val combined = if (other.exprEnc.flat) {
+      val combined = if (!other.exprEnc.isSerializedAsStruct) {
         assert(joined.right.output.length == 1)
         Alias(joined.right.output.head, "_2")()
       } else {
@@ -1110,14 +1110,14 @@ class Dataset[T] private[sql](
     // combine the outputs of each join side.
     val conditionExpr = joined.condition.get transformUp {
       case a: Attribute if joined.left.outputSet.contains(a) =>
-        if (this.exprEnc.flat) {
+        if (!this.exprEnc.isSerializedAsStruct) {
           left.output.head
         } else {
           val index = joined.left.output.indexWhere(_.exprId == a.exprId)
           GetStructField(left.output.head, index)
         }
       case a: Attribute if joined.right.outputSet.contains(a) =>
-        if (other.exprEnc.flat) {
+        if (!other.exprEnc.isSerializedAsStruct) {
           right.output.head
         } else {
           val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -1390,7 +1390,7 @@ class Dataset[T] private[sql](
     implicit val encoder = c1.encoder
     val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
 
-    if (encoder.flat) {
+    if (!encoder.isSerializedAsStruct) {
       new Dataset[U1](sparkSession, project, encoder)
     } else {
       // Flattens inner fields of U1
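In `joinWith`, the old `flat` test becomes `!isSerializedAsStruct`. When the join condition is rewritten, an attribute from one side maps either to that side's whole combined column or to one field of it.

The sketch isolates that branch with hypothetical types. `Attr`, `Whole`, `FieldOf` and `rewriteAttr` are illustrative names, not Spark APIs. `Whole` plays the role of `left.output.head`, and `FieldOf` plays the role of `GetStructField(left.output.head, index)`.

```scala
// Hypothetical, simplified attribute: only the expression id and name matter here.
final case class Attr(id: Long, name: String)

// What a join-condition attribute is rewritten to (toy stand-ins for Catalyst expressions).
sealed trait Ref
final case class Whole(column: String) extends Ref                // like `left.output.head`
final case class FieldOf(column: String, index: Int) extends Ref  // like `GetStructField(left.output.head, index)`

def rewriteAttr(a: Attr, sideOutput: Seq[Attr], combinedColumn: String, serializedAsStruct: Boolean): Ref =
  if (!serializedAsStruct) {
    // Non-struct side: the single output column is the object itself, so use it directly.
    Whole(combinedColumn)
  } else {
    // Struct side: the object was packed into one struct column; pick the matching field.
    FieldOf(combinedColumn, sideOutput.indexWhere(_.id == a.id))
  }

// Usage: a Dataset of Person(name, age) joined with a Dataset[Int].
val personOut = Seq(Attr(1, "name"), Attr(2, "age"))
val ageRef = rewriteAttr(Attr(2, "age"), personOut, "_1", serializedAsStruct = true)
val intRef = rewriteAttr(Attr(3, "value"), Seq(Attr(3, "value")), "_2", serializedAsStruct = false)
```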

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 6bab21d..555bcdf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -457,7 +457,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(_.withInputType(vExprEnc, dataAttributes).named)
-    val keyColumn = if (kExprEnc.flat) {
+    val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 6d44890..39200ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -38,18 +38,14 @@ object TypedAggregateExpression {
     val bufferSerializer = bufferEncoder.namedExpressions
 
     val outputEncoder = encoderFor[OUT]
-    val outputType = if (outputEncoder.flat) {
-      outputEncoder.schema.head.dataType
-    } else {
-      outputEncoder.schema
-    }
+    val outputType = outputEncoder.objSerializer.dataType
 
     // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
     // expression is an alias of `BoundReference`, which means the buffer object doesn't need
     // serialization.
     val isSimpleBuffer = {
       bufferSerializer.head match {
-        case Alias(_: BoundReference, _) if bufferEncoder.flat => true
+        case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
         case _ => false
       }
     }
@@ -71,7 +67,7 @@ object TypedAggregateExpression {
         outputEncoder.serializer,
         outputEncoder.deserializer.dataType,
         outputType,
-        !outputEncoder.flat || outputEncoder.schema.head.nullable)
+        outputEncoder.objSerializer.nullable)
     } else {
       ComplexTypedAggregateExpression(
         aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
@@ -82,7 +78,7 @@ object TypedAggregateExpression {
         bufferEncoder.resolveAndBind().deserializer,
         outputEncoder.serializer,
         outputType,
-        !outputEncoder.flat || outputEncoder.schema.head.nullable)
+        outputEncoder.objSerializer.nullable)
     }
   }
 }
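`TypedAggregateExpression` now reads the output type and nullability directly from `objSerializer` instead of branching on `flat`. The sketch compares the old and new computations on a hypothetical encoder summary. `ObjSer`, `schemaOf`, `oldOutput` and `newOutput` are illustrative names.

The sketch assumes that a struct-serialized object serializer is wrapped in a null check and is therefore nullable. This is what the updated `ScalaReflectionSuite` test expects for product types: an `If` around a `CreateNamedStruct`. Under that assumption, both computations agree.

```scala
// Hypothetical, simplified data types.
sealed trait DataType
case object IntT extends DataType
final case class StructT(fields: List[(String, DataType, Boolean)]) extends DataType

// Hypothetical summary of an encoder's object serializer: its type and nullability.
final case class ObjSer(dataType: DataType, nullable: Boolean)

// Row schema as ExpressionEncoder derives it: struct fields, or a single `value` column.
def schemaOf(s: ObjSer): StructT = s.dataType match {
  case st: StructT => st
  case other       => StructT(List(("value", other, s.nullable)))
}

// Before: branch on `flat` and inspect the row schema.
def oldOutput(flat: Boolean, schema: StructT): (DataType, Boolean) =
  if (flat) (schema.fields.head._2, schema.fields.head._3) else (schema, true)

// After: read both values straight off the object serializer.
def newOutput(s: ObjSer): (DataType, Boolean) = (s.dataType, s.nullable)

// Usage: a flat Int output and a struct output whose serializer is null-guarded (nullable).
val intOut = ObjSer(IntT, nullable = false)
val structOut = ObjSer(StructT(List(("a", IntT, false))), nullable = true)
```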

http://git-wip-us.apache.org/repos/asf/spark/blob/cb5ea201/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 4e593ff..27b3b3d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1065,7 +1065,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("Dataset should throw RuntimeException if top-level product input object is null") {
     val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
     assert(e.getMessage.contains("Null value appeared in non-nullable field"))
-    assert(e.getMessage.contains("top level Product input object"))
+    assert(e.getMessage.contains("top level Product or row object"))
   }
 
   test("dropDuplicates") {

