Repository: spark Updated Branches: refs/heads/master c220cc42a -> ab197308a
[SPARK-25104][SQL] Avro: Validate user specified output schema ## What changes were proposed in this pull request? With code changes in https://github.com/apache/spark/pull/21847 , Spark can write out to Avro files as per a user-provided output schema. To make it more robust and user-friendly, we should validate the Avro schema before tasks are launched. We should also support writing the logical decimal type as BYTES (by default it is written as FIXED). ## How was this patch tested? Unit test Closes #22094 from gengliangwang/AvroSerializerMatch. Authored-by: Gengliang Wang <gengliang.w...@databricks.com> Signed-off-by: DB Tsai <d_t...@apple.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab197308 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab197308 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab197308 Branch: refs/heads/master Commit: ab197308a79c74f0a4205a8f60438811b5e0b991 Parents: c220cc4 Author: Gengliang Wang <gengliang.w...@databricks.com> Authored: Tue Aug 14 04:43:14 2018 +0000 Committer: DB Tsai <d_t...@apple.com> Committed: Tue Aug 14 04:43:14 2018 +0000 ---------------------------------------------------------------------- .../apache/spark/sql/avro/AvroSerializer.scala | 108 +++++++++++-------- .../spark/sql/avro/AvroLogicalTypeSuite.scala | 40 +++++++ .../org/apache/spark/sql/avro/AvroSuite.scala | 57 ++++++++++ 3 files changed, 158 insertions(+), 47 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ab197308/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 3a9544c..f551c83 100644 --- 
a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -26,6 +26,7 @@ import org.apache.avro.Conversions.DecimalConversion import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema import org.apache.avro.Schema.Type +import org.apache.avro.Schema.Type._ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} import org.apache.avro.generic.GenericData.Record import org.apache.avro.util.Utf8 @@ -72,62 +73,70 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: private lazy val decimalConversions = new DecimalConversion() private def newConverter(catalystType: DataType, avroType: Schema): Converter = { - catalystType match { - case NullType => + (catalystType, avroType.getType) match { + case (NullType, NULL) => (getter, ordinal) => null - case BooleanType => + case (BooleanType, BOOLEAN) => (getter, ordinal) => getter.getBoolean(ordinal) - case ByteType => + case (ByteType, INT) => (getter, ordinal) => getter.getByte(ordinal).toInt - case ShortType => + case (ShortType, INT) => (getter, ordinal) => getter.getShort(ordinal).toInt - case IntegerType => + case (IntegerType, INT) => (getter, ordinal) => getter.getInt(ordinal) - case LongType => + case (LongType, LONG) => (getter, ordinal) => getter.getLong(ordinal) - case FloatType => + case (FloatType, FLOAT) => (getter, ordinal) => getter.getFloat(ordinal) - case DoubleType => + case (DoubleType, DOUBLE) => (getter, ordinal) => getter.getDouble(ordinal) - case d: DecimalType => + case (d: DecimalType, FIXED) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => (getter, ordinal) => val decimal = getter.getDecimal(ordinal, d.precision, d.scale) decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, LogicalTypes.decimal(d.precision, d.scale)) - case StringType => avroType.getType match { - case 
Type.ENUM => - import scala.collection.JavaConverters._ - val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet - (getter, ordinal) => - val data = getter.getUTF8String(ordinal).toString - if (!enumSymbols.contains(data)) { - throw new IncompatibleSchemaException( - "Cannot write \"" + data + "\" since it's not defined in enum \"" + - enumSymbols.mkString("\", \"") + "\"") - } - new EnumSymbol(avroType, data) - case _ => - (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) - } - case BinaryType => avroType.getType match { - case Type.FIXED => - val size = avroType.getFixedSize() - (getter, ordinal) => - val data: Array[Byte] = getter.getBinary(ordinal) - if (data.length != size) { - throw new IncompatibleSchemaException( - s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + - "binary data into FIXED Type with size of " + - s"$size ${if (size > 1) "bytes" else "byte"}") - } - new Fixed(avroType, data) - case _ => - (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) - } - case DateType => + case (d: DecimalType, BYTES) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (StringType, ENUM) => + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + "Cannot write \"" + data + "\" since it's not defined in enum \"" + + enumSymbols.mkString("\", \"") + "\"") + } + new EnumSymbol(avroType, data) + + case (StringType, STRING) => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + + case (BinaryType, FIXED) => + val size = avroType.getFixedSize() + (getter, ordinal) => + val data: 
Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + throw new IncompatibleSchemaException( + s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + + "binary data into FIXED Type with size of " + + s"$size ${if (size > 1) "bytes" else "byte"}") + } + new Fixed(avroType, data) + + case (BinaryType, BYTES) => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + + case (DateType, INT) => (getter, ordinal) => getter.getInt(ordinal) - case TimestampType => avroType.getLogicalType match { + + case (TimestampType, LONG) => avroType.getLogicalType match { case _: TimestampMillis => (getter, ordinal) => getter.getLong(ordinal) / 1000 case _: TimestampMicros => (getter, ordinal) => getter.getLong(ordinal) // For backward compatibility, if the Avro type is Long and it is not logical type, @@ -137,7 +146,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: s"Cannot convert Catalyst Timestamp type to Avro logical type ${other}") } - case ArrayType(et, containsNull) => + case (ArrayType(et, containsNull), ARRAY) => val elementConverter = newConverter( et, resolveNullableType(avroType.getElementType, containsNull)) (getter, ordinal) => { @@ -158,12 +167,12 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: java.util.Arrays.asList(result: _*) } - case st: StructType => + case (st: StructType, RECORD) => val structConverter = newStructConverter(st, avroType) val numFields = st.length (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) - case MapType(kt, vt, valueContainsNull) if kt == StringType => + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => val valueConverter = newConverter( vt, resolveNullableType(avroType.getValueType, valueContainsNull)) (getter, ordinal) => @@ -185,12 +194,17 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: result case other => - throw new 
IncompatibleSchemaException(s"Unexpected type: $other") + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystType to " + + s"Avro type $avroType.") } } private def newStructConverter( catalystStruct: StructType, avroStruct: Schema): InternalRow => Record = { + if (avroStruct.getType != RECORD) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + + s"Avro type $avroStruct.") + } val avroFields = avroStruct.getFields assert(avroFields.size() == catalystStruct.length) val fieldConverters = catalystStruct.zip(avroFields.asScala).map { @@ -212,7 +226,7 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: } private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { - if (nullable) { + if (nullable && avroType.getType != NULL) { // avro uses union to represent nullable type. val fields = avroType.getTypes.asScala assert(fields.length == 2) http://git-wip-us.apache.org/repos/asf/spark/blob/ab197308/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 24d8c53..ca7eef2 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -267,6 +267,46 @@ class AvroLogicalTypeSuite extends QueryTest with SharedSQLContext with SQLTestU } } + test("Logical type: write Decimal with BYTES type") { + val specifiedSchema = """ + { + "type" : "record", + "name" : "topLevelRecord", + "namespace" : "topLevelRecord", + "fields" : [ { + "name" : "bytes", + "type" : [ { + "type" : "bytes", + "namespace" : "topLevelRecord.bytes", + "logicalType" : "decimal", + "precision" : 
4, + "scale" : 2 + }, "null" ] + }, { + "name" : "fixed", + "type" : [ { + "type" : "bytes", + "logicalType" : "decimal", + "precision" : 4, + "scale" : 2 + }, "null" ] + } ] + } + """ + withTempDir { dir => + val (avroSchema, avroFile) = decimalSchemaAndFile(dir.getAbsolutePath) + assert(specifiedSchema != avroSchema) + val expected = + decimalInputData.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) } + val df = spark.read.format("avro").load(avroFile) + + withTempPath { path => + df.write.format("avro").option("avroSchema", specifiedSchema).save(path.toString) + checkAnswer(spark.read.format("avro").load(path.toString), expected) + } + } + } + test("Logical type: Decimal with too large precision") { withTempDir { dir => val schema = new Schema.Parser().parse("""{ http://git-wip-us.apache.org/repos/asf/spark/blob/ab197308/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index b07b146..c4f4d8e 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.avro.Schema import org.apache.avro.Schema.{Field, Type} +import org.apache.avro.Schema.Type._ import org.apache.avro.file.{DataFileReader, DataFileWriter} import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} @@ -850,6 +851,62 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("throw exception if unable to write with user provided Avro schema") { + val input: Seq[(DataType, Schema.Type)] = Seq( + (NullType, NULL), + 
(BooleanType, BOOLEAN), + (ByteType, INT), + (ShortType, INT), + (IntegerType, INT), + (LongType, LONG), + (FloatType, FLOAT), + (DoubleType, DOUBLE), + (BinaryType, BYTES), + (DateType, INT), + (TimestampType, LONG), + (DecimalType(4, 2), BYTES) + ) + def assertException(f: () => AvroSerializer) { + val message = intercept[org.apache.spark.sql.avro.IncompatibleSchemaException] { + f() + }.getMessage + assert(message.contains("Cannot convert Catalyst type")) + } + + def resolveNullable(schema: Schema, nullable: Boolean): Schema = { + if (nullable && schema.getType != NULL) { + Schema.createUnion(schema, Schema.create(NULL)) + } else { + schema + } + } + for { + i <- input + j <- input + nullable <- Seq(true, false) + } if (i._2 != j._2) { + val avroType = resolveNullable(Schema.create(j._2), nullable) + val avroArrayType = resolveNullable(Schema.createArray(avroType), nullable) + val avroMapType = resolveNullable(Schema.createMap(avroType), nullable) + val name = "foo" + val avroField = new Field(name, avroType, "", null) + val recordSchema = Schema.createRecord("name", "doc", "space", true, Seq(avroField).asJava) + val avroRecordType = resolveNullable(recordSchema, nullable) + + val catalystType = i._1 + val catalystArrayType = ArrayType(catalystType, nullable) + val catalystMapType = MapType(StringType, catalystType, nullable) + val catalystStructType = StructType(Seq(StructField(name, catalystType, nullable))) + + for { + avro <- Seq(avroType, avroArrayType, avroMapType, avroRecordType) + catalyst <- Seq(catalystType, catalystArrayType, catalystMapType, catalystStructType) + } { + assertException(() => new AvroSerializer(catalyst, avro, nullable)) + } + } + } + test("reading from invalid path throws exception") { // Directory given has no avro files --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org