Github user dbtsai commented on a diff in the pull request:
https://github.com/apache/spark/pull/21847#discussion_r206358703
--- Diff: external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala ---
@@ -165,16 +182,118 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
result
}
-  private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
- if (nullable) {
+  // Resolve an Avro union against a supplied DataType, e.g. a LongType resolved against
+  // a ["null", "long"] union should return a schema of type Schema.Type.LONG.
+  // This function also handles resolving a DataType against unions of two or more types, e.g.
+  // an IntegerType resolved against ["int", "long", "null"] will correctly return a schema of
+  // type Schema.Type.INT.
+  private def resolveUnionType(avroType: Schema, catalystType: DataType,
+      nullable: Boolean): Schema = {
+ if (avroType.getType == Type.UNION) {
// avro uses union to represent nullable type.
- val fields = avroType.getTypes.asScala
- assert(fields.length == 2)
- val actualType = fields.filter(_.getType != NULL)
- assert(actualType.length == 1)
+ val fieldTypes = avroType.getTypes.asScala
+
+      // If we're nullable, we need to have at least two types. Cases with more than two
+      // types are captured in test("read read-write, read-write w/ schema, read") w/
+      // test.avro input
+      if (nullable && fieldTypes.length < 2) {
+        throw new IncompatibleSchemaException(
+          s"Cannot resolve nullable ${catalystType} against union type ${avroType}")
+      }
+
+ val actualType = catalystType match {
+ case NullType => fieldTypes.filter(_.getType == Type.NULL)
+ case BooleanType => fieldTypes.filter(_.getType == Type.BOOLEAN)
+ case ByteType => fieldTypes.filter(_.getType == Type.INT)
+        case BinaryType =>
+          val at = fieldTypes.filter(x => x.getType == Type.BYTES || x.getType == Type.FIXED)
+          if (at.length > 1) {
+            throw new IncompatibleSchemaException(
+              s"Cannot resolve schema of ${catalystType} against union ${avroType.toString}")
+          } else {
+            at
+          }
+        case ShortType | IntegerType => fieldTypes.filter(_.getType == Type.INT)
+ case LongType => fieldTypes.filter(_.getType == Type.LONG)
+ case FloatType => fieldTypes.filter(_.getType == Type.FLOAT)
+ case DoubleType => fieldTypes.filter(_.getType == Type.DOUBLE)
+ case d: DecimalType => fieldTypes.filter(_.getType == Type.STRING)
+        case StringType => fieldTypes
+          .filter(x => x.getType == Type.STRING || x.getType == Type.ENUM)
+        case DateType => fieldTypes.filter(x => x.getType == Type.INT || x.getType == Type.LONG)
+ case TimestampType => fieldTypes.filter(_.getType == Type.LONG)
+ case ArrayType(et, containsNull) =>
+ // Find array that matches the element type specified
+          fieldTypes.filter(x => x.getType == Type.ARRAY
+            && typeMatchesSchema(et, x.getElementType))
+ case st: StructType => // Find the matching record!
+          val recordTypes = fieldTypes.filter(x => x.getType == Type.RECORD)
+          if (recordTypes.length > 1) {
+            throw new IncompatibleSchemaException(
+              "Unions of multiple record types are NOT supported with user-specified schema")
+          }
+          recordTypes
+ case MapType(kt, vt, valueContainsNull) =>
+          // Find the map that matches the value type. Maps in Avro always have string keys.
+          fieldTypes.filter(x => x.getType == Type.MAP && typeMatchesSchema(vt, x.getValueType))
+ case other =>
+ throw new IncompatibleSchemaException(s"Unexpected type: $other")
+ }
+
+ if (actualType.length != 1) {
+ throw new IncompatibleSchemaException(
+          s"Failed to resolve ${catalystType} against ambiguous schema ${avroType}")
+ }
actualType.head
} else {
avroType
}
}
+
+ // Given a Schema and a DataType, do they match?
+  private def typeMatchesSchema(catalystType: DataType, avroSchema: Schema): Boolean = {
+    if (catalystType.isInstanceOf[StructType]) {
+      val avroFields = resolveUnionType(avroSchema, catalystType,
+        avroSchema.getType == Type.UNION)
+        .getFields
+      if (avroFields.size() == catalystType.asInstanceOf[StructType].length) {
+        catalystType.asInstanceOf[StructType].zip(avroFields.asScala).forall {
+          case (f1, f2) => typeMatchesSchema(f1.dataType, f2.schema)
+        }
+ } else {
+ false
+ }
+ } else {
+      val isTypeCompatible = (a: Schema, b: DataType, c: Type) =>
+        resolveUnionType(a, b, a.getType == Type.UNION).getType == c
+
+ catalystType match {
+        case ByteType | ShortType | IntegerType =>
+          isTypeCompatible(avroSchema, catalystType, Type.INT)
+        case BooleanType => isTypeCompatible(avroSchema, catalystType, Type.BOOLEAN)
+        case BinaryType => isTypeCompatible(avroSchema, catalystType, Type.BYTES)
+        case LongType | TimestampType => isTypeCompatible(avroSchema, catalystType, Type.LONG)
+        case FloatType => isTypeCompatible(avroSchema, catalystType, Type.FLOAT)
+        case DoubleType => isTypeCompatible(avroSchema, catalystType, Type.DOUBLE)
+        case d: DecimalType =>
+          // newConverter always returns a string representation for DecimalType, so we honor
+          // that here, since we don't yet support Avro's logical types
+          isTypeCompatible(avroSchema, catalystType, Type.STRING)
+        case StringType => isTypeCompatible(avroSchema, catalystType, Type.STRING) ||
+          isTypeCompatible(avroSchema, catalystType, Type.ENUM)
+        case DateType => isTypeCompatible(avroSchema, catalystType, Type.INT) ||
+          isTypeCompatible(avroSchema, catalystType, Type.LONG)
+        case ArrayType(et, containsNull) =>
+          isTypeCompatible(avroSchema, catalystType, Type.ARRAY) &&
+            typeMatchesSchema(et,
+              resolveUnionType(avroSchema, catalystType, avroSchema.getType == Type.UNION)
+                .getElementType)
+        case MapType(kt, vt, valueContainsNull) =>
+          isTypeCompatible(avroSchema, catalystType, Type.MAP) &&
+            typeMatchesSchema(vt,
+              resolveUnionType(avroSchema, catalystType, avroSchema.getType == Type.UNION)
+                .getValueType)
+      }
+    }
+ }
+
--- End diff ---
remove the extra line
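
For context (not part of the diff): a minimal, self-contained sketch of the
union-resolution behavior the comment above describes, using Avro's
Schema.Parser directly. The object name and println calls are illustrative
only, not anything in this PR:

    import org.apache.avro.Schema
    import scala.collection.JavaConverters._

    object UnionResolutionSketch {
      def main(args: Array[String]): Unit = {
        // A nullable long, i.e. ["null", "long"]: resolving LongType against it
        // should pick the LONG branch, mirroring resolveUnionType's filter.
        val nullableLong = new Schema.Parser().parse("""["null", "long"]""")
        val longBranch = nullableLong.getTypes.asScala.filter(_.getType == Schema.Type.LONG)
        println(longBranch.map(_.getType)) // Buffer(LONG)

        // A wider union ["int", "long", "null"]: resolving IntegerType filters
        // on Type.INT, so exactly one branch survives and resolution succeeds.
        val wide = new Schema.Parser().parse("""["int", "long", "null"]""")
        val intBranch = wide.getTypes.asScala.filter(_.getType == Schema.Type.INT)
        println(intBranch.map(_.getType)) // Buffer(INT)
      }
    }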
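
Similarly, a hedged sketch of the record check typeMatchesSchema performs: a
hypothetical, simplified recordMatches helper (illustration only; the PR's
version resolves unions first and recurses through all supported types):

    import org.apache.avro.Schema
    import org.apache.spark.sql.types._
    import scala.collection.JavaConverters._

    // Does a non-union Avro RECORD line up field-for-field with a StructType?
    def recordMatches(st: StructType, avro: Schema): Boolean = {
      avro.getType == Schema.Type.RECORD &&
        avro.getFields.size == st.length &&
        st.zip(avro.getFields.asScala).forall { case (f, af) =>
          f.dataType match {
            // Leaf comparisons only, for the sketch; real code handles more types
            case LongType    => af.schema.getType == Schema.Type.LONG
            case IntegerType => af.schema.getType == Schema.Type.INT
            case StringType  => af.schema.getType == Schema.Type.STRING
            case _           => false
          }
        }
    }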
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]