xkrogen commented on a change in pull request #31597:
URL: https://github.com/apache/spark/pull/31597#discussion_r581197032



##########
File path: 
core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
##########
@@ -22,64 +22,88 @@ import java.nio.ByteBuffer
 
 import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.avro.{Schema, SchemaBuilder}
-import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.generic.GenericData.{Array => AvroArray, EnumSymbol, 
Fixed, Record}
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.internal.config.SERIALIZER
 
 class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext 
{
   conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
 
-  val schema : Schema = SchemaBuilder
+  val recordSchema : Schema = SchemaBuilder
     .record("testRecord").fields()
     .requiredString("data")
     .endRecord()
-  val record = new Record(schema)
-  record.put("data", "test data")
+  val recordDatum = new Record(recordSchema)
+  recordDatum.put("data", "test data")
 
-  test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
-    assert(schema === 
genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
-  }
+  val arraySchema = SchemaBuilder.array().items().`type`(recordSchema)
+  val arrayDatum = new AvroArray[Record](1, arraySchema)
+  arrayDatum.add(recordDatum)
 
-  test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+  val enumSchema = SchemaBuilder.enumeration("enum").symbols("A", "B")
+  val enumDatum = new EnumSymbol(enumSchema, "A")
 
-    val outputStream = new ByteArrayOutputStream()
-    val output = new Output(outputStream)
-    genericSer.serializeDatum(record, output)
-    output.flush()
-    output.close()
+  val fixedSchema = SchemaBuilder.fixed("fixed").size(4)
+  val fixedDatum = new Fixed(fixedSchema, "ABCD".getBytes)
 
-    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
-    assert(genericSer.deserializeDatum(input) === record)
+  test("schema compression and decompression") {
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+    assert(recordSchema ===
+      
genericSer.decompress(ByteBuffer.wrap(genericSer.compress(recordSchema))))
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSerFull = new GenericAvroSerializer[Record](conf.getAvroSchema)
 
     val output = new Output(new ByteArrayOutputStream())
 
     val beginningNormalPosition = output.total()
-    genericSerFull.serializeDatum(record, output)
+    genericSerFull.serializeDatum(recordDatum, output)
     output.flush()
     val normalLength = output.total - beginningNormalPosition
 
-    conf.registerAvroSchemas(schema)
-    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
+    conf.registerAvroSchemas(recordSchema)
+    val genericSerFinger = new 
GenericAvroSerializer[Record](conf.getAvroSchema)
     val beginningFingerprintPosition = output.total()
-    genericSerFinger.serializeDatum(record, output)
+    genericSerFinger.serializeDatum(recordDatum, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
 
     assert(fingerprintLength < normalLength)
   }
 
   test("caches previously seen schemas") {
     val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
-    val compressedSchema = genericSer.compress(schema)
+    val compressedSchema = genericSer.compress(recordSchema)
     val decompressedSchema = 
genericSer.decompress(ByteBuffer.wrap(compressedSchema))
 
-    assert(compressedSchema.eq(genericSer.compress(schema)))
+    assert(compressedSchema.eq(genericSer.compress(recordSchema)))
     
assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
   }
+
+  Seq(
+    ("GenericData.Record", recordDatum),
+    ("GenericData.Array", arrayDatum),
+    ("GenericData.EnumSymbol", enumDatum),
+    ("GenericData.Fixed", fixedDatum)
+  ).foreach { case (name, datum) =>
+    test(s"SPARK-34477: $name serialization and deserialization") {

Review comment:
       Minor nit: drop the `GenericData.` prefix from the names in the sequence 
above and instead write this line as `GenericData.$name serialization ...`?

##########
File path: 
core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
##########
@@ -44,8 +44,8 @@ import org.apache.spark.util.Utils
  *                string representation of the Avro schema, used to decrease 
the amount of data
  *                that needs to be serialized.

Review comment:
       Can we add a `@tparam D` entry here documenting the new type parameter?

##########
File path: 
core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala
##########
@@ -22,64 +22,88 @@ import java.nio.ByteBuffer
 
 import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.avro.{Schema, SchemaBuilder}
-import org.apache.avro.generic.GenericData.Record
+import org.apache.avro.generic.GenericData.{Array => AvroArray, EnumSymbol, 
Fixed, Record}
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.internal.config.SERIALIZER
 
 class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext 
{
   conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
 
-  val schema : Schema = SchemaBuilder
+  val recordSchema : Schema = SchemaBuilder
     .record("testRecord").fields()
     .requiredString("data")
     .endRecord()
-  val record = new Record(schema)
-  record.put("data", "test data")
+  val recordDatum = new Record(recordSchema)
+  recordDatum.put("data", "test data")
 
-  test("schema compression and decompression") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
-    assert(schema === 
genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
-  }
+  val arraySchema = SchemaBuilder.array().items().`type`(recordSchema)
+  val arrayDatum = new AvroArray[Record](1, arraySchema)
+  arrayDatum.add(recordDatum)
 
-  test("record serialization and deserialization") {
-    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+  val enumSchema = SchemaBuilder.enumeration("enum").symbols("A", "B")
+  val enumDatum = new EnumSymbol(enumSchema, "A")
 
-    val outputStream = new ByteArrayOutputStream()
-    val output = new Output(outputStream)
-    genericSer.serializeDatum(record, output)
-    output.flush()
-    output.close()
+  val fixedSchema = SchemaBuilder.fixed("fixed").size(4)
+  val fixedDatum = new Fixed(fixedSchema, "ABCD".getBytes)
 
-    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
-    assert(genericSer.deserializeDatum(input) === record)
+  test("schema compression and decompression") {
+    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
+    assert(recordSchema ===
+      
genericSer.decompress(ByteBuffer.wrap(genericSer.compress(recordSchema))))
   }
 
   test("uses schema fingerprint to decrease message size") {
-    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)
+    val genericSerFull = new GenericAvroSerializer[Record](conf.getAvroSchema)
 
     val output = new Output(new ByteArrayOutputStream())
 
     val beginningNormalPosition = output.total()
-    genericSerFull.serializeDatum(record, output)
+    genericSerFull.serializeDatum(recordDatum, output)
     output.flush()
     val normalLength = output.total - beginningNormalPosition
 
-    conf.registerAvroSchemas(schema)
-    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
+    conf.registerAvroSchemas(recordSchema)
+    val genericSerFinger = new 
GenericAvroSerializer[Record](conf.getAvroSchema)
     val beginningFingerprintPosition = output.total()
-    genericSerFinger.serializeDatum(record, output)
+    genericSerFinger.serializeDatum(recordDatum, output)
     val fingerprintLength = output.total - beginningFingerprintPosition
 
     assert(fingerprintLength < normalLength)
   }
 
   test("caches previously seen schemas") {
     val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
-    val compressedSchema = genericSer.compress(schema)
+    val compressedSchema = genericSer.compress(recordSchema)
     val decompressedSchema = 
genericSer.decompress(ByteBuffer.wrap(compressedSchema))
 
-    assert(compressedSchema.eq(genericSer.compress(schema)))
+    assert(compressedSchema.eq(genericSer.compress(recordSchema)))
     
assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
   }
+
+  Seq(
+    ("GenericData.Record", recordDatum),
+    ("GenericData.Array", arrayDatum),
+    ("GenericData.EnumSymbol", enumDatum),
+    ("GenericData.Fixed", fixedDatum)
+  ).foreach { case (name, datum) =>
+    test(s"SPARK-34477: $name serialization and deserialization") {
+      val genericSer = new 
GenericAvroSerializer[datum.type](conf.getAvroSchema)
+
+      val outputStream = new ByteArrayOutputStream()
+      val output = new Output(outputStream)
+      genericSer.serializeDatum(datum, output)
+      output.flush()
+      output.close()
+
+      val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
+      assert(genericSer.deserializeDatum(input) === datum)
+    }
+
+    test(s"SPARK-34477: $name serialization and deserialization through 
KryoSerializer ") {
+      require(conf.get(SERIALIZER) == 
"org.apache.spark.serializer.KryoSerializer")

Review comment:
       Do we need the `require`? We already set this value explicitly at the 
top of the class.

##########
File path: core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
##########
@@ -153,8 +153,15 @@ class KryoSerializer(conf: SparkConf)
     kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
     kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
 
-    kryo.register(classOf[GenericRecord], new 
GenericAvroSerializer(avroSchemas))
-    kryo.register(classOf[GenericData.Record], new 
GenericAvroSerializer(avroSchemas))
+    kryo.register(classOf[GenericRecord], new 
GenericAvroSerializer[GenericRecord](avroSchemas))

Review comment:
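       Can we make these lines a little more DRY, e.g.:
   ```
       def registerAvro[T >: Null <: GenericContainer : ClassTag]: Unit =
         kryo.register(implicitly[ClassTag[T]].runtimeClass, new GenericAvroSerializer[T](avroSchemas))
       registerAvro[GenericRecord]
       registerAvro[GenericData.Record]
       // ...
   ```
   (`classOf[T]` doesn't compile for an abstract type parameter, so the helper needs a `ClassTag`. It also needs the same bounds as `GenericAvroSerializer`.)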

##########
File path: 
core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
##########
@@ -44,8 +44,8 @@ import org.apache.spark.util.Utils
  *                string representation of the Avro schema, used to decrease 
the amount of data
  *                that needs to be serialized.
  */
-private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
-  extends KSerializer[GenericRecord] {
+private[serializer] class GenericAvroSerializer[D >: Null <: GenericContainer]

Review comment:
       Why do we need the `Null` lower bound (`D >: Null`) here?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to