This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fec210b36be [SPARK-41395][SQL] `InterpretedMutableProjection` should use `setDecimal` to set null values for decimals in an unsafe row
fec210b36be is described below

commit fec210b36be22f187b51b67970960692f75ac31f
Author: Bruce Robbins <bersprock...@gmail.com>
AuthorDate: Fri Dec 9 21:44:45 2022 +0900

    [SPARK-41395][SQL] `InterpretedMutableProjection` should use `setDecimal` to set null values for decimals in an unsafe row
    
    ### What changes were proposed in this pull request?
    
    Change `InterpretedMutableProjection` to use `setDecimal` rather than 
`setNullAt` to set null values for decimals in unsafe rows.
    
    ### Why are the changes needed?
    
    The following returns the wrong answer:
    
    ```
    set spark.sql.codegen.wholeStage=false;
    set spark.sql.codegen.factoryMode=NO_CODEGEN;
    
    select max(col1), max(col2) from values
    (cast(null  as decimal(27,2)), cast(null   as decimal(27,2))),
    (cast(77.77 as decimal(27,2)), cast(245.00 as decimal(27,2)))
    as data(col1, col2);
    
    +---------+---------+
    |max(col1)|max(col2)|
    +---------+---------+
    |null     |239.88   |
    +---------+---------+
    ```
    This is because `InterpretedMutableProjection` inappropriately uses 
`InternalRow#setNullAt` on unsafe rows to set null for decimal types with 
precision > `Decimal.MAX_LONG_DIGITS`.
    
    When `setNullAt` is used, the pointer to the decimal's storage area in the
    variable-length region gets zeroed out. Later, when
    `InterpretedMutableProjection` calls `setDecimal` on that field,
    `UnsafeRow#setDecimal` picks up the zero pointer and stores decimal data on
    top of the null-tracking bit set. Later updates to the null-tracking bit set
    (e.g., calls to `setNotNullAt`) further corrupt the decimal data (turning
    245.00 into 239.88, for example). The stomping of the null-tracking bit set
    can also make non-null fields appear null (turning 77.77 into null, for
    example).
    
    This bug can manifest for end-users after codegen fallback (say, if an 
expression's generated code fails to compile).
    
    [Codegen for mutable projection](https://github.com/apache/spark/blob/89b2ee27d258dec8fe265fa862846e800a374d8e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala#L1729)
    uses `mutableRow.setDecimal` for null decimal values regardless of precision
    or the type of `mutableRow`, so this PR does the same.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests.
    
    Closes #38923 from bersprockets/unsafe_decimal_issue.
    
    Authored-by: Bruce Robbins <bersprock...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../expressions/InterpretedMutableProjection.scala |  3 +-
 .../expressions/MutableProjectionSuite.scala       | 62 ++++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 91c9457af7d..4e129e96d1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DecimalType
 
 
 /**
@@ -72,7 +73,7 @@ class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mutable
 
   private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case 
(e, i) =>
     val writer = InternalRow.getWriter(i, e.dataType)
-    if (!e.nullable) {
+    if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) {
       (v: Any) => writer(mutableRow, v)
     } else {
       (v: Any) => {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index 0f01bfbb894..e3f11283816 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -65,6 +65,68 @@ class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(SafeProjection.create(fixedLengthTypes)(projUnsafeRow) === inputRow)
   }
 
+  def testRows(
+      bufferSchema: StructType,
+      buffer: InternalRow,
+      scalaRows: Seq[Seq[Any]]): Unit = {
+    val bufferTypes = bufferSchema.map(_.dataType).toArray
+    val proj = createMutableProjection(bufferTypes)
+
+    scalaRows.foreach { scalaRow =>
+      val inputRow = InternalRow.fromSeq(scalaRow.zip(bufferTypes).map {
+        case (v, dataType) => 
CatalystTypeConverters.createToCatalystConverter(dataType)(v)
+      })
+      val projRow = proj.target(buffer)(inputRow)
+      assert(SafeProjection.create(bufferTypes)(projRow) === inputRow)
+    }
+  }
+
+  testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal 
(high precision)") {
+    val bufferSchema = StructType(Array(
+      StructField("dec1", DecimalType(27, 2), nullable = true),
+      StructField("dec2", DecimalType(27, 2), nullable = true)))
+    val buffer = UnsafeProjection.create(bufferSchema)
+      .apply(new GenericInternalRow(bufferSchema.length))
+    val scalaRows = Seq(
+      Seq(null, null),
+      Seq(BigDecimal(77.77), BigDecimal(245.00)))
+    testRows(bufferSchema, buffer, scalaRows)
+  }
+
+  testBothCodegenAndInterpreted("SPARK-41395: unsafe buffer with null decimal 
(low precision)") {
+    val bufferSchema = StructType(Array(
+      StructField("dec1", DecimalType(10, 2), nullable = true),
+      StructField("dec2", DecimalType(10, 2), nullable = true)))
+    val buffer = UnsafeProjection.create(bufferSchema)
+      .apply(new GenericInternalRow(bufferSchema.length))
+    val scalaRows = Seq(
+      Seq(null, null),
+      Seq(BigDecimal(77.77), BigDecimal(245.00)))
+    testRows(bufferSchema, buffer, scalaRows)
+  }
+
+  testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal 
(high precision)") {
+    val bufferSchema = StructType(Array(
+      StructField("dec1", DecimalType(27, 2), nullable = true),
+      StructField("dec2", DecimalType(27, 2), nullable = true)))
+    val buffer = new GenericInternalRow(bufferSchema.length)
+    val scalaRows = Seq(
+      Seq(null, null),
+      Seq(BigDecimal(77.77), BigDecimal(245.00)))
+    testRows(bufferSchema, buffer, scalaRows)
+  }
+
+  testBothCodegenAndInterpreted("SPARK-41395: generic buffer with null decimal 
(low precision)") {
+    val bufferSchema = StructType(Array(
+      StructField("dec1", DecimalType(10, 2), nullable = true),
+      StructField("dec2", DecimalType(10, 2), nullable = true)))
+    val buffer = new GenericInternalRow(bufferSchema.length)
+    val scalaRows = Seq(
+      Seq(null, null),
+      Seq(BigDecimal(77.77), BigDecimal(245.00)))
+    testRows(bufferSchema, buffer, scalaRows)
+  }
+
   testBothCodegenAndInterpreted("variable-length types") {
     val proj = createMutableProjection(variableLengthTypes)
     val scalaValues = Seq("abc", BigDecimal(10),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to