Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/22749#discussion_r226507104
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---
@@ -212,21 +183,88 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
  * and a `deserializer`.
  *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param serializer A set of expressions, one for each top-level field that can be used to
- *                   extract the values from a raw object into an [[InternalRow]].
- * @param deserializer An expression that will construct an object given an [[InternalRow]].
+ * @param objSerializer An expression that can be used to encode a raw object to corresponding
+ *                      Spark SQL representation that can be a primitive column, array, map or a
+ *                      struct. This represents how Spark SQL generally serializes an object of
+ *                      type `T`.
+ * @param objDeserializer An expression that will construct an object given a Spark SQL
+ *                        representation. This represents how Spark SQL generally deserializes
+ *                        a serialized value in Spark SQL representation back to an object of
+ *                        type `T`.
  * @param clsTag A classtag for `T`.
  */
 case class ExpressionEncoder[T](
-    schema: StructType,
-    flat: Boolean,
-    serializer: Seq[Expression],
-    deserializer: Expression,
+    objSerializer: Expression,
+    objDeserializer: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
-  if (flat) require(serializer.size == 1)
+  /**
+   * A set of expressions, one for each top-level field that can be used to
+   * extract the values from a raw object into an [[InternalRow]]:
+   * 1. If `serializer` encodes a raw object to a struct, we directly use the `serializer`.
+   * 2. For other cases, we create a struct to wrap the `serializer`.
+   */
+  val serializer: Seq[NamedExpression] = {
+    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
+    val clsName = Utils.getSimpleName(clsTag.runtimeClass)
+
+    if (serializedAsStruct) {
+      val nullSafeSerializer = objSerializer.transformUp {
+        case r: BoundReference =>
+          // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
+          // doesn't allow top-level row to be null, only its columns can be null.
+          AssertNotNull(r, Seq("top level Product or row object"))
+      }
+      nullSafeSerializer match {
+        case If(_, _, s: CreateNamedStruct) => s
--- End diff --
ok.
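
BTW, for anyone skimming the thread, here is a rough sketch (not code in this PR) of how the struct-vs-non-struct split in this hunk surfaces through the encoder API, assuming the patch is applied. `Point` and `EncoderShapeDemo` are made-up names for illustration:

// Rough sketch, not part of this PR: shows how the serializedAsStruct
// branch above is reflected in the schema of the resulting encoder.
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical Product type used only for this illustration.
case class Point(x: Int, y: Int)

object EncoderShapeDemo {
  def main(args: Array[String]): Unit = {
    // The objSerializer of a Product type already produces a struct, so the
    // new `serializer` val uses it directly: one column per top-level field.
    val productEnc = ExpressionEncoder[Point]()
    println(productEnc.schema.simpleString)   // struct<x:int,y:int>

    // The objSerializer of a primitive produces a single non-struct column,
    // so it gets wrapped in a single-field struct (the "value" column).
    val primitiveEnc = ExpressionEncoder[Int]()
    println(primitiveEnc.schema.simpleString) // struct<value:int>
  }
}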