cloud-fan commented on code in PR #39186:
URL: https://github.com/apache/spark/pull/39186#discussion_r1056264625
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala:
##########
@@ -186,237 +167,129 @@ object ScalaReflection extends ScalaReflection {
* @param walkedTypePath The paths from top to bottom to access current
field when deserializing.
*/
private def deserializerFor(
- tpe: `Type`,
- walkedTypePath: WalkedTypePath): Expression => Expression =
cleanUpReflectionObjects {
- baseType(tpe) match {
- case t if !dataTypeFor(t).isInstanceOf[ObjectType] => identity
-
- case t if isSubtype(t, localTypeOf[Option[_]]) =>
- val TypeRef(_, _, Seq(optType)) = t
- val className = getClassNameFromType(optType)
- val newTypePath = walkedTypePath.recordOption(className)
- val dataType = dataTypeFor(optType)
- val deserializerFunc = deserializerFor(optType, newTypePath)
- path => WrapOption(deserializerFunc(path), dataType)
-
- case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Integer])
-
- case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Long])
-
- case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Double])
-
- case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Float])
-
- case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Short])
-
- case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Byte])
-
- case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Boolean])
-
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
- createDeserializerForLocalDate
-
- case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
- createDeserializerForSqlDate
-
- case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
- createDeserializerForInstant
-
- case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val cls = getClassFromType(t)
- path => createDeserializerForTypesSupportValueOf(
- Invoke(path, "toString", ObjectType(classOf[String]), returnNullable
= false), cls)
-
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
- createDeserializerForSqlTimestamp
-
- case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
- createDeserializerForLocalDateTime
-
- case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
- createDeserializerForDuration
-
- case t if isSubtype(t, localTypeOf[java.time.Period]) =>
- createDeserializerForPeriod
-
- case t if isSubtype(t, localTypeOf[java.lang.String]) =>
- createDeserializerForString(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
- createDeserializerForJavaBigDecimal(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[BigDecimal]) =>
- createDeserializerForScalaBigDecimal(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
- createDeserializerForJavaBigInteger(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
- createDeserializerForScalaBigInt
-
- case t if isSubtype(t, localTypeOf[Array[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, elementNullable) = schemaFor(elementType)
- val className = getClassNameFromType(elementType)
- val newTypePath = walkedTypePath.recordArray(className)
- val deserializerFunc = deserializerFor(elementType, newTypePath)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- deserializerFunc)
- }
-
- val arrayCls = arrayClassFor(elementType)
-
- val methodName = elementType match {
- case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
- case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
- case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
- case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
- case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
- case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
- case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
- // non-primitive
- case _ => "array"
- }
- path => {
- val arrayData = UnresolvedMapObjects(mapFunction, path)
- Invoke(arrayData, methodName, arrayCls, returnNullable = false)
- }
-
- // We serialize a `Set` to Catalyst array. When we deserialize a
Catalyst array
- // to a `Set`, if there are duplicated elements, the elements will be
de-duplicated.
- case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) ||
- isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, elementNullable) = schemaFor(elementType)
- val className = getClassNameFromType(elementType)
- val newTypePath = walkedTypePath.recordArray(className)
- val deserializerFunc = deserializerFor(elementType, newTypePath)
- val mapFunction: Expression => Expression = element => {
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- deserializerFunc)
- }
-
- val companion = t.dealias.typeSymbol.companion.typeSignature
- val cls = companion.member(TermName("newBuilder")) match {
- case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]]
- case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]])
=>
- classOf[scala.collection.Set[_]]
- case _ => mirror.runtimeClass(t.typeSymbol.asClass)
- }
- UnresolvedMapObjects(mapFunction, _, Some(cls))
-
- case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val classNameForKey = getClassNameFromType(keyType)
- val classNameForValue = getClassNameFromType(valueType)
-
- val newTypePath = walkedTypePath.recordMap(classNameForKey,
classNameForValue)
-
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val keyDeserializerFunc = deserializerFor(keyType, newTypePath)
- val valueDeserializerFunc = deserializerFor(valueType, newTypePath)
- val cls = mirror.runtimeClass(t.typeSymbol.asClass)
- UnresolvedCatalystToExternalMap(_, keyDeserializerFunc,
valueDeserializerFunc, cls)
-
- case t if t.typeSymbol.annotations.exists(_.tree.tpe =:=
typeOf[SQLUserDefinedType]) =>
- val udt =
getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
- getConstructor().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType =
ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val cls = udt.userClass
- path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt =
UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val cls = udt.userClass
- path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
-
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
-
- val cls = getClassFromType(tpe)
-
- val arguDeserializerFuncs = params.zipWithIndex.map { case
((fieldName, fieldType), i) =>
- val Schema(dataType, nullable) = schemaFor(fieldType)
- val clsName = getClassNameFromType(fieldType)
- val newTypePath = walkedTypePath.recordField(clsName, fieldName)
-
- // For tuples, we based grab the inner fields by ordinal instead of
name.
- val newPathFunc = if (cls.getName startsWith "scala.Tuple") {
- addToPathOrdinal(_, i, dataType, newTypePath)
+ enc: AgnosticEncoder[_],
+ input: Expression,
+ typePath: WalkedTypePath): Expression = enc match {
+ case _ if isNativeEncoder(enc) =>
+ input
+ case BooleanEncoder =>
Review Comment:
Never mind, we have both `PrimitiveBooleanEncoder` and `BooleanEncoder`.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala:
##########
@@ -186,237 +167,129 @@ object ScalaReflection extends ScalaReflection {
* @param walkedTypePath The paths from top to bottom to access current
field when deserializing.
*/
private def deserializerFor(
- tpe: `Type`,
- walkedTypePath: WalkedTypePath): Expression => Expression =
cleanUpReflectionObjects {
- baseType(tpe) match {
- case t if !dataTypeFor(t).isInstanceOf[ObjectType] => identity
-
- case t if isSubtype(t, localTypeOf[Option[_]]) =>
- val TypeRef(_, _, Seq(optType)) = t
- val className = getClassNameFromType(optType)
- val newTypePath = walkedTypePath.recordOption(className)
- val dataType = dataTypeFor(optType)
- val deserializerFunc = deserializerFor(optType, newTypePath)
- path => WrapOption(deserializerFunc(path), dataType)
-
- case t if isSubtype(t, localTypeOf[java.lang.Integer]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Integer])
-
- case t if isSubtype(t, localTypeOf[java.lang.Long]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Long])
-
- case t if isSubtype(t, localTypeOf[java.lang.Double]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Double])
-
- case t if isSubtype(t, localTypeOf[java.lang.Float]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Float])
-
- case t if isSubtype(t, localTypeOf[java.lang.Short]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Short])
-
- case t if isSubtype(t, localTypeOf[java.lang.Byte]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Byte])
-
- case t if isSubtype(t, localTypeOf[java.lang.Boolean]) =>
- createDeserializerForTypesSupportValueOf(_, classOf[java.lang.Boolean])
-
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) =>
- createDeserializerForLocalDate
-
- case t if isSubtype(t, localTypeOf[java.sql.Date]) =>
- createDeserializerForSqlDate
-
- case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
- createDeserializerForInstant
-
- case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val cls = getClassFromType(t)
- path => createDeserializerForTypesSupportValueOf(
- Invoke(path, "toString", ObjectType(classOf[String]), returnNullable
= false), cls)
-
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
- createDeserializerForSqlTimestamp
-
- case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) =>
- createDeserializerForLocalDateTime
-
- case t if isSubtype(t, localTypeOf[java.time.Duration]) =>
- createDeserializerForDuration
-
- case t if isSubtype(t, localTypeOf[java.time.Period]) =>
- createDeserializerForPeriod
-
- case t if isSubtype(t, localTypeOf[java.lang.String]) =>
- createDeserializerForString(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) =>
- createDeserializerForJavaBigDecimal(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[BigDecimal]) =>
- createDeserializerForScalaBigDecimal(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
- createDeserializerForJavaBigInteger(_, returnNullable = false)
-
- case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
- createDeserializerForScalaBigInt
-
- case t if isSubtype(t, localTypeOf[Array[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, elementNullable) = schemaFor(elementType)
- val className = getClassNameFromType(elementType)
- val newTypePath = walkedTypePath.recordArray(className)
- val deserializerFunc = deserializerFor(elementType, newTypePath)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- deserializerFunc)
- }
-
- val arrayCls = arrayClassFor(elementType)
-
- val methodName = elementType match {
- case t if isSubtype(t, definitions.IntTpe) => "toIntArray"
- case t if isSubtype(t, definitions.LongTpe) => "toLongArray"
- case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray"
- case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray"
- case t if isSubtype(t, definitions.ShortTpe) => "toShortArray"
- case t if isSubtype(t, definitions.ByteTpe) => "toByteArray"
- case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray"
- // non-primitive
- case _ => "array"
- }
- path => {
- val arrayData = UnresolvedMapObjects(mapFunction, path)
- Invoke(arrayData, methodName, arrayCls, returnNullable = false)
- }
-
- // We serialize a `Set` to Catalyst array. When we deserialize a
Catalyst array
- // to a `Set`, if there are duplicated elements, the elements will be
de-duplicated.
- case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) ||
- isSubtype(t, localTypeOf[scala.collection.Set[_]]) =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, elementNullable) = schemaFor(elementType)
- val className = getClassNameFromType(elementType)
- val newTypePath = walkedTypePath.recordArray(className)
- val deserializerFunc = deserializerFor(elementType, newTypePath)
- val mapFunction: Expression => Expression = element => {
- deserializerForWithNullSafetyAndUpcast(
- element,
- dataType,
- nullable = elementNullable,
- newTypePath,
- deserializerFunc)
- }
-
- val companion = t.dealias.typeSymbol.companion.typeSignature
- val cls = companion.member(TermName("newBuilder")) match {
- case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]]
- case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]])
=>
- classOf[scala.collection.Set[_]]
- case _ => mirror.runtimeClass(t.typeSymbol.asClass)
- }
- UnresolvedMapObjects(mapFunction, _, Some(cls))
-
- case t if isSubtype(t, localTypeOf[Map[_, _]]) =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val classNameForKey = getClassNameFromType(keyType)
- val classNameForValue = getClassNameFromType(valueType)
-
- val newTypePath = walkedTypePath.recordMap(classNameForKey,
classNameForValue)
-
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val keyDeserializerFunc = deserializerFor(keyType, newTypePath)
- val valueDeserializerFunc = deserializerFor(valueType, newTypePath)
- val cls = mirror.runtimeClass(t.typeSymbol.asClass)
- UnresolvedCatalystToExternalMap(_, keyDeserializerFunc,
valueDeserializerFunc, cls)
-
- case t if t.typeSymbol.annotations.exists(_.tree.tpe =:=
typeOf[SQLUserDefinedType]) =>
- val udt =
getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().
- getConstructor().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType =
ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val cls = udt.userClass
- path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt =
UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor().
- newInstance().asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- // Code touching Scala Reflection should be called outside the
returned function to allow
- // caching the Scala Reflection result
- val cls = udt.userClass
- path => Invoke(obj, "deserialize", ObjectType(cls), Seq(path))
-
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
-
- val cls = getClassFromType(tpe)
-
- val arguDeserializerFuncs = params.zipWithIndex.map { case
((fieldName, fieldType), i) =>
- val Schema(dataType, nullable) = schemaFor(fieldType)
- val clsName = getClassNameFromType(fieldType)
- val newTypePath = walkedTypePath.recordField(clsName, fieldName)
-
- // For tuples, we based grab the inner fields by ordinal instead of
name.
- val newPathFunc = if (cls.getName startsWith "scala.Tuple") {
- addToPathOrdinal(_, i, dataType, newTypePath)
+ enc: AgnosticEncoder[_],
+ input: Expression,
+ typePath: WalkedTypePath): Expression = enc match {
+ case _ if isNativeEncoder(enc) =>
+ input
+ case BooleanEncoder =>
Review Comment:
Maybe `BoxedBooleanEncoder` would be a better name, to make the distinction from `PrimitiveBooleanEncoder` explicit.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]