shardulm94 commented on a change in pull request #31597:
URL: https://github.com/apache/spark/pull/31597#discussion_r581249774



##########
File path: core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
##########
@@ -22,64 +22,88 @@ import java.nio.ByteBuffer
 
 import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.avro.{Schema, SchemaBuilder}
-import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.generic.GenericData.{Array => AvroArray, EnumSymbol, Fixed, Record}
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.internal.config.SERIALIZER
 
 class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
   conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
 
-  val schema : Schema = SchemaBuilder
+  val recordSchema : Schema = SchemaBuilder
     .record("testRecord").fields()
     .requiredString("data")
     .endRecord()
-  val record = new Record(schema)
-  record.put("data", "test data")
+  val recordDatum = new Record(recordSchema)
+  recordDatum.put("data", "test data")
 
-  test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
-    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
-  }
+  val arraySchema = SchemaBuilder.array().items().`type`(recordSchema)
+  val arrayDatum = new AvroArray[Record](1, arraySchema)
+  arrayDatum.add(recordDatum)
 
-  test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+  val enumSchema = SchemaBuilder.enumeration("enum").symbols("A", "B")
+  val enumDatum = new EnumSymbol(enumSchema, "A")
 
-    val outputStream = new ByteArrayOutputStream()
-    val output = new Output(outputStream)
-    genericSer.serializeDatum(record, output)
-    output.flush()
-    output.close()
+  val fixedSchema = SchemaBuilder.fixed("fixed").size(4)
+  val fixedDatum = new Fixed(fixedSchema, "ABCD".getBytes)
 
-    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
-    assert(genericSer.deserializeDatum(input) === record)
+  test("schema compression and decompression") {
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+    assert(recordSchema ===
+      genericSer.decompress(ByteBuffer.wrap(genericSer.compress(recordSchema))))
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSerFull = new GenericAvroSerializer[Record](conf.getAvroSchema)
 
     val output = new Output(new ByteArrayOutputStream())
 
     val beginningNormalPosition = output.total()
-    genericSerFull.serializeDatum(record, output)
+    genericSerFull.serializeDatum(recordDatum, output)
     output.flush()
     val normalLength = output.total - beginningNormalPosition
 
-    conf.registerAvroSchemas(schema)
-    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
+    conf.registerAvroSchemas(recordSchema)
+    val genericSerFinger = new GenericAvroSerializer[Record](conf.getAvroSchema)
     val beginningFingerprintPosition = output.total()
-    genericSerFinger.serializeDatum(record, output)
+    genericSerFinger.serializeDatum(recordDatum, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
 
     assert(fingerprintLength < normalLength)
   }
 
   test("caches previously seen schemas") {
     val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
-    val compressedSchema = genericSer.compress(schema)
+    val compressedSchema = genericSer.compress(recordSchema)
     val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))
 
-    assert(compressedSchema.eq(genericSer.compress(schema)))
+    assert(compressedSchema.eq(genericSer.compress(recordSchema)))
     assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
   }
+
+  Seq(
+    ("GenericData.Record", recordDatum),
+    ("GenericData.Array", arrayDatum),
+    ("GenericData.EnumSymbol", enumDatum),
+    ("GenericData.Fixed", fixedDatum)
+  ).foreach { case (name, datum) =>
+    test(s"SPARK-34477: $name serialization and deserialization") {
+      val genericSer = new GenericAvroSerializer[datum.type](conf.getAvroSchema)
+
+      val outputStream = new ByteArrayOutputStream()
+      val output = new Output(outputStream)
+      genericSer.serializeDatum(datum, output)
+      output.flush()
+      output.close()
+
+      val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
+      assert(genericSer.deserializeDatum(input) === datum)
+    }
+
+    test(s"SPARK-34477: $name serialization and deserialization through 
KryoSerializer ") {
+      require(conf.get(SERIALIZER) == 
"org.apache.spark.serializer.KryoSerializer")

Review comment:
       I decided to keep this for convenience. Since this test case exercises Kryo
specifically and won't work without it, the `require` makes the failure much more
obvious if someone tries to remove the conf at the top of the file.
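
       For anyone skimming this later, here is a minimal sketch of the guard pattern
being discussed: the suite-level `conf.set(SERIALIZER, ...)` plus the in-test `require`.
The suite/trait names come from the file above; the round trip after the `require` is
only a hypothetical illustration of pushing a value through Kryo directly, not this PR's
actual test body.

```scala
import java.nio.ByteBuffer

import org.apache.spark.{SharedSparkContext, SparkFunSuite}
import org.apache.spark.internal.config.SERIALIZER
import org.apache.spark.serializer.KryoSerializer

class KryoGuardSketchSuite extends SparkFunSuite with SharedSparkContext {
  // Suite-level conf the comment refers to: every test in this file runs with Kryo.
  conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")

  test("round trip through the configured KryoSerializer") {
    // The guard kept in the PR: if someone removes the conf.set above, this fails
    // immediately with a clear message instead of a later assertion failing for a
    // much less obvious reason.
    require(conf.get(SERIALIZER) == "org.apache.spark.serializer.KryoSerializer")

    // Hypothetical body: serialize a value with Kryo and check it survives the trip.
    val kryo = new KryoSerializer(conf).newInstance()
    val datum = "test data"
    val bytes: ByteBuffer = kryo.serialize(datum)
    assert(kryo.deserialize[String](bytes) === datum)
  }
}
```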



