alexeykudinkin commented on a change in pull request #4955:
URL: https://github.com/apache/hudi/pull/4955#discussion_r831597205



##########
File path: 
hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
##########
@@ -37,10 +37,16 @@ import scala.collection.mutable.ArrayBuffer
 /**
  * A deserializer to deserialize data in avro format to data in catalyst 
format.
  *
- * NOTE: This is a version of {@code AvroDeserializer} impl from Spark 2.4.4 
w/ the fix for SPARK-30267
+ * NOTE: This code is borrowed from Spark 2.4.4
+ *       This code is borrowed, so that we can better control compatibility 
w/in Spark minor
+ *       branches (3.2.x, 3.1.x, etc)
+ *
+ *       PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY 
NECESSARY
+ *
+ * NOTE: This is a version of [[AvroDeserializer]] impl from Spark 2.4.4 w/ 
the fix for SPARK-30267

Review comment:
       @vinothchandar this is the diff against Spark

##########
File path: 
hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
##########
@@ -0,0 +1,503 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.math.BigDecimal
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic._
+import org.apache.avro.util.Utf8
+import org.apache.spark.sql.avro.AvroDeserializer.{createDateRebaseFuncInRead, 
createTimestampRebaseFuncInRead}
+import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr}
+import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
+import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, 
UnsafeArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, 
DateTimeUtils, GenericArrayData, RebaseDateTime}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY
+import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A deserializer to deserialize data in avro format to data in catalyst 
format.
+ *
+ * NOTE: This code is borrowed from Spark 3.2.1
+ *       This code is borrowed, so that we can better control compatibility 
w/in Spark minor
+ *       branches (3.2.x, 3.1.x, etc)
+ *
+ *       PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY 
NECESSARY
+ */
+private[sql] class AvroDeserializer(rootAvroType: Schema,
+                                    rootCatalystType: DataType,
+                                    positionalFieldMatch: Boolean,
+                                    datetimeRebaseSpec: RebaseSpec,
+                                    filters: StructFilters) {
+
+  def this(rootAvroType: Schema,
+           rootCatalystType: DataType,
+           datetimeRebaseMode: String) = {
+    this(
+      rootAvroType,
+      rootCatalystType,
+      positionalFieldMatch = false,
+      RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)),
+      new NoopFilters)
+  }
+
+  private lazy val decimalConversions = new DecimalConversion()
+
+  private val dateRebaseFunc = 
createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro")
+
+  private val timestampRebaseFunc = 
createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro")
+
+  private val converter: Any => Option[Any] = try {
+    rootCatalystType match {
+      // A shortcut for empty schema.
+      case st: StructType if st.isEmpty =>
+        (_: Any) => Some(InternalRow.empty)
+
+      case st: StructType =>
+        val resultRow = new SpecificInternalRow(st.map(_.dataType))
+        val fieldUpdater = new RowUpdater(resultRow)
+        val applyFilters = filters.skipRow(resultRow, _)
+        val writer = getRecordWriter(rootAvroType, st, Nil, Nil, applyFilters)
+        (data: Any) => {
+          val record = data.asInstanceOf[GenericRecord]
+          val skipRow = writer(fieldUpdater, record)
+          if (skipRow) None else Some(resultRow)
+        }
+
+      case _ =>
+        val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+        val fieldUpdater = new RowUpdater(tmpRow)
+        val writer = newWriter(rootAvroType, rootCatalystType, Nil, Nil)
+        (data: Any) => {
+          writer(fieldUpdater, 0, data)
+          Some(tmpRow.get(0, rootCatalystType))
+        }
+    }
+  } catch {
+    case ise: IncompatibleSchemaException => throw new 
IncompatibleSchemaException(
+      s"Cannot convert Avro type $rootAvroType to SQL type 
${rootCatalystType.sql}.", ise)
+  }
+
+  def deserialize(data: Any): Option[Any] = converter(data)
+
+  /**
+   * Creates a writer to write avro values to Catalyst values at the given 
ordinal with the given
+   * updater.
+   */
+  private def newWriter(avroType: Schema,
+                        catalystType: DataType,
+                        avroPath: Seq[String],
+                        catalystPath: Seq[String]): (CatalystDataUpdater, Int, 
Any) => Unit = {
+    val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " +
+      s"SQL ${toFieldStr(catalystPath)} because "
+    val incompatibleMsg = errorPrefix +
+      s"schema is incompatible (avroType = $avroType, sqlType = 
${catalystType.sql})"
+
+    (avroType.getType, catalystType) match {
+      case (NULL, NullType) => (updater, ordinal, _) =>
+        updater.setNullAt(ordinal)
+
+      // TODO: we can avoid boxing if a future version of Avro provides primitive 
accessors.
+      case (BOOLEAN, BooleanType) => (updater, ordinal, value) =>
+        updater.setBoolean(ordinal, value.asInstanceOf[Boolean])
+
+      case (INT, IntegerType) => (updater, ordinal, value) =>
+        updater.setInt(ordinal, value.asInstanceOf[Int])
+
+      case (INT, DateType) => (updater, ordinal, value) =>
+        updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int]))
+
+      case (LONG, LongType) => (updater, ordinal, value) =>
+        updater.setLong(ordinal, value.asInstanceOf[Long])
+
+      case (LONG, TimestampType) => avroType.getLogicalType match {
+        // For backward compatibility, if the Avro type is Long and it is not 
logical type
+        // (the `null` case), the value is processed as timestamp type with 
millisecond precision.
+        case null | _: TimestampMillis => (updater, ordinal, value) =>
+          val millis = value.asInstanceOf[Long]
+          val micros = DateTimeUtils.millisToMicros(millis)
+          updater.setLong(ordinal, timestampRebaseFunc(micros))
+        case _: TimestampMicros => (updater, ordinal, value) =>
+          val micros = value.asInstanceOf[Long]
+          updater.setLong(ordinal, timestampRebaseFunc(micros))
+        case other => throw new IncompatibleSchemaException(errorPrefix +
+          s"Avro logical type $other cannot be converted to SQL type 
${TimestampType.sql}.")
+      }
+
+      // Before we upgrade Avro to 1.8 for logical type support, spark-avro 
converts Long to Date.
+      // For backward compatibility, we still keep this conversion.
+      case (LONG, DateType) => (updater, ordinal, value) =>
+        updater.setInt(ordinal, (value.asInstanceOf[Long] / 
MILLIS_PER_DAY).toInt)
+
+      case (FLOAT, FloatType) => (updater, ordinal, value) =>
+        updater.setFloat(ordinal, value.asInstanceOf[Float])
+
+      case (DOUBLE, DoubleType) => (updater, ordinal, value) =>
+        updater.setDouble(ordinal, value.asInstanceOf[Double])
+
+      case (STRING, StringType) => (updater, ordinal, value) =>
+        val str = value match {
+          case s: String => UTF8String.fromString(s)
+          case s: Utf8 =>
+            val bytes = new Array[Byte](s.getByteLength)
+            System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength)
+            UTF8String.fromBytes(bytes)
+        }
+        updater.set(ordinal, str)
+
+      case (ENUM, StringType) => (updater, ordinal, value) =>
+        updater.set(ordinal, UTF8String.fromString(value.toString))
+
+      case (FIXED, BinaryType) => (updater, ordinal, value) =>
+        updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone())
+
+      case (BYTES, BinaryType) => (updater, ordinal, value) =>
+        val bytes = value match {
+          case b: ByteBuffer =>
+            val bytes = new Array[Byte](b.remaining)
+            b.get(bytes)
+            bytes
+          case b: Array[Byte] => b
+          case other =>
+            throw new RuntimeException(errorPrefix + s"$other is not a valid 
avro binary.")
+        }
+        updater.set(ordinal, bytes)
+
+      case (FIXED, _: DecimalType) => (updater, ordinal, value) =>
+        val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
+        val bigDecimal = 
decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d)
+        val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
+        updater.setDecimal(ordinal, decimal)
+
+      case (BYTES, _: DecimalType) => (updater, ordinal, value) =>
+        val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal]
+        val bigDecimal = 
decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d)
+        val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale)
+        updater.setDecimal(ordinal, decimal)
+
+      case (RECORD, st: StructType) =>
+        // Avro datasource doesn't accept filters with nested attributes. See 
SPARK-32328.
+        // We can always return `false` from `applyFilters` for nested records.
+        val writeRecord =
+          getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = 
_ => false)
+        (updater, ordinal, value) =>
+          val row = new SpecificInternalRow(st)
+          writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord])
+          updater.set(ordinal, row)
+
+      case (ARRAY, ArrayType(elementType, containsNull)) =>
+        val avroElementPath = avroPath :+ "element"
+        val elementWriter = newWriter(avroType.getElementType, elementType,
+          avroElementPath, catalystPath :+ "element")
+        (updater, ordinal, value) =>
+          val collection = value.asInstanceOf[java.util.Collection[Any]]
+          val result = createArrayData(elementType, collection.size())
+          val elementUpdater = new ArrayDataUpdater(result)
+
+          var i = 0
+          val iter = collection.iterator()
+          while (iter.hasNext) {
+            val element = iter.next()
+            if (element == null) {
+              if (!containsNull) {
+                throw new RuntimeException(
+                  s"Array value at path ${toFieldStr(avroElementPath)} is not 
allowed to be null")
+              } else {
+                elementUpdater.setNullAt(i)
+              }
+            } else {
+              elementWriter(elementUpdater, i, element)
+            }
+            i += 1
+          }
+
+          updater.set(ordinal, result)
+
+      case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == 
StringType =>
+        val keyWriter = newWriter(SchemaBuilder.builder().stringType(), 
StringType,
+          avroPath :+ "key", catalystPath :+ "key")
+        val valueWriter = newWriter(avroType.getValueType, valueType,
+          avroPath :+ "value", catalystPath :+ "value")
+        (updater, ordinal, value) =>
+          val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]]
+          val keyArray = createArrayData(keyType, map.size())
+          val keyUpdater = new ArrayDataUpdater(keyArray)
+          val valueArray = createArrayData(valueType, map.size())
+          val valueUpdater = new ArrayDataUpdater(valueArray)
+          val iter = map.entrySet().iterator()
+          var i = 0
+          while (iter.hasNext) {
+            val entry = iter.next()
+            assert(entry.getKey != null)
+            keyWriter(keyUpdater, i, entry.getKey)
+            if (entry.getValue == null) {
+              if (!valueContainsNull) {
+                throw new RuntimeException(
+                  s"Map value at path ${toFieldStr(avroPath :+ "value")} is 
not allowed to be null")
+              } else {
+                valueUpdater.setNullAt(i)
+              }
+            } else {
+              valueWriter(valueUpdater, i, entry.getValue)
+            }
+            i += 1
+          }
+
+          // The Avro map will never have null or duplicated map keys, it's 
safe to create a
+          // ArrayBasedMapData directly here.
+          updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))
+
+      case (UNION, _) =>
+        val allTypes = avroType.getTypes.asScala
+        val nonNullTypes = allTypes.filter(_.getType != NULL)
+        val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava)
+        if (nonNullTypes.nonEmpty) {
+          if (nonNullTypes.length == 1) {
+            newWriter(nonNullTypes.head, catalystType, avroPath, catalystPath)
+          } else {
+            nonNullTypes.map(_.getType).toSeq match {
+              case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == 
LongType =>
+                (updater, ordinal, value) =>
+                  value match {
+                    case null => updater.setNullAt(ordinal)
+                    case l: java.lang.Long => updater.setLong(ordinal, l)
+                    case i: java.lang.Integer => updater.setLong(ordinal, 
i.longValue())
+                  }
+
+              case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && 
catalystType == DoubleType =>
+                (updater, ordinal, value) =>
+                  value match {
+                    case null => updater.setNullAt(ordinal)
+                    case d: java.lang.Double => updater.setDouble(ordinal, d)
+                    case f: java.lang.Float => updater.setDouble(ordinal, 
f.doubleValue())
+                  }
+
+              case _ =>
+                catalystType match {
+                  case st: StructType if st.length == nonNullTypes.size =>
+                    val fieldWriters = nonNullTypes.zip(st.fields).map {
+                      case (schema, field) =>
+                        newWriter(schema, field.dataType, avroPath, 
catalystPath :+ field.name)
+                    }.toArray
+                    (updater, ordinal, value) => {
+                      val row = new SpecificInternalRow(st)
+                      val fieldUpdater = new RowUpdater(row)
+                      val i = GenericData.get().resolveUnion(nonNullAvroType, 
value)
+                      fieldWriters(i)(fieldUpdater, i, value)
+                      updater.set(ordinal, row)
+                    }
+
+                  case _ => throw new 
IncompatibleSchemaException(incompatibleMsg)
+                }
+            }
+          }
+        } else {
+          (updater, ordinal, _) => updater.setNullAt(ordinal)
+        }
+
+      case _ => throw new IncompatibleSchemaException(incompatibleMsg)
+    }
+  }
+
+  // TODO: move the following method in Decimal object on creating Decimal 
from BigDecimal?
+  private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): 
Decimal = {
+    if (precision <= Decimal.MAX_LONG_DIGITS) {
+      // Constructs a `Decimal` with an unscaled `Long` value if possible.
+      Decimal(decimal.unscaledValue().longValue(), precision, scale)
+    } else {
+      // Otherwise, resorts to an unscaled `BigInteger` instead.
+      Decimal(decimal, precision, scale)
+    }
+  }
+
+  private def getRecordWriter(avroType: Schema,
+                              catalystType: StructType,
+                              avroPath: Seq[String],
+                              catalystPath: Seq[String],
+                              applyFilters: Int => Boolean): 
(CatalystDataUpdater, GenericRecord) => Boolean = {
+    val validFieldIndexes = ArrayBuffer.empty[Int]
+    val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit]
+
+    val avroSchemaHelper =
+      new AvroUtils.AvroSchemaHelper(avroType, avroPath, positionalFieldMatch)
+    val length = catalystType.length
+    var i = 0
+    while (i < length) {
+      val catalystField = catalystType.fields(i)
+      avroSchemaHelper.getAvroField(catalystField.name, i) match {
+        case Some(avroField) =>
+          validFieldIndexes += avroField.pos()
+
+          val baseWriter = newWriter(avroField.schema(), 
catalystField.dataType,
+            avroPath :+ avroField.name, catalystPath :+ catalystField.name)
+          val ordinal = i
+          val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => 
{
+            if (value == null) {
+              fieldUpdater.setNullAt(ordinal)
+            } else {
+              baseWriter(fieldUpdater, ordinal, value)
+            }
+          }
+          fieldWriters += fieldWriter
+        case None if !catalystField.nullable =>
+          val fieldDescription =
+            toFieldDescription(catalystPath :+ catalystField.name, i, 
positionalFieldMatch)
+          throw new IncompatibleSchemaException(
+            s"Cannot find non-nullable $fieldDescription in Avro schema.")
+        case _ => // nothing to do
+      }
+      i += 1
+    }
+
+    (fieldUpdater, record) => {
+      var i = 0
+      var skipRow = false
+      while (i < validFieldIndexes.length && !skipRow) {
+        fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i)))
+        skipRow = applyFilters(i)
+        i += 1
+      }
+      skipRow
+    }
+  }
+
+  private def createArrayData(elementType: DataType, length: Int): ArrayData = 
elementType match {
+    case BooleanType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Boolean](length))
+    case ByteType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Byte](length))
+    case ShortType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Short](length))
+    case IntegerType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Int](length))
+    case LongType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Long](length))
+    case FloatType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Float](length))
+    case DoubleType => UnsafeArrayData.fromPrimitiveArray(new 
Array[Double](length))
+    case _ => new GenericArrayData(new Array[Any](length))
+  }
+
+  /**
+   * A base interface for updating values inside catalyst data structure like 
`InternalRow` and
+   * `ArrayData`.
+   */
+  sealed trait CatalystDataUpdater {
+    def set(ordinal: Int, value: Any): Unit
+
+    def setNullAt(ordinal: Int): Unit = set(ordinal, null)
+
+    def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
+
+    def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
+
+    def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
+
+    def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
+
+    def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
+
+    def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
+
+    def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
+
+    def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
+  }
+
+  final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
+    override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, 
value)
+
+    override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
+
+    override def setBoolean(ordinal: Int, value: Boolean): Unit = 
row.setBoolean(ordinal, value)
+
+    override def setByte(ordinal: Int, value: Byte): Unit = 
row.setByte(ordinal, value)
+
+    override def setShort(ordinal: Int, value: Short): Unit = 
row.setShort(ordinal, value)
+
+    override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, 
value)
+
+    override def setLong(ordinal: Int, value: Long): Unit = 
row.setLong(ordinal, value)
+
+    override def setDouble(ordinal: Int, value: Double): Unit = 
row.setDouble(ordinal, value)
+
+    override def setFloat(ordinal: Int, value: Float): Unit = 
row.setFloat(ordinal, value)
+
+    override def setDecimal(ordinal: Int, value: Decimal): Unit =
+      row.setDecimal(ordinal, value, value.precision)
+  }
+
+  final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
+    override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, 
value)
+
+    override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
+
+    override def setBoolean(ordinal: Int, value: Boolean): Unit = 
array.setBoolean(ordinal, value)
+
+    override def setByte(ordinal: Int, value: Byte): Unit = 
array.setByte(ordinal, value)
+
+    override def setShort(ordinal: Int, value: Short): Unit = 
array.setShort(ordinal, value)
+
+    override def setInt(ordinal: Int, value: Int): Unit = 
array.setInt(ordinal, value)
+
+    override def setLong(ordinal: Int, value: Long): Unit = 
array.setLong(ordinal, value)
+
+    override def setDouble(ordinal: Int, value: Double): Unit = 
array.setDouble(ordinal, value)
+
+    override def setFloat(ordinal: Int, value: Float): Unit = 
array.setFloat(ordinal, value)
+
+    override def setDecimal(ordinal: Int, value: Decimal): Unit = 
array.update(ordinal, value)
+  }
+}
+
+object AvroDeserializer {
+
+  // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making 
[[AvroDeserializer]] implementation

Review comment:
       @vinothchandar this is the diff against Spark

##########
File path: 
hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
##########
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.nio.ByteBuffer
+import scala.collection.JavaConverters._
+import org.apache.avro.Conversions.DecimalConversion
+import org.apache.avro.LogicalTypes
+import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
+import org.apache.avro.Schema
+import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
+import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
+import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.util.Utf8
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, 
createTimestampRebaseFuncInWrite}
+import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, 
SpecificInternalRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime}
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
+import org.apache.spark.sql.types._
+
+/**
+ * A serializer to serialize data in catalyst format to data in avro format.
+ *
+ * NOTE: This code is borrowed from Spark 3.2.1
+ * This code is borrowed, so that we can better control compatibility w/in 
Spark minor
+ * branches (3.2.x, 3.1.x, etc)
+ *
+ * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
+ */
+private[sql] class AvroSerializer(rootCatalystType: DataType,
+                                  rootAvroType: Schema,
+                                  nullable: Boolean,
+                                  positionalFieldMatch: Boolean,
+                                  datetimeRebaseMode: 
LegacyBehaviorPolicy.Value) extends Logging {
+
+  def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: 
Boolean) = {
+    this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = 
false,
+      
LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)))
+  }
+
+  def serialize(catalystData: Any): Any = {
+    converter.apply(catalystData)
+  }
+
+  private val dateRebaseFunc = createDateRebaseFuncInWrite(
+    datetimeRebaseMode, "Avro")
+
+  private val timestampRebaseFunc = createTimestampRebaseFuncInWrite(
+    datetimeRebaseMode, "Avro")
+
+  private val converter: Any => Any = {
+    val actualAvroType = resolveNullableType(rootAvroType, nullable)
+    val baseConverter = try {
+      rootCatalystType match {
+        case st: StructType =>
+          newStructConverter(st, actualAvroType, Nil, Nil).asInstanceOf[Any => 
Any]
+        case _ =>
+          val tmpRow = new SpecificInternalRow(Seq(rootCatalystType))
+          val converter = newConverter(rootCatalystType, actualAvroType, Nil, 
Nil)
+          (data: Any) =>
+            tmpRow.update(0, data)
+            converter.apply(tmpRow, 0)
+      }
+    } catch {
+      case ise: IncompatibleSchemaException => throw new 
IncompatibleSchemaException(
+        s"Cannot convert SQL type ${rootCatalystType.sql} to Avro type 
$rootAvroType.", ise)
+    }
+    if (nullable) {
+      (data: Any) =>
+        if (data == null) {
+          null
+        } else {
+          baseConverter.apply(data)
+        }
+    } else {
+      baseConverter
+    }
+  }
+
+  private type Converter = (SpecializedGetters, Int) => Any
+
+  private lazy val decimalConversions = new DecimalConversion()
+
+  private def newConverter(catalystType: DataType,
+                           avroType: Schema,
+                           catalystPath: Seq[String],
+                           avroPath: Seq[String]): Converter = {
+    val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
+      s"to Avro ${toFieldStr(avroPath)} because "
+    (catalystType, avroType.getType) match {
+      case (NullType, NULL) =>
+        (getter, ordinal) => null
+      case (BooleanType, BOOLEAN) =>
+        (getter, ordinal) => getter.getBoolean(ordinal)
+      case (ByteType, INT) =>
+        (getter, ordinal) => getter.getByte(ordinal).toInt
+      case (ShortType, INT) =>
+        (getter, ordinal) => getter.getShort(ordinal).toInt
+      case (IntegerType, INT) =>
+        (getter, ordinal) => getter.getInt(ordinal)
+      case (LongType, LONG) =>
+        (getter, ordinal) => getter.getLong(ordinal)
+      case (FloatType, FLOAT) =>
+        (getter, ordinal) => getter.getFloat(ordinal)
+      case (DoubleType, DOUBLE) =>
+        (getter, ordinal) => getter.getDouble(ordinal)
+      case (d: DecimalType, FIXED)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, 
d.scale) =>
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
+      case (d: DecimalType, BYTES)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, 
d.scale) =>
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
+      case (StringType, ENUM) =>
+        val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+        (getter, ordinal) =>
+          val data = getter.getUTF8String(ordinal).toString
+          if (!enumSymbols.contains(data)) {
+            throw new IncompatibleSchemaException(errorPrefix +
+              s""""$data" cannot be written since it's not defined in enum """ 
+
+              enumSymbols.mkString("\"", "\", \"", "\""))
+          }
+          new EnumSymbol(avroType, data)
+
+      case (StringType, STRING) =>
+        (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+
+      case (BinaryType, FIXED) =>
+        val size = avroType.getFixedSize
+        (getter, ordinal) =>
+          val data: Array[Byte] = getter.getBinary(ordinal)
+          if (data.length != size) {
+            def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else 
"byte"}"
+
+            throw new IncompatibleSchemaException(errorPrefix + 
len2str(data.length) +
+              " of binary data cannot be written into FIXED type with size of 
" + len2str(size))
+          }
+          new Fixed(avroType, data)
+
+      case (BinaryType, BYTES) =>
+        (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+
+      case (DateType, INT) =>
+        (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal))
+
+      case (TimestampType, LONG) => avroType.getLogicalType match {
+        // For backward compatibility, if the Avro type is Long and it is not 
logical type
+        // (the `null` case), output the timestamp value with millisecond 
precision.
+        case null | _: TimestampMillis => (getter, ordinal) =>
+          
DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal)))
+        case _: TimestampMicros => (getter, ordinal) =>
+          timestampRebaseFunc(getter.getLong(ordinal))
+        case other => throw new IncompatibleSchemaException(errorPrefix +
+          s"SQL type ${TimestampType.sql} cannot be converted to Avro logical 
type $other")
+      }
+
+      case (ArrayType(et, containsNull), ARRAY) =>
+        val elementConverter = newConverter(
+          et, resolveNullableType(avroType.getElementType, containsNull),
+          catalystPath :+ "element", avroPath :+ "element")
+        (getter, ordinal) => {
+          val arrayData = getter.getArray(ordinal)
+          val len = arrayData.numElements()
+          val result = new Array[Any](len)
+          var i = 0
+          while (i < len) {
+            if (containsNull && arrayData.isNullAt(i)) {
+              result(i) = null
+            } else {
+              result(i) = elementConverter(arrayData, i)
+            }
+            i += 1
+          }
+          // avro writer is expecting a Java Collection, so we convert it into
+          // `ArrayList` backed by the specified array without data copying.
+          java.util.Arrays.asList(result: _*)
+        }
+
+      case (st: StructType, RECORD) =>
+        val structConverter = newStructConverter(st, avroType, catalystPath, 
avroPath)
+        val numFields = st.length
+        (getter, ordinal) => structConverter(getter.getStruct(ordinal, 
numFields))
+
+      case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
+        val valueConverter = newConverter(
+          vt, resolveNullableType(avroType.getValueType, valueContainsNull),
+          catalystPath :+ "value", avroPath :+ "value")
+        (getter, ordinal) =>
+          val mapData = getter.getMap(ordinal)
+          val len = mapData.numElements()
+          val result = new java.util.HashMap[String, Any](len)
+          val keyArray = mapData.keyArray()
+          val valueArray = mapData.valueArray()
+          var i = 0
+          while (i < len) {
+            val key = keyArray.getUTF8String(i).toString
+            if (valueContainsNull && valueArray.isNullAt(i)) {
+              result.put(key, null)
+            } else {
+              result.put(key, valueConverter(valueArray, i))
+            }
+            i += 1
+          }
+          result
+
+      case _ =>
+        throw new IncompatibleSchemaException(errorPrefix +
+          s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = 
$avroType)")
+    }
+  }
+  /**
+   * Builds a function that converts a Catalyst [[InternalRow]] described by
+   * `catalystStruct` into an Avro [[Record]] with schema `avroStruct`.
+   *
+   * Schema validation and per-field converter construction happen once, when
+   * this method is called; the returned function only copies field values.
+   *
+   * NOTE(review): `positionalFieldMatch` is defined outside this method;
+   * presumably it selects position-based instead of name-based field
+   * matching in the schema helper -- confirm against the class definition.
+   *
+   * @param catalystStruct Catalyst struct type of the rows to convert
+   * @param avroStruct     Avro schema of the records to produce
+   * @param catalystPath   path of this struct within the Catalyst schema; used
+   *                       in error messages and extended with each field name
+   *                       when building the field converters
+   * @param avroPath       path of this struct within the Avro schema; used in
+   *                       the same way as `catalystPath`
+   * @return a function that writes every field of a row into the Avro field
+   *         it was matched with, putting `null` for fields that are null
+   * @throws IncompatibleSchemaException if `avroStruct` is not a RECORD, if
+   *         the two schemas differ in number of fields, or if a Catalyst
+   *         field has no matching Avro field
+   */
+  private def newStructConverter(catalystStruct: StructType,
+                                 avroStruct: Schema,
+                                 catalystPath: Seq[String],
+                                 avroPath: Seq[String]): InternalRow => Record 
= {
+
+    val avroPathStr = toFieldStr(avroPath)
+    if (avroStruct.getType != RECORD) { // a struct can only be written to a RECORD
+      throw new IncompatibleSchemaException(s"$avroPathStr was not a RECORD")
+    }
+    val avroFields = avroStruct.getFields.asScala
+    if (avroFields.size != catalystStruct.length) { // field counts must be equal
+      throw new IncompatibleSchemaException(
+        s"Avro $avroPathStr schema length (${avroFields.size}) doesn't match " 
+
+          s"SQL ${toFieldStr(catalystPath)} schema length 
(${catalystStruct.length})")
+    }
+    val avroSchemaHelper =
+      new AvroUtils.AvroSchemaHelper(avroStruct, avroPath, 
positionalFieldMatch)
+
+    val (avroIndices: Array[Int], fieldConverters: Array[Converter]) = // both indexed by Catalyst field position
+      catalystStruct.zipWithIndex.map { case (catalystField, catalystPos) =>
+        val avroField = avroSchemaHelper.getAvroField(catalystField.name, 
catalystPos) match {
+          case Some(f) => f
+          case None =>
+            val fieldDescription = toFieldDescription(
+              catalystPath :+ catalystField.name, catalystPos, 
positionalFieldMatch)
+            throw new IncompatibleSchemaException(
+              s"Cannot find $fieldDescription in Avro schema at $avroPathStr")
+        }
+        val converter = newConverter(catalystField.dataType,
+          resolveNullableType(avroField.schema(), catalystField.nullable),
+          catalystPath :+ catalystField.name, avroPath :+ avroField.name)
+        (avroField.pos(), converter) // target Avro position paired with its converter
+      }.toArray.unzip
+
+    val numFields = catalystStruct.length
+    row: InternalRow =>
+      val result = new Record(avroStruct) // a new record is allocated for every row
+      var i = 0
+      while (i < numFields) {
+        if (row.isNullAt(i)) {
+          result.put(avroIndices(i), null) // null field: the converter is not invoked
+        } else {
+          result.put(avroIndices(i), fieldConverters(i).apply(row, i))
+        }
+        i += 1
+      }
+      result
+  }
+
+  /**
+   * Resolves a possibly nullable Avro type to the type that carries the data.
+   *
+   * An Avro type is treated as nullable when it is a [[UNION]] of exactly two
+   * types: one null type and one non-null type. For such a type the non-null
+   * member is returned; any other (non-union) type is returned unchanged.
+   *
+   * A warning is also logged when the nullability of the Avro type and the
+   * nullability of the Catalyst type differ.
+   *
+   * @param avroType the Avro schema to resolve
+   * @param nullable whether the corresponding Catalyst type is nullable
+   * @return the non-null member of a nullable union, otherwise `avroType`
+   * @throws UnsupportedAvroTypeException when `avroType` is a union of any
+   *         other shape
+   */
+  private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema 
= {
+    val (avroNullable, resolvedAvroType) = resolveAvroType(avroType)
+    warnNullabilityDifference(avroNullable, nullable)
+    resolvedAvroType
+  }
+
+  /**
+   * Checks whether the input Avro type is nullable and, if so, unwraps it.
+   *
+   * A type that is not a UNION is reported as non-nullable and returned as
+   * is. A UNION is accepted only when it has exactly two members of which
+   * exactly one is not the null type; that member is then returned.
+   *
+   * @param avroType the Avro schema to inspect
+   * @return a pair whose first element is a [[Boolean]] telling whether
+   *         `avroType` is nullable (i.e. is a supported union) and whose
+   *         second element is the possibly resolved type
+   * @throws UnsupportedAvroTypeException when `avroType` is a UNION that is
+   *         not made of one null type and one non-null type
+   */
+  private def resolveAvroType(avroType: Schema): (Boolean, Schema) = {
+    if (avroType.getType == Type.UNION) {
+      val fields = avroType.getTypes.asScala
+      val actualType = fields.filter(_.getType != Type.NULL) // the non-null union members
+      if (fields.length != 2 || actualType.length != 1) {
+        throw new UnsupportedAvroTypeException(
+          s"Unsupported Avro UNION type $avroType: Only UNION of a null type 
and a non-null " +
+            "type is supported")
+      }
+      (true, actualType.head)
+    } else {
+      (false, avroType)
+    }
+  }
+
+  /**
+   * Logs a warning when the nullability of the Avro type and the nullability
+   * of the Catalyst type differ. Nothing is logged when they agree.
+   *
+   * Two distinct messages are used: one for a nullable Avro schema paired
+   * with a non-nullable Catalyst schema, and one for the opposite case,
+   * whose message warns of a runtime exception on records with null values.
+   *
+   * @param avroNullable     whether the Avro type is nullable
+   * @param catalystNullable whether the Catalyst type is nullable
+   */
+  private def warnNullabilityDifference(avroNullable: Boolean, 
catalystNullable: Boolean): Unit = {
+    if (avroNullable && !catalystNullable) {
+      logWarning("Writing Avro files with nullable Avro schema and 
non-nullable catalyst schema.")
+    }
+    if (!avroNullable && catalystNullable) {
+      logWarning("Writing Avro files with non-nullable Avro schema and 
nullable catalyst " +
+        "schema will throw runtime exception if there is a record with null 
value.")
+    }
+  }
+}
+
+object AvroSerializer {
+
+  // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making 
[[AvroSerializer]] implementation

Review comment:
       @vinothchandar this is the diff against Spark




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to