hvanhovell commented on code in PR #39517: URL: https://github.com/apache/spark/pull/39517#discussion_r1067435036
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala: ########## @@ -377,27 +408,96 @@ object ScalaReflection extends ScalaReflection { val getter = Invoke( KnownNotNull(input), field.name, - dataTypeFor(field.enc), - returnNullable = field.enc.nullable) + externalDataTypeFor(field.enc), + returnNullable = field.nullable) field.name -> serializerFor(field.enc, getter) } createSerializerForObject(input, serializedFields) + + case RowEncoder(fields) => + val serializedFields = fields.zipWithIndex.map { case (field, index) => + val fieldValue = serializerFor( + field.enc, + ValidateExternalType( + GetExternalRowField(input, index, field.name), + field.enc.dataType, + lenientExternalDataTypeFor(field.enc))) + + val convertedField = if (field.nullable) { + exprs.If( + Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) :: Nil), + // Because we strip UDTs, `field.dataType` can be different from `fieldValue.dataType`. + // We should use `fieldValue.dataType` here. + exprs.Literal.create(null, fieldValue.dataType), + fieldValue + ) + } else { + AssertNotNull(fieldValue) + } + field.name -> convertedField + } + createSerializerForObject(input, serializedFields) } private def serializerForArray( - isArray: Boolean, elementEnc: AgnosticEncoder[_], - input: Expression): Expression = { - dataTypeFor(elementEnc) match { - case dt: ObjectType => - createSerializerForMapObjects(input, dt, serializerFor(elementEnc, _)) - case dt if isArray && elementEnc.isPrimitive => - createSerializerForPrimitiveArray(input, dt) - case dt => - createSerializerForGenericArray(input, dt, elementEnc.nullable) + elementNullable: Boolean, + input: Expression, + lenientSerialization: Boolean): Expression = { + // Default serializer for Seq and generic Arrays. This does not work for primitive arrays. 
+    val genericSerializer = createSerializerForMapObjects( +      input, +      ObjectType(classOf[AnyRef]), +      validateAndSerializeElement(elementEnc, elementNullable)) + +    // Check if it is possible the user can pass a primitive array. This is the only case when it +    // is safe to directly convert to an array (for generic arrays and Seqs the type and the +    // nullability can be violated). If the user has passed a primitive array we create a special +    // code path to deal with these. +    val primitiveEncoderOption = elementEnc match { +      case _ if !lenientSerialization => None +      case enc: PrimitiveLeafEncoder[_] => Option(enc) +      case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive) +      case _ => None +    } +    primitiveEncoderOption match { +      case Some(primitiveEncoder) => +        val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass +        val check = Invoke( +          targetObject = exprs.Literal.fromObject(primitiveArrayClass), +          functionName = "isInstance", +          BooleanType, +          arguments = input :: Nil, +          propagateNull = false, +          returnNullable = false) +        exprs.If( +          check, Review Comment: We could widen this special code path to also cover arrays whose elements are allowed to be null (i.e. arrays of boxed values, since primitive arrays cannot hold null). In that case, however, we do need to make sure that the runtime element type of the array is sound. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org