This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 56dd1f7c101e [SPARK-46634][SQL] literal validation should not drill 
down to null fields
56dd1f7c101e is described below

commit 56dd1f7c101ed0db7a6fcb7ac2f6f06136ac3d37
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Tue Jan 9 08:58:54 2024 -0800

    [SPARK-46634][SQL] literal validation should not drill down to null fields
    
    ### What changes were proposed in this pull request?
    
    This fixes a minor bug in literal validation. The contract of `InternalRow` 
is that callers should check `isNullAt` rather than rely on the `get` function 
to return null. `InternalRow` is an abstract class, and there is no guarantee 
that `get` works for a null field. This PR fixes literal validation to check 
`isNullAt` before getting the field value.
    
    ### Why are the changes needed?
    
    Literal validation can fail for `InternalRow` implementations whose `get` 
does not support null fields.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit test in `LiteralExpressionSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44640 from cloud-fan/literal.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/expressions/literals.scala      |  4 +++-
 .../catalyst/expressions/LiteralExpressionSuite.scala  | 18 ++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 79b2985adc1d..6c72afae91e9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -243,7 +243,9 @@ object Literal {
           v.isInstanceOf[InternalRow] && {
             val row = v.asInstanceOf[InternalRow]
             st.fields.map(_.dataType).zipWithIndex.forall {
-              case (fieldDataType, i) => doValidate(row.get(i, fieldDataType), 
fieldDataType)
+              case (fieldDataType, i) =>
+                // Do not need to validate null values.
+                row.isNullAt(i) || doValidate(row.get(i, fieldDataType), 
fieldDataType)
             }
           }
         case _ => false
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index f63b60f5ebba..d42e0b7d681d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -478,6 +478,24 @@ class LiteralExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       UTF8String.fromString("Spark SQL"))
   }
 
+  // A generic internal row that throws exception when accessing null values
+  class NullAccessForbiddenGenericInternalRow(override val values: Array[Any])
+    extends GenericInternalRow(values) {
+    override def get(ordinal: Int, dataType: DataType): AnyRef = {
+      if (values(ordinal) == null) {
+        throw new RuntimeException(s"Should not access null field at 
$ordinal!")
+      }
+      super.get(ordinal, dataType)
+    }
+  }
+
+  test("SPARK-46634: literal validation should not drill down to null fields") 
{
+    val exceptionInternalRow = new 
NullAccessForbiddenGenericInternalRow(Array(null, 1))
+    val schema = StructType.fromDDL("id INT, age INT")
+    // This should not fail because it should check whether the field is null 
before drilling down
+    Literal.validateLiteralValue(exceptionInternalRow, schema)
+  }
+
   test("SPARK-46604: Literal support immutable ArraySeq") {
     import org.apache.spark.util.ArrayImplicits._
     val immArraySeq = Array(1.0, 4.0).toImmutableArraySeq


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to