Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/22749#discussion_r226301402
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---
@@ -212,21 +183,88 @@ object ExpressionEncoder {
 * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
 * and a `deserializer`.
 *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param serializer A set of expressions, one for each top-level field that can be used to
- *                   extract the values from a raw object into an [[InternalRow]].
- * @param deserializer An expression that will construct an object given an [[InternalRow]].
+ * @param objSerializer An expression that can be used to encode a raw object to its
+ *                      corresponding Spark SQL representation, which can be a primitive
+ *                      column, an array, a map, or a struct. This represents how Spark SQL
+ *                      generally serializes an object of type `T`.
+ * @param objDeserializer An expression that will construct an object given a Spark SQL
+ *                        representation. This represents how Spark SQL generally deserializes
+ *                        a serialized value in the Spark SQL representation back to an object
+ *                        of type `T`.
* @param clsTag A classtag for `T`.
*/
 case class ExpressionEncoder[T](
-    schema: StructType,
-    flat: Boolean,
-    serializer: Seq[Expression],
-    deserializer: Expression,
+    objSerializer: Expression,
+    objDeserializer: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
-  if (flat) require(serializer.size == 1)
+  /**
+   * A set of expressions, one for each top-level field that can be used to
+   * extract the values from a raw object into an [[InternalRow]]:
+   * 1. If the `objSerializer` encodes a raw object to a struct, we directly use it.
+   * 2. For other cases, we create a struct to wrap the `objSerializer`.
+   */
+  val serializer: Seq[NamedExpression] = {
+    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
+    val clsName = Utils.getSimpleName(clsTag.runtimeClass)
+
+    if (serializedAsStruct) {
+      val nullSafeSerializer = objSerializer.transformUp {
+        case r: BoundReference =>
+          // For an input object of Product type, we can't encode it to a row if it's null, as
+          // Spark SQL doesn't allow the top-level row to be null; only its columns can be null.
+          AssertNotNull(r, Seq("top level Product or row object"))
+      }
+      nullSafeSerializer match {
+        case If(_, _, s: CreateNamedStruct) => s
+        case s: CreateNamedStruct => s
+        case _ =>
+          throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
+      }
+    } else {
+      // For other input objects like primitives, arrays, maps, etc., we construct a struct to
+      // wrap the serializer, which becomes a single column of the row.
+      CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
+    }
+  }.flatten
+
+  /**
+   * Returns an expression that can be used to deserialize an input row to an object of type `T`
+   * with a compatible schema. Fields of the row will be extracted using `UnresolvedAttribute`
+   * expressions with the same names as the constructor arguments.
+   *
+   * For complex objects that are encoded to structs, fields of the struct will be extracted
+   * using `GetColumnByOrdinal` with the corresponding ordinal.
+   */
+  val deserializer: Expression = {
+    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
+
+    if (serializedAsStruct) {
+      // We serialized this kind of object to a root-level row. The input of the general
+      // deserializer is a `GetColumnByOrdinal(0)` expression that extracts the first column
+      // of a row. We need to transform the attribute accessors.
+      objDeserializer.transform {
+        case UnresolvedExtractValue(GetColumnByOrdinal(0, _),
+            Literal(part: UTF8String, StringType)) =>
+          UnresolvedAttribute.quoted(part.toString)
+        case GetStructField(GetColumnByOrdinal(0, dt), ordinal, _) =>
+          GetColumnByOrdinal(ordinal, dt)
+        case If(IsNull(GetColumnByOrdinal(0, _)), _, n: NewInstance) => n
+        case If(IsNull(GetColumnByOrdinal(0, _)), _, i: InitializeJavaBean) => i
+      }
+    } else {
+      // For other input objects like primitives, arrays, maps, etc., we deserialize the first
+      // column of a row to the object.
+      objDeserializer
+    }
+  }
+
+  // The schema after converting `T` to a Spark SQL row. This schema is dependent on the given
+  // serializer.
+  val schema: StructType = StructType(serializer.map { s =>
+    StructField(s.name, s.dataType, s.nullable)
--- End diff ---
nvm, the serializer doesn't need analysis
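
For reference, the effect of the struct-wrapping rule in this diff can be sketched as follows. This is a minimal sketch, assuming a Spark build that includes this change; `Point` is a hypothetical case class used only for illustration:

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Hypothetical example type; any top-level case class would do.
    case class Point(x: Int, y: Int)

    // An Int is not serialized as a struct, so the encoder wraps the single
    // objSerializer expression into a struct with one column named "value".
    val intEnc = ExpressionEncoder[Int]()
    intEnc.serializer.map(_.name)   // Seq("value")
    intEnc.schema                   // StructType(StructField("value", IntegerType, false))

    // A Product type is already serialized as a struct, so its serializer
    // fields become the top-level columns directly.
    val pointEnc = ExpressionEncoder[Point]()
    pointEnc.serializer.map(_.name) // Seq("x", "y")
    pointEnc.schema                 // struct with fields x: int, y: int (both non-nullable)

The deserializer side mirrors this: in the struct case the `GetColumnByOrdinal(0, _)` accessors are rewritten into top-level attribute or ordinal accessors, while in the single-column case `objDeserializer` is used as-is.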
---