This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5c7518cce43 Revert "[SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders"
5c7518cce43 is described below
commit 5c7518cce434098f8912db202b70c0f628c32852
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Mon Jan 16 14:45:37 2023 -0800
Revert "[SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders"
This reverts commit 2d4be52b71ba73a7c4586c3cd5faa6ede473cd4e.
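For context while reading the patch: with this revert, RowEncoder.apply once again builds an ExpressionEncoder[Row] directly from a StructType instead of delegating to an AgnosticEncoders.RowEncoder. A minimal usage sketch (the schema and data below are illustrative, not part of this commit):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    // Hypothetical schema, used only to show the restored entry point.
    val schema = new StructType()
      .add("id", LongType, nullable = false)
      .add("name", StringType)

    // Post-revert, RowEncoder(schema) returns an ExpressionEncoder[Row] whose
    // serializer/deserializer expressions are generated inside RowEncoder itself.
    val encoder = RowEncoder(schema).resolveAndBind()
    val internal = encoder.createSerializer()(Row(1L, "a")) // Row -> InternalRow
    val external = encoder.createDeserializer()(internal)   // InternalRow -> Row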
---
.../spark/sql/catalyst/JavaTypeInference.scala | 4 +-
.../spark/sql/catalyst/ScalaReflection.scala | 315 ++++++------------
.../spark/sql/catalyst/SerializerBuildHelper.scala | 25 +-
.../sql/catalyst/encoders/AgnosticEncoder.scala | 128 ++------
.../sql/catalyst/encoders/ExpressionEncoder.scala | 5 +-
.../spark/sql/catalyst/encoders/RowEncoder.scala | 354 +++++++++++++++++----
.../sql/catalyst/expressions/objects/objects.scala | 76 +++--
.../spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +-
.../catalyst/encoders/ExpressionEncoderSuite.scala | 2 -
.../sql/catalyst/encoders/RowEncoderSuite.scala | 14 -
.../catalyst/expressions/CodeGenerationSuite.scala | 2 +-
.../expressions/ObjectExpressionsSuite.scala | 9 +-
12 files changed, 499 insertions(+), 444 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 81f363dda36..827807055ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -423,10 +423,10 @@ object JavaTypeInference {
case c if c == classOf[java.time.Period] =>
createSerializerForJavaPeriod(inputObject)
case c if c == classOf[java.math.BigInteger] =>
- createSerializerForBigInteger(inputObject)
+ createSerializerForJavaBigInteger(inputObject)
case c if c == classOf[java.math.BigDecimal] =>
- createSerializerForBigDecimal(inputObject)
+ createSerializerForJavaBigDecimal(inputObject)
case c if c == classOf[java.lang.Boolean] =>
createSerializerForBoolean(inputObject)
case c if c == classOf[java.lang.Byte] =>
createSerializerForByte(inputObject)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index bb14a47f51b..e02e42cea1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst
import javax.lang.model.SourceVersion
import scala.annotation.tailrec
-import scala.language.existentials
import scala.reflect.ClassTag
import scala.reflect.internal.Symbols
import scala.util.{Failure, Success}
@@ -28,13 +27,12 @@ import scala.util.{Failure, Success}
import org.apache.commons.lang3.reflect.ConstructorUtils
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, _}
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -84,24 +82,12 @@ object ScalaReflection extends ScalaReflection {
}
}
- /**
- * Return the data type we expect to see when deserializing a value with encoder `enc`.
- */
- private[catalyst] def externalDataTypeFor(enc: AgnosticEncoder[_]): DataType = {
- externalDataTypeFor(enc, lenientSerialization = false)
- }
-
- private[catalyst] def lenientExternalDataTypeFor(enc: AgnosticEncoder[_]): DataType =
- externalDataTypeFor(enc, enc.lenientSerialization)
-
- private def externalDataTypeFor(
- enc: AgnosticEncoder[_],
- lenientSerialization: Boolean): DataType = {
+ // TODO this name is slightly misleading. This returns the input
+ // data type we expect to see during serialization.
+ private[catalyst] def dataTypeFor(enc: AgnosticEncoder[_]): DataType = {
// DataType can be native.
if (isNativeEncoder(enc)) {
enc.dataType
- } else if (lenientSerialization) {
- ObjectType(classOf[java.lang.Object])
} else {
ObjectType(enc.clsTag.runtimeClass)
}
@@ -137,7 +123,7 @@ object ScalaReflection extends ScalaReflection {
case NullEncoder => true
case CalendarIntervalEncoder => true
case BinaryEncoder => true
- case _: SparkDecimalEncoder => true
+ case SparkDecimalEncoder => true
case _ => false
}
@@ -169,19 +155,11 @@ object ScalaReflection extends ScalaReflection {
val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
// Assumes we are deserializing the first column of a row.
val input = GetColumnByOrdinal(0, enc.dataType)
- enc match {
- case RowEncoder(fields) =>
- val children = fields.zipWithIndex.map { case (f, i) =>
- deserializerFor(f.enc, GetStructField(input, i), walkedTypePath)
- }
- CreateExternalRow(children, enc.schema)
- case _ =>
- val deserializer = deserializerFor(
- enc,
- upCastToExpectedType(input, enc.dataType, walkedTypePath),
- walkedTypePath)
- expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
- }
+ val deserializer = deserializerFor(
+ enc,
+ upCastToExpectedType(input, enc.dataType, walkedTypePath),
+ walkedTypePath)
+ expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
}
/**
@@ -200,7 +178,19 @@ object ScalaReflection extends ScalaReflection {
walkedTypePath: WalkedTypePath): Expression = enc match {
case _ if isNativeEncoder(enc) =>
path
- case _: BoxedLeafEncoder[_, _] =>
+ case BoxedBooleanEncoder =>
+ createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
+ case BoxedByteEncoder =>
+ createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
+ case BoxedShortEncoder =>
+ createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
+ case BoxedIntEncoder =>
+ createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
+ case BoxedLongEncoder =>
+ createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
+ case BoxedFloatEncoder =>
+ createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
+ case BoxedDoubleEncoder =>
createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
case JavaEnumEncoder(tag) =>
val toString = createDeserializerForString(path, returnNullable = false)
@@ -214,9 +204,9 @@ object ScalaReflection extends ScalaReflection {
returnNullable = false)
case StringEncoder =>
createDeserializerForString(path, returnNullable = false)
- case _: ScalaDecimalEncoder =>
+ case ScalaDecimalEncoder =>
createDeserializerForScalaBigDecimal(path, returnNullable = false)
- case _: JavaDecimalEncoder =>
+ case JavaDecimalEncoder =>
createDeserializerForJavaBigDecimal(path, returnNullable = false)
case ScalaBigIntEncoder =>
createDeserializerForScalaBigInt(path)
@@ -226,13 +216,13 @@ object ScalaReflection extends ScalaReflection {
createDeserializerForDuration(path)
case YearMonthIntervalEncoder =>
createDeserializerForPeriod(path)
- case _: DateEncoder =>
+ case DateEncoder =>
createDeserializerForSqlDate(path)
- case _: LocalDateEncoder =>
+ case LocalDateEncoder =>
createDeserializerForLocalDate(path)
- case _: TimestampEncoder =>
+ case TimestampEncoder =>
createDeserializerForSqlTimestamp(path)
- case _: InstantEncoder =>
+ case InstantEncoder =>
createDeserializerForInstant(path)
case LocalDateTimeEncoder =>
createDeserializerForLocalDateTime(path)
@@ -242,29 +232,39 @@ object ScalaReflection extends ScalaReflection {
case OptionEncoder(valueEnc) =>
val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName)
val deserializer = deserializerFor(valueEnc, path, newTypePath)
- WrapOption(deserializer, externalDataTypeFor(valueEnc))
-
- case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) =>
+ WrapOption(deserializer, dataTypeFor(valueEnc))
+
+ case ArrayEncoder(elementEnc: AgnosticEncoder[_]) =>
+ val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
+ val mapFunction: Expression => Expression = element => {
+ // upcast the array element to the data type the encoder expected.
+ deserializerForWithNullSafetyAndUpcast(
+ element,
+ elementEnc.dataType,
+ nullable = elementEnc.nullable,
+ newTypePath,
+ deserializerFor(elementEnc, _, newTypePath))
+ }
Invoke(
- deserializeArray(
- path,
- elementEnc,
- containsNull,
- None,
- walkedTypePath),
+ UnresolvedMapObjects(mapFunction, path),
toArrayMethodName(elementEnc),
ObjectType(enc.clsTag.runtimeClass),
returnNullable = false)
- case IterableEncoder(clsTag, elementEnc, containsNull, _) =>
- deserializeArray(
- path,
- elementEnc,
- containsNull,
- Option(clsTag.runtimeClass),
- walkedTypePath)
+ case IterableEncoder(clsTag, elementEnc) =>
+ val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
+ val mapFunction: Expression => Expression = element => {
+ // upcast the array element to the data type the encoder expected.
+ deserializerForWithNullSafetyAndUpcast(
+ element,
+ elementEnc.dataType,
+ nullable = elementEnc.nullable,
+ newTypePath,
+ deserializerFor(elementEnc, _, newTypePath))
+ }
+ UnresolvedMapObjects(mapFunction, path, Some(clsTag.runtimeClass))
- case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
+ case MapEncoder(tag, keyEncoder, valueEncoder) =>
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
valueEncoder.clsTag.runtimeClass.getName)
@@ -298,39 +298,6 @@ object ScalaReflection extends ScalaReflection {
IsNull(path),
expressions.Literal.create(null, dt),
NewInstance(cls, arguments, dt, propagateNull = false))
-
- case RowEncoder(fields) =>
- val convertedFields = fields.zipWithIndex.map { case (f, i) =>
- val newTypePath = walkedTypePath.recordField(
- f.enc.clsTag.runtimeClass.getName,
- f.name)
- exprs.If(
- Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
- exprs.Literal.create(null, externalDataTypeFor(f.enc)),
- deserializerFor(f.enc, GetStructField(path, i), newTypePath))
- }
- exprs.If(IsNull(path),
- exprs.Literal.create(null, externalDataTypeFor(enc)),
- CreateExternalRow(convertedFields, enc.schema))
- }
-
- private def deserializeArray(
- path: Expression,
- elementEnc: AgnosticEncoder[_],
- containsNull: Boolean,
- cls: Option[Class[_]],
- walkedTypePath: WalkedTypePath): Expression = {
- val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expects.
- deserializerForWithNullSafetyAndUpcast(
- element,
- elementEnc.dataType,
- nullable = containsNull,
- newTypePath,
- deserializerFor(elementEnc, _, newTypePath))
- }
- UnresolvedMapObjects(mapFunction, path, cls)
}
/**
@@ -339,7 +306,7 @@ object ScalaReflection extends ScalaReflection {
* input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
*/
def serializerFor(enc: AgnosticEncoder[_]): Expression = {
- val input = BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable)
+ val input = BoundReference(0, dataTypeFor(enc), nullable = enc.nullable)
serializerFor(enc, input)
}
@@ -360,32 +327,25 @@ object ScalaReflection extends ScalaReflection {
case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
case StringEncoder => createSerializerForString(input)
- case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt)
- case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt)
- case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt)
- case ScalaBigIntEncoder => createSerializerForBigInteger(input)
- case JavaBigIntEncoder => createSerializerForBigInteger(input)
+ case ScalaDecimalEncoder => createSerializerForScalaBigDecimal(input)
+ case JavaDecimalEncoder => createSerializerForJavaBigDecimal(input)
+ case ScalaBigIntEncoder => createSerializerForScalaBigInt(input)
+ case JavaBigIntEncoder => createSerializerForJavaBigInteger(input)
case DayTimeIntervalEncoder => createSerializerForJavaDuration(input)
case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input)
- case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input)
- case DateEncoder(false) => createSerializerForSqlDate(input)
- case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input)
- case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input)
- case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
- case InstantEncoder(false) => createSerializerForJavaInstant(input)
+ case DateEncoder => createSerializerForSqlDate(input)
+ case LocalDateEncoder => createSerializerForJavaLocalDate(input)
+ case TimestampEncoder => createSerializerForSqlTimestamp(input)
+ case InstantEncoder => createSerializerForJavaInstant(input)
case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
case UDTEncoder(udt, udtClass) =>
createSerializerForUserDefinedType(input, udt, udtClass)
case OptionEncoder(valueEnc) =>
- serializerFor(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input))
+ serializerFor(valueEnc, UnwrapOption(dataTypeFor(valueEnc), input))
- case ArrayEncoder(elementEncoder, containsNull) =>
- if (elementEncoder.isPrimitive) {
- createSerializerForPrimitiveArray(input, elementEncoder.dataType)
- } else {
- serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false)
- }
+ case ArrayEncoder(elementEncoder) =>
+ serializerForArray(isArray = true, elementEncoder, input)
- case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) =>
+ case IterableEncoder(ctag, elementEncoder) =>
val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) {
// There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
// Note that the property of `Set` is only kept when manipulating the data as domain object.
@@ -393,19 +353,19 @@ object ScalaReflection extends ScalaReflection {
} else {
input
}
- serializerForArray(elementEncoder, containsNull, getter, lenientSerialization)
+ serializerForArray(isArray = false, elementEncoder, getter)
- case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) =>
+ case MapEncoder(_, keyEncoder, valueEncoder) =>
createSerializerForMap(
input,
MapElementInformation(
- ObjectType(classOf[AnyRef]),
- nullable = keyEncoder.nullable,
- validateAndSerializeElement(keyEncoder, keyEncoder.nullable)),
+ dataTypeFor(keyEncoder),
+ nullable = !keyEncoder.isPrimitive,
+ serializerFor(keyEncoder, _)),
MapElementInformation(
- ObjectType(classOf[AnyRef]),
- nullable = valueContainsNull,
- validateAndSerializeElement(valueEncoder, valueContainsNull))
+ dataTypeFor(valueEncoder),
+ nullable = !valueEncoder.isPrimitive,
+ serializerFor(valueEncoder, _))
)
case ProductEncoder(_, fields) =>
@@ -417,94 +377,25 @@ object ScalaReflection extends ScalaReflection {
val getter = Invoke(
KnownNotNull(input),
field.name,
- externalDataTypeFor(field.enc),
- returnNullable = field.nullable)
+ dataTypeFor(field.enc),
+ returnNullable = field.enc.nullable)
field.name -> serializerFor(field.enc, getter)
}
createSerializerForObject(input, serializedFields)
-
- case RowEncoder(fields) =>
- val serializedFields = fields.zipWithIndex.map { case (field, index) =>
- val fieldValue = serializerFor(
- field.enc,
- ValidateExternalType(
- GetExternalRowField(input, index, field.name),
- field.enc.dataType,
- lenientExternalDataTypeFor(field.enc)))
-
- val convertedField = if (field.nullable) {
- exprs.If(
- Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil),
- // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`.
- // We should use `fieldValue.dataType` here.
- exprs.Literal.create(null, fieldValue.dataType),
- fieldValue
- )
- } else {
- AssertNotNull(fieldValue)
- }
- field.name -> convertedField
- }
- createSerializerForObject(input, serializedFields)
}
private def serializerForArray(
+ isArray: Boolean,
elementEnc: AgnosticEncoder[_],
- elementNullable: Boolean,
- input: Expression,
- lenientSerialization: Boolean): Expression = {
- // Default serializer for Seq and generic Arrays. This does not work for primitive arrays.
- val genericSerializer = createSerializerForMapObjects(
- input,
- ObjectType(classOf[AnyRef]),
- validateAndSerializeElement(elementEnc, elementNullable))
-
- // Check if it is possible the user can pass a primitive array. This is the only case when it
- // is safe to directly convert to an array (for generic arrays and Seqs the type and the
- // nullability can be violated). If the user has passed a primitive array we create a special
- // code path to deal with these.
- val primitiveEncoderOption = elementEnc match {
- case _ if !lenientSerialization => None
- case enc: PrimitiveLeafEncoder[_] => Option(enc)
- case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive)
- case _ => None
+ input: Expression): Expression = {
+ dataTypeFor(elementEnc) match {
+ case dt: ObjectType =>
+ createSerializerForMapObjects(input, dt, serializerFor(elementEnc, _))
+ case dt if isArray && elementEnc.isPrimitive =>
+ createSerializerForPrimitiveArray(input, dt)
+ case dt =>
+ createSerializerForGenericArray(input, dt, elementEnc.nullable)
}
- primitiveEncoderOption match {
- case Some(primitiveEncoder) =>
- val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass
- val check = Invoke(
- targetObject = exprs.Literal.fromObject(primitiveArrayClass),
- functionName = "isInstance",
- BooleanType,
- arguments = input :: Nil,
- propagateNull = false,
- returnNullable = false)
- exprs.If(
- check,
- // TODO replace this with `createSerializerForPrimitiveArray` as
- // soon as Cast support ObjectType casts.
- StaticInvoke(
- classOf[ArrayData],
- ArrayType(elementEnc.dataType, containsNull = false),
- "toArrayData",
- input :: Nil,
- propagateNull = false,
- returnNullable = false),
- genericSerializer)
- case None =>
- genericSerializer
- }
- }
-
- private def validateAndSerializeElement(
- enc: AgnosticEncoder[_],
- nullable: Boolean): Expression => Expression = { input =>
- expressionWithNullSafety(
- serializerFor(
- enc,
- ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))),
- nullable,
- WalkedTypePath())
}
/**
@@ -707,8 +598,8 @@ object ScalaReflection extends ScalaReflection {
case StringType => classOf[UTF8String]
case CalendarIntervalType => classOf[CalendarInterval]
case _: StructType => classOf[InternalRow]
- case _: ArrayType => classOf[ArrayData]
- case _: MapType => classOf[MapData]
+ case _: ArrayType => classOf[ArrayType]
+ case _: MapType => classOf[MapType]
case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
case ObjectType(cls) => cls
case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
@@ -766,11 +657,7 @@ object ScalaReflection extends ScalaReflection {
case NoSymbol => fallbackClass
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
- IterableEncoder(
- ClassTag(targetClass),
- encoder,
- encoder.nullable,
- lenientSerialization = false)
+ IterableEncoder(ClassTag(targetClass), encoder)
}
baseType(tpe) match {
@@ -811,18 +698,18 @@ object ScalaReflection extends ScalaReflection {
// Leaf encoders
case t if isSubtype(t, localTypeOf[String]) => StringEncoder
- case t if isSubtype(t, localTypeOf[Decimal]) => DEFAULT_SPARK_DECIMAL_ENCODER
- case t if isSubtype(t, localTypeOf[BigDecimal]) => DEFAULT_SCALA_DECIMAL_ENCODER
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => DEFAULT_JAVA_DECIMAL_ENCODER
+ case t if isSubtype(t, localTypeOf[Decimal]) => SparkDecimalEncoder
+ case t if isSubtype(t, localTypeOf[BigDecimal]) => ScalaDecimalEncoder
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => JavaDecimalEncoder
case t if isSubtype(t, localTypeOf[BigInt]) => ScalaBigIntEncoder
case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => JavaBigIntEncoder
case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalEncoder
case t if isSubtype(t, localTypeOf[java.time.Duration]) => DayTimeIntervalEncoder
case t if isSubtype(t, localTypeOf[java.time.Period]) => YearMonthIntervalEncoder
- case t if isSubtype(t, localTypeOf[java.sql.Date]) => STRICT_DATE_ENCODER
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => STRICT_LOCAL_DATE_ENCODER
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
- case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) => DateEncoder
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => LocalDateEncoder
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => TimestampEncoder
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) => InstantEncoder
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
// UDT encoders
@@ -852,7 +739,7 @@ object ScalaReflection extends ScalaReflection {
elementType,
seenTypeSet,
path.recordArray(getClassNameFromType(elementType)))
- ArrayEncoder(encoder, encoder.nullable)
+ ArrayEncoder(encoder)
case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) =>
createIterableEncoder(t, classOf[scala.collection.Seq[_]])
@@ -870,7 +757,7 @@ object ScalaReflection extends ScalaReflection {
valueType,
seenTypeSet,
path.recordValueForMap(getClassNameFromType(valueType)))
- MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder, valueEncoder.nullable)
+ MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder)
case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
@@ -888,7 +775,7 @@ object ScalaReflection extends ScalaReflection {
fieldType,
seenTypeSet + t,
path.recordField(getClassNameFromType(fieldType), fieldName))
- EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
+ EncoderField(fieldName, encoder)
}
ProductEncoder(ClassTag(getClassFromType(t)), params)
case _ =>
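To see the restored ScalaReflection surface end to end, here is a sketch of deriving an encoder for a case class and turning it into Catalyst expressions; the class and field names are made up for illustration, and this assumes a context where a TypeTag for the class is available:

    import org.apache.spark.sql.catalyst.ScalaReflection

    // Hypothetical domain class, not part of this commit.
    case class Point(x: Double, y: Double)

    // encoderFor derives an AgnosticEncoder from the type; serializerFor and
    // deserializerFor (both patched above) turn it into Catalyst expressions.
    val enc = ScalaReflection.encoderFor[Point]
    val serializer = ScalaReflection.serializerFor(enc)
    val deserializer = ScalaReflection.deserializerFor(enc)
    println(serializer.dataType) // struct with fields x and y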
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 33b0edb0c44..25f6ce520d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -158,29 +158,20 @@ object SerializerBuildHelper {
returnNullable = false)
}
- def createSerializerForBigDecimal(inputObject: Expression): Expression = {
- createSerializerForBigDecimal(inputObject, DecimalType.SYSTEM_DEFAULT)
- }
-
- def createSerializerForBigDecimal(inputObject: Expression, dt: DecimalType): Expression = {
+ def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
- dt,
+ DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil,
- returnNullable = false), dt, nullOnOverflow)
+ returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow)
}
- def createSerializerForAnyDecimal(inputObject: Expression, dt: DecimalType): Expression = {
- CheckOverflow(StaticInvoke(
- Decimal.getClass,
- dt,
- "fromDecimal",
- inputObject :: Nil,
- returnNullable = false), dt, nullOnOverflow)
+ def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
+ createSerializerForJavaBigDecimal(inputObject)
}
- def createSerializerForBigInteger(inputObject: Expression): Expression = {
+ def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
@@ -189,6 +180,10 @@ object SerializerBuildHelper {
returnNullable = false), DecimalType.BigIntDecimal, nullOnOverflow)
}
+ def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
+ createSerializerForJavaBigInteger(inputObject)
+ }
+
def createSerializerForPrimitiveArray(
inputObject: Expression,
dataType: DataType): Expression = {
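Both restored helpers funnel Scala and Java BigDecimal values through Catalyst's Decimal at the system-default precision. A rough runtime analogue of what the generated StaticInvoke plus CheckOverflow pair does (illustrative only; the real path is expression codegen, not these direct calls):

    import org.apache.spark.sql.types.{Decimal, DecimalType}

    val input = new java.math.BigDecimal("123.45")
    // StaticInvoke(Decimal.getClass, ..., "apply", ...) roughly corresponds to:
    val d = Decimal(input)
    // CheckOverflow(..., DecimalType.SYSTEM_DEFAULT, nullOnOverflow) roughly checks:
    val fits = d.changePrecision(
      DecimalType.SYSTEM_DEFAULT.precision, DecimalType.SYSTEM_DEFAULT.scale)
    val result = if (fits) d else null // nullOnOverflow; ANSI mode throws instead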
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index cdc64f2ddb5..6081ac8dc28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -16,33 +16,28 @@
*/
package org.apache.spark.sql.catalyst.encoders
-import java.{sql => jsql}
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import scala.reflect.{classTag, ClassTag}
-import org.apache.spark.sql.{Encoder, Row}
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
/**
* A non implementation specific encoder. This encoder containers all the information needed
* to generate an implementation specific encoder (e.g. InternalRow <=> Custom Object).
- *
- * The input of the serialization does not need to match the external type of the encoder. This is
- * called lenient serialization. An example of this is lenient date serialization, in this case both
- * [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it
- * will always produce instance of the external type.
*/
trait AgnosticEncoder[T] extends Encoder[T] {
def isPrimitive: Boolean
def nullable: Boolean = !isPrimitive
def dataType: DataType
override def schema: StructType = StructType(StructField("value", dataType, nullable) :: Nil)
- def lenientSerialization: Boolean = false
}
+// TODO check RowEncoder
+// TODO check BeanEncoder
object AgnosticEncoders {
case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
extends AgnosticEncoder[Option[E]] {
@@ -51,48 +46,35 @@ object AgnosticEncoders {
override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]])
}
- case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean)
+ case class ArrayEncoder[E](element: AgnosticEncoder[E])
extends AgnosticEncoder[Array[E]] {
override def isPrimitive: Boolean = false
- override def dataType: DataType = ArrayType(element.dataType, containsNull)
+ override def dataType: DataType = ArrayType(element.dataType, element.nullable)
override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap
}
- /**
- * Encoder for collections.
- *
- * This encoder can be lenient for [[Row]] encoders. In that case we allow [[Seq]], primitive
- * array (if any), and generic arrays as input.
- */
- case class IterableEncoder[C, E](
+ case class IterableEncoder[C <: Iterable[E], E](
override val clsTag: ClassTag[C],
- element: AgnosticEncoder[E],
- containsNull: Boolean,
- override val lenientSerialization: Boolean)
+ element: AgnosticEncoder[E])
extends AgnosticEncoder[C] {
override def isPrimitive: Boolean = false
- override val dataType: DataType = ArrayType(element.dataType, containsNull)
+ override val dataType: DataType = ArrayType(element.dataType, element.nullable)
}
case class MapEncoder[C, K, V](
override val clsTag: ClassTag[C],
keyEncoder: AgnosticEncoder[K],
- valueEncoder: AgnosticEncoder[V],
- valueContainsNull: Boolean)
+ valueEncoder: AgnosticEncoder[V])
extends AgnosticEncoder[C] {
override def isPrimitive: Boolean = false
override val dataType: DataType = MapType(
keyEncoder.dataType,
valueEncoder.dataType,
- valueContainsNull)
+ valueEncoder.nullable)
}
- case class EncoderField(
- name: String,
- enc: AgnosticEncoder[_],
- nullable: Boolean,
- metadata: Metadata) {
- def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
+ case class EncoderField(name: String, enc: AgnosticEncoder[_]) {
+ def structField: StructField = StructField(name, enc.dataType, enc.nullable)
}
// This supports both Product and DefinedByConstructorParams
@@ -105,13 +87,6 @@ object AgnosticEncoders {
override def dataType: DataType = schema
}
- case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
- override def isPrimitive: Boolean = false
- override val schema: StructType = StructType(fields.map(_.structField))
- override def dataType: DataType = schema
- override def clsTag: ClassTag[Row] = classTag[Row]
- }
-
// This will only work for encoding from/to Sparks' InternalRow format.
// It is here for compatibility.
case class UDTEncoder[E >: Null](
@@ -141,74 +116,39 @@ object AgnosticEncoders {
}
// Primitive encoders
- abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType)
- extends LeafEncoder[E](dataType)
- case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType)
- case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType)
- case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType)
- case object PrimitiveIntEncoder extends PrimitiveLeafEncoder[Int](IntegerType)
- case object PrimitiveLongEncoder extends PrimitiveLeafEncoder[Long](LongType)
- case object PrimitiveFloatEncoder extends PrimitiveLeafEncoder[Float](FloatType)
- case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType)
+ case object PrimitiveBooleanEncoder extends LeafEncoder[Boolean](BooleanType)
+ case object PrimitiveByteEncoder extends LeafEncoder[Byte](ByteType)
+ case object PrimitiveShortEncoder extends LeafEncoder[Short](ShortType)
+ case object PrimitiveIntEncoder extends LeafEncoder[Int](IntegerType)
+ case object PrimitiveLongEncoder extends LeafEncoder[Long](LongType)
+ case object PrimitiveFloatEncoder extends LeafEncoder[Float](FloatType)
+ case object PrimitiveDoubleEncoder extends LeafEncoder[Double](DoubleType)
// Primitive wrapper encoders.
- abstract class BoxedLeafEncoder[E : ClassTag, P](
- dataType: DataType,
- val primitive: PrimitiveLeafEncoder[P])
- extends LeafEncoder[E](dataType)
- case object BoxedBooleanEncoder
- extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder)
- case object BoxedByteEncoder
- extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder)
- case object BoxedShortEncoder
- extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder)
- case object BoxedIntEncoder
- extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder)
- case object BoxedLongEncoder
- extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder)
- case object BoxedFloatEncoder
- extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder)
- case object BoxedDoubleEncoder
- extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder)
+ case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
+ case object BoxedBooleanEncoder extends LeafEncoder[java.lang.Boolean](BooleanType)
+ case object BoxedByteEncoder extends LeafEncoder[java.lang.Byte](ByteType)
+ case object BoxedShortEncoder extends LeafEncoder[java.lang.Short](ShortType)
+ case object BoxedIntEncoder extends LeafEncoder[java.lang.Integer](IntegerType)
+ case object BoxedLongEncoder extends LeafEncoder[java.lang.Long](LongType)
+ case object BoxedFloatEncoder extends LeafEncoder[java.lang.Float](FloatType)
+ case object BoxedDoubleEncoder extends LeafEncoder[java.lang.Double](DoubleType)
// Nullable leaf encoders
- case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
case object StringEncoder extends LeafEncoder[String](StringType)
case object BinaryEncoder extends LeafEncoder[Array[Byte]](BinaryType)
+ case object SparkDecimalEncoder extends LeafEncoder[Decimal](DecimalType.SYSTEM_DEFAULT)
+ case object ScalaDecimalEncoder extends LeafEncoder[BigDecimal](DecimalType.SYSTEM_DEFAULT)
+ case object JavaDecimalEncoder extends LeafEncoder[JBigDecimal](DecimalType.SYSTEM_DEFAULT)
case object ScalaBigIntEncoder extends LeafEncoder[BigInt](DecimalType.BigIntDecimal)
case object JavaBigIntEncoder extends LeafEncoder[JBigInt](DecimalType.BigIntDecimal)
case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType)
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
- case class DateEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[jsql.Date](DateType)
- case class LocalDateEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[LocalDate](DateType)
- case class TimestampEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[jsql.Timestamp](TimestampType)
- case class InstantEncoder(override val lenientSerialization: Boolean)
- extends LeafEncoder[Instant](TimestampType)
+ case object DateEncoder extends LeafEncoder[java.sql.Date](DateType)
+ case object LocalDateEncoder extends LeafEncoder[LocalDate](DateType)
+ case object TimestampEncoder extends LeafEncoder[java.sql.Timestamp](TimestampType)
+ case object InstantEncoder extends LeafEncoder[Instant](TimestampType)
case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType)
-
- case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt)
- case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt)
- case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean)
- extends LeafEncoder[JBigDecimal](dt)
-
- val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false)
- val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false)
- val STRICT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = false)
- val STRICT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = false)
- val LENIENT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = true)
- val LENIENT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = true)
- val LENIENT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = true)
- val LENIENT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = true)
-
- val DEFAULT_SPARK_DECIMAL_ENCODER: SparkDecimalEncoder =
- SparkDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
- val DEFAULT_SCALA_DECIMAL_ENCODER: ScalaDecimalEncoder =
- ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
- val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
- JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
}
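With the revert applied, the container encoders no longer carry explicit nullability or leniency parameters; nullability is derived from the element encoder. A small sketch constructing the restored shapes (the values are illustrative):

    import scala.reflect.ClassTag
    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._

    // ArrayEncoder/IterableEncoder/MapEncoder as they read after this revert:
    val arrayOfInts = ArrayEncoder(BoxedIntEncoder)
    val seqOfStrings = IterableEncoder[Seq[String], String](
      ClassTag(classOf[Seq[String]]), StringEncoder)
    val stringToLong = MapEncoder[Map[String, java.lang.Long], String, java.lang.Long](
      ClassTag(classOf[Map[String, java.lang.Long]]), StringEncoder, BoxedLongEncoder)
    // Boxed element encoders are not primitive, so containsNull comes out true:
    println(arrayOfInts.dataType) // ArrayType(IntegerType, true)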
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9ca2fc72ad9..82a6863b5ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -47,10 +47,7 @@ import org.apache.spark.util.Utils
object ExpressionEncoder {
def apply[T : TypeTag](): ExpressionEncoder[T] = {
- apply(ScalaReflection.encoderFor[T])
- }
-
- def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
+ val enc = ScalaReflection.encoderFor[T]
new ExpressionEncoder[T](
ScalaReflection.serializerFor(enc),
ScalaReflection.deserializerFor(enc),
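After this revert, ExpressionEncoder.apply exists only in its TypeTag form; the overload taking a pre-built AgnosticEncoder is gone. Typical derivation still looks like this (the element type is illustrative):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // The encoder is derived via ScalaReflection.encoderFor[T] inside apply;
    // callers can no longer pass an AgnosticEncoder directly.
    val enc = ExpressionEncoder[(Int, String)]()
    val toInternal = enc.createSerializer()
    val fromInternal = enc.resolveAndBind().createDeserializer()
    val row = toInternal((1, "a"))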
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 78243894544..8eb3475acaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -17,11 +17,19 @@
package org.apache.spark.sql.catalyst.encoders
-import scala.collection.mutable
-import scala.reflect.classTag
+import scala.annotation.tailrec
+import scala.collection.Map
+import scala.reflect.ClassTag
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMont [...]
+import org.apache.spark.sql.catalyst.{ScalaReflection, WalkedTypePath}
+import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -60,46 +68,27 @@ import org.apache.spark.sql.types._
*/
object RowEncoder {
def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
- ExpressionEncoder(encoderFor(schema, lenient))
+ val cls = classOf[Row]
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val serializer = serializerFor(inputObject, schema, lenient)
+ val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema)
+ new ExpressionEncoder[Row](
+ serializer,
+ deserializer,
+ ClassTag(cls))
}
-
def apply(schema: StructType): ExpressionEncoder[Row] = {
apply(schema, lenient = false)
}
- def encoderFor(schema: StructType): AgnosticEncoder[Row] = {
- encoderFor(schema, lenient = false)
- }
+ private def serializerFor(
+ inputObject: Expression,
+ inputType: DataType,
+ lenient: Boolean): Expression = inputType match {
+ case dt if ScalaReflection.isNativeType(dt) => inputObject
- def encoderFor(schema: StructType, lenient: Boolean): AgnosticEncoder[Row] = {
- encoderForDataType(schema, lenient).asInstanceOf[AgnosticEncoder[Row]]
- }
+ case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType, lenient)
- private[catalyst] def encoderForDataType(
- dataType: DataType,
- lenient: Boolean): AgnosticEncoder[_] = dataType match {
- case NullType => NullEncoder
- case BooleanType => BoxedBooleanEncoder
- case ByteType => BoxedByteEncoder
- case ShortType => BoxedShortEncoder
- case IntegerType => BoxedIntEncoder
- case LongType => BoxedLongEncoder
- case FloatType => BoxedFloatEncoder
- case DoubleType => BoxedDoubleEncoder
- case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
- case BinaryType => BinaryEncoder
- case StringType => StringEncoder
- case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
- case TimestampType => TimestampEncoder(lenient)
- case TimestampNTZType => LocalDateTimeEncoder
- case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
- case DateType => DateEncoder(lenient)
- case CalendarIntervalType => CalendarIntervalEncoder
- case _: DayTimeIntervalType => DayTimeIntervalEncoder
- case _: YearMonthIntervalType => YearMonthIntervalEncoder
- case p: PythonUserDefinedType =>
- // TODO check if this works.
- encoderForDataType(p.sqlType, lenient)
case udt: UserDefinedType[_] =>
val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
val udtClass: Class[_] = if (annotation != null) {
@@ -109,26 +98,281 @@ object RowEncoder {
throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
}
}
- UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
- case ArrayType(elementType, containsNull) =>
- IterableEncoder(
- classTag[mutable.WrappedArray[_]],
- encoderForDataType(elementType, lenient),
- containsNull,
- lenientSerialization = true)
- case MapType(keyType, valueType, valueContainsNull) =>
- MapEncoder(
- classTag[scala.collection.Map[_, _]],
- encoderForDataType(keyType, lenient),
- encoderForDataType(valueType, lenient),
- valueContainsNull)
+ val obj = NewInstance(
+ udtClass,
+ Nil,
+ dataType = ObjectType(udtClass), false)
+ Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
+
+ case TimestampType =>
+ if (lenient) {
+ createSerializerForAnyTimestamp(inputObject)
+ } else if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createSerializerForJavaInstant(inputObject)
+ } else {
+ createSerializerForSqlTimestamp(inputObject)
+ }
+
+ case TimestampNTZType => createSerializerForLocalDateTime(inputObject)
+
+ case DateType =>
+ if (lenient) {
+ createSerializerForAnyDate(inputObject)
+ } else if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createSerializerForJavaLocalDate(inputObject)
+ } else {
+ createSerializerForSqlDate(inputObject)
+ }
+
+ case _: DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
+
+ case _: YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
+
+ case d: DecimalType =>
+ CheckOverflow(StaticInvoke(
+ Decimal.getClass,
+ d,
+ "fromDecimal",
+ inputObject :: Nil,
+ returnNullable = false), d, !SQLConf.get.ansiEnabled)
+
+ case StringType => createSerializerForString(inputObject)
+
+ case t @ ArrayType(et, containsNull) =>
+ et match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+ StaticInvoke(
+ classOf[ArrayData],
+ t,
+ "toArrayData",
+ inputObject :: Nil,
+ returnNullable = false)
+
+ case _ =>
+ createSerializerForMapObjects(
+ inputObject,
+ ObjectType(classOf[Object]),
+ element => {
+ val value = serializerFor(ValidateExternalType(element, et, lenient), et, lenient)
+ expressionWithNullSafety(value, containsNull, WalkedTypePath())
+ })
+ }
+
+ case t @ MapType(kt, vt, valueNullable) =>
+ val keys =
+ Invoke(
+ Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+ returnNullable = false),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
+ val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient)
+
+ val values =
+ Invoke(
+ Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]]),
+ returnNullable = false),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
+ val convertedValues = serializerFor(values, ArrayType(vt, valueNullable), lenient)
+
+ val nonNullOutput = NewInstance(
+ classOf[ArrayBasedMapData],
+ convertedKeys :: convertedValues :: Nil,
+ dataType = t,
+ propagateNull = false)
+
+ if (inputObject.nullable) {
+ expressionForNullableExpr(inputObject, nonNullOutput)
+ } else {
+ nonNullOutput
+ }
+
case StructType(fields) =>
- AgnosticRowEncoder(fields.map { field =>
- EncoderField(
- field.name,
- encoderForDataType(field.dataType, lenient),
- field.nullable,
- field.metadata)
+ val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
+ val fieldValue = serializerFor(
+ ValidateExternalType(
+ GetExternalRowField(inputObject, index, field.name),
+ field.dataType,
+ lenient),
+ field.dataType,
+ lenient)
+ val convertedField = if (field.nullable) {
+ If(
+ Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
+ // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`.
+ // We should use `fieldValue.dataType` here.
+ Literal.create(null, fieldValue.dataType),
+ fieldValue
+ )
+ } else {
+ fieldValue
+ }
+ Literal(field.name) :: convertedField :: Nil
})
+
+ if (inputObject.nullable) {
+ expressionForNullableExpr(inputObject, nonNullOutput)
+ } else {
+ nonNullOutput
+ }
+ // For other data types, return the internal catalyst value as it is.
+ case _ => inputObject
+ }
+
+ /**
+ * Returns the `DataType` that can be used when generating code that converts input data
+ * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned
+ * by this function can be more permissive since multiple external types may map to a single
+ * internal type. For example, for an input with DecimalType in external row, its external types
+ * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
+ * `org.apache.spark.sql.types.Decimal`.
+ */
+ def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt match {
+ // In order to support both Decimal and java/scala BigDecimal in external row, we make this
+ // as java.lang.Object.
+ case _: DecimalType => ObjectType(classOf[java.lang.Object])
+ // In order to support both Array and Seq in external row, we make this as java.lang.Object.
+ case _: ArrayType => ObjectType(classOf[java.lang.Object])
+ case _: DateType | _: TimestampType if lenient => ObjectType(classOf[java.lang.Object])
+ case _ => externalDataTypeFor(dt)
+ }
+
+ @tailrec
+ def externalDataTypeFor(dt: DataType): DataType = dt match {
+ case _ if ScalaReflection.isNativeType(dt) => dt
+ case TimestampType =>
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ ObjectType(classOf[java.time.Instant])
+ } else {
+ ObjectType(classOf[java.sql.Timestamp])
+ }
+ case TimestampNTZType =>
+ ObjectType(classOf[java.time.LocalDateTime])
+ case DateType =>
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ ObjectType(classOf[java.time.LocalDate])
+ } else {
+ ObjectType(classOf[java.sql.Date])
+ }
+ case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
+ case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period])
+ case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
+ case udt: UserDefinedType[_] => ObjectType(udt.userClass)
+ case _ => dt.physicalDataType match {
+ case _: PhysicalArrayType => ObjectType(classOf[scala.collection.Seq[_]])
+ case _: PhysicalDecimalType => ObjectType(classOf[java.math.BigDecimal])
+ case _: PhysicalMapType => ObjectType(classOf[scala.collection.Map[_, _]])
+ case PhysicalStringType => ObjectType(classOf[java.lang.String])
+ case _: PhysicalStructType => ObjectType(classOf[Row])
+ // For other data types, return the data type as it is.
+ case _ => dt
+ }
+ }
+
+ private def deserializerFor(input: Expression, schema: StructType): Expression = {
+ val fields = schema.zipWithIndex.map { case (f, i) =>
+ deserializerFor(GetStructField(input, i))
+ }
+ CreateExternalRow(fields, schema)
+ }
+
+ private def deserializerFor(input: Expression): Expression = {
+ deserializerFor(input, input.dataType)
+ }
+
+ @tailrec
+ private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
+ case dt if ScalaReflection.isNativeType(dt) => input
+
+ case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
+
+ case udt: UserDefinedType[_] =>
+ val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
+ val udtClass: Class[_] = if (annotation != null) {
+ annotation.udt()
+ } else {
+ UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
+ throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
+ }
+ }
+ val obj = NewInstance(
+ udtClass,
+ Nil,
+ dataType = ObjectType(udtClass))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
+
+ case TimestampType =>
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createDeserializerForInstant(input)
+ } else {
+ createDeserializerForSqlTimestamp(input)
+ }
+
+ case TimestampNTZType =>
+ createDeserializerForLocalDateTime(input)
+
+ case DateType =>
+ if (SQLConf.get.datetimeJava8ApiEnabled) {
+ createDeserializerForLocalDate(input)
+ } else {
+ createDeserializerForSqlDate(input)
+ }
+
+ case _: DayTimeIntervalType => createDeserializerForDuration(input)
+
+ case _: YearMonthIntervalType => createDeserializerForPeriod(input)
+
+ case _: DecimalType => createDeserializerForJavaBigDecimal(input, returnNullable = false)
+
+ case StringType => createDeserializerForString(input, returnNullable = false)
+
+ case ArrayType(et, nullable) =>
+ val arrayData =
+ Invoke(
+ MapObjects(deserializerFor(_), input, et),
+ "array",
+ ObjectType(classOf[Array[_]]), returnNullable = false)
+ // TODO should use `scala.collection.immutable.ArrayDeq.unsafeMake` method to create
+ // `immutable.Seq` in Scala 2.13 when Scala version compatibility is no longer required.
+ StaticInvoke(
+ scala.collection.mutable.WrappedArray.getClass,
+ ObjectType(classOf[scala.collection.Seq[_]]),
+ "make",
+ arrayData :: Nil,
+ returnNullable = false)
+
+ case MapType(kt, vt, valueNullable) =>
+ val keyArrayType = ArrayType(kt, false)
+ val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
+
+ val valueArrayType = ArrayType(vt, valueNullable)
+ val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType))
+
+ StaticInvoke(
+ ArrayBasedMapData.getClass,
+ ObjectType(classOf[Map[_, _]]),
+ "toScalaMap",
+ keyData :: valueData :: Nil,
+ returnNullable = false)
+
+ case schema @ StructType(fields) =>
+ val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ If(
+ Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
+ Literal.create(null, externalDataTypeFor(f.dataType)),
+ deserializerFor(GetStructField(input, i)))
+ }
+ If(IsNull(input),
+ Literal.create(null, externalDataTypeFor(input.dataType)),
+ CreateExternalRow(convertedFields, schema))
+
+ // For other data types, return the internal catalyst value as it is.
+ case _ => input
+ }
+
+ private def expressionForNullableExpr(
+ expr: Expression,
+ newExprWhenNotNull: Expression): Expression = {
+ If(IsNull(expr), Literal.create(null, newExprWhenNotNull.dataType), newExprWhenNotNull)
}
}
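The restored externalDataTypeFor mapping is the behavioral core of this file: it decides which JVM class each Row field deserializes to. A condensed sketch of the mapping for common types, assuming spark.sql.datetime.java8API.enabled is left at its default of false:

    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    RowEncoder.externalDataTypeFor(TimestampType)              // ObjectType(java.sql.Timestamp)
    RowEncoder.externalDataTypeFor(DateType)                   // ObjectType(java.sql.Date)
    RowEncoder.externalDataTypeFor(DecimalType.SYSTEM_DEFAULT) // ObjectType(java.math.BigDecimal)
    RowEncoder.externalDataTypeFor(ArrayType(IntegerType))     // ObjectType(scala.collection.Seq)
    // For input validation, ambiguous cases are widened to java.lang.Object:
    RowEncoder.externalDataTypeForInput(DecimalType.SYSTEM_DEFAULT, lenient = false)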
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index b8313bda069..a644b90a96f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.{Method, Modifier}
import scala.collection.JavaConverters._
-import scala.collection.mutable
import scala.collection.mutable.{Builder, WrappedArray}
import scala.reflect.ClassTag
-import scala.util.Try
+import scala.util.{Properties, Try}
import org.apache.commons.lang3.reflect.MethodUtils
@@ -31,6 +30,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -886,26 +886,10 @@ case class MapObjects private(
_.asInstanceOf[ArrayData].toSeq[Any](et)
}
- private def elementClassTag(): ClassTag[Any] = {
- val clazz = lambdaFunction.dataType match {
- case ObjectType(cls) => cls
- case dt if lambdaFunction.nullable => ScalaReflection.javaBoxedType(dt)
- case dt => ScalaReflection.dataTypeJavaClass(dt)
- }
- ClassTag(clazz).asInstanceOf[ClassTag[Any]]
- }
-
private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
- // The implicit tag is a workaround to deal with a small change in the
- // (scala) signature of ArrayBuilder.make between Scala 2.12 and 2.13.
- implicit val tag: ClassTag[Any] = elementClassTag()
- input => {
- val builder = mutable.ArrayBuilder.make[Any]
- builder.sizeHint(input.size)
- executeFuncOnCollection(input).foreach(builder += _)
- mutable.WrappedArray.make(builder.result())
- }
+ // Scala WrappedArray
+ inputCollection => WrappedArray.make(executeFuncOnCollection(inputCollection).toArray)
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
executeFuncOnCollection(_).toSeq
@@ -1063,20 +1047,44 @@ case class MapObjects private(
val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
- val tag = ctx.addReferenceObj("tag", elementClassTag())
- val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
- val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
- val builder = ctx.freshName("collectionBuilder")
- (
- s"""
+ def doCodeGenForScala212 = {
+ // WrappedArray in Scala 2.12
+ val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
+ val builder = ctx.freshName("collectionBuilder")
+ (
+ s"""
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
- (genValue: String) => s"$builder.$$plus$$eq($genValue);",
- s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
- s"MODULE$$.make($builder.result());"
- )
+ (genValue: String) => s"$builder.$$plus$$eq($genValue);",
+ s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
+ s"MODULE$$.make(((${classOf[IndexedSeq[_]].getName})$builder" +
+ s".result()).toArray(scala.reflect.ClassTag$$.MODULE$$.Object()));"
+ )
+ }
+
+ def doCodeGenForScala213 = {
+ // In Scala 2.13, WrappedArray is mutable.ArraySeq and newBuilder method need
+ // a ClassTag type construction parameter
+ val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder(" +
+ s"scala.reflect.ClassTag$$.MODULE$$.Object())"
+ val builder = ctx.freshName("collectionBuilder")
+ (
+ s"""
+ ${classOf[Builder[_, _]].getName} $builder = $getBuilder;
+ $builder.sizeHint($dataLength);
+ """,
+ (genValue: String) => s"$builder.$$plus$$eq($genValue);",
+ s"(${cls.getName})$builder.result();"
+ )
+ }
+ val scalaVersion = Properties.versionNumberString
+ if (scalaVersion.startsWith("2.12")) {
+ doCodeGenForScala212
+ } else {
+ doCodeGenForScala213
+ }
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) ||
classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala sequence or set
@@ -1900,14 +1908,14 @@ case class GetExternalRowField(
* Validates the actual data type of input expression at runtime. If it doesn't match the
* expectation, throw an exception.
*/
-case class ValidateExternalType(child: Expression, expected: DataType, externalDataType: DataType)
+case class ValidateExternalType(child: Expression, expected: DataType, lenient: Boolean)
extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
override def nullable: Boolean = child.nullable
- override val dataType: DataType = externalDataType
+ override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected, lenient)
private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
@@ -1919,7 +1927,7 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
}
case _: ArrayType =>
(value: Any) => {
- value.getClass.isArray || value.isInstanceOf[Seq[_]] || value.isInstanceOf[Set[_]]
+ value.getClass.isArray || value.isInstanceOf[Seq[_]]
}
case _: DateType =>
(value: Any) => {
@@ -1960,7 +1968,7 @@ case class ValidateExternalType(child: Expression, expected: DataType, externalD
classOf[scala.math.BigDecimal],
classOf[Decimal]))
case _: ArrayType =>
- s"$obj.getClass().isArray() || ${genCheckTypes(Seq(classOf[Seq[_]],
classOf[Set[_]]))}"
+ s"$obj.getClass().isArray() || $obj instanceof
${classOf[scala.collection.Seq[_]].getName}"
case _: DateType =>
genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate]))
case _: TimestampType =>
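The MapObjects codegen above now branches on the runtime Scala version because WrappedArray's builder API changed between 2.12 and 2.13 (newBuilder gained a ClassTag parameter). A minimal plain-Scala sketch of the same check, mirroring rather than reusing the generated code:

    import scala.collection.mutable.WrappedArray
    import scala.util.Properties

    // Same version probe the generated code performs to pick the builder shape.
    val onScala212 = Properties.versionNumberString.startsWith("2.12")
    // WrappedArray.make itself works on both versions (on 2.13 WrappedArray is a
    // deprecated alias of mutable.ArraySeq), which is why the interpreted path
    // above can simply call it; only the generated builder strings differ.
    val wrapped = WrappedArray.make(Array[AnyRef](Int.box(1), Int.box(2)))
    println(s"Scala 2.12? $onScala212 -> ${wrapped.getClass.getName}")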
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index f8ebdfe7676..7e7ce29972b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FooEnum.FooEnum
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, MapObjects, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -388,10 +388,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
}
test("SPARK-15062: Get correct serializer for List[_]") {
+ val list = List(1, 2, 3)
val serializer = serializerFor[List[Int]]
- assert(serializer.isInstanceOf[MapObjects])
- val mapObjects = serializer.asInstanceOf[MapObjects]
- assert(mapObjects.customCollectionCls.isEmpty)
+ assert(serializer.isInstanceOf[NewInstance])
+ assert(serializer.asInstanceOf[NewInstance]
+ .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}
test("SPARK 16792: Get correct deserializer for List[_]") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index c6546105231..3a0db1ca121 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -480,8 +480,6 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
encodeDecodeTest(ScroogeLikeExample(1), "SPARK-40385 class with only a companion object constructor")
- encodeDecodeTest(Array(Set(1, 2), Set(2, 3)), "array of sets")
-
productTest(("UDT", new ExamplePoint(0.1, 0.2)))
test("AnyVal class with Any fields") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 2d2e9d1b2bf..c6bddfa5eee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.encoders
-import scala.collection.mutable
import scala.util.Random
import org.apache.spark.sql.{RandomDataGenerator, Row}
@@ -311,19 +310,6 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
}
- private def roundTripArray[T](dt: DataType, nullable: Boolean, data: Array[T]): Unit = {
- val schema = new StructType().add("a", ArrayType(dt, nullable))
- test(s"RowEncoder should return WrappedArray with properly typed array for $schema") {
- val encoder = RowEncoder(schema).resolveAndBind()
- val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.WrappedArray[_]](0)
- assert(result.array.getClass === data.getClass)
- assert(result === data)
- }
- }
-
- roundTripArray(IntegerType, nullable = false, Array(1, 2, 3).map(Int.box))
- roundTripArray(StringType, nullable = true, Array("hello", "world", "!", null))
-
test("SPARK-25791: Datatype of serializers should be accessible") {
val udtSQLType = new StructType().add("a", IntegerType)
val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 265b0eeb8bd..737fcb1bada 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"),
IntegerType,
- IntegerType) :: Nil)
+ lenient = false) :: Nil)
}
test("SPARK-17160: field names are properly escaped by AssertTrue") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 05ab7a65a32..2286b734477 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -496,11 +496,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
(java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
(Array(3, 2, 1), ArrayType(IntegerType))
).foreach { case (input, dt) =>
- val enc = RowEncoder.encoderForDataType(dt, lenient = false)
val validateType = ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
dt,
- ScalaReflection.lenientExternalDataTypeFor(enc))
+ lenient = false)
checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
}
@@ -508,7 +507,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
DoubleType,
- DoubleType),
+ lenient = false),
InternalRow.fromSeq(Seq(Row(1))),
"java.lang.Integer is not a valid external type for schema of double")
}
@@ -560,10 +559,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
ExternalMapToCatalyst(
inputObject,
- ScalaReflection.externalDataTypeFor(keyEnc),
+ ScalaReflection.dataTypeFor(keyEnc),
kvSerializerFor(keyEnc),
keyNullable = keyEnc.nullable,
- ScalaReflection.externalDataTypeFor(valueEnc),
+ ScalaReflection.dataTypeFor(valueEnc),
kvSerializerFor(valueEnc),
valueNullable = valueEnc.nullable
)
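These test updates all follow from ValidateExternalType's changed third parameter: a lenient flag instead of a precomputed external DataType. A hedged sketch of constructing it the way the updated tests do (names mirror the test code; the surrounding setup is illustrative):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions.BoundReference
    import org.apache.spark.sql.catalyst.expressions.objects.{GetExternalRowField, ValidateExternalType}
    import org.apache.spark.sql.types._

    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
    // dataType is now derived internally via RowEncoder.externalDataTypeForInput(expected, lenient).
    val validate = ValidateExternalType(
      GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
      DoubleType,
      lenient = false)
    println(validate.dataType) // DoubleType: native types pass through unchanged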
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]