xkrogen commented on code in PR #36506:
URL: https://github.com/apache/spark/pull/36506#discussion_r948281127
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -313,13 +364,12 @@ private[sql] class AvroSerializer(
private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
if (avroType.getType == Type.UNION) {
val fields = avroType.getTypes.asScala
- val actualType = fields.filter(_.getType != Type.NULL)
- if (fields.length != 2 || actualType.length != 1) {
- throw new UnsupportedAvroTypeException(
- s"Unsupported Avro UNION type $avroType: Only UNION of a null type
and a non-null " +
- "type is supported")
+ val nonNullTypes = fields.filter(_.getType != Type.NULL)
+ if (nonNullTypes.length == 1) {
+ (true, nonNullTypes.head)
+ } else {
+ (false, avroType)
}
Review Comment:
We could unpack this using a match statement to be a little more
Scala-idiomatic:
```suggestion
avroType.getTypes.asScala.filter(_.getType != Type.NULL).toSeq match {
case Seq() => throw new UnsupportedAvroTypeException(
"""seems this situation wasn't handled, but `type: ["null"]` is a
valid avro schema""")
case Seq(singleType) => (true, singleType)
case _ => (false, avroType)
}
```
Either way works, though I will note that the current code doesn't properly
handle the situation where there is a union with a single type that is `NULL`.
Either style we use, we should make sure we throw an error for that case.
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -287,14 +298,54 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0,
member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when
converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = avroType.getTypes.asScala.filter(_.getType !=
NULL).toSeq
+ validateComplexUnionMembers(catalystStruct, nonNullTypes, catalystPath,
avroPath)
+
+ val fieldConverters = nonNullTypes.zipWithIndex.map { case (avroField, i)
=>
+ val cf = catalystStruct.fields(i)
+ newConverter(cf.dataType, resolveNullableType(avroField, nullable =
true),
+ catalystPath :+ cf.name, avroPath :+ cf.name)
Review Comment:
Avro union branches are identified by their type, so it seems like
technically we should use `avroPath :+ avroField.getFullName` here instead of
`avroPath :+ cf.name`. The full name could have periods in it so we should
probably wrap it like `avroPath :+ s"[${avroField.getFullName}]"`
##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala:
##########
@@ -327,11 +329,19 @@ abstract class AvroSuite
dataFileWriter.flush()
dataFileWriter.close()
- val df = spark.sqlContext.read.format("avro").load(s"$dir.avro")
+ val df = spark.sqlContext.read.format("avro").load(nativeWriterPath)
assertResult(field1)(df.selectExpr("field1.member0").first().get(0))
assertResult(field2)(df.selectExpr("field2.member1").first().get(0))
assertResult(field3)(df.selectExpr("field3.member2").first().get(0))
assertResult(field4)(df.selectExpr("field4.member3").first().get(0))
+
+ df.write.format("avro").option("avroSchema",
schema.toString).save(sparkWriterPath)
+
+ val df2 = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+ assertResult(field1)(df2.selectExpr("field1.member0").first().get(0))
+ assertResult(field2)(df2.selectExpr("field2.member1").first().get(0))
+ assertResult(field3)(df2.selectExpr("field3.member2").first().get(0))
+ assertResult(field4)(df2.selectExpr("field4.member3").first().get(0))
Review Comment:
This is good, but at best we only validate that Spark is able to re-read the
written output. So for example if we had implemented `AvroSerializer` to
actually write out a record type with fields `memberN`, instead of writing out
a union type, this test would still pass.
Also note that `df2` loads `nativeWriterPath` rather than `sparkWriterPath`,
so as written we never actually read back the Spark-written output at all —
looks like a typo.
It would be better to _also_ validate the output using a native Avro reader,
and check that we actually wrote a valid union type.
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -218,6 +218,17 @@ private[sql] class AvroSerializer(
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal,
numFields))
+ case (st: StructType, UNION) =>
+ val unionConvertor = newComplexUnionConverter(st, avroType,
catalystPath, avroPath)
+ val numFields = st.length
+ (getter, ordinal) => unionConvertor(getter.getStruct(ordinal,
numFields))
+
+ case (DoubleType, UNION) if nonNullUnionTypes(avroType) == Set(FLOAT,
DOUBLE) =>
+ (getter, ordinal) => getter.getDouble(ordinal)
+
+ case (LongType, UNION) if nonNullUnionTypes(avroType) == Set(INT, LONG)
=>
+ (getter, ordinal) => getter.getLong(ordinal)
Review Comment:
This does imply some loss of information if you were to do a round-trip from
Avro to SQL to Avro, since all records written out would have double/long
values even if the input was a float/int.
The representation chosen on the read path is inherently lossy, since we
discard the information about which union branch the datum originated from, so
I don't think there's anything we can do here to avoid this behavior. Just
wanted to point it out.
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -287,14 +298,54 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0,
member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when
converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = avroType.getTypes.asScala.filter(_.getType !=
NULL).toSeq
+ validateComplexUnionMembers(catalystStruct, nonNullTypes, catalystPath,
avroPath)
+
+ val fieldConverters = nonNullTypes.zipWithIndex.map { case (avroField, i)
=>
+ val cf = catalystStruct.fields(i)
+ newConverter(cf.dataType, resolveNullableType(avroField, nullable =
true),
+ catalystPath :+ cf.name, avroPath :+ cf.name)
+ }.toArray
+
+ val numFields = catalystStruct.length
+ row: InternalRow =>
+ (0 until numFields).dropWhile(row.isNullAt).headOption match {
+ case Some(i) => fieldConverters(i).apply(row, i)
+ case None => null
+ }
+ }
+
+ def validateComplexUnionMembers(
+ catalystStruct: StructType,
+ unionTypes: Seq[Schema],
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): Unit = {
+ val expectedFieldNames = unionTypes.indices.map(i => s"member$i")
+ if (catalystStruct.fieldNames.toSeq != expectedFieldNames) {
+ throw new IncompatibleSchemaException(s"Generic Avro union at
${toFieldStr(avroPath)} " +
+ s"does not match the SQL schema at ${toFieldStr(catalystPath)}. It
expected the " +
+ s"following members ${expectedFieldNames.mkString("(", ", ", ")")} but
got " +
+ s"${catalystStruct.fieldNames.mkString("(", ", ", ")")}")
+ }
+ }
Review Comment:
Can we just inline this into `newComplexUnionConverter()`? The method
definition/parameter list is as long as the body :)
If we do need to keep it, it should be made private.
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -287,14 +298,54 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0,
member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when
converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = avroType.getTypes.asScala.filter(_.getType !=
NULL).toSeq
+ validateComplexUnionMembers(catalystStruct, nonNullTypes, catalystPath,
avroPath)
+
+ val fieldConverters = nonNullTypes.zipWithIndex.map { case (avroField, i)
=>
+ val cf = catalystStruct.fields(i)
+ newConverter(cf.dataType, resolveNullableType(avroField, nullable =
true),
Review Comment:
Avro doesn't allow nested unions so when we're iterating over the union
branches here, we'll never see a union. Thus there is no point in
`resolveNullableType()`, since all it does is extract nullability from unions.
We can just pass `avroField` directly here.
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -287,14 +298,54 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0,
member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when
converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = avroType.getTypes.asScala.filter(_.getType !=
NULL).toSeq
+ validateComplexUnionMembers(catalystStruct, nonNullTypes, catalystPath,
avroPath)
+
+ val fieldConverters = nonNullTypes.zipWithIndex.map { case (avroField, i)
=>
+ val cf = catalystStruct.fields(i)
+ newConverter(cf.dataType, resolveNullableType(avroField, nullable =
true),
+ catalystPath :+ cf.name, avroPath :+ cf.name)
+ }.toArray
+
+ val numFields = catalystStruct.length
+ row: InternalRow =>
+ (0 until numFields).dropWhile(row.isNullAt).headOption match {
+ case Some(i) => fieldConverters(i).apply(row, i)
+ case None => null
+ }
Review Comment:
Since this is a performance-critical section (executed on a per-row basis),
it's better to avoid Scala collections, which can be much less performant due
to creation of lots of temporary intermediate objects. We can instead use a
while-loop (note that `return` and `break` also have bad performance in Scala
since they are implemented using exceptions, so we avoid it by using a control
variable):
```suggestion
    row: InternalRow => {
      var idx = 0
      var resolved = false
      var retVal: Any = null
      while (idx < numFields && !resolved) {
        if (!row.isNullAt(idx)) {
          retVal = fieldConverters(idx).apply(row, idx)
          resolved = true
        }
        idx += 1
      }
      retVal
    }
```
(The `resolved` flag, rather than checking `retVal == null`, preserves the
original semantics: we take the first non-null struct field even if its
converter produces null, and the lambda's last expression is `retVal` so the
converted value is actually returned.)
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -337,4 +387,8 @@ private[sql] class AvroSerializer(
"schema will throw runtime exception if there is a record with null
value.")
}
}
+
+ private def nonNullUnionTypes(avroType: Schema): Set[Type] = {
+ avroType.getTypes.asScala.map(_.getType).filter(_ != NULL).toSet
+ }
Review Comment:
It's about time we had a utility method for this, given how many places we
do it here and in `AvroDeserializer`, but I would suggest that we:
1. Make this return `Set[Schema]` (i.e. remove the `map(_.getType)`) or even
`Seq[Schema]` to make it more generally useful.
2. Move it into `AvroUtils` to make it more accessible
For the newly added union check the usage becomes slightly less concise:
```
case (DoubleType, UNION) if nonNullUnionTypes(avroType).map(_.getType)
== Set(FLOAT, DOUBLE) =>
...
case (LongType, UNION) if nonNullUnionTypes(avroType).map(_.getType)
== Set(INT, LONG) =>
...
```
But now we can use it in a bunch of other places:
`AvroSerializer#resolveAvroType()`,
`AvroSerializer#newComplexUnionConverter()`,
`SchemaConverters#toSqlTypeHelper()`, and `AvroDeserializer#newWriter()`
##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala:
##########
@@ -219,7 +219,7 @@ class AvroFunctionsSuite extends QueryTest with
SharedSparkSession {
functions.from_avro($"avro", avroTypeStruct)), df)
}
- test("to_avro with unsupported nullable Avro schema") {
+ test("to_avro with complex union Avro schema") {
val df = spark.range(10).select(struct($"id",
$"id".cast("string").as("str")).as("struct"))
for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int",
"long"]""")) {
Review Comment:
This is only checking the `(INT, LONG)` case. Can we also check the more
generic case of different types, like `(INT, STRING)`?
##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala:
##########
@@ -1144,32 +1154,87 @@ abstract class AvroSuite
}
}
- test("unsupported nullable avro type") {
+ test("int/long double/float conversion") {
val catalystSchema =
StructType(Seq(
- StructField("Age", IntegerType, nullable = false),
- StructField("Name", StringType, nullable = false)))
+ StructField("Age", LongType),
+ StructField("Length", DoubleType),
+ StructField("Name", StringType)))
- for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int",
"long"]""")) {
+ for (optionalNull <- Seq(""""null",""", "")) {
val avroSchema = s"""
|{
| "type" : "record",
| "name" : "test_schema",
| "fields" : [
- | {"name": "Age", "type": $unsupportedAvroType},
+ | {"name": "Age", "type": [$optionalNull "int", "long"]},
+ | {"name": "Length", "type": [$optionalNull "float", "double"]},
| {"name": "Name", "type": ["null", "string"]}
| ]
|}
""".stripMargin
val df = spark.createDataFrame(
- spark.sparkContext.parallelize(Seq(Row(2, "Aurora"))), catalystSchema)
+ spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"))),
catalystSchema)
Review Comment:
Can we also test with a record that has the non-upcasted types (int/float)?
```suggestion
spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"), Row(1,
0.9F, null))), catalystSchema)
```
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -287,14 +298,54 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0,
member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when
converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = avroType.getTypes.asScala.filter(_.getType !=
NULL).toSeq
+ validateComplexUnionMembers(catalystStruct, nonNullTypes, catalystPath,
avroPath)
+
+ val fieldConverters = nonNullTypes.zipWithIndex.map { case (avroField, i)
=>
Review Comment:
`avroField` seems like a bit of a misleading name, since this is a `Schema`
rather than a `Schema.Field`, and it's not actually a field at all -- it's a
union branch. Maybe `avroBranch` or `avroBranchType`?
##########
connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala:
##########
@@ -287,14 +298,54 @@ private[sql] class AvroSerializer(
result
}
+ /**
+ * Complex unions map to struct types where field names are member0,
member1, etc.
+ * This is consistent with the behavior in [[SchemaConverters]] and when
converting between Avro
+ * and Parquet.
+ */
+ private def newComplexUnionConverter(
+ catalystStruct: StructType,
+ avroType: Schema,
+ catalystPath: Seq[String],
+ avroPath: Seq[String]): InternalRow => Any = {
+ val nonNullTypes = avroType.getTypes.asScala.filter(_.getType !=
NULL).toSeq
+ validateComplexUnionMembers(catalystStruct, nonNullTypes, catalystPath,
avroPath)
+
+ val fieldConverters = nonNullTypes.zipWithIndex.map { case (avroField, i)
=>
+ val cf = catalystStruct.fields(i)
Review Comment:
`zip` instead of `zipWithIndex` ?
```suggestion
val fieldConverters = nonNullTypes.zip(catalystStruct).map { case
(avroField, cf) =>
```
##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala:
##########
@@ -327,11 +329,19 @@ abstract class AvroSuite
dataFileWriter.flush()
dataFileWriter.close()
- val df = spark.sqlContext.read.format("avro").load(s"$dir.avro")
+ val df = spark.sqlContext.read.format("avro").load(nativeWriterPath)
assertResult(field1)(df.selectExpr("field1.member0").first().get(0))
assertResult(field2)(df.selectExpr("field2.member1").first().get(0))
assertResult(field3)(df.selectExpr("field3.member2").first().get(0))
assertResult(field4)(df.selectExpr("field4.member3").first().get(0))
+
+ df.write.format("avro").option("avroSchema",
schema.toString).save(sparkWriterPath)
+
+ val df2 = spark.sqlContext.read.format("avro").load(nativeWriterPath)
+ assertResult(field1)(df2.selectExpr("field1.member0").first().get(0))
+ assertResult(field2)(df2.selectExpr("field2.member1").first().get(0))
+ assertResult(field3)(df2.selectExpr("field3.member2").first().get(0))
+ assertResult(field4)(df2.selectExpr("field4.member3").first().get(0))
Review Comment:
One other thing which is not tested currently is a top-level null for the
union. I've seen this get handled improperly in some areas in the past; it
would be nice to see it covered by the testing here since it is a special case.
##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala:
##########
@@ -1144,32 +1154,87 @@ abstract class AvroSuite
}
}
- test("unsupported nullable avro type") {
+ test("int/long double/float conversion") {
val catalystSchema =
StructType(Seq(
- StructField("Age", IntegerType, nullable = false),
- StructField("Name", StringType, nullable = false)))
+ StructField("Age", LongType),
+ StructField("Length", DoubleType),
+ StructField("Name", StringType)))
- for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int",
"long"]""")) {
+ for (optionalNull <- Seq(""""null",""", "")) {
val avroSchema = s"""
|{
| "type" : "record",
| "name" : "test_schema",
| "fields" : [
- | {"name": "Age", "type": $unsupportedAvroType},
+ | {"name": "Age", "type": [$optionalNull "int", "long"]},
+ | {"name": "Length", "type": [$optionalNull "float", "double"]},
| {"name": "Name", "type": ["null", "string"]}
| ]
|}
""".stripMargin
val df = spark.createDataFrame(
- spark.sparkContext.parallelize(Seq(Row(2, "Aurora"))), catalystSchema)
+ spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"))),
catalystSchema)
withTempPath { tempDir =>
- val message = intercept[SparkException] {
+ df.write.format("avro").option("avroSchema",
avroSchema).save(tempDir.getPath)
+ checkAnswer(
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .load(tempDir.getPath),
+ df)
+ }
+ }
+ }
+
+ test("non-matching complex union types") {
+ val catalystSchema =
+ StructType(Seq(
+ StructField("Union", StructType(Seq(
+ StructField("member0", IntegerType),
+ StructField("member1", StructType(Seq(StructField("f1", StringType,
nullable = false))))
+ )))))
Review Comment:
You can use the builder-style `new StructType().add(...)` to be more concise
here:
```suggestion
    val catalystSchema = new StructType().add("Union", new StructType()
      .add("member0", IntegerType)
      .add("member1", new StructType().add("f1", StringType, nullable = false)))
```
(I prefer this format since it's less verbose, but either way is okay)
##########
connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala:
##########
@@ -1144,32 +1154,87 @@ abstract class AvroSuite
}
}
- test("unsupported nullable avro type") {
+ test("int/long double/float conversion") {
val catalystSchema =
StructType(Seq(
- StructField("Age", IntegerType, nullable = false),
- StructField("Name", StringType, nullable = false)))
+ StructField("Age", LongType),
+ StructField("Length", DoubleType),
+ StructField("Name", StringType)))
- for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int",
"long"]""")) {
+ for (optionalNull <- Seq(""""null",""", "")) {
val avroSchema = s"""
|{
| "type" : "record",
| "name" : "test_schema",
| "fields" : [
- | {"name": "Age", "type": $unsupportedAvroType},
+ | {"name": "Age", "type": [$optionalNull "int", "long"]},
+ | {"name": "Length", "type": [$optionalNull "float", "double"]},
| {"name": "Name", "type": ["null", "string"]}
| ]
|}
""".stripMargin
val df = spark.createDataFrame(
- spark.sparkContext.parallelize(Seq(Row(2, "Aurora"))), catalystSchema)
+ spark.sparkContext.parallelize(Seq(Row(2L, 1.8D, "Aurora"))),
catalystSchema)
withTempPath { tempDir =>
- val message = intercept[SparkException] {
+ df.write.format("avro").option("avroSchema",
avroSchema).save(tempDir.getPath)
+ checkAnswer(
+ spark.read
+ .format("avro")
+ .option("avroSchema", avroSchema)
+ .load(tempDir.getPath),
+ df)
+ }
+ }
+ }
+
+ test("non-matching complex union types") {
+ val catalystSchema =
+ StructType(Seq(
+ StructField("Union", StructType(Seq(
+ StructField("member0", IntegerType),
+ StructField("member1", StructType(Seq(StructField("f1", StringType,
nullable = false))))
+ )))))
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(Seq(Row(Row(1, null)))), catalystSchema)
+
+ val recordSchema =
"""{"type":"record","name":"r","fields":[{"name":"f1","type":"string"}]}"""
+ for ((unionSchema, compatible) <- Seq(
+ (""""null","int",""" + recordSchema, true),
+ (""""int","null",""" + recordSchema, true),
+ (""""int",""" + recordSchema + ""","null"""", true),
+ (""""int",""" + recordSchema, true),
+ (""""null",""" + recordSchema + ""","int"""", false),
+ (""""null",""" + recordSchema, false),
+
(""""null","int",{"type":"record","name":"r","fields":[{"name":"f2","type":"string"}]}""",
+ false)
+ )) {
+ val avroSchema = s"""
+ |{
+ | "type" : "record",
+ | "name" : "test_schema",
+ | "fields" : [
+ | {"name": "Union", "type": [$unionSchema]}
+ | ]
+ |}
+ """.stripMargin
Review Comment:
Maybe use `SchemaBuilder` here? It was hard for me to read the unionSchema
examples with so many quotes
```suggestion
val recordS =
SchemaBuilder.record("r").fields().requiredString("f1").endRecord()
val intS = Schema.create(Schema.Type.INT)
val nullS = Schema.create(Schema.Type.NULL)
for ((unionTypes, compatible) <- Seq(
(Seq(nullS, intS, recordS), true),
(Seq(intS, nullS, recordS), true),
(Seq(intS, recordS, nullS), true),
(Seq(intS, recordS), true),
(Seq(nullS, recordS, intS), false),
(Seq(nullS, recordS), false),
(Seq(nullS,
SchemaBuilder.record("r").fields().requiredString("f2").endRecord()), false)
)) {
      val avroSchema = SchemaBuilder.record("test_schema").fields()
        .name("Union").`type`(Schema.createUnion(unionTypes: _*)).noDefault()
        .endRecord().toString()
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]