This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d78199ae2a67 [SPARK-49406][SQL][TESTS] Add three test cases for `from_avro/to_avro` from the perspective of `GenericRecord`
d78199ae2a67 is described below

commit d78199ae2a67b1590903015323059baf59fae88b
Author: panbingkun <[email protected]>
AuthorDate: Thu Aug 29 12:59:19 2024 +0200

    [SPARK-49406][SQL][TESTS] Add three test cases for `from_avro/to_avro` from the perspective of `GenericRecord`
    
    ### What changes were proposed in this pull request?
    This PR aims to add three test cases for `from_avro`/`to_avro` from the perspective of `GenericRecord`.
    
    ### Why are the changes needed?
    These test cases help better understand the two functions `from_avro` and `to_avro`.
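
    For reference, a minimal sketch of the round trip the new tests exercise
    (the `df` DataFrame and `avroSchema` string below are illustrative, not
    part of this patch):

        import org.apache.spark.sql.avro.functions.{from_avro, to_avro}
        import org.apache.spark.sql.functions.{col, struct}

        // `df` has columns (name: string, age: int); `avroSchema` is the
        // matching Avro record schema in JSON form.
        val binary = df.select(to_avro(struct(col("name"), col("age")), avroSchema).as("data"))
        val back = binary.select(from_avro(col("data"), avroSchema).as("person"))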
    
    ### Does this PR introduce _any_ user-facing change?
    No, it only adds supplementary test cases.
    
    ### How was this patch tested?
    Updated the existing UT suite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47888 from panbingkun/SPARK-49406.
    
    Authored-by: panbingkun <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../apache/spark/sql/avro/AvroFunctionsSuite.scala | 222 ++++++++++++++++++++-
 1 file changed, 219 insertions(+), 3 deletions(-)

diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
index 7001fa96deb8..432c3fa9be3a 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -22,16 +22,18 @@ import java.io.ByteArrayOutputStream
 import scala.jdk.CollectionConverters._
 
 import org.apache.avro.{Schema, SchemaBuilder}
-import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder}
-import org.apache.avro.io.EncoderFactory
+import org.apache.avro.file.SeekableByteArrayInput
+import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericRecord, GenericRecordBuilder}
+import org.apache.avro.io.{DecoderFactory, EncoderFactory}
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.avro.{functions => Fns}
 import org.apache.spark.sql.execution.LocalTableScanExec
 import org.apache.spark.sql.functions.{col, lit, struct}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{BinaryType, StructType}
 
 class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -371,4 +373,218 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
           stop = 138)))
     }
   }
+
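+  // Encodes a `GenericRecord` to raw Avro binary (binary encoding, no
+  // data-file header) with the given writer schema.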
+  private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = {
+    val schema = new Schema.Parser().parse(avroSchema)
+    val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+    var outputStream: ByteArrayOutputStream = null
+    var bytes: Array[Byte] = null
+    try {
+      outputStream = new ByteArrayOutputStream()
+      val encoder = EncoderFactory.get.binaryEncoder(outputStream, null)
+      datumWriter.write(record, encoder)
+      encoder.flush()
+      bytes = outputStream.toByteArray
+    } finally {
+      if (outputStream != null) {
+        outputStream.close()
+      }
+    }
+    bytes
+  }
+
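+  // Decodes raw Avro binary back into a `GenericRecord`; the inverse of
+  // `serialize` above.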
+  private def deserialize(bytes: Array[Byte], avroSchema: String): GenericRecord = {
+    val schema = new Schema.Parser().parse(avroSchema)
+    val datumReader = new GenericDatumReader[GenericRecord](schema)
+    var inputStream: SeekableByteArrayInput = null
+    var record: GenericRecord = null
+    try {
+      inputStream = new SeekableByteArrayInput(bytes)
+      val decoder = DecoderFactory.get.binaryDecoder(inputStream, null)
+      record = datumReader.read(null, decoder)
+    } finally {
+      if (inputStream != null) {
+        inputStream.close()
+      }
+    }
+    record
+  }
+
+  // write: `GenericRecord` -> binary (by `serialize`) -> dataframe
+  // read: dataframe -> binary -> `GenericRecord` (by `deserialize`)
+  test("roundtrip in serialize and deserialize - GenericRecord") {
+    val avroSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name": "person",
+        |  "fields": [
+        |    {"name": "name", "type": "string"},
+        |    {"name": "age", "type": "int"},
+        |    {"name": "country", "type": "string"}
+        |  ]
+        |}
+        |""".stripMargin
+    val testTable = "test_avro"
+    withTable(testTable) {
+      val schema = new Schema.Parser().parse(avroSchema)
+      val person1 = new GenericRecordBuilder(schema)
+        .set("name", "sparkA")
+        .set("age", 18)
+        .set("country", "usa")
+        .build()
+      val person2 = new GenericRecordBuilder(schema)
+        .set("name", "sparkB")
+        .set("age", 19)
+        .set("country", "usb")
+        .build()
+      Seq(person1, person2)
+        .map(p => serialize(p, avroSchema))
+        .toDF("data")
+        .repartition(1)
+        .writeTo(testTable)
+        .create()
+
+      val expectedSchema = new StructType().add("data", BinaryType)
+      assert(spark.table(testTable).schema === expectedSchema)
+
+      // Note that what is returned here is `Row[Array[Byte]]`
+      val avroDF = sql(s"SELECT data FROM $testTable")
+      val readbacks = avroDF
+        .collect()
+        .map(row => deserialize(row.get(0).asInstanceOf[Array[Byte]], avroSchema))
+
+      val readbackPerson1 = readbacks.head
+      assert(readbackPerson1.get(0).toString === person1.get(0))
+      assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int])
+      assert(readbackPerson1.get(2).toString === person1.get(2))
+
+      val readbackPerson2 = readbacks(1)
+      assert(readbackPerson2.get(0).toString === person2.get(0))
+      assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int])
+      assert(readbackPerson2.get(2).toString === person2.get(2))
+    }
+  }
+
+  // write: `GenericRecord` -> binary (by `serialize`) -> dataframe
+  // read: dataframe -> binary -> struct (by `from_avro`) -> `GenericRecord`
+  test("use `serialize` to write GenericRecord and `from_avro` to read 
GenericRecord") {
+    val avroSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name": "person",
+        |  "fields": [
+        |    {"name": "name", "type": "string"},
+        |    {"name": "age", "type": "int"},
+        |    {"name": "country", "type": "string"}
+        |  ]
+        |}
+        |""".stripMargin
+    val testTable = "test_avro"
+    withTable(testTable) {
+      val schema = new Schema.Parser().parse(avroSchema)
+      val person1 = new GenericRecordBuilder(schema)
+        .set("name", "sparkA")
+        .set("age", 18)
+        .set("country", "usa")
+        .build()
+      val person2 = new GenericRecordBuilder(schema)
+        .set("name", "sparkB")
+        .set("age", 19)
+        .set("country", "usb")
+        .build()
+      Seq(person1, person2)
+        .map(p => serialize(p, avroSchema))
+        .toDF("data")
+        .repartition(1)
+        .writeTo(testTable)
+        .create()
+
+      val expectedSchema = new StructType().add("data", BinaryType)
+      assert(spark.table(testTable).schema === expectedSchema)
+
+      // Note that what is returned here is `Row[Struct]`
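+      // The third argument of `from_avro` is an options map (empty here).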
+      val avroDF = sql(s"SELECT from_avro(data, '$avroSchema', map()) FROM 
$testTable")
+      val readbacks = avroDF
+        .collect()
+        .map(row =>
+          new GenericRecordBuilder(schema)
+            .set("name", row.getStruct(0).getString(0))
+            .set("age", row.getStruct(0).getInt(1))
+            .set("country", row.getStruct(0).getString(2))
+            .build())
+
+      val readbackPerson1 = readbacks.head
+      assert(readbackPerson1.get(0) === person1.get(0))
+      assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int])
+      assert(readbackPerson1.get(2) === person1.get(2))
+
+      val readbackPerson2 = readbacks(1)
+      assert(readbackPerson2.get(0) === person2.get(0))
+      assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int])
+      assert(readbackPerson2.get(2) === person2.get(2))
+    }
+  }
+
+  // write: `GenericRecord` (to `struct`) -> binary (by `to_avro`) -> dataframe
+  // read: dataframe -> binary -> `GenericRecord` (by `deserialize`)
+  test("use `to_avro` to write GenericRecord and `deserialize` to read 
GenericRecord") {
+    val avroSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name": "person",
+        |  "fields": [
+        |    {"name": "name", "type": "string"},
+        |    {"name": "age", "type": "int"},
+        |    {"name": "country", "type": "string"}
+        |  ]
+        |}
+        |""".stripMargin
+    val testTable = "test_avro"
+    withTable(testTable) {
+      val schema = new Schema.Parser().parse(avroSchema)
+      val person1 = new GenericRecordBuilder(schema)
+        .set("name", "sparkA")
+        .set("age", 18)
+        .set("country", "usa")
+        .build()
+      val person2 = new GenericRecordBuilder(schema)
+        .set("name", "sparkB")
+        .set("age", 19)
+        .set("country", "usb")
+        .build()
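+      // Build a DataFrame from the record fields, then encode the struct
+      // column to Avro binary with `to_avro` (imported as `Fns`).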
+      Seq(person1, person2)
+        .map(p => (
+          p.get(0).asInstanceOf[String],
+          p.get(1).asInstanceOf[Int],
+          p.get(2).asInstanceOf[String]))
+        .toDF("name", "age", "country")
+        .select(Fns.to_avro(struct($"name", $"age", $"country"), avroSchema).as("data"))
+        .repartition(1)
+        .writeTo(testTable)
+        .create()
+
+      val expectedSchema = new StructType().add("data", BinaryType)
+      assert(spark.table(testTable).schema === expectedSchema)
+
+      // Note that what is returned here is `Row[Array[Byte]]`
+      val avroDF = sql(s"select data from $testTable")
+      val readbacks = avroDF
+        .collect()
+        .map(row => row.get(0).asInstanceOf[Array[Byte]])
+        .map(bytes => deserialize(bytes, avroSchema))
+
+      val readbackPerson1 = readbacks.head
+      assert(readbackPerson1.get(0).toString === person1.get(0))
+      assert(readbackPerson1.get(1).asInstanceOf[Int] === person1.get(1).asInstanceOf[Int])
+      assert(readbackPerson1.get(2).toString === person1.get(2))
+
+      val readbackPerson2 = readbacks(1)
+      assert(readbackPerson2.get(0).toString === person2.get(0))
+      assert(readbackPerson2.get(1).asInstanceOf[Int] === person2.get(1).asInstanceOf[Int])
+      assert(readbackPerson2.get(2).toString === person2.get(2))
+    }
+  }
 }

