Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22749#discussion_r226298803

--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---
@@ -212,21 +183,88 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects that uses Catalyst Expressions for a `serializer`
  * and a `deserializer`.
  *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param serializer A set of expressions, one for each top-level field that can be used to
- *                   extract the values from a raw object into an [[InternalRow]].
- * @param deserializer An expression that will construct an object given an [[InternalRow]].
+ * @param objSerializer An expression that can be used to encode a raw object to corresponding
+ *                      Spark SQL representation that can be a primitive column, array, map or a
+ *                      struct. This represents how Spark SQL generally serializes an object of
+ *                      type `T`.
+ * @param objDeserializer An expression that will construct an object given a Spark SQL
+ *                        representation. This represents how Spark SQL generally deserializes
+ *                        a serialized value in Spark SQL representation back to an object of
+ *                        type `T`.
  * @param clsTag A classtag for `T`.
  */
case class ExpressionEncoder[T](
-    schema: StructType,
-    flat: Boolean,
-    serializer: Seq[Expression],
-    deserializer: Expression,
+    objSerializer: Expression,
+    objDeserializer: Expression,
    clsTag: ClassTag[T])
  extends Encoder[T] {

-  if (flat) require(serializer.size == 1)
+  /**
+   * A set of expressions, one for each top-level field that can be used to
+   * extract the values from a raw object into an [[InternalRow]]:
+   * 1. If `serializer` encodes a raw object to a struct, we directly use the `serializer`.
+   * 2. For other cases, we create a struct to wrap the `serializer`.
+   */
+  val serializer: Seq[NamedExpression] = {
+    val serializedAsStruct = objSerializer.dataType.isInstanceOf[StructType]
+    val clsName = Utils.getSimpleName(clsTag.runtimeClass)
+
+    if (serializedAsStruct) {
+      val nullSafeSerializer = objSerializer.transformUp {
+        case r: BoundReference =>
+          // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
+          // doesn't allow top-level row to be null, only its columns can be null.
+          AssertNotNull(r, Seq("top level Product or row object"))
+      }
+      nullSafeSerializer match {
+        case If(_, _, s: CreateNamedStruct) => s
--- End diff --

let's also make sure the `If` condition is `IsNull`, which better explains why we strip it (it can't be null)
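A minimal sketch of the stricter match this asks for, assuming the names from the diff above (`nullSafeSerializer`, `clsName`, `objSerializer`) are in scope inside the `serializer` val; the fallback cases shown here are illustrative and the final shape is up to the PR author:

```scala
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, IsNull}

nullSafeSerializer match {
  // Only strip the If when its predicate is IsNull: the AssertNotNull inserted
  // above guarantees the input object is non-null, so the null branch is dead code.
  case If(_: IsNull, _, s: CreateNamedStruct) => s
  // Serializer already produces a struct with no null check wrapped around it.
  case s: CreateNamedStruct => s
  case other: Expression =>
    throw new RuntimeException(s"class $clsName has unexpected serializer: $other")
}
```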