Repository: spark Updated Branches: refs/heads/master bdd27961c -> 0cea9e3cd
[SPARK-24855][SQL][EXTERNAL] Built-in AVRO support should support specified schema on write ## What changes were proposed in this pull request? Allows `avroSchema` option to be specified on write, allowing a user to specify a schema in cases where this is required. A trivial use case is reading in an avro dataset, making some small adjustment to a column or columns and writing out using the same schema. Implicit schema creation from SQL Struct results in a schema that, while functionally similar for the most part, is not necessarily compatible. Allows `fixed` Field type to be utilized for records of specified `avroSchema`. ## How was this patch tested? Unit tests in AvroSuite are extended to test this with enum and fixed types. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #21847 from lindblombr/specify_schema_on_write. Lead-authored-by: Brian Lindblom <blindb...@apple.com> Co-authored-by: DB Tsai <d_t...@apple.com> Signed-off-by: DB Tsai <d_t...@apple.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0cea9e3c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0cea9e3c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0cea9e3c Branch: refs/heads/master Commit: 0cea9e3cd0a92799bdcc0f9bc2cf96259c343a30 Parents: bdd2796 Author: Brian Lindblom <blindb...@apple.com> Authored: Fri Aug 10 03:35:29 2018 +0000 Committer: DB Tsai <d_t...@apple.com> Committed: Fri Aug 10 03:35:29 2018 +0000 ---------------------------------------------------------------------- .../apache/spark/sql/avro/AvroFileFormat.scala | 6 +- .../apache/spark/sql/avro/AvroSerializer.scala | 40 +++- .../org/apache/spark/sql/avro/AvroSuite.scala | 228 ++++++++++++++++++- 3 files changed, 257 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/0cea9e3c/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 6ffcf37..6df23c9 100755 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -113,8 +113,10 @@ private[avro] class AvroFileFormat extends FileFormat options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { val parsedOptions = new AvroOptions(options, spark.sessionState.newHadoopConf()) - val outputAvroSchema = SchemaConverters.toAvroType(dataSchema, nullable = false, - parsedOptions.recordName, parsedOptions.recordNamespace, parsedOptions.outputTimestampType) + val outputAvroSchema: Schema = parsedOptions.schema + .map(new Schema.Parser().parse) + .getOrElse(SchemaConverters.toAvroType(dataSchema, nullable = false, + parsedOptions.recordName, parsedOptions.recordNamespace)) AvroJob.setOutputKeySchema(job, outputAvroSchema) http://git-wip-us.apache.org/repos/asf/spark/blob/0cea9e3c/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 9885826..216c52a 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -23,8 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis} import org.apache.avro.Schema -import 
org.apache.avro.Schema.Type.NULL -import org.apache.avro.generic.GenericData.Record +import org.apache.avro.Schema.Type +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} import org.apache.avro.util.Utf8 import org.apache.spark.sql.catalyst.InternalRow @@ -87,10 +87,36 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: (getter, ordinal) => getter.getDouble(ordinal) case d: DecimalType => (getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString - case StringType => - (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) - case BinaryType => - (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + case StringType => avroType.getType match { + case Type.ENUM => + import scala.collection.JavaConverters._ + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + "Cannot write \"" + data + "\" since it's not defined in enum \"" + + enumSymbols.mkString("\", \"") + "\"") + } + new EnumSymbol(avroType, data) + case _ => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + } + case BinaryType => avroType.getType match { + case Type.FIXED => + val size = avroType.getFixedSize() + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + throw new IncompatibleSchemaException( + s"Cannot write ${data.length} ${if (data.length > 1) "bytes" else "byte"} of " + + "binary data into FIXED Type with size of " + + s"$size ${if (size > 1) "bytes" else "byte"}") + } + new Fixed(avroType, data) + case _ => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + } case DateType => (getter, ordinal) => getter.getInt(ordinal) case TimestampType => avroType.getLogicalType match { @@ -182,7 +208,7 @@ class AvroSerializer(rootCatalystType: DataType, 
rootAvroType: Schema, nullable: // avro uses union to represent nullable type. val fields = avroType.getTypes.asScala assert(fields.length == 2) - val actualType = fields.filter(_.getType != NULL) + val actualType = fields.filter(_.getType != Type.NULL) assert(actualType.length == 1) actualType.head } else { http://git-wip-us.apache.org/repos/asf/spark/blob/0cea9e3c/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 47995bb..ada9980 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -32,6 +32,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils +import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -100,6 +101,25 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkAnswer(newEntries, originalEntries) } + def checkAvroSchemaEquals(avroSchema: String, expectedAvroSchema: String): Unit = { + assert(new Schema.Parser().parse(avroSchema) == + new Schema.Parser().parse(expectedAvroSchema)) + } + + def getAvroSchemaStringFromFiles(filePath: String): String = { + new DataFileReader({ + val file = new File(filePath) + if (file.isFile) { + file + } else { + file.listFiles() + .filter(_.isFile) + .filter(_.getName.endsWith("avro")) + .head + } + }, new GenericDatumReader[Any]()).getSchema.toString(false) + } + test("resolve avro data source") { Seq("avro", "com.databricks.spark.avro").foreach { provider => 
assert(DataSource.lookupDataSource(provider, spark.sessionState.conf) === @@ -471,7 +491,6 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } """ val df = spark.read.format("avro").option("avroSchema", avroSchema).load(timestampAvro) - checkAnswer(df, expected) } @@ -773,6 +792,205 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result === Row("foo")) } + test("support user provided avro schema for writing nullable enum type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": [{ "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing data not in the enum will throw an exception + val message = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing non-nullable enum type") { + 
withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "enum", + | "type": { "type": "enum", + | "name": "Suit", + | "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"] + | } + | }] + |} + """.stripMargin + + val dfWithNull = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row(null), Row("HEARTS"), Row("DIAMONDS"), + Row(null), Row("CLUBS"), Row("HEARTS"), Row("SPADES"))), + StructType(Seq(StructField("Suit", StringType, true)))) + + val df = spark.createDataFrame(dfWithNull.na.drop().rdd, + StructType(Seq(StructField("Suit", StringType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing nulls without using avro union type will + // throw an exception as avro uses union type to handle null. 
+ val message1 = intercept[SparkException] { + dfWithNull.write.format("avro") + .option("avroSchema", avroSchema).save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.avro.AvroRuntimeException: Not a union:")) + + // Writing df containing data not in the enum will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row("SPADES"), Row("NOT-IN-ENUM"), Row("HEARTS"), Row("DIAMONDS"))), + StructType(Seq(StructField("Suit", StringType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write \"NOT-IN-ENUM\" since it's not defined in enum")) + } + } + + test("support user provided avro schema for writing nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": [{ "type": "fixed", + | "size": 2, + | "name": "fixed2" + | }, "null"] + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(null))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + 
.save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, true)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + + test("support user provided avro schema for writing non-nullable fixed type") { + withTempPath { tempDir => + val avroSchema = + """ + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [{ + | "name": "fixed2", + | "type": { "type": "fixed", + | "size": 2, + | "name": "fixed2" + | } + | }] + |} + """.stripMargin + + val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168).map(_.toByte)), Row(Array(1, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + + val tempSaveDir = s"$tempDir/save/" + + df.write.format("avro").option("avroSchema", avroSchema).save(tempSaveDir) + + checkAnswer(df, spark.read.format("avro").load(tempSaveDir)) + checkAvroSchemaEquals(avroSchema, getAvroSchemaStringFromFiles(tempSaveDir)) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message1 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192, 168, 1).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + 
.save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message1.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 3 bytes of binary data into FIXED Type with size of 2 bytes")) + + // Writing df containing binary data that doesn't fit FIXED size will throw an exception + val message2 = intercept[SparkException] { + spark.createDataFrame(spark.sparkContext.parallelize(Seq( + Row(Array(192).map(_.toByte)))), + StructType(Seq(StructField("fixed2", BinaryType, false)))) + .write.format("avro").option("avroSchema", avroSchema) + .save(s"$tempDir/${UUID.randomUUID()}") + }.getCause.getMessage + assert(message2.contains("org.apache.spark.sql.avro.IncompatibleSchemaException: " + + "Cannot write 1 byte of binary data into FIXED Type with size of 2 bytes")) + } + } + test("reading from invalid path throws exception") { // Directory given has no avro files @@ -936,13 +1154,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { withTempPath { dir => val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) writeDf.write.format("avro").save(dir.toString) - val file = new File(dir.toString) - .listFiles() - .filter(_.isFile) - .filter(_.getName.endsWith("avro")) - .head - val reader = new DataFileReader(file, new GenericDatumReader[Any]()) - val schema = reader.getSchema.toString() + val schema = getAvroSchemaStringFromFiles(dir.toString) assert(schema.contains("\"namespace\":\"topLevelRecord\"")) assert(schema.contains("\"namespace\":\"topLevelRecord.data\"")) assert(schema.contains("\"namespace\":\"topLevelRecord.data.data\"")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org