This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ddf2da74f52 [SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders
ddf2da74f52 is described below

commit ddf2da74f527ee00af99fe3928015149f9477734
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Jan 17 10:52:28 2023 -0800

    [SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders
    
    ### What changes were proposed in this pull request?
    This PR makes `RowEncoder` produce an `AgnosticEncoder`. The expression
    generation for these encoders is moved to `ScalaReflection` (this will be
    moved out in a subsequent PR).
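    
    A minimal sketch of the new shape, using only signatures added in this
    patch (the schema and variable names are illustrative):
    
        import org.apache.spark.sql.Row
        import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
        import org.apache.spark.sql.types._
    
        val schema = new StructType().add("id", LongType).add("name", StringType)
        // RowEncoder now first builds an implementation-agnostic encoder...
        val agnostic: AgnosticEncoder[Row] = RowEncoder.encoderFor(schema)
        // ...which ExpressionEncoder turns into serializer/deserializer
        // expressions via ScalaReflection.
        val expressionEncoder: ExpressionEncoder[Row] = ExpressionEncoder(agnostic)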
    
    The generated serializer and deserializer expressions change slightly for
    both schema-based and type-based encoders. They are not semantically
    different from the old expressions. Concretely, the following changes have
    been introduced:
    - There is more type validation in maps/arrays/seqs for type-based
      encoders. This should be a positive change, since it prevents users from
      passing wrong data through erasure hacks.
    - Array/Seq serialization is a bit stricter. Previously it was possible to
      pass in sequences/arrays with the wrong element type and/or nullability
      (see the sketch below).
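    
    A hedged illustration of that second point (the validation path exists in
    this patch; the failing call is left commented out):
    
        import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    
        val enc = ExpressionEncoder[Seq[Int]]()
        val toRow = enc.createSerializer()
        // Smuggling a Seq[String] through erasure used to serialize bad data;
        // the added ValidateExternalType checks now make it fail at runtime:
        // toRow(Seq("not an int").asInstanceOf[Seq[Int]])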
    
    ### Why are the changes needed?
    For the Spark Connect Scala Client we also want to be able to use
    `Row`-based results.
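    
    `Row`-based results rely on the lenient serialization kept by this patch;
    a small sketch (whether a prior resolve step is needed before
    `createSerializer` here is an assumption):
    
        import java.time.LocalDate
        import org.apache.spark.sql.Row
        import org.apache.spark.sql.catalyst.encoders.RowEncoder
        import org.apache.spark.sql.types._
    
        val toRow = RowEncoder(new StructType().add("d", DateType), lenient = true)
          .createSerializer()
        // Lenient date serialization accepts either external date type:
        toRow(Row(java.sql.Date.valueOf("2023-01-17")))
        toRow(Row(LocalDate.of(2023, 1, 17)))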
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    This is a refactoring, so mostly existing tests. I have added tests to the
    catalyst test suites for cases that triggered failures downstream (typed
    arrays in `WrappedArray` & the `Seq[_]` change in Scala 2.13).
    
    Closes #39627 from hvanhovell/SPARK-41993-v2.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../spark/sql/catalyst/JavaTypeInference.scala     |   4 +-
 .../spark/sql/catalyst/ScalaReflection.scala       | 317 ++++++++++++------
 .../spark/sql/catalyst/SerializerBuildHelper.scala |  25 +-
 .../sql/catalyst/encoders/AgnosticEncoder.scala    | 128 ++++++--
 .../sql/catalyst/encoders/ExpressionEncoder.scala  |   5 +-
 .../spark/sql/catalyst/encoders/RowEncoder.scala   | 354 ++++-----------------
 .../sql/catalyst/expressions/objects/objects.scala |  87 +++--
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  |   9 +-
 .../catalyst/encoders/ExpressionEncoderSuite.scala |   2 +
 .../sql/catalyst/encoders/RowEncoderSuite.scala    |  24 ++
 .../catalyst/expressions/CodeGenerationSuite.scala |   2 +-
 .../expressions/ObjectExpressionsSuite.scala       |   9 +-
 12 files changed, 462 insertions(+), 504 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 827807055ce..81f363dda36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -423,10 +423,10 @@ object JavaTypeInference {
         case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject)
 
         case c if c == classOf[java.math.BigInteger] =>
-          createSerializerForJavaBigInteger(inputObject)
+          createSerializerForBigInteger(inputObject)
 
         case c if c == classOf[java.math.BigDecimal] =>
-          createSerializerForJavaBigDecimal(inputObject)
+          createSerializerForBigDecimal(inputObject)
 
         case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
         case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index e02e42cea1a..42208cd1098 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
 import javax.lang.model.SourceVersion
 
 import scala.annotation.tailrec
+import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.reflect.internal.Symbols
 import scala.util.{Failure, Success}
@@ -27,12 +28,13 @@ import scala.util.{Failure, Success}
 import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
 import org.apache.spark.sql.catalyst.SerializerBuildHelper._
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.catalyst.expressions.{Expression, _}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -82,12 +84,24 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
-  // TODO this name is slightly misleading. This returns the input
-  //  data type we expect to see during serialization.
-  private[catalyst] def dataTypeFor(enc: AgnosticEncoder[_]): DataType = {
+  /**
+   * Return the data type we expect to see when deserializing a value with encoder `enc`.
+   */
+  private[catalyst] def externalDataTypeFor(enc: AgnosticEncoder[_]): DataType = {
+    externalDataTypeFor(enc, lenientSerialization = false)
+  }
+
+  private[catalyst] def lenientExternalDataTypeFor(enc: AgnosticEncoder[_]): DataType =
+    externalDataTypeFor(enc, enc.lenientSerialization)
+
+  private def externalDataTypeFor(
+      enc: AgnosticEncoder[_],
+      lenientSerialization: Boolean): DataType = {
     // DataType can be native.
     if (isNativeEncoder(enc)) {
       enc.dataType
+    } else if (lenientSerialization) {
+      ObjectType(classOf[java.lang.Object])
     } else {
       ObjectType(enc.clsTag.runtimeClass)
     }
@@ -123,7 +137,7 @@ object ScalaReflection extends ScalaReflection {
     case NullEncoder => true
     case CalendarIntervalEncoder => true
     case BinaryEncoder => true
-    case SparkDecimalEncoder => true
+    case _: SparkDecimalEncoder => true
     case _ => false
   }
 
@@ -155,11 +169,19 @@ object ScalaReflection extends ScalaReflection {
     val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
     // Assumes we are deserializing the first column of a row.
     val input = GetColumnByOrdinal(0, enc.dataType)
-    val deserializer = deserializerFor(
-      enc,
-      upCastToExpectedType(input, enc.dataType, walkedTypePath),
-      walkedTypePath)
-    expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
+    enc match {
+      case RowEncoder(fields) =>
+        val children = fields.zipWithIndex.map { case (f, i) =>
+          deserializerFor(f.enc, GetStructField(input, i), walkedTypePath)
+        }
+        CreateExternalRow(children, enc.schema)
+      case _ =>
+        val deserializer = deserializerFor(
+          enc,
+          upCastToExpectedType(input, enc.dataType, walkedTypePath),
+          walkedTypePath)
+        expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
+    }
   }
 
   /**
@@ -178,19 +200,7 @@ object ScalaReflection extends ScalaReflection {
       walkedTypePath: WalkedTypePath): Expression = enc match {
     case _ if isNativeEncoder(enc) =>
       path
-    case BoxedBooleanEncoder =>
-      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
-    case BoxedByteEncoder =>
-      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
-    case BoxedShortEncoder =>
-      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
-    case BoxedIntEncoder =>
-      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
-    case BoxedLongEncoder =>
-      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
-    case BoxedFloatEncoder =>
-      createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
-    case BoxedDoubleEncoder =>
+    case _: BoxedLeafEncoder[_, _] =>
       createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
     case JavaEnumEncoder(tag) =>
       val toString = createDeserializerForString(path, returnNullable = false)
@@ -204,9 +214,9 @@ object ScalaReflection extends ScalaReflection {
         returnNullable = false)
     case StringEncoder =>
       createDeserializerForString(path, returnNullable = false)
-    case ScalaDecimalEncoder =>
+    case _: ScalaDecimalEncoder =>
       createDeserializerForScalaBigDecimal(path, returnNullable = false)
-    case JavaDecimalEncoder =>
+    case _: JavaDecimalEncoder =>
       createDeserializerForJavaBigDecimal(path, returnNullable = false)
     case ScalaBigIntEncoder =>
       createDeserializerForScalaBigInt(path)
@@ -216,13 +226,13 @@ object ScalaReflection extends ScalaReflection {
       createDeserializerForDuration(path)
     case YearMonthIntervalEncoder =>
       createDeserializerForPeriod(path)
-    case DateEncoder =>
+    case _: DateEncoder =>
       createDeserializerForSqlDate(path)
-    case LocalDateEncoder =>
+    case _: LocalDateEncoder =>
       createDeserializerForLocalDate(path)
-    case TimestampEncoder =>
+    case _: TimestampEncoder =>
       createDeserializerForSqlTimestamp(path)
-    case InstantEncoder =>
+    case _: InstantEncoder =>
       createDeserializerForInstant(path)
     case LocalDateTimeEncoder =>
       createDeserializerForLocalDateTime(path)
@@ -232,39 +242,29 @@ object ScalaReflection extends ScalaReflection {
     case OptionEncoder(valueEnc) =>
       val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName)
       val deserializer = deserializerFor(valueEnc, path, newTypePath)
-      WrapOption(deserializer, dataTypeFor(valueEnc))
-
-    case ArrayEncoder(elementEnc: AgnosticEncoder[_]) =>
-      val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
-      val mapFunction: Expression => Expression = element => {
-        // upcast the array element to the data type the encoder expected.
-        deserializerForWithNullSafetyAndUpcast(
-          element,
-          elementEnc.dataType,
-          nullable = elementEnc.nullable,
-          newTypePath,
-          deserializerFor(elementEnc, _, newTypePath))
-      }
+      WrapOption(deserializer, externalDataTypeFor(valueEnc))
+
+    case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) =>
       Invoke(
-        UnresolvedMapObjects(mapFunction, path),
+        deserializeArray(
+          path,
+          elementEnc,
+          containsNull,
+          None,
+          walkedTypePath),
         toArrayMethodName(elementEnc),
         ObjectType(enc.clsTag.runtimeClass),
         returnNullable = false)
 
-    case IterableEncoder(clsTag, elementEnc) =>
-      val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
-      val mapFunction: Expression => Expression = element => {
-        // upcast the array element to the data type the encoder expected.
-        deserializerForWithNullSafetyAndUpcast(
-          element,
-          elementEnc.dataType,
-          nullable = elementEnc.nullable,
-          newTypePath,
-          deserializerFor(elementEnc, _, newTypePath))
-      }
-      UnresolvedMapObjects(mapFunction, path, Some(clsTag.runtimeClass))
+    case IterableEncoder(clsTag, elementEnc, containsNull, _) =>
+      deserializeArray(
+        path,
+        elementEnc,
+        containsNull,
+        Option(clsTag.runtimeClass),
+        walkedTypePath)
 
-    case MapEncoder(tag, keyEncoder, valueEncoder) =>
+    case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
       val newTypePath = walkedTypePath.recordMap(
         keyEncoder.clsTag.runtimeClass.getName,
         valueEncoder.clsTag.runtimeClass.getName)
@@ -298,6 +298,39 @@ object ScalaReflection extends ScalaReflection {
         IsNull(path),
         expressions.Literal.create(null, dt),
         NewInstance(cls, arguments, dt, propagateNull = false))
+
+    case RowEncoder(fields) =>
+      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+        val newTypePath = walkedTypePath.recordField(
+          f.enc.clsTag.runtimeClass.getName,
+          f.name)
+        exprs.If(
+          Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
+          exprs.Literal.create(null, externalDataTypeFor(f.enc)),
+          deserializerFor(f.enc, GetStructField(path, i), newTypePath))
+      }
+      exprs.If(IsNull(path),
+        exprs.Literal.create(null, externalDataTypeFor(enc)),
+        CreateExternalRow(convertedFields, enc.schema))
+  }
+
+  private def deserializeArray(
+      path: Expression,
+      elementEnc: AgnosticEncoder[_],
+      containsNull: Boolean,
+      cls: Option[Class[_]],
+      walkedTypePath: WalkedTypePath): Expression = {
+    val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
+    val mapFunction: Expression => Expression = element => {
+      // upcast the array element to the data type the encoder expects.
+      deserializerForWithNullSafetyAndUpcast(
+        element,
+        elementEnc.dataType,
+        nullable = containsNull,
+        newTypePath,
+        deserializerFor(elementEnc, _, newTypePath))
+    }
+    UnresolvedMapObjects(mapFunction, path, cls)
   }
 
   /**
@@ -306,7 +339,7 @@ object ScalaReflection extends ScalaReflection {
    * input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
    */
   def serializerFor(enc: AgnosticEncoder[_]): Expression = {
-    val input = BoundReference(0, dataTypeFor(enc), nullable = enc.nullable)
+    val input = BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable)
     serializerFor(enc, input)
   }
 
@@ -327,45 +360,52 @@ object ScalaReflection extends ScalaReflection {
     case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
     case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
     case StringEncoder => createSerializerForString(input)
-    case ScalaDecimalEncoder => createSerializerForScalaBigDecimal(input)
-    case JavaDecimalEncoder => createSerializerForJavaBigDecimal(input)
-    case ScalaBigIntEncoder => createSerializerForScalaBigInt(input)
-    case JavaBigIntEncoder => createSerializerForJavaBigInteger(input)
+    case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt)
+    case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt)
+    case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt)
+    case ScalaBigIntEncoder => createSerializerForBigInteger(input)
+    case JavaBigIntEncoder => createSerializerForBigInteger(input)
     case DayTimeIntervalEncoder => createSerializerForJavaDuration(input)
     case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input)
-    case DateEncoder => createSerializerForSqlDate(input)
-    case LocalDateEncoder => createSerializerForJavaLocalDate(input)
-    case TimestampEncoder => createSerializerForSqlTimestamp(input)
-    case InstantEncoder => createSerializerForJavaInstant(input)
+    case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input)
+    case DateEncoder(false) => createSerializerForSqlDate(input)
+    case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input)
+    case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input)
+    case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
+    case InstantEncoder(false) => createSerializerForJavaInstant(input)
     case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
     case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass)
     case OptionEncoder(valueEnc) =>
-      serializerFor(valueEnc, UnwrapOption(dataTypeFor(valueEnc), input))
+      serializerFor(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input))
 
-    case ArrayEncoder(elementEncoder) =>
-      serializerForArray(isArray = true, elementEncoder, input)
+    case ArrayEncoder(elementEncoder, containsNull) =>
+      if (elementEncoder.isPrimitive) {
+        createSerializerForPrimitiveArray(input, elementEncoder.dataType)
+      } else {
+        serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false)
+      }
 
-    case IterableEncoder(ctag, elementEncoder) =>
+    case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) =>
       val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) {
         // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
         // Note that the property of `Set` is only kept when manipulating the data as domain object.
-        Invoke(input, "toSeq", ObjectType(classOf[Seq[_]]))
+        Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]]))
       } else {
         input
       }
-      serializerForArray(isArray = false, elementEncoder, getter)
+      serializerForArray(elementEncoder, containsNull, getter, lenientSerialization)
 
-    case MapEncoder(_, keyEncoder, valueEncoder) =>
+    case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) =>
       createSerializerForMap(
         input,
         MapElementInformation(
-          dataTypeFor(keyEncoder),
-          nullable = !keyEncoder.isPrimitive,
-          serializerFor(keyEncoder, _)),
+          ObjectType(classOf[AnyRef]),
+          nullable = keyEncoder.nullable,
+          validateAndSerializeElement(keyEncoder, keyEncoder.nullable)),
         MapElementInformation(
-          dataTypeFor(valueEncoder),
-          nullable = !valueEncoder.isPrimitive,
-          serializerFor(valueEncoder, _))
+          ObjectType(classOf[AnyRef]),
+          nullable = valueContainsNull,
+          validateAndSerializeElement(valueEncoder, valueContainsNull))
       )
 
     case ProductEncoder(_, fields) =>
@@ -377,25 +417,94 @@ object ScalaReflection extends ScalaReflection {
         val getter = Invoke(
           KnownNotNull(input),
           field.name,
-          dataTypeFor(field.enc),
-          returnNullable = field.enc.nullable)
+          externalDataTypeFor(field.enc),
+          returnNullable = field.nullable)
         field.name -> serializerFor(field.enc, getter)
       }
       createSerializerForObject(input, serializedFields)
+
+    case RowEncoder(fields) =>
+      val serializedFields = fields.zipWithIndex.map { case (field, index) =>
+        val fieldValue = serializerFor(
+          field.enc,
+          ValidateExternalType(
+            GetExternalRowField(input, index, field.name),
+            field.enc.dataType,
+            lenientExternalDataTypeFor(field.enc)))
+
+        val convertedField = if (field.nullable) {
+          exprs.If(
+            Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: 
Nil),
+            // Because we strip UDTs, `field.dataType` can be different from 
`fieldValue.dataType`.
+            // We should use `fieldValue.dataType` here.
+            exprs.Literal.create(null, fieldValue.dataType),
+            fieldValue
+          )
+        } else {
+          AssertNotNull(fieldValue)
+        }
+        field.name -> convertedField
+      }
+      createSerializerForObject(input, serializedFields)
   }
 
   private def serializerForArray(
-      isArray: Boolean,
       elementEnc: AgnosticEncoder[_],
-      input: Expression): Expression = {
-    dataTypeFor(elementEnc) match {
-      case dt: ObjectType =>
-        createSerializerForMapObjects(input, dt, serializerFor(elementEnc, _))
-      case dt if isArray && elementEnc.isPrimitive =>
-        createSerializerForPrimitiveArray(input, dt)
-      case dt =>
-        createSerializerForGenericArray(input, dt, elementEnc.nullable)
+      elementNullable: Boolean,
+      input: Expression,
+      lenientSerialization: Boolean): Expression = {
+    // Default serializer for Seq and generic Arrays. This does not work for primitive arrays.
+    val genericSerializer = createSerializerForMapObjects(
+      input,
+      ObjectType(classOf[AnyRef]),
+      validateAndSerializeElement(elementEnc, elementNullable))
+
+    // Check if it is possible the user can pass a primitive array. This is the only case when it
+    // is safe to directly convert to an array (for generic arrays and Seqs the type and the
+    // nullability can be violated). If the user has passed a primitive array we create a special
+    // code path to deal with these.
+    val primitiveEncoderOption = elementEnc match {
+      case _ if !lenientSerialization => None
+      case enc: PrimitiveLeafEncoder[_] => Option(enc)
+      case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive)
+      case _ => None
     }
+    primitiveEncoderOption match {
+      case Some(primitiveEncoder) =>
+        val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass
+        val check = Invoke(
+          targetObject = exprs.Literal.fromObject(primitiveArrayClass),
+          functionName = "isInstance",
+          BooleanType,
+          arguments = input :: Nil,
+          propagateNull = false,
+          returnNullable = false)
+        exprs.If(
+          check,
+          // TODO replace this with `createSerializerForPrimitiveArray` as
+          //  soon as Cast support ObjectType casts.
+          StaticInvoke(
+            classOf[ArrayData],
+            ArrayType(elementEnc.dataType, containsNull = false),
+            "toArrayData",
+            input :: Nil,
+            propagateNull = false,
+            returnNullable = false),
+          genericSerializer)
+      case None =>
+        genericSerializer
+    }
+  }
+
+  private def validateAndSerializeElement(
+      enc: AgnosticEncoder[_],
+      nullable: Boolean): Expression => Expression = { input =>
+    expressionWithNullSafety(
+      serializerFor(
+        enc,
+        ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))),
+      nullable,
+      WalkedTypePath())
   }
 
   /**
@@ -598,8 +707,8 @@ object ScalaReflection extends ScalaReflection {
     case StringType => classOf[UTF8String]
     case CalendarIntervalType => classOf[CalendarInterval]
     case _: StructType => classOf[InternalRow]
-    case _: ArrayType => classOf[ArrayType]
-    case _: MapType => classOf[MapType]
+    case _: ArrayType => classOf[ArrayData]
+    case _: MapType => classOf[MapData]
     case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
     case ObjectType(cls) => cls
     case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
@@ -657,7 +766,11 @@ object ScalaReflection extends ScalaReflection {
         case NoSymbol => fallbackClass
         case _ => mirror.runtimeClass(t.typeSymbol.asClass)
       }
-      IterableEncoder(ClassTag(targetClass), encoder)
+      IterableEncoder(
+        ClassTag(targetClass),
+        encoder,
+        encoder.nullable,
+        lenientSerialization = false)
     }
 
     baseType(tpe) match {
@@ -698,18 +811,18 @@ object ScalaReflection extends ScalaReflection {
 
       // Leaf encoders
       case t if isSubtype(t, localTypeOf[String]) => StringEncoder
-      case t if isSubtype(t, localTypeOf[Decimal]) => SparkDecimalEncoder
-      case t if isSubtype(t, localTypeOf[BigDecimal]) => ScalaDecimalEncoder
-      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => JavaDecimalEncoder
+      case t if isSubtype(t, localTypeOf[Decimal]) => DEFAULT_SPARK_DECIMAL_ENCODER
+      case t if isSubtype(t, localTypeOf[BigDecimal]) => DEFAULT_SCALA_DECIMAL_ENCODER
+      case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => DEFAULT_JAVA_DECIMAL_ENCODER
       case t if isSubtype(t, localTypeOf[BigInt]) => ScalaBigIntEncoder
       case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => JavaBigIntEncoder
       case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalEncoder
       case t if isSubtype(t, localTypeOf[java.time.Duration]) => DayTimeIntervalEncoder
       case t if isSubtype(t, localTypeOf[java.time.Period]) => YearMonthIntervalEncoder
-      case t if isSubtype(t, localTypeOf[java.sql.Date]) => DateEncoder
-      case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => LocalDateEncoder
-      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => TimestampEncoder
-      case t if isSubtype(t, localTypeOf[java.time.Instant]) => InstantEncoder
+      case t if isSubtype(t, localTypeOf[java.sql.Date]) => STRICT_DATE_ENCODER
+      case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => STRICT_LOCAL_DATE_ENCODER
+      case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
+      case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
       case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
 
       // UDT encoders
@@ -739,7 +852,7 @@ object ScalaReflection extends ScalaReflection {
           elementType,
           seenTypeSet,
           path.recordArray(getClassNameFromType(elementType)))
-        ArrayEncoder(encoder)
+        ArrayEncoder(encoder, encoder.nullable)
 
       case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) =>
         createIterableEncoder(t, classOf[scala.collection.Seq[_]])
@@ -757,7 +870,7 @@ object ScalaReflection extends ScalaReflection {
           valueType,
           seenTypeSet,
           path.recordValueForMap(getClassNameFromType(valueType)))
-        MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder)
+        MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder, valueEncoder.nullable)
 
       case t if definedByConstructorParams(t) =>
         if (seenTypeSet.contains(t)) {
@@ -775,7 +888,7 @@ object ScalaReflection extends ScalaReflection {
               fieldType,
               seenTypeSet + t,
               path.recordField(getClassNameFromType(fieldType), fieldName))
-            EncoderField(fieldName, encoder)
+            EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
         }
         ProductEncoder(ClassTag(getClassFromType(t)), params)
       case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 25f6ce520d9..33b0edb0c44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -158,20 +158,29 @@ object SerializerBuildHelper {
       returnNullable = false)
   }
 
-  def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
+  def createSerializerForBigDecimal(inputObject: Expression): Expression = {
+    createSerializerForBigDecimal(inputObject, DecimalType.SYSTEM_DEFAULT)
+  }
+
+  def createSerializerForBigDecimal(inputObject: Expression, dt: DecimalType): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
-      DecimalType.SYSTEM_DEFAULT,
+      dt,
       "apply",
       inputObject :: Nil,
-      returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow)
+      returnNullable = false), dt, nullOnOverflow)
   }
 
-  def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
-    createSerializerForJavaBigDecimal(inputObject)
+  def createSerializerForAnyDecimal(inputObject: Expression, dt: DecimalType): Expression = {
+    CheckOverflow(StaticInvoke(
+      Decimal.getClass,
+      dt,
+      "fromDecimal",
+      inputObject :: Nil,
+      returnNullable = false), dt, nullOnOverflow)
   }
 
-  def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
+  def createSerializerForBigInteger(inputObject: Expression): Expression = {
     CheckOverflow(StaticInvoke(
       Decimal.getClass,
       DecimalType.BigIntDecimal,
@@ -180,10 +189,6 @@ object SerializerBuildHelper {
       returnNullable = false), DecimalType.BigIntDecimal, nullOnOverflow)
   }
 
-  def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
-    createSerializerForJavaBigInteger(inputObject)
-  }
-
   def createSerializerForPrimitiveArray(
       inputObject: Expression,
       dataType: DataType): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 6081ac8dc28..cdc64f2ddb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -16,28 +16,33 @@
  */
 package org.apache.spark.sql.catalyst.encoders
 
+import java.{sql => jsql}
 import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
 import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 
 import scala.reflect.{classTag, ClassTag}
 
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
 /**
  * A non-implementation-specific encoder. This encoder contains all the information needed
  * to generate an implementation-specific encoder (e.g. InternalRow <=> Custom Object).
+ *
+ * The input of the serialization does not need to match the external type of the encoder. This
+ * is called lenient serialization. An example of this is lenient date serialization; in this
+ * case both [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never
+ * lenient; it will always produce an instance of the external type.
  */
 trait AgnosticEncoder[T] extends Encoder[T] {
   def isPrimitive: Boolean
   def nullable: Boolean = !isPrimitive
   def dataType: DataType
   override def schema: StructType = StructType(StructField("value", dataType, nullable) :: Nil)
+  def lenientSerialization: Boolean = false
 }
 
-// TODO check RowEncoder
-// TODO check BeanEncoder
 object AgnosticEncoders {
   case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
     extends AgnosticEncoder[Option[E]] {
@@ -46,35 +51,48 @@ object AgnosticEncoders {
     override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]])
   }
 
-  case class ArrayEncoder[E](element: AgnosticEncoder[E])
+  case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean)
     extends AgnosticEncoder[Array[E]] {
     override def isPrimitive: Boolean = false
-    override def dataType: DataType = ArrayType(element.dataType, element.nullable)
+    override def dataType: DataType = ArrayType(element.dataType, containsNull)
     override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap
   }
 
-  case class IterableEncoder[C <: Iterable[E], E](
+  /**
+   * Encoder for collections.
+   *
+   * This encoder can be lenient for [[Row]] encoders. In that case we allow [[Seq]], primitive
+   * array (if any), and generic arrays as input.
+   */
+  case class IterableEncoder[C, E](
       override val clsTag: ClassTag[C],
-      element: AgnosticEncoder[E])
+      element: AgnosticEncoder[E],
+      containsNull: Boolean,
+      override val lenientSerialization: Boolean)
     extends AgnosticEncoder[C] {
     override def isPrimitive: Boolean = false
-    override val dataType: DataType = ArrayType(element.dataType, element.nullable)
+    override val dataType: DataType = ArrayType(element.dataType, containsNull)
   }
 
   case class MapEncoder[C, K, V](
       override val clsTag: ClassTag[C],
       keyEncoder: AgnosticEncoder[K],
-      valueEncoder: AgnosticEncoder[V])
+      valueEncoder: AgnosticEncoder[V],
+      valueContainsNull: Boolean)
     extends AgnosticEncoder[C] {
     override def isPrimitive: Boolean = false
     override val dataType: DataType = MapType(
       keyEncoder.dataType,
       valueEncoder.dataType,
-      valueEncoder.nullable)
+      valueContainsNull)
   }
 
-  case class EncoderField(name: String, enc: AgnosticEncoder[_]) {
-    def structField: StructField = StructField(name, enc.dataType, 
enc.nullable)
+  case class EncoderField(
+      name: String,
+      enc: AgnosticEncoder[_],
+      nullable: Boolean,
+      metadata: Metadata) {
+    def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
   }
 
   // This supports both Product and DefinedByConstructorParams
@@ -87,6 +105,13 @@ object AgnosticEncoders {
     override def dataType: DataType = schema
   }
 
+  case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
+    override def isPrimitive: Boolean = false
+    override val schema: StructType = StructType(fields.map(_.structField))
+    override def dataType: DataType = schema
+    override def clsTag: ClassTag[Row] = classTag[Row]
+  }
+
   // This will only work for encoding from/to Sparks' InternalRow format.
   // It is here for compatibility.
   case class UDTEncoder[E >: Null](
@@ -116,39 +141,74 @@ object AgnosticEncoders {
   }
 
   // Primitive encoders
-  case object PrimitiveBooleanEncoder extends LeafEncoder[Boolean](BooleanType)
-  case object PrimitiveByteEncoder extends LeafEncoder[Byte](ByteType)
-  case object PrimitiveShortEncoder extends LeafEncoder[Short](ShortType)
-  case object PrimitiveIntEncoder extends LeafEncoder[Int](IntegerType)
-  case object PrimitiveLongEncoder extends LeafEncoder[Long](LongType)
-  case object PrimitiveFloatEncoder extends LeafEncoder[Float](FloatType)
-  case object PrimitiveDoubleEncoder extends LeafEncoder[Double](DoubleType)
+  abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType)
+    extends LeafEncoder[E](dataType)
+  case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType)
+  case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType)
+  case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType)
+  case object PrimitiveIntEncoder extends PrimitiveLeafEncoder[Int](IntegerType)
+  case object PrimitiveLongEncoder extends PrimitiveLeafEncoder[Long](LongType)
+  case object PrimitiveFloatEncoder extends PrimitiveLeafEncoder[Float](FloatType)
+  case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType)
 
   // Primitive wrapper encoders.
-  case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
-  case object BoxedBooleanEncoder extends LeafEncoder[java.lang.Boolean](BooleanType)
-  case object BoxedByteEncoder extends LeafEncoder[java.lang.Byte](ByteType)
-  case object BoxedShortEncoder extends LeafEncoder[java.lang.Short](ShortType)
-  case object BoxedIntEncoder extends LeafEncoder[java.lang.Integer](IntegerType)
-  case object BoxedLongEncoder extends LeafEncoder[java.lang.Long](LongType)
-  case object BoxedFloatEncoder extends LeafEncoder[java.lang.Float](FloatType)
-  case object BoxedDoubleEncoder extends LeafEncoder[java.lang.Double](DoubleType)
+  abstract class BoxedLeafEncoder[E : ClassTag, P](
+      dataType: DataType,
+      val primitive: PrimitiveLeafEncoder[P])
+    extends LeafEncoder[E](dataType)
+  case object BoxedBooleanEncoder
+    extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder)
+  case object BoxedByteEncoder
+    extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder)
+  case object BoxedShortEncoder
+    extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder)
+  case object BoxedIntEncoder
+    extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder)
+  case object BoxedLongEncoder
+    extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder)
+  case object BoxedFloatEncoder
+    extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder)
+  case object BoxedDoubleEncoder
+    extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder)
 
   // Nullable leaf encoders
+  case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
   case object StringEncoder extends LeafEncoder[String](StringType)
   case object BinaryEncoder extends LeafEncoder[Array[Byte]](BinaryType)
-  case object SparkDecimalEncoder extends LeafEncoder[Decimal](DecimalType.SYSTEM_DEFAULT)
-  case object ScalaDecimalEncoder extends LeafEncoder[BigDecimal](DecimalType.SYSTEM_DEFAULT)
-  case object JavaDecimalEncoder extends LeafEncoder[JBigDecimal](DecimalType.SYSTEM_DEFAULT)
   case object ScalaBigIntEncoder extends LeafEncoder[BigInt](DecimalType.BigIntDecimal)
   case object JavaBigIntEncoder extends LeafEncoder[JBigInt](DecimalType.BigIntDecimal)
   case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType)
   case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
   case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
-  case object DateEncoder extends LeafEncoder[java.sql.Date](DateType)
-  case object LocalDateEncoder extends LeafEncoder[LocalDate](DateType)
-  case object TimestampEncoder extends LeafEncoder[java.sql.Timestamp](TimestampType)
-  case object InstantEncoder extends LeafEncoder[Instant](TimestampType)
+  case class DateEncoder(override val lenientSerialization: Boolean)
+    extends LeafEncoder[jsql.Date](DateType)
+  case class LocalDateEncoder(override val lenientSerialization: Boolean)
+    extends LeafEncoder[LocalDate](DateType)
+  case class TimestampEncoder(override val lenientSerialization: Boolean)
+    extends LeafEncoder[jsql.Timestamp](TimestampType)
+  case class InstantEncoder(override val lenientSerialization: Boolean)
+    extends LeafEncoder[Instant](TimestampType)
   case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType)
+
+  case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt)
+  case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt)
+  case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean)
+    extends LeafEncoder[JBigDecimal](dt)
+
+  val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false)
+  val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false)
+  val STRICT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = false)
+  val STRICT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = false)
+  val LENIENT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = true)
+  val LENIENT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = true)
+  val LENIENT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = true)
+  val LENIENT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = true)
+
+  val DEFAULT_SPARK_DECIMAL_ENCODER: SparkDecimalEncoder =
+    SparkDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
+  val DEFAULT_SCALA_DECIMAL_ENCODER: ScalaDecimalEncoder =
+    ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
+  val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
+    JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 82a6863b5ff..9ca2fc72ad9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -47,7 +47,10 @@ import org.apache.spark.util.Utils
 object ExpressionEncoder {
 
   def apply[T : TypeTag](): ExpressionEncoder[T] = {
-    val enc = ScalaReflection.encoderFor[T]
+    apply(ScalaReflection.encoderFor[T])
+  }
+
+  def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
     new ExpressionEncoder[T](
       ScalaReflection.serializerFor(enc),
       ScalaReflection.deserializerFor(enc),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 8eb3475acaa..78243894544 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -17,19 +17,11 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import scala.annotation.tailrec
-import scala.collection.Map
-import scala.reflect.ClassTag
+import scala.collection.mutable
+import scala.reflect.classTag
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{ScalaReflection, WalkedTypePath}
-import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
-import org.apache.spark.sql.catalyst.SerializerBuildHelper._
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMont [...]
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -68,224 +60,46 @@ import org.apache.spark.sql.types._
  */
 object RowEncoder {
   def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
-    val cls = classOf[Row]
-    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
-    val serializer = serializerFor(inputObject, schema, lenient)
-    val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema)
-    new ExpressionEncoder[Row](
-      serializer,
-      deserializer,
-      ClassTag(cls))
+    ExpressionEncoder(encoderFor(schema, lenient))
   }
+
   def apply(schema: StructType): ExpressionEncoder[Row] = {
     apply(schema, lenient = false)
   }
 
-  private def serializerFor(
-      inputObject: Expression,
-      inputType: DataType,
-      lenient: Boolean): Expression = inputType match {
-    case dt if ScalaReflection.isNativeType(dt) => inputObject
-
-    case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType, lenient)
-
-    case udt: UserDefinedType[_] =>
-      val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
-      val udtClass: Class[_] = if (annotation != null) {
-        annotation.udt()
-      } else {
-        UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
-          throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
-        }
-      }
-      val obj = NewInstance(
-        udtClass,
-        Nil,
-        dataType = ObjectType(udtClass), false)
-      Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
-
-    case TimestampType =>
-      if (lenient) {
-        createSerializerForAnyTimestamp(inputObject)
-      } else if (SQLConf.get.datetimeJava8ApiEnabled) {
-        createSerializerForJavaInstant(inputObject)
-      } else {
-        createSerializerForSqlTimestamp(inputObject)
-      }
-
-    case TimestampNTZType => createSerializerForLocalDateTime(inputObject)
-
-    case DateType =>
-      if (lenient) {
-        createSerializerForAnyDate(inputObject)
-      } else if (SQLConf.get.datetimeJava8ApiEnabled) {
-        createSerializerForJavaLocalDate(inputObject)
-      } else {
-        createSerializerForSqlDate(inputObject)
-      }
-
-    case _: DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
-
-    case _: YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
-
-    case d: DecimalType =>
-      CheckOverflow(StaticInvoke(
-        Decimal.getClass,
-        d,
-        "fromDecimal",
-        inputObject :: Nil,
-        returnNullable = false), d, !SQLConf.get.ansiEnabled)
-
-    case StringType => createSerializerForString(inputObject)
-
-    case t @ ArrayType(et, containsNull) =>
-      et match {
-        case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
-          StaticInvoke(
-            classOf[ArrayData],
-            t,
-            "toArrayData",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case _ =>
-          createSerializerForMapObjects(
-            inputObject,
-            ObjectType(classOf[Object]),
-            element => {
-              val value = serializerFor(ValidateExternalType(element, et, lenient), et, lenient)
-              expressionWithNullSafety(value, containsNull, WalkedTypePath())
-            })
-      }
-
-    case t @ MapType(kt, vt, valueNullable) =>
-      val keys =
-        Invoke(
-          Invoke(inputObject, "keysIterator", 
ObjectType(classOf[scala.collection.Iterator[_]]),
-            returnNullable = false),
-          "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
-      val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient)
-
-      val values =
-        Invoke(
-          Invoke(inputObject, "valuesIterator", 
ObjectType(classOf[scala.collection.Iterator[_]]),
-            returnNullable = false),
-          "toSeq",
-          ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
-      val convertedValues = serializerFor(values, ArrayType(vt, valueNullable), lenient)
-
-      val nonNullOutput = NewInstance(
-        classOf[ArrayBasedMapData],
-        convertedKeys :: convertedValues :: Nil,
-        dataType = t,
-        propagateNull = false)
-
-      if (inputObject.nullable) {
-        expressionForNullableExpr(inputObject, nonNullOutput)
-      } else {
-        nonNullOutput
-      }
-
-    case StructType(fields) =>
-      val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
-        val fieldValue = serializerFor(
-          ValidateExternalType(
-            GetExternalRowField(inputObject, index, field.name),
-            field.dataType,
-            lenient),
-          field.dataType,
-          lenient)
-        val convertedField = if (field.nullable) {
-          If(
-            Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: 
Nil),
-            // Because we strip UDTs, `field.dataType` can be different from 
`fieldValue.dataType`.
-            // We should use `fieldValue.dataType` here.
-            Literal.create(null, fieldValue.dataType),
-            fieldValue
-          )
-        } else {
-          fieldValue
-        }
-        Literal(field.name) :: convertedField :: Nil
-      })
-
-      if (inputObject.nullable) {
-        expressionForNullableExpr(inputObject, nonNullOutput)
-      } else {
-        nonNullOutput
-      }
-    // For other data types, return the internal catalyst value as it is.
-    case _ => inputObject
-  }
-
-  /**
-   * Returns the `DataType` that can be used when generating code that converts input data
-   * into the Spark SQL internal format.  Unlike `externalDataTypeFor`, the `DataType` returned
-   * by this function can be more permissive since multiple external types may map to a single
-   * internal type.  For example, for an input with DecimalType in external row, its external types
-   * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
-   * `org.apache.spark.sql.types.Decimal`.
-   */
-  def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt match {
-    // In order to support both Decimal and java/scala BigDecimal in external row, we make this
-    // as java.lang.Object.
-    case _: DecimalType => ObjectType(classOf[java.lang.Object])
-    // In order to support both Array and Seq in external row, we make this as java.lang.Object.
-    case _: ArrayType => ObjectType(classOf[java.lang.Object])
-    case _: DateType | _: TimestampType if lenient => ObjectType(classOf[java.lang.Object])
-    case _ => externalDataTypeFor(dt)
-  }
-
-  @tailrec
-  def externalDataTypeFor(dt: DataType): DataType = dt match {
-    case _ if ScalaReflection.isNativeType(dt) => dt
-    case TimestampType =>
-      if (SQLConf.get.datetimeJava8ApiEnabled) {
-        ObjectType(classOf[java.time.Instant])
-      } else {
-        ObjectType(classOf[java.sql.Timestamp])
-      }
-    case TimestampNTZType =>
-      ObjectType(classOf[java.time.LocalDateTime])
-    case DateType =>
-      if (SQLConf.get.datetimeJava8ApiEnabled) {
-        ObjectType(classOf[java.time.LocalDate])
-      } else {
-        ObjectType(classOf[java.sql.Date])
-      }
-    case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
-    case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period])
-    case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
-    case udt: UserDefinedType[_] => ObjectType(udt.userClass)
-    case _ => dt.physicalDataType match {
-      case _: PhysicalArrayType => ObjectType(classOf[scala.collection.Seq[_]])
-      case _: PhysicalDecimalType => ObjectType(classOf[java.math.BigDecimal])
-      case _: PhysicalMapType => ObjectType(classOf[scala.collection.Map[_, _]])
-      case PhysicalStringType => ObjectType(classOf[java.lang.String])
-      case _: PhysicalStructType => ObjectType(classOf[Row])
-      // For other data types, return the data type as it is.
-      case _ => dt
-    }
-  }
-
-  private def deserializerFor(input: Expression, schema: StructType): Expression = {
-    val fields = schema.zipWithIndex.map { case (f, i) =>
-      deserializerFor(GetStructField(input, i))
-    }
-    CreateExternalRow(fields, schema)
+  def encoderFor(schema: StructType): AgnosticEncoder[Row] = {
+    encoderFor(schema, lenient = false)
   }
 
-  private def deserializerFor(input: Expression): Expression = {
-    deserializerFor(input, input.dataType)
+  def encoderFor(schema: StructType, lenient: Boolean): AgnosticEncoder[Row] = {
+    encoderForDataType(schema, lenient).asInstanceOf[AgnosticEncoder[Row]]
   }
 
-  @tailrec
-  private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
-    case dt if ScalaReflection.isNativeType(dt) => input
-
-    case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
-
+  private[catalyst] def encoderForDataType(
+      dataType: DataType,
+      lenient: Boolean): AgnosticEncoder[_] = dataType match {
+    case NullType => NullEncoder
+    case BooleanType => BoxedBooleanEncoder
+    case ByteType => BoxedByteEncoder
+    case ShortType => BoxedShortEncoder
+    case IntegerType => BoxedIntEncoder
+    case LongType => BoxedLongEncoder
+    case FloatType => BoxedFloatEncoder
+    case DoubleType => BoxedDoubleEncoder
+    case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
+    case BinaryType => BinaryEncoder
+    case StringType => StringEncoder
+    case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
+    case TimestampType => TimestampEncoder(lenient)
+    case TimestampNTZType => LocalDateTimeEncoder
+    case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
+    case DateType => DateEncoder(lenient)
+    case CalendarIntervalType => CalendarIntervalEncoder
+    case _: DayTimeIntervalType => DayTimeIntervalEncoder
+    case _: YearMonthIntervalType => YearMonthIntervalEncoder
+    case p: PythonUserDefinedType =>
+      // TODO check if this works.
+      encoderForDataType(p.sqlType, lenient)
     case udt: UserDefinedType[_] =>
       val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
       val udtClass: Class[_] = if (annotation != null) {
@@ -295,84 +109,26 @@ object RowEncoder {
           throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
         }
       }
-      val obj = NewInstance(
-        udtClass,
-        Nil,
-        dataType = ObjectType(udtClass))
-      Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
-
-    case TimestampType =>
-      if (SQLConf.get.datetimeJava8ApiEnabled) {
-        createDeserializerForInstant(input)
-      } else {
-        createDeserializerForSqlTimestamp(input)
-      }
-
-    case TimestampNTZType =>
-      createDeserializerForLocalDateTime(input)
-
-    case DateType =>
-      if (SQLConf.get.datetimeJava8ApiEnabled) {
-        createDeserializerForLocalDate(input)
-      } else {
-        createDeserializerForSqlDate(input)
-      }
-
-    case _: DayTimeIntervalType => createDeserializerForDuration(input)
-
-    case _: YearMonthIntervalType => createDeserializerForPeriod(input)
-
-    case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
-
-    case StringType => createDeserializerForString(input, returnNullable = false)
-
-    case ArrayType(et, nullable) =>
-      val arrayData =
-        Invoke(
-          MapObjects(deserializerFor(_), input, et),
-          "array",
-          ObjectType(classOf[Array[_]]), returnNullable = false)
-      // TODO should use `scala.collection.immutable.ArrayDeq.unsafeMake` method to create
-      //  `immutable.Seq` in Scala 2.13 when Scala version compatibility is no longer required.
-      StaticInvoke(
-        scala.collection.mutable.WrappedArray.getClass,
-        ObjectType(classOf[scala.collection.Seq[_]]),
-        "make",
-        arrayData :: Nil,
-        returnNullable = false)
-
-    case MapType(kt, vt, valueNullable) =>
-      val keyArrayType = ArrayType(kt, false)
-      val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
-
-      val valueArrayType = ArrayType(vt, valueNullable)
-      val valueData = deserializerFor(Invoke(input, "valueArray", 
valueArrayType))
-
-      StaticInvoke(
-        ArrayBasedMapData.getClass,
-        ObjectType(classOf[Map[_, _]]),
-        "toScalaMap",
-        keyData :: valueData :: Nil,
-        returnNullable = false)
-
-    case schema @ StructType(fields) =>
-      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
-        If(
-          Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
-          Literal.create(null, externalDataTypeFor(f.dataType)),
-          deserializerFor(GetStructField(input, i)))
-      }
-      If(IsNull(input),
-        Literal.create(null, externalDataTypeFor(input.dataType)),
-        CreateExternalRow(convertedFields, schema))
-
-    // For other data types, return the internal catalyst value as it is.
-    case _ => input
-  }
-
-  private def expressionForNullableExpr(
-      expr: Expression,
-      newExprWhenNotNull: Expression): Expression = {
-    If(IsNull(expr), Literal.create(null, newExprWhenNotNull.dataType), newExprWhenNotNull)
+      UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
+    case ArrayType(elementType, containsNull) =>
+      IterableEncoder(
+        classTag[mutable.WrappedArray[_]],
+        encoderForDataType(elementType, lenient),
+        containsNull,
+        lenientSerialization = true)
+    case MapType(keyType, valueType, valueContainsNull) =>
+      MapEncoder(
+        classTag[scala.collection.Map[_, _]],
+        encoderForDataType(keyType, lenient),
+        encoderForDataType(valueType, lenient),
+        valueContainsNull)
+    case StructType(fields) =>
+      AgnosticRowEncoder(fields.map { field =>
+        EncoderField(
+          field.name,
+          encoderForDataType(field.dataType, lenient),
+          field.nullable,
+          field.metadata)
+      })
   }
 }
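
For a concrete feel of the new path: RowEncoder(schema) now builds an AgnosticEncoder via encoderForDataType and only then derives the ExpressionEncoder[Row]. A minimal round-trip sketch, assuming the catalyst classpath and the createSerializer/createDeserializer API exercised in the suites below:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("id", IntegerType, nullable = false)
      .add("tags", ArrayType(StringType, containsNull = true))

    // RowEncoder(schema) routes through encoderForDataType; resolveAndBind()
    // then yields the bound ExpressionEncoder[Row], as before this change.
    val encoder = RowEncoder(schema).resolveAndBind()
    val internalRow = encoder.createSerializer()(Row(1, Seq("a", "b")))
    val externalRow = encoder.createDeserializer()(internalRow)
    // With this change the deserialized array field comes back as a
    // mutable.WrappedArray backed by a properly typed array.
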
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index a644b90a96f..56facda2af6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.catalyst.expressions.objects
 import java.lang.reflect.{Method, Modifier}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.collection.mutable.{Builder, WrappedArray}
 import scala.reflect.ClassTag
-import scala.util.{Properties, Try}
+import scala.util.Try
 
 import org.apache.commons.lang3.reflect.MethodUtils
 
@@ -30,7 +31,6 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -859,7 +859,7 @@ case class MapObjects private(
     case _ => inputData.dataType
   }
 
-  private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
+  private def executeFuncOnCollection(inputCollection: Iterable[_]): Iterator[_] = {
     val row = new GenericInternalRow(1)
     inputCollection.iterator.map { element =>
       row.update(0, element)
@@ -867,7 +867,7 @@ case class MapObjects private(
     }
   }
 
-  private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
+  private lazy val convertToSeq: Any => scala.collection.Seq[_] = inputDataType match {
    case ObjectType(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
       _.asInstanceOf[scala.collection.Seq[_]].toSeq
     case ObjectType(cls) if cls.isArray =>
@@ -879,17 +879,33 @@ case class MapObjects private(
         if (inputCollection.getClass.isArray) {
           inputCollection.asInstanceOf[Array[_]].toSeq
         } else {
-          inputCollection.asInstanceOf[Seq[_]]
+          inputCollection.asInstanceOf[scala.collection.Seq[_]]
         }
       }
     case ArrayType(et, _) =>
       _.asInstanceOf[ArrayData].toSeq[Any](et)
   }
 
-  private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
+  private def elementClassTag(): ClassTag[Any] = {
+    val clazz = lambdaFunction.dataType match {
+      case ObjectType(cls) => cls
+      case dt if lambdaFunction.nullable => ScalaReflection.javaBoxedType(dt)
+      case dt => ScalaReflection.dataTypeJavaClass(dt)
+    }
+    ClassTag(clazz).asInstanceOf[ClassTag[Any]]
+  }
+
+  private lazy val mapElements: scala.collection.Seq[_] => Any = customCollectionCls match {
     case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
-      // Scala WrappedArray
-      inputCollection => WrappedArray.make(executeFuncOnCollection(inputCollection).toArray)
+      // The implicit tag is a workaround to deal with a small change in the
+      // (scala) signature of ArrayBuilder.make between Scala 2.12 and 2.13.
+      implicit val tag: ClassTag[Any] = elementClassTag()
+      input => {
+        val builder = mutable.ArrayBuilder.make[Any]
+        builder.sizeHint(input.size)
+        executeFuncOnCollection(input).foreach(builder += _)
+        mutable.WrappedArray.make(builder.result())
+      }
     case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
       // Scala sequence
       executeFuncOnCollection(_).toSeq
@@ -1047,44 +1063,20 @@ case class MapObjects private(
    val (initCollection, addElement, getResult): (String, String => String, String) =
       customCollectionCls match {
         case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
-          def doCodeGenForScala212 = {
-            // WrappedArray in Scala 2.12
-            val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
-            val builder = ctx.freshName("collectionBuilder")
-            (
-              s"""
-                 ${classOf[Builder[_, _]].getName} $builder = $getBuilder;
-                 $builder.sizeHint($dataLength);
-               """,
-              (genValue: String) => s"$builder.$$plus$$eq($genValue);",
-              s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
-                s"MODULE$$.make(((${classOf[IndexedSeq[_]].getName})$builder" +
-                s".result()).toArray(scala.reflect.ClassTag$$.MODULE$$.Object()));"
-            )
-          }
-
-          def doCodeGenForScala213 = {
-            // In Scala 2.13, WrappedArray is mutable.ArraySeq and newBuilder method need
-            // a ClassTag type construction parameter
-            val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder(" +
-              s"scala.reflect.ClassTag$$.MODULE$$.Object())"
-            val builder = ctx.freshName("collectionBuilder")
-            (
-              s"""
+          val tag = ctx.addReferenceObj("tag", elementClassTag())
+          val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
+          val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
+          val builder = ctx.freshName("collectionBuilder")
+          (
+            s"""
                  ${classOf[Builder[_, _]].getName} $builder = $getBuilder;
                  $builder.sizeHint($dataLength);
                """,
-              (genValue: String) => s"$builder.$$plus$$eq($genValue);",
-              s"(${cls.getName})$builder.result();"
-            )
-          }
+            (genValue: String) => s"$builder.$$plus$$eq($genValue);",
+            s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
+              s"MODULE$$.make($builder.result());"
+          )
 
-          val scalaVersion = Properties.versionNumberString
-          if (scalaVersion.startsWith("2.12")) {
-            doCodeGenForScala212
-          } else {
-            doCodeGenForScala213
-          }
        case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) ||
           classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
           // Scala sequence or set
@@ -1908,14 +1900,14 @@ case class GetExternalRowField(
  * Validates the actual data type of input expression at runtime.  If it doesn't match the
  * expectation, throw an exception.
  */
-case class ValidateExternalType(child: Expression, expected: DataType, lenient: Boolean)
+case class ValidateExternalType(child: Expression, expected: DataType, externalDataType: DataType)
   extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
 
   override def nullable: Boolean = child.nullable
 
-  override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected, lenient)
+  override val dataType: DataType = externalDataType
 
   private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
 
@@ -1927,7 +1919,9 @@ case class ValidateExternalType(child: Expression, expected: DataType, lenient:
       }
     case _: ArrayType =>
       (value: Any) => {
-        value.getClass.isArray || value.isInstanceOf[Seq[_]]
+        value.getClass.isArray ||
+          value.isInstanceOf[scala.collection.Seq[_]] ||
+          value.isInstanceOf[Set[_]]
       }
     case _: DateType =>
       (value: Any) => {
@@ -1968,7 +1962,8 @@ case class ValidateExternalType(child: Expression, expected: DataType, lenient:
           classOf[scala.math.BigDecimal],
           classOf[Decimal]))
       case _: ArrayType =>
-        s"$obj.getClass().isArray() || $obj instanceof 
${classOf[scala.collection.Seq[_]].getName}"
+        val check = genCheckTypes(Seq(classOf[scala.collection.Seq[_]], classOf[Set[_]]))
+        s"$obj.getClass().isArray() || $check"
       case _: DateType =>
        genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate]))
       case _: TimestampType =>
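
The single builder branch above replaces the Scala-version split because mutable.ArrayBuilder.make resolves its ClassTag implicitly on both 2.12 and 2.13. A standalone sketch of the pattern (the method name is illustrative, not part of this patch):

    import scala.collection.mutable
    import scala.reflect.ClassTag

    def wrapElements[T: ClassTag](elements: Iterator[T]): mutable.WrappedArray[T] = {
      // ArrayBuilder.make takes its ClassTag implicitly under 2.12 and 2.13,
      // so no Properties.versionNumberString branching is required.
      val builder = mutable.ArrayBuilder.make[T]
      elements.foreach(builder += _)
      // WrappedArray.make wraps the typed array produced by the builder.
      mutable.WrappedArray.make(builder.result())
    }
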
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 7e7ce29972b..f8ebdfe7676 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.FooEnum.FooEnum
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, MapObjects, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
@@ -388,11 +388,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
   }
 
   test("SPARK-15062: Get correct serializer for List[_]") {
-    val list = List(1, 2, 3)
     val serializer = serializerFor[List[Int]]
-    assert(serializer.isInstanceOf[NewInstance])
-    assert(serializer.asInstanceOf[NewInstance]
-      .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
+    assert(serializer.isInstanceOf[MapObjects])
+    val mapObjects = serializer.asInstanceOf[MapObjects]
+    assert(mapObjects.customCollectionCls.isEmpty)
   }
 
   test("SPARK 16792: Get correct deserializer for List[_]") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3a0db1ca121..c6546105231 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -480,6 +480,8 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
   encodeDecodeTest(ScroogeLikeExample(1),
     "SPARK-40385 class with only a companion object constructor")
 
+  encodeDecodeTest(Array(Set(1, 2), Set(2, 3)), "array of sets")
+
   productTest(("UDT", new ExamplePoint(0.1, 0.2)))
 
   test("AnyVal class with Any fields") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index c6bddfa5eee..b133b38a559 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
+import scala.collection.mutable
 import scala.util.Random
 
 import org.apache.spark.sql.{RandomDataGenerator, Row}
@@ -310,6 +311,19 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
    assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
   }
 
+  private def roundTripArray[T](dt: DataType, nullable: Boolean, data: Array[T]): Unit = {
+    val schema = new StructType().add("a", ArrayType(dt, nullable))
+    test(s"RowEncoder should return WrappedArray with properly typed array for $schema") {
+      val encoder = RowEncoder(schema).resolveAndBind()
+      val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.WrappedArray[_]](0)
+      assert(result.array.getClass === data.getClass)
+      assert(result === data)
+    }
+  }
+
+  roundTripArray(IntegerType, nullable = false, Array(1, 2, 3).map(Int.box))
+  roundTripArray(StringType, nullable = true, Array("hello", "world", "!", null))
+
   test("SPARK-25791: Datatype of serializers should be accessible") {
     val udtSQLType = new StructType().add("a", IntegerType)
     val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass")
@@ -458,4 +472,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
       }
     }
   }
+
+  test("Encoding an ArraySeq/WrappedArray in scala-2.13") {
+    val schema = new StructType()
+      .add("headers", ArrayType(new StructType()
+        .add("key", StringType)
+        .add("value", BinaryType)))
+    val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
+    val data = Row(mutable.WrappedArray.make(Array(Row("key", "value".getBytes))))
+    val row = encoder.createSerializer()(data)
+  }
 }
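
The last test relies on the new lenient mode: RowEncoder(schema, lenient = true) relaxes the external-type validation so WrappedArray/ArraySeq values (the shape Seq takes on Scala 2.13) are accepted for ArrayType fields. A usage sketch under the same assumptions as the suite:

    import scala.collection.mutable
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType().add("xs", ArrayType(IntegerType, containsNull = false))
    // lenient = true admits WrappedArray/ArraySeq inputs in addition to plain Seq.
    val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
    val serialized = encoder.createSerializer()(Row(mutable.WrappedArray.make(Array(1, 2, 3))))
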
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 737fcb1bada..265b0eeb8bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       ValidateExternalType(
         GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"),
         IntegerType,
-        lenient = false) :: Nil)
+        IntegerType) :: Nil)
   }
 
   test("SPARK-17160: field names are properly escaped by AssertTrue") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 2286b734477..05ab7a65a32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -496,10 +496,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
       (Array(3, 2, 1), ArrayType(IntegerType))
     ).foreach { case (input, dt) =>
+      val enc = RowEncoder.encoderForDataType(dt, lenient = false)
       val validateType = ValidateExternalType(
         GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
         dt,
-        lenient = false)
+        ScalaReflection.lenientExternalDataTypeFor(enc))
      checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
     }
 
@@ -507,7 +508,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       ValidateExternalType(
         GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
         DoubleType,
-        lenient = false),
+        DoubleType),
       InternalRow.fromSeq(Seq(Row(1))),
       "java.lang.Integer is not a valid external type for schema of double")
   }
@@ -559,10 +560,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     ExternalMapToCatalyst(
       inputObject,
-      ScalaReflection.dataTypeFor(keyEnc),
+      ScalaReflection.externalDataTypeFor(keyEnc),
       kvSerializerFor(keyEnc),
       keyNullable = keyEnc.nullable,
-      ScalaReflection.dataTypeFor(valueEnc),
+      ScalaReflection.externalDataTypeFor(valueEnc),
       kvSerializerFor(valueEnc),
       valueNullable = valueEnc.nullable
     )
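
To summarize the ValidateExternalType migration shown in these tests: the external JVM-side DataType is now computed from an encoder up front instead of being re-derived from a lenient flag inside the expression. A sketch mirroring the test code above (catalyst test classpath assumed):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.catalyst.expressions.objects.{GetExternalRowField, ValidateExternalType}
    import org.apache.spark.sql.types._

    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
    val enc = RowEncoder.encoderForDataType(IntegerType, lenient = false)
    val validate = ValidateExternalType(
      GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
      IntegerType,                                     // expected Catalyst type
      ScalaReflection.lenientExternalDataTypeFor(enc)) // its external JVM type
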


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
