Repository: spark
Updated Branches:
  refs/heads/master c220cc42a -> ab197308a


[SPARK-25104][SQL] Avro: Validate user specified output schema

## What changes were proposed in this pull request?

With the code changes in https://github.com/apache/spark/pull/21847 , Spark can 
write Avro files according to a user-provided output schema.

To make this more robust and user friendly, we should validate the Avro schema 
before any tasks are launched.

We should also support writing the logical decimal type as BYTES (by default it 
is written as FIXED).

## How was this patch tested?

Unit test

Closes #22094 from gengliangwang/AvroSerializerMatch.

Authored-by: Gengliang Wang <gengliang.w...@databricks.com>
Signed-off-by: DB Tsai <d_t...@apple.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab197308
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab197308
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab197308

Branch: refs/heads/master
Commit: ab197308a79c74f0a4205a8f60438811b5e0b991
Parents: c220cc4
Author: Gengliang Wang <gengliang.w...@databricks.com>
Authored: Tue Aug 14 04:43:14 2018 +0000
Committer: DB Tsai <d_t...@apple.com>
Committed: Tue Aug 14 04:43:14 2018 +0000

----------------------------------------------------------------------
 .../apache/spark/sql/avro/AvroSerializer.scala  | 108 +++++++++++--------
 .../spark/sql/avro/AvroLogicalTypeSuite.scala   |  40 +++++++
 .../org/apache/spark/sql/avro/AvroSuite.scala   |  57 ++++++++++
 3 files changed, 158 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ab197308/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 3a9544c..f551c83 100644
--- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -26,6 +26,7 @@ import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type
+import org.apache.avro.Schema.Type._
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
 import org.apache.avro.generic.GenericData.Record
 import org.apache.avro.util.Utf8
@@ -72,62 +73,70 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
   private lazy val decimalConversions = new DecimalConversion()
 
   private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
-    catalystType match {
-      case NullType =>
+    (catalystType, avroType.getType) match {
+      case (NullType, NULL) =>
         (getter, ordinal) => null
-      case BooleanType =>
+      case (BooleanType, BOOLEAN) =>
         (getter, ordinal) => getter.getBoolean(ordinal)
-      case ByteType =>
+      case (ByteType, INT) =>
         (getter, ordinal) => getter.getByte(ordinal).toInt
-      case ShortType =>
+      case (ShortType, INT) =>
         (getter, ordinal) => getter.getShort(ordinal).toInt
-      case IntegerType =>
+      case (IntegerType, INT) =>
         (getter, ordinal) => getter.getInt(ordinal)
-      case LongType =>
+      case (LongType, LONG) =>
         (getter, ordinal) => getter.getLong(ordinal)
-      case FloatType =>
+      case (FloatType, FLOAT) =>
         (getter, ordinal) => getter.getFloat(ordinal)
-      case DoubleType =>
+      case (DoubleType, DOUBLE) =>
         (getter, ordinal) => getter.getDouble(ordinal)
-      case d: DecimalType =>
+      case (d: DecimalType, FIXED)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
         (getter, ordinal) =>
           val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
           decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
             LogicalTypes.decimal(d.precision, d.scale))
 
-      case StringType => avroType.getType match {
-        case Type.ENUM =>
-          import scala.collection.JavaConverters._
-          val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
-          (getter, ordinal) =>
-            val data = getter.getUTF8String(ordinal).toString
-            if (!enumSymbols.contains(data)) {
-              throw new IncompatibleSchemaException(
-                "Cannot write \"" + data + "\" since it's not defined in enum 
\"" +
-                  enumSymbols.mkString("\", \"") + "\"")
-            }
-            new EnumSymbol(avroType, data)
-        case _ =>
-          (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
-      }
-      case BinaryType => avroType.getType match {
-        case Type.FIXED =>
-          val size = avroType.getFixedSize()
-          (getter, ordinal) =>
-            val data: Array[Byte] = getter.getBinary(ordinal)
-            if (data.length != size) {
-              throw new IncompatibleSchemaException(
-                s"Cannot write ${data.length} ${if (data.length > 1) "bytes" 
else "byte"} of " +
-                  "binary data into FIXED Type with size of " +
-                  s"$size ${if (size > 1) "bytes" else "byte"}")
-            }
-            new Fixed(avroType, data)
-        case _ =>
-          (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
-      }
-      case DateType =>
+      case (d: DecimalType, BYTES)
+        if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) =>
+        (getter, ordinal) =>
+          val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
+          decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType,
+            LogicalTypes.decimal(d.precision, d.scale))
+
+      case (StringType, ENUM) =>
+        val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet
+        (getter, ordinal) =>
+          val data = getter.getUTF8String(ordinal).toString
+          if (!enumSymbols.contains(data)) {
+            throw new IncompatibleSchemaException(
+              "Cannot write \"" + data + "\" since it's not defined in enum 
\"" +
+                enumSymbols.mkString("\", \"") + "\"")
+          }
+          new EnumSymbol(avroType, data)
+
+      case (StringType, STRING) =>
+        (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes)
+
+      case (BinaryType, FIXED) =>
+        val size = avroType.getFixedSize()
+        (getter, ordinal) =>
+          val data: Array[Byte] = getter.getBinary(ordinal)
+          if (data.length != size) {
+            throw new IncompatibleSchemaException(
+              s"Cannot write ${data.length} ${if (data.length > 1) "bytes" 
else "byte"} of " +
+                "binary data into FIXED Type with size of " +
+                s"$size ${if (size > 1) "bytes" else "byte"}")
+          }
+          new Fixed(avroType, data)
+
+      case (BinaryType, BYTES) =>
+        (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal))
+
+      case (DateType, INT) =>
         (getter, ordinal) => getter.getInt(ordinal)
-      case TimestampType => avroType.getLogicalType match {
+
+      case (TimestampType, LONG) => avroType.getLogicalType match {
           case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000
           case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal)
           // For backward compatibility, if the Avro type is Long and it is not logical type,
@@ -137,7 +146,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
             s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}")
         }
 
-      case ArrayType(et, containsNull) =>
+      case (ArrayType(et, containsNull), ARRAY) =>
         val elementConverter = newConverter(
           et, resolveNullableType(avroType.getElementType, containsNull))
         (getter, ordinal) => {
@@ -158,12 +167,12 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
           java.util.Arrays.asList(result: _*)
         }
 
-      case st: StructType =>
+      case (st: StructType, RECORD) =>
         val structConverter = newStructConverter(st, avroType)
         val numFields = st.length
         (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))
 
-      case MapType(kt, vt, valueContainsNull) if kt == StringType =>
+      case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
         val valueConverter = newConverter(
           vt, resolveNullableType(avroType.getValueType, valueContainsNull))
         (getter, ordinal) =>
@@ -185,12 +194,17 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
           result
 
       case other =>
-        throw new IncompatibleSchemaException(s"Unexpected type: $other")
+        throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " +
+          s"Avro type $avroType.")
     }
   }
 
   private def newStructConverter(
       catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = {
+    if (avroStruct.getType != RECORD) {
+      throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
+        s"Avro type $avroStruct.")
+    }
     val avroFields = avroStruct.getFields
     assert(avroFields.size() == catalystStruct.length)
     val fieldConverters = catalystStruct.zip(avroFields.asScala).map {
@@ -212,7 +226,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:
   }
 
   private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = {
-    if (nullable) {
+    if (nullable && avroType.getType != NULL) {
       // avro uses union to represent nullable type.
       val fields = avroType.getTypes.asScala
       assert(fields.length == 2)
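
For reference, a standalone sketch of the Avro conversion the new BYTES branch
delegates to (plain Avro Java API, independent of Spark; the fixed-schema name
and size here are illustrative). DecimalConversion.toBytes encodes the
decimal's unscaled value as variable-length two's-complement bytes, while
toFixed sign-extends the same value to the declared size:

import java.math.BigDecimal

import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.Conversions.DecimalConversion

val decimalType = LogicalTypes.decimal(4, 2)

// BYTES-backed decimal, as enabled by this patch.
val bytesSchema = decimalType.addToSchema(Schema.create(Schema.Type.BYTES))
// FIXED-backed decimal (the previous default); 2 bytes can hold precision 4.
val fixedSchema = decimalType.addToSchema(Schema.createFixed("dec", null, null, 2))

val conversions = new DecimalConversion()
// Unscaled value of 12.34 is 1234, stored as its two's-complement bytes.
val asBytes = conversions.toBytes(new BigDecimal("12.34"), bytesSchema, decimalType)
// Same value, sign-extended to exactly 2 bytes.
val asFixed = conversions.toFixed(new BigDecimal("12.34"), fixedSchema, decimalType)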

http://git-wip-us.apache.org/repos/asf/spark/blob/ab197308/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index 24d8c53..ca7eef2 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -267,6 +267,46 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU
     }
   }
 
+  test("Logical type: write Decimal with BYTES type") {
+    val specifiedSchema = """
+      {
+        "type" : "record",
+        "name" : "topLevelRecord",
+        "namespace" : "topLevelRecord",
+        "fields" : [ {
+          "name" : "bytes",
+          "type" : [ {
+            "type" : "bytes",
+            "namespace" : "topLevelRecord.bytes",
+            "logicalType" : "decimal",
+            "precision" : 4,
+            "scale" : 2
+          }, "null" ]
+        }, {
+          "name" : "fixed",
+          "type" : [ {
+            "type" : "bytes",
+            "logicalType" : "decimal",
+            "precision" : 4,
+            "scale" : 2
+          }, "null" ]
+        } ]
+      }
+    """
+    withTempDir { dir =>
+      val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath)
+      assert(specifiedSchema != avroSchema)
+      val expected =
+        decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) }
+      val df = spark.read.format("avro").load(avroFile)
+
+      withTempPath { path =>
+        df.write.format("avro").option("avroSchema", 
specifiedSchema).save(path.toString)
+        checkAnswer(spark.read.format("avro").load(path.toString), expected)
+      }
+    }
+  }
+
   test("Logical type: Decimal with too large precision") {
     withTempDir { dir =>
       val schema = new Schema.Parser().parse("""{

http://git-wip-us.apache.org/repos/asf/spark/blob/ab197308/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
----------------------------------------------------------------------
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index b07b146..c4f4d8e 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.avro.Schema
 import org.apache.avro.Schema.{Field, Type}
+import org.apache.avro.Schema.Type._
 import org.apache.avro.file.{DataFileReader, DataFileWriter}
 import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
@@ -850,6 +851,62 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     }
   }
 
+  test("throw exception if unable to write with user provided Avro schema") {
+    val input: Seq[(DataType, Schema.Type)] = Seq(
+      (NullType, NULL),
+      (BooleanType, BOOLEAN),
+      (ByteType, INT),
+      (ShortType, INT),
+      (IntegerType, INT),
+      (LongType, LONG),
+      (FloatType, FLOAT),
+      (DoubleType, DOUBLE),
+      (BinaryType, BYTES),
+      (DateType, INT),
+      (TimestampType, LONG),
+      (DecimalType(4, 2), BYTES)
+    )
+    def assertException(f: () => AvroSerializer) {
+      val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] {
+        f()
+      }.getMessage
+      assert(message.contains("Cannot convert Catalyst type"))
+    }
+
+    def resolveNullable(schema: Schema, nullable: Boolean): Schema = {
+      if (nullable && schema.getType != NULL) {
+        Schema.createUnion(schema, Schema.create(NULL))
+      } else {
+        schema
+      }
+    }
+    for {
+      i <- input
+      j <- input
+      nullable <- Seq(true, false)
+    } if (i._2 != j._2) {
+      val avroType = resolveNullable(Schema.create(j._2), nullable)
+      val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable)
+      val avroMapType = resolveNullable(Schema.createMap(avroType), nullable)
+      val name = "foo"
+      val avroField = new Field(name, avroType, "", null)
+      val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava)
+      val avroRecordType = resolveNullable(recordSchema, nullable)
+
+      val catalystType = i._1
+      val catalystArrayType = ArrayType(catalystType, nullable)
+      val catalystMapType = MapType(StringType, catalystType, nullable)
+      val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable)))
+
+      for {
+        avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType)
+        catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType)
+      } {
+        assertException(() => new AvroSerializer(catalyst, avro, nullable))
+      }
+    }
+  }
+
   test("reading from invalid path throws exception") {
 
     // Directory given has no avro files
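
To round this out, a minimal sketch of the fail-fast behavior the new test
exercises (a hypothetical standalone object; it is assumed to live in package
org.apache.spark.sql.avro, as the suite does, so it can reach AvroSerializer).
Constructing the serializer with a mismatched pair, e.g. Catalyst LongType
against Avro STRING, now throws immediately:

package org.apache.spark.sql.avro

import org.apache.avro.Schema
import org.apache.spark.sql.types.LongType

object SchemaValidationSketch {
  def main(args: Array[String]): Unit = {
    // LongType only matches Avro LONG; STRING should be rejected up front,
    // before any data is serialized.
    try {
      new AvroSerializer(LongType, Schema.create(Schema.Type.STRING), nullable = false)
      sys.error("expected IncompatibleSchemaException")
    } catch {
      case e: IncompatibleSchemaException =>
        assert(e.getMessage.contains("Cannot convert Catalyst type"))
    }
  }
}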

