Github user MaxGekk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21847#discussion_r205561746
  
    --- Diff: 
external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala ---
    @@ -165,16 +183,100 @@ class AvroSerializer(rootCatalystType: DataType, 
rootAvroType: Schema, nullable:
           result
       }
     
    -  private def resolveNullableType(avroType: Schema, nullable: Boolean): 
Schema = {
    -    if (nullable) {
    +  private def resolveNullableType(avroType: Schema, catalystType: DataType,
    +                                  nullable: Boolean): Schema = {
    +    if (nullable && avroType.getType == Type.UNION) {
           // avro uses union to represent nullable type.
    -      val fields = avroType.getTypes.asScala
    -      assert(fields.length == 2)
    -      val actualType = fields.filter(_.getType != NULL)
    +      val fieldTypes = avroType.getTypes.asScala
    +
    +      // If we're nullable, we need to have at least two types.  Cases 
with more than two types are
    +      // captured in test("read read-write, read-write w/ schema, read") 
w/ test.avro input
    +      assert(fieldTypes.length >= 2)
    +
    +      val actualType = catalystType match {
    +        case NullType => fieldTypes.filter(_.getType == Type.NULL)
    +        case BooleanType => fieldTypes.filter(_.getType == Type.BOOLEAN)
    +        case ByteType => fieldTypes.filter(_.getType == Type.INT)
    +        case BinaryType =>
    +          val at = fieldTypes.filter(x => x.getType == Type.BYTES || 
x.getType == Type.FIXED)
    +          if (at.length > 1) {
    +            throw new IncompatibleSchemaException (
    +              s"Cannot resolve schema of ${catalystType} against union 
${avroType.toString}")
    +          } else {
    +            at
    +          }
    +        case ShortType | IntegerType => fieldTypes.filter(_.getType == 
Type.INT)
    +        case LongType => fieldTypes.filter(_.getType == Type.LONG)
    +        case FloatType => fieldTypes.filter(_.getType == Type.FLOAT)
    +        case DoubleType => fieldTypes.filter(_.getType == Type.DOUBLE)
    +        case d: DecimalType => fieldTypes.filter(_.getType == Type.STRING)
    +        case StringType => fieldTypes
    +          .filter(x => x.getType == Type.STRING || x.getType == Type.ENUM)
    +        case DateType => fieldTypes.filter(x => x.getType == Type.INT || 
x.getType == Type.LONG)
    +        case TimestampType => fieldTypes.filter(_.getType == Type.LONG)
    +        case ArrayType(et, containsNull) =>
    +          // Find array that matches the type
    +          fieldTypes.filter(x => x.getType == Type.ARRAY && 
typeMatchesSchema(et, x.getElementType))
    +        case st: StructType => // Find the matching record!
    +          val recordTypes = fieldTypes.filter(x => x.getType == 
Type.RECORD)
    +          if (recordTypes.length > 1) {
    +            throw new IncompatibleSchemaException(
    +              "Unions of multiple record types are NOT supported with 
user-specified schema")
    +          }
    +          recordTypes
    +        case MapType(kt, vt, valueContainsNull) =>
    +          // Find the map that matches the type
    +          fieldTypes.filter(x => x.getType == Type.MAP && 
typeMatchesSchema(vt, x.getValueType))
    +        case other =>
    +          throw new IncompatibleSchemaException(s"Unexpected type: $other")
    +      }
    +
           assert(actualType.length == 1)
           actualType.head
         } else {
           avroType
         }
       }
    +
    +  // Given a Schema and a DataType, do they match?
    +  private def typeMatchesSchema(catalystType: DataType, avroSchema: 
Schema): Boolean = {
    +    if (catalystType.isInstanceOf[StructType]) {
    +      val avroFields = resolveNullableType(avroSchema, catalystType,
    +        avroSchema.getType == Type.UNION)
    +        .getFields
    +      assert(avroFields.size() == 
catalystType.asInstanceOf[StructType].length)
    +      catalystType.asInstanceOf[StructType].zip(avroFields.asScala).map {
    +        case (f1, f2) => typeMatchesSchema(f1.dataType, f2.schema)
    +      }.foldLeft(true)(_ && _)
    +    } else {
    +      val isTypeCompatible = (a: Schema, b: DataType, c: Type) =>
    +        resolveNullableType(a, b, a.getType == Type.UNION).getType == c
    +
    +      catalystType match {
    +        case ByteType | ShortType | IntegerType =>
    +          isTypeCompatible(avroSchema, catalystType, Type.INT)
    +        case BooleanType => isTypeCompatible(avroSchema, catalystType, 
Type.BOOLEAN)
    +        case BinaryType => isTypeCompatible(avroSchema, catalystType, 
Type.BYTES)
    +        case LongType | TimestampType => isTypeCompatible(avroSchema, 
catalystType, Type.LONG)
    +        case FloatType => isTypeCompatible(avroSchema, catalystType, 
Type.FLOAT)
    +        case DoubleType => isTypeCompatible(avroSchema, catalystType, 
Type.DOUBLE)
    +        case d: DecimalType => isTypeCompatible(avroSchema, catalystType, 
Type.STRING)
    --- End diff --
    
    Could you add a comment explaining why you pass `Type.STRING` for `DecimalType`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to