This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d66667f56d74 [SPARK-50906][SQL][SS] Add nullability check for if 
inputs of to_avro align with schema
d66667f56d74 is described below

commit d66667f56d747910c958b3ff53ce28946a81db6c
Author: fanyue-xia <[email protected]>
AuthorDate: Tue Jan 28 19:40:46 2025 -0800

    [SPARK-50906][SQL][SS] Add nullability check for if inputs of to_avro align 
with schema
    
    ### What changes were proposed in this pull request?
    
    Previously, `to_avro` did not explicitly check for a `null` input when 
the schema does not allow `null`. As a result, an NPE was raised in this 
situation. This PR adds the check during serialization, before writing to Avro, 
and raises a user-facing error when this occurs.
    
    ### Why are the changes needed?
    
    It makes the error easier for users to understand and act on.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #49590 from fanyue-xia/to_avro_improve_NPE.
    
    Lead-authored-by: fanyue-xia <[email protected]>
    Co-authored-by: fanyue-xia <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  7 ++
 .../org/apache/spark/sql/avro/AvroSuite.scala      | 78 ++++++++++++++++++++--
 .../org/apache/spark/sql/avro/AvroSerializer.scala | 11 +++
 3 files changed, 89 insertions(+), 7 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 3bac6638f706..92da13df5ff1 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -117,6 +117,13 @@
     },
     "sqlState" : "42604"
   },
+  "AVRO_CANNOT_WRITE_NULL_FIELD" : {
+    "message" : [
+      "Cannot write null value for field <name> defined as non-null Avro data 
type <dataType>.",
+      "To allow null value for this field, specify its avro schema as a union 
type with \"null\" using `avroSchema` option."
+    ],
+    "sqlState" : "22004"
+  },
   "AVRO_INCOMPATIBLE_READ_TYPE" : {
     "message" : [
       "Cannot convert Avro <avroPath> to SQL <sqlPath> because the original 
encoded data type is <avroType>, however you're trying to read the field as 
<sqlType>, which would lead to an incorrect answer.",
diff --git 
a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala 
b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 8342ca4e8427..4ddf6503d99e 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -25,7 +25,7 @@ import java.util.UUID
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.avro.{AvroTypeException, Schema, SchemaBuilder, 
SchemaFormatter}
+import org.apache.avro.{Schema, SchemaBuilder, SchemaFormatter}
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.Schema.Type._
 import org.apache.avro.file.{DataFileReader, DataFileWriter}
@@ -33,7 +33,7 @@ import org.apache.avro.generic.{GenericData, 
GenericDatumReader, GenericDatumWri
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, 
SparkUpgradeException}
+import org.apache.spark.{SPARK_VERSION_SHORT, SparkConf, SparkException, 
SparkRuntimeException, SparkThrowable, SparkUpgradeException}
 import org.apache.spark.TestUtils.assertExceptionMsg
 import org.apache.spark.sql._
 import org.apache.spark.sql.TestingUDT.IntervalData
@@ -45,7 +45,7 @@ import 
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone
 import org.apache.spark.sql.execution.{FormattedMode, SparkPlan}
 import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, 
DataSource, FilePartition}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.LegacyBehaviorPolicy
 import org.apache.spark.sql.internal.LegacyBehaviorPolicy._
 import org.apache.spark.sql.internal.SQLConf
@@ -100,6 +100,14 @@ abstract class AvroSuite
     SchemaFormatter.format(AvroUtils.JSON_INLINE_FORMAT, schema)
   }
 
+  private def getRootCause(ex: Throwable): Throwable = {
+    var rootCause = ex
+    while (rootCause.getCause != null) {
+      rootCause = rootCause.getCause
+    }
+    rootCause
+  }
+
   // Check whether an Avro schema of union type is converted to SQL in an 
expected way, when the
   // stable ID option is on.
   //
@@ -1317,7 +1325,16 @@ abstract class AvroSuite
         dfWithNull.write.format("avro")
           .option("avroSchema", 
avroSchema).save(s"$tempDir/${UUID.randomUUID()}")
       }
-      assertExceptionMsg[AvroTypeException](e1, "value null is not a 
SuitEnumType")
+
+      val expectedDatatype = "{\"type\":\"enum\",\"name\":\"SuitEnumType\"," +
+        "\"symbols\":[\"SPADES\",\"HEARTS\",\"DIAMONDS\",\"CLUBS\"]}"
+
+      checkError(
+        getRootCause(e1).asInstanceOf[SparkThrowable],
+        condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+        parameters = Map(
+          "name" -> "`Suit`",
+          "dataType" -> expectedDatatype))
 
       // Writing df containing data not in the enum will throw an exception
       val e2 = intercept[SparkException] {
@@ -1332,6 +1349,50 @@ abstract class AvroSuite
     }
   }
 
+  test("to_avro nested struct schema nullability mismatch") {
+    Seq((true, false), (false, true)).foreach {
+      case (innerNull, outerNull) =>
+        val innerSchema = StructType(Seq(StructField("field1", IntegerType, 
innerNull)))
+        val outerSchema = StructType(Seq(StructField("innerStruct", 
innerSchema, outerNull)))
+        val nestedSchema = StructType(Seq(StructField("outerStruct", 
outerSchema, false)))
+
+        val rowWithNull = if (innerNull) Row(Row(null)) else Row(null)
+        val data = Seq(Row(Row(Row(1))), Row(rowWithNull), Row(Row(Row(3))))
+        val df = spark.createDataFrame(spark.sparkContext.parallelize(data), 
nestedSchema)
+
+        val avroTypeStruct = s"""{
+          |  "type": "record",
+          |  "name": "outerStruct",
+          |  "fields": [
+          |    {
+          |      "name": "innerStruct",
+          |      "type": {
+          |        "type": "record",
+          |        "name": "innerStruct",
+          |        "fields": [
+          |          {"name": "field1", "type": "int"}
+          |        ]
+          |      }
+          |    }
+          |  ]
+          |}
+        """.stripMargin // nullability mismatch for innerStruct
+
+        val expectedErrorName = if (outerNull) "`innerStruct`" else "`field1`"
+        val expectedErrorSchema = if (outerNull) 
"{\"type\":\"record\",\"name\":\"innerStruct\"" +
+          ",\"fields\":[{\"name\":\"field1\",\"type\":\"int\"}]}" else 
"\"int\""
+
+        checkError(
+          exception = intercept[SparkRuntimeException] {
+            df.select(avro.functions.to_avro($"outerStruct", 
avroTypeStruct)).collect()
+          },
+          condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+          parameters = Map(
+            "name" -> expectedErrorName,
+            "dataType" -> expectedErrorSchema))
+    }
+  }
+
   test("support user provided avro schema for writing nullable fixed type") {
     withTempPath { tempDir =>
       val avroSchema =
@@ -1517,9 +1578,12 @@ abstract class AvroSuite
           .save(s"$tempDir/${UUID.randomUUID()}")
       }
       assert(ex.getCondition == "TASK_WRITE_FAILED")
-      assert(ex.getCause.isInstanceOf[java.lang.NullPointerException])
-      assert(ex.getCause.getMessage.contains(
-        "null value for (non-nullable) string at test_schema.Name"))
+      checkError(
+        ex.getCause.asInstanceOf[SparkThrowable],
+        condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+        parameters = Map(
+          "name" -> "`Name`",
+          "dataType" -> "\"string\""))
     }
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 814a28e24f52..1d83a46a278f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -29,11 +29,13 @@ import org.apache.avro.Schema.Type._
 import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
 import org.apache.avro.util.Utf8
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.avro.AvroUtils.{nonNullUnionBranches, toFieldStr, 
AvroMatchedField}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, 
SpecificInternalRow}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.execution.datasources.DataSourceUtils
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
@@ -282,11 +284,20 @@ private[sql] class AvroSerializer(
     }.toArray.unzip
 
     val numFields = catalystStruct.length
+    val avroFields = avroStruct.getFields()
+    val isSchemaNullable = avroFields.asScala.map(_.schema().isNullable)
     row: InternalRow =>
       val result = new Record(avroStruct)
       var i = 0
       while (i < numFields) {
         if (row.isNullAt(i)) {
+          if (!isSchemaNullable(i)) {
+            throw new SparkRuntimeException(
+              errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
+              messageParameters = Map(
+                "name" -> toSQLId(avroFields.get(i).name),
+                "dataType" -> avroFields.get(i).schema().toString))
+          }
           result.put(avroIndices(i), null)
         } else {
           result.put(avroIndices(i), fieldConverters(i).apply(row, i))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to