Github user xuanyuanking commented on a diff in the pull request: https://github.com/apache/spark/pull/22878#discussion_r229767694 --- Diff: external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala --- @@ -1374,4 +1377,182 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { |} """.stripMargin) } + + test("generic record converts to row and back") { + val nested = + SchemaBuilder.record("simple_record").fields() + .name("nested1").`type`("int").withDefault(0) + .name("nested2").`type`("string").withDefault("string").endRecord() + val schema = SchemaBuilder.record("record").fields() + .name("boolean").`type`("boolean").withDefault(false) + .name("int").`type`("int").withDefault(0) + .name("long").`type`("long").withDefault(0L) + .name("float").`type`("float").withDefault(0.0F) + .name("double").`type`("double").withDefault(0.0) + .name("string").`type`("string").withDefault("string") + .name("bytes").`type`("bytes").withDefault(java.nio.ByteBuffer.wrap("bytes".getBytes)) + .name("nested").`type`(nested).withDefault(new GenericRecordBuilder(nested).build) + .name("enum").`type`( + SchemaBuilder.enumeration("simple_enums") + .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS")) + .withDefault("SPADES") + .name("int_array").`type`( + SchemaBuilder.array().items().`type`("int")) + .withDefault(java.util.Arrays.asList(1, 2, 3)) + .name("string_array").`type`( + SchemaBuilder.array().items().`type`("string")) + .withDefault(java.util.Arrays.asList("a", "b", "c")) + .name("record_array").`type`( + SchemaBuilder.array.items.`type`(nested)) + .withDefault(java.util.Arrays.asList( + new GenericRecordBuilder(nested).build, + new GenericRecordBuilder(nested).build)) + .name("enum_array").`type`( + SchemaBuilder.array.items.`type`( + SchemaBuilder.enumeration("simple_enums") + .symbols("SPADES", "HEARTS", "CLUBS", "DIAMONDS"))) + .withDefault(java.util.Arrays.asList("SPADES", "HEARTS", "SPADES")) + .name("fixed_array").`type`( + 
SchemaBuilder.array.items().`type`( + SchemaBuilder.fixed("simple_fixed").size(3))) + .withDefault(java.util.Arrays.asList("foo", "bar", "baz")) + .name("fixed").`type`(SchemaBuilder.fixed("simple_fixed").size(16)) + .withDefault("string_length_16") + .endRecord() + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + assert(record == recordFromRow) + } + + test("encoder resolves union types to rows") { + val schema = SchemaBuilder.record("record").fields() + .name("int_null_union").`type`( + SchemaBuilder.unionOf.`type`("null").and.`type`("int").endUnion) + .withDefault(null) + .name("string_null_union").`type`( + SchemaBuilder.unionOf.`type`("null").and.`type`("string").endUnion) + .withDefault(null) + .name("int_long_union").`type`( + SchemaBuilder.unionOf.`type`("int").and.`type`("long").endUnion) + .withDefault(0) + .name("float_double_union").`type`( + SchemaBuilder.unionOf.`type`("float").and.`type`("double").endUnion) + .withDefault(0.0) + .endRecord + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + val row = expressionEncoder.toRow(record) + val recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + assert(record.get(0) == recordFromRow.get(0)) + assert(record.get(1) == recordFromRow.get(1)) + assert(record.get(2) == recordFromRow.get(2)) + assert(record.get(3) == recordFromRow.get(3)) + record.put(0, 0) + record.put(1, "value") + val updatedRow = expressionEncoder.toRow(record) + val updatedRecordFromRow = expressionEncoder.resolveAndBind().fromRow(updatedRow) + assert(record.get(0) == updatedRecordFromRow.get(0)) + assert(record.get(1) == 
updatedRecordFromRow.get(1)) + } + + test("encoder resolves complex unions to rows") { + val nested = + SchemaBuilder.record("simple_record").fields() + .name("nested1").`type`("int").withDefault(0) + .name("nested2").`type`("string").withDefault("foo").endRecord() + val schema = SchemaBuilder.record("record").fields() + .name("int_float_string_record").`type`( + SchemaBuilder.unionOf() + .`type`("null").and() + .`type`("int").and() + .`type`("float").and() + .`type`("string").and() + .`type`(nested).endUnion() + ).withDefault(null).endRecord() + + val encoder = AvroEncoder.of[GenericData.Record](schema) + val expressionEncoder = encoder.asInstanceOf[ExpressionEncoder[GenericData.Record]] + val record = new GenericRecordBuilder(schema).build + var row = expressionEncoder.toRow(record) + var recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, 1) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, 1F) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, "bar") + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + 
assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).getStruct(3, 2) == null) + assert(record == recordFromRow) + + record.put(0, new GenericRecordBuilder(nested).build()) + row = expressionEncoder.toRow(record) + recordFromRow = expressionEncoder.resolveAndBind().fromRow(row) + + assert(row.getStruct(0, 4).get(0, IntegerType) == null) + assert(row.getStruct(0, 4).get(1, FloatType) == null) + assert(row.getStruct(0, 4).get(2, StringType) == null) + assert(record == recordFromRow) + } + + test("create Dataset from GenericRecord") { + // need a spark context with kryo as serializer + val conf = new SparkConf() + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.driver.allowMultipleContexts", "true") + .set("spark.master", "local[2]") + .set("spark.app.name", "AvroSuite") + val context = new SparkContext(conf) + + val schema: Schema = + SchemaBuilder + .record("GenericRecordTest") + .namespace("com.databricks.spark.avro") + .fields() + .requiredString("field1") + .name("enumVal").`type`().enumeration("letters").symbols("a", "b", "c").enumDefault("a") + .name("fixedVal").`type`().fixed("MD5").size(16).fixedDefault(ByteBuffer.allocate(16)) + .endRecord() + + implicit val enc = AvroEncoder.of[GenericData.Record](schema) + + val genericRecords = (1 to 10) map { i => + new GenericRecordBuilder(schema) + .set("field1", "field-" + i) + .build() + } + + val rdd: RDD[GenericData.Record] = context.parallelize(genericRecords) + val ds = rdd.toDS() + assert(ds.count() == genericRecords.size) + context.stop() + } } --- End diff -- Yep, actually I tested the cases you mentioned myself, but doing so requires adding a lot of Avro-generated code. Moreover, IIUC, testing `SpecificRecord` only exercises one additional piece of logic, `avroClass.getMethod("getClassSchema")`, so I don't think we need to add that generated code just for this test. 
If we really want to achieve this, maybe we could add a small, simple specific record example based on the existing `test.avsc`? Or, if we just want to show the usage, maybe adding the corresponding documentation is enough. WDYT :)
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org