This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 3042b2c7c5 [HUDI-4525] Fixing Spark 3.3 `AvroSerializer` implementation (#6279)
3042b2c7c5 is described below

commit 3042b2c7c54c91578dba69a1a814563fb00718d5
Author: Alexey Kudinkin <[email protected]>
AuthorDate: Wed Aug 3 14:27:21 2022 -0700

    [HUDI-4525] Fixing Spark 3.3 `AvroSerializer` implementation (#6279)
---
 .github/workflows/bot.yml                          |   2 +-
 .../org/apache/hudi/io/HoodieAppendHandle.java     |   6 +-
 .../scala/org/apache/hudi/HoodieSparkUtils.scala   |   1 +
 .../TestConvertFilterToCatalystExpression.scala    |   4 +-
 .../org/apache/hudi/TestHoodieSparkSqlWriter.scala |  34 +++++-
 .../org/apache/spark/sql/avro/AvroSerializer.scala |  27 ++++-
 .../apache/spark/sql/avro/AvroDeserializer.scala   |  35 +++---
 .../org/apache/spark/sql/avro/AvroSerializer.scala | 121 ++++++++++++++++-----
 8 files changed, 172 insertions(+), 58 deletions(-)

diff --git a/.github/workflows/bot.yml b/.github/workflows/bot.yml
index 26c07b96bf..3aa9bdbcc6 100644
--- a/.github/workflows/bot.yml
+++ b/.github/workflows/bot.yml
@@ -69,4 +69,4 @@ jobs:
           FLINK_PROFILE: ${{ matrix.flinkProfile }}
         if: ${{ !endsWith(env.SPARK_PROFILE, '2.4') }} # skip test spark 2.4 as it's covered by Azure CI
         run:
-          mvn test -Punit-tests -D"$SCALA_PROFILE" -D"$SPARK_PROFILE" -D"$FLINK_PROFILE" '-Dtest=org.apache.spark.sql.hudi.Test*' -pl hudi-spark-datasource/hudi-spark
+          mvn test -Punit-tests -D"$SCALA_PROFILE" -D"$SPARK_PROFILE" -D"$FLINK_PROFILE" '-Dtest=Test*' -pl hudi-spark-datasource/hudi-spark
diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java
index 426e20f83b..e0d40642a6 100644
--- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java
+++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/io/HoodieAppendHandle.java
@@ -471,10 +471,12 @@ public class HoodieAppendHandle<T extends 
HoodieRecordPayload, I, K, O> extends
 
     return HoodieLogFormat.newWriterBuilder()
         
.onParentPath(FSUtils.getPartitionPath(hoodieTable.getMetaClient().getBasePath(),
 partitionPath))
-        .withFileId(fileId).overBaseCommit(baseCommitTime)
+        .withFileId(fileId)
+        .overBaseCommit(baseCommitTime)
         
.withLogVersion(latestLogFile.map(HoodieLogFile::getLogVersion).orElse(HoodieLogFile.LOGFILE_BASE_VERSION))
         .withFileSize(latestLogFile.map(HoodieLogFile::getFileSize).orElse(0L))
-        .withSizeThreshold(config.getLogFileMaxSize()).withFs(fs)
+        .withSizeThreshold(config.getLogFileMaxSize())
+        .withFs(fs)
         .withRolloverLogWriteToken(writeToken)
         .withLogWriteToken(latestLogFile.map(x -> 
FSUtils.getWriteTokenFromLogPath(x.getPath())).orElse(writeToken))
         .withFileExtension(HoodieLogFile.DELTA_EXTENSION).build();
diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
index 97bbe3e79b..a2f5d1ce97 100644
--- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
@@ -54,6 +54,7 @@ private[hudi] trait SparkVersionsSupport {
   def isSpark3_2: Boolean = getSparkVersion.startsWith("3.2")
   def isSpark3_3: Boolean = getSparkVersion.startsWith("3.3")
 
+  def gteqSpark3_0: Boolean = getSparkVersion >= "3.0"
   def gteqSpark3_1: Boolean = getSparkVersion >= "3.1"
   def gteqSpark3_1_3: Boolean = getSparkVersion >= "3.1.3"
   def gteqSpark3_2: Boolean = getSparkVersion >= "3.2"
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
index 8aa47ffc2f..2d4498ac28 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestConvertFilterToCatalystExpression.scala
@@ -69,7 +69,7 @@ class TestConvertFilterToCatalystExpression {
   private def checkConvertFilter(filter: Filter, expectExpression: String): 
Unit = {
     // [SPARK-25769][SPARK-34636][SPARK-34626][SQL] sql method in 
UnresolvedAttribute,
     // AttributeReference and Alias don't quote qualified names properly
-    val removeQuotesIfNeed = if (expectExpression != null && 
HoodieSparkUtils.isSpark3_2) {
+    val removeQuotesIfNeed = if (expectExpression != null && 
HoodieSparkUtils.gteqSpark3_2) {
       expectExpression.replace("`", "")
     } else {
       expectExpression
@@ -86,7 +86,7 @@ class TestConvertFilterToCatalystExpression {
   private def checkConvertFilters(filters: Array[Filter], expectExpression: 
String): Unit = {
     // [SPARK-25769][SPARK-34636][SPARK-34626][SQL] sql method in 
UnresolvedAttribute,
     // AttributeReference and Alias don't quote qualified names properly
-    val removeQuotesIfNeed = if (expectExpression != null && 
HoodieSparkUtils.isSpark3_2) {
+    val removeQuotesIfNeed = if (expectExpression != null && 
HoodieSparkUtils.gteqSpark3_2) {
       expectExpression.replace("`", "")
     } else {
       expectExpression
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
index 4829c44932..93469f2796 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
@@ -22,6 +22,7 @@ import java.time.Instant
 import java.util.{Collections, Date, UUID}
 import org.apache.commons.io.FileUtils
 import org.apache.hudi.DataSourceWriteOptions._
+import org.apache.hudi.HoodieSparkUtils.gteqSpark3_0
 import org.apache.hudi.client.SparkRDDWriteClient
 import org.apache.hudi.common.model._
 import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient, 
TableSchemaResolver}
@@ -41,7 +42,8 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, 
assertTrue, fail}
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{CsvSource, EnumSource, ValueSource}
+import org.junit.jupiter.params.provider.Arguments.arguments
+import org.junit.jupiter.params.provider.{Arguments, CsvSource, EnumSource, 
MethodSource, ValueSource}
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito.{spy, times, verify}
 import org.scalatest.Assertions.assertThrows
@@ -485,11 +487,8 @@ class TestHoodieSparkSqlWriter {
    * @param populateMetaFields Flag for populating meta fields
    */
   @ParameterizedTest
-  @CsvSource(
-    Array("COPY_ON_WRITE,parquet,true", "COPY_ON_WRITE,parquet,false", 
"MERGE_ON_READ,parquet,true", "MERGE_ON_READ,parquet,false",
-      "COPY_ON_WRITE,orc,true", "COPY_ON_WRITE,orc,false", 
"MERGE_ON_READ,orc,true", "MERGE_ON_READ,orc,false"
-    ))
-  def testDatasourceInsertForTableTypeBaseFileMetaFields(tableType: String, 
baseFileFormat: String, populateMetaFields: Boolean): Unit = {
+  @MethodSource(Array("testDatasourceInsert"))
+  def testDatasourceInsertForTableTypeBaseFileMetaFields(tableType: String, 
populateMetaFields: Boolean, baseFileFormat: String): Unit = {
     val hoodieFooTableName = "hoodie_foo_tbl"
     val fooTableModifier = Map("path" -> tempBasePath,
       HoodieWriteConfig.TBL_NAME.key -> hoodieFooTableName,
@@ -1069,3 +1068,26 @@ class TestHoodieSparkSqlWriter {
     assertTrue(kg2 == classOf[SimpleKeyGenerator].getName)
   }
 }
+
+object TestHoodieSparkSqlWriter {
+  def testDatasourceInsert: java.util.stream.Stream[Arguments] = {
+    val scenarios = Array(
+      Seq("COPY_ON_WRITE", true),
+      Seq("COPY_ON_WRITE", false),
+      Seq("MERGE_ON_READ", true),
+      Seq("MERGE_ON_READ", false)
+    )
+
+    val parquetScenarios = scenarios.map { _ :+ "parquet" }
+    val orcScenarios = scenarios.map { _ :+ "orc" }
+
+    // TODO(HUDI-4496) Fix Orc support in Spark 3.x
+    val targetScenarios = if (gteqSpark3_0) {
+      parquetScenarios
+    } else {
+      parquetScenarios ++ orcScenarios
+    }
+
+    java.util.Arrays.stream(targetScenarios.map(as => 
arguments(as.map(_.asInstanceOf[AnyRef]):_*)))
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 73267f4147..ba9812b026 100644
--- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -45,8 +45,13 @@ import java.util.TimeZone
  * A serializer to serialize data in catalyst format to data in avro format.
  *
  * NOTE: This code is borrowed from Spark 3.2.1
- * This code is borrowed, so that we can better control compatibility w/in 
Spark minor
- * branches (3.2.x, 3.1.x, etc)
+ *       This code is borrowed, so that we can better control compatibility 
w/in Spark minor
+ *       branches (3.2.x, 3.1.x, etc)
+ *
+ * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH 
THE MODIFICATION
+ *       BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND 
PROPERLY ALL THE
+ *       MODIFICATIONS.
+ *
  *
  * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
  */
@@ -211,11 +216,20 @@ private[sql] class AvroSerializer(rootCatalystType: 
DataType,
         val numFields = st.length
         (getter, ordinal) => structConverter(getter.getStruct(ordinal, 
numFields))
 
+      
////////////////////////////////////////////////////////////////////////////////////////////
+      // Following section is amended to the original (Spark's) implementation
+      // >>> BEGINS
+      
////////////////////////////////////////////////////////////////////////////////////////////
+
       case (st: StructType, UNION) =>
         val unionConverter = newUnionConverter(st, avroType, catalystPath, 
avroPath)
         val numFields = st.length
         (getter, ordinal) => unionConverter(getter.getStruct(ordinal, 
numFields))
 
+      
////////////////////////////////////////////////////////////////////////////////////////////
+      // <<< ENDS
+      
////////////////////////////////////////////////////////////////////////////////////////////
+
       case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
         val valueConverter = newConverter(
           vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -293,6 +307,11 @@ private[sql] class AvroSerializer(rootCatalystType: 
DataType,
       result
   }
 
+  
////////////////////////////////////////////////////////////////////////////////////////////
+  // Following section is amended to the original (Spark's) implementation
+  // >>> BEGINS
+  
////////////////////////////////////////////////////////////////////////////////////////////
+
   private def newUnionConverter(catalystStruct: StructType,
                                 avroUnion: Schema,
                                 catalystPath: Seq[String],
@@ -337,6 +356,10 @@ private[sql] class AvroSerializer(rootCatalystType: 
DataType,
       avroStruct.getTypes.size() - 1 == catalystStruct.length) || 
avroStruct.getTypes.size() == catalystStruct.length
   }
 
+  
////////////////////////////////////////////////////////////////////////////////////////////
+  // <<< ENDS
+  
////////////////////////////////////////////////////////////////////////////////////////////
+
   /**
    * Resolve a possibly nullable Avro Type.
    *
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
index fbefb36ddc..5e7bab3e51 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -48,17 +48,15 @@ import java.util.TimeZone
  *
  * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
  */
-private[sql] class AvroDeserializer(
-                                     rootAvroType: Schema,
-                                     rootCatalystType: DataType,
-                                     positionalFieldMatch: Boolean,
-                                     datetimeRebaseSpec: RebaseSpec,
-                                     filters: StructFilters) {
-
-  def this(
-            rootAvroType: Schema,
-            rootCatalystType: DataType,
-            datetimeRebaseMode: String) = {
+private[sql] class AvroDeserializer(rootAvroType: Schema,
+                                    rootCatalystType: DataType,
+                                    positionalFieldMatch: Boolean,
+                                    datetimeRebaseSpec: RebaseSpec,
+                                    filters: StructFilters) {
+
+  def this(rootAvroType: Schema,
+           rootCatalystType: DataType,
+           datetimeRebaseMode: String) = {
     this(
       rootAvroType,
       rootCatalystType,
@@ -69,11 +67,9 @@ private[sql] class AvroDeserializer(
 
   private lazy val decimalConversions = new DecimalConversion()
 
-  private val dateRebaseFunc = createDateRebaseFuncInRead(
-    datetimeRebaseSpec.mode, "Avro")
+  private val dateRebaseFunc = 
createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro")
 
-  private val timestampRebaseFunc = createTimestampRebaseFuncInRead(
-    datetimeRebaseSpec, "Avro")
+  private val timestampRebaseFunc = 
createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro")
 
   private val converter: Any => Option[Any] = try {
     rootCatalystType match {
@@ -112,11 +108,10 @@ private[sql] class AvroDeserializer(
    * Creates a writer to write avro values to Catalyst values at the given 
ordinal with the given
    * updater.
    */
-  private def newWriter(
-                         avroType: Schema,
-                         catalystType: DataType,
-                         avroPath: Seq[String],
-                         catalystPath: Seq[String]): (CatalystDataUpdater, 
Int, Any) => Unit = {
+  private def newWriter(avroType: Schema,
+                        catalystType: DataType,
+                        avroPath: Seq[String],
+                        catalystPath: Seq[String]): (CatalystDataUpdater, Int, 
Any) => Unit = {
     val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " +
       s"SQL ${toFieldStr(catalystPath)} because "
     val incompatibleMsg = errorPrefix +
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 73d245d42d..450d9d7346 100644
--- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -29,6 +29,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.avro.generic.GenericData.Record
 import org.apache.avro.util.Utf8
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, 
createTimestampRebaseFuncInWrite}
 import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, 
SpecificInternalRow}
@@ -44,17 +45,20 @@ import java.util.TimeZone
  * A serializer to serialize data in catalyst format to data in avro format.
  *
  * NOTE: This code is borrowed from Spark 3.3.0
- * This code is borrowed, so that we can better control compatibility w/in 
Spark minor
- * branches (3.2.x, 3.1.x, etc)
+ *       This code is borrowed, so that we can better control compatibility 
w/in Spark minor
+ *       branches (3.2.x, 3.1.x, etc)
+ *
+ * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH 
THE MODIFICATION
+ *       BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND 
PROPERLY ALL THE
+ *       MODIFICATIONS.
  *
  * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
  */
-private[sql] class AvroSerializer(
-                                   rootCatalystType: DataType,
-                                   rootAvroType: Schema,
-                                   nullable: Boolean,
-                                   positionalFieldMatch: Boolean,
-                                   datetimeRebaseMode: 
LegacyBehaviorPolicy.Value) extends Logging {
+private[sql] class AvroSerializer(rootCatalystType: DataType,
+                                  rootAvroType: Schema,
+                                  nullable: Boolean,
+                                  positionalFieldMatch: Boolean,
+                                  datetimeRebaseMode: 
LegacyBehaviorPolicy.Value) extends Logging {
 
   def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: 
Boolean) = {
     this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = 
false,
@@ -65,10 +69,10 @@ private[sql] class AvroSerializer(
     converter.apply(catalystData)
   }
 
-  private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInWrite(
+  private val dateRebaseFunc = createDateRebaseFuncInWrite(
     datetimeRebaseMode, "Avro")
 
-  private val timestampRebaseFunc = 
DataSourceUtils.createTimestampRebaseFuncInWrite(
+  private val timestampRebaseFunc = createTimestampRebaseFuncInWrite(
     datetimeRebaseMode, "Avro")
 
   private val converter: Any => Any = {
@@ -104,11 +108,10 @@ private[sql] class AvroSerializer(
 
   private lazy val decimalConversions = new DecimalConversion()
 
-  private def newConverter(
-                            catalystType: DataType,
-                            avroType: Schema,
-                            catalystPath: Seq[String],
-                            avroPath: Seq[String]): Converter = {
+  private def newConverter(catalystType: DataType,
+                           avroType: Schema,
+                           catalystPath: Seq[String],
+                           avroPath: Seq[String]): Converter = {
     val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
       s"to Avro ${toFieldStr(avroPath)} because "
     (catalystType, avroType.getType) match {
@@ -162,6 +165,7 @@ private[sql] class AvroSerializer(
           val data: Array[Byte] = getter.getBinary(ordinal)
           if (data.length != size) {
             def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else 
"byte"}"
+
             throw new IncompatibleSchemaException(errorPrefix + 
len2str(data.length) +
               " of binary data cannot be written into FIXED type with size of 
" + len2str(size))
           }
@@ -223,6 +227,20 @@ private[sql] class AvroSerializer(
         val numFields = st.length
         (getter, ordinal) => structConverter(getter.getStruct(ordinal, 
numFields))
 
+      
////////////////////////////////////////////////////////////////////////////////////////////
+      // Following section is amended to the original (Spark's) implementation
+      // >>> BEGINS
+      
////////////////////////////////////////////////////////////////////////////////////////////
+
+      case (st: StructType, UNION) =>
+        val unionConverter = newUnionConverter(st, avroType, catalystPath, 
avroPath)
+        val numFields = st.length
+        (getter, ordinal) => unionConverter(getter.getStruct(ordinal, 
numFields))
+
+      
////////////////////////////////////////////////////////////////////////////////////////////
+      // <<< ENDS
+      
////////////////////////////////////////////////////////////////////////////////////////////
+
       case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
         val valueConverter = newConverter(
           vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -257,11 +275,10 @@ private[sql] class AvroSerializer(
     }
   }
 
-  private def newStructConverter(
-                                  catalystStruct: StructType,
-                                  avroStruct: Schema,
-                                  catalystPath: Seq[String],
-                                  avroPath: Seq[String]): InternalRow => 
Record = {
+  private def newStructConverter(catalystStruct: StructType,
+                                 avroStruct: Schema,
+                                 catalystPath: Seq[String],
+                                 avroPath: Seq[String]): InternalRow => Record 
= {
 
     val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
       avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch)
@@ -292,6 +309,60 @@ private[sql] class AvroSerializer(
       result
   }
 
+  
////////////////////////////////////////////////////////////////////////////////////////////
+  // Following section is amended to the original (Spark's) implementation
+  // >>> BEGINS
+  
////////////////////////////////////////////////////////////////////////////////////////////
+
+  private def newUnionConverter(catalystStruct: StructType,
+                                avroUnion: Schema,
+                                catalystPath: Seq[String],
+                                avroPath: Seq[String]): InternalRow => Any = {
+    if (avroUnion.getType != UNION || !canMapUnion(catalystStruct, avroUnion)) 
{
+      throw new IncompatibleSchemaException(s"Cannot convert Catalyst type 
$catalystStruct to " +
+        s"Avro type $avroUnion.")
+    }
+    val nullable = avroUnion.getTypes.size() > 0 && 
avroUnion.getTypes.get(0).getType == Type.NULL
+    val avroInnerTypes = if (nullable) {
+      avroUnion.getTypes.asScala.tail
+    } else {
+      avroUnion.getTypes.asScala
+    }
+    val fieldConverters = catalystStruct.zip(avroInnerTypes).map {
+      case (f1, f2) => newConverter(f1.dataType, f2, catalystPath, avroPath)
+    }
+    val numFields = catalystStruct.length
+    (row: InternalRow) =>
+      var i = 0
+      var result: Any = null
+      while (i < numFields) {
+        if (!row.isNullAt(i)) {
+          if (result != null) {
+            throw new IncompatibleSchemaException(s"Cannot convert Catalyst 
record $catalystStruct to " +
+              s"Avro union $avroUnion. Record has more than one optional 
values set")
+          }
+          result = fieldConverters(i).apply(row, i)
+        }
+        i += 1
+      }
+      if (!nullable && result == null) {
+        throw new IncompatibleSchemaException(s"Cannot convert Catalyst record 
$catalystStruct to " +
+          s"Avro union $avroUnion. Record has no values set, while should have 
exactly one")
+      }
+      result
+  }
+
+  private def canMapUnion(catalystStruct: StructType, avroStruct: Schema): 
Boolean = {
+    (avroStruct.getTypes.size() > 0 &&
+      avroStruct.getTypes.get(0).getType == Type.NULL &&
+      avroStruct.getTypes.size() - 1 == catalystStruct.length) || 
avroStruct.getTypes.size() == catalystStruct.length
+  }
+
+  
////////////////////////////////////////////////////////////////////////////////////////////
+  // <<< ENDS
+  
////////////////////////////////////////////////////////////////////////////////////////////
+
+
   /**
    * Resolve a possibly nullable Avro Type.
    *
@@ -319,12 +390,12 @@ private[sql] class AvroSerializer(
     if (avroType.getType == Type.UNION) {
       val fields = avroType.getTypes.asScala
       val actualType = fields.filter(_.getType != Type.NULL)
-      if (fields.length != 2 || actualType.length != 1) {
-        throw new UnsupportedAvroTypeException(
-          s"Unsupported Avro UNION type $avroType: Only UNION of a null type 
and a non-null " +
-            "type is supported")
+      if (fields.length == 2 && actualType.length == 1) {
+        (true, actualType.head)
+      } else {
+        // This is just a normal union, not used to designate nullability
+        (false, avroType)
       }
-      (true, actualType.head)
     } else {
       (false, avroType)
     }

Reply via email to