This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new b5c4bb76db0 [SPARK-41535][SQL] Set null correctly for calendar 
interval fields in `InterpretedUnsafeProjection` and 
`InterpretedMutableProjection`
b5c4bb76db0 is described below

commit b5c4bb76db06623e4252e02f76f235d9a476059b
Author: Bruce Robbins <[email protected]>
AuthorDate: Tue Dec 20 09:29:18 2022 +0900

    [SPARK-41535][SQL] Set null correctly for calendar interval fields in 
`InterpretedUnsafeProjection` and `InterpretedMutableProjection`
    
    In `InterpretedUnsafeProjection`, use `UnsafeWriter.write`, rather than 
`UnsafeWriter.setNullAt`, to set null for interval fields. Also, in 
`InterpretedMutableProjection`, use `InternalRow.setInterval`, rather than 
`InternalRow.setNullAt`, to set null for interval fields.
    
    This returns the wrong answer:
    ```
    set spark.sql.codegen.wholeStage=false;
    set spark.sql.codegen.factoryMode=NO_CODEGEN;
    
    select first(col1), last(col2) from values
    (make_interval(0, 0, 0, 7, 0, 0, 0), make_interval(17, 0, 0, 2, 0, 0, 0))
    as data(col1, col2);
    
    +---------------+---------------+
    |first(col1)    |last(col2)     |
    +---------------+---------------+
    |16 years 2 days|16 years 2 days|
    +---------------+---------------+
    ```
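    In the above case, `TungstenAggregationIterator` uses
    `InterpretedUnsafeProjection` to create the aggregation buffer and to
    initialize all the fields to null. `InterpretedUnsafeProjection` incorrectly
    calls `UnsafeRowWriter#setNullAt`, rather than `UnsafeRowWriter#write`, for the
    two calendar interval fields. As a result, the writer never allocates memory
    from the variable length region for the two intervals, and the pointers in the
    fixed region get left as zero. Later, when `InterpretedMutableProjection`
    updates these fields, `UnsafeRow#setInterval` follows the zero pointers and
    writes the interval data at offset 0, on top of the null-tracking bit set.
    Both intervals therefore land in the same place, and subsequent updates to the
    null-tracking bit set (e.g., calls to `setNotNullAt`) further corrupt the
    interval data, turning `17 years 2 days` into `16 years 2 days`.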
    
    Even after one fixes the above bug in `InterpretedUnsafeProjection` so that 
the buffer is created correctly, `InterpretedMutableProjection` has a similar 
bug to SPARK-41395, except this time for calendar interval data:
    ```
    set spark.sql.codegen.wholeStage=false;
    set spark.sql.codegen.factoryMode=NO_CODEGEN;
    
    select first(col1), last(col2), max(col3) from values
    (null, null, 1),
    (make_interval(0, 0, 0, 7, 0, 0, 0), make_interval(17, 0, 0, 2, 0, 0, 0), 3)
    as data(col1, col2, col3);
    
    +---------------+---------------+---------+
    |first(col1)    |last(col2)     |max(col3)|
    +---------------+---------------+---------+
    |16 years 2 days|16 years 2 days|3        |
    +---------------+---------------+---------+
    ```
    Both bugs can be hit whenever Spark falls back from generated code to the
    interpreted projections (codegen fallback).
    
    Does this PR introduce any user-facing change? No.
    
    How was this patch tested? New unit tests in `MutableProjectionSuite` and
    `UnsafeRowConverterSuite`.
    
    Closes #39117 from bersprockets/unsafe_interval_issue.
    
    Authored-by: Bruce Robbins <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 7f153842041d66e9cf0465262f4458cfffda4f43)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../expressions/codegen/UnsafeArrayWriter.java     | 14 +++++++
 .../catalyst/expressions/codegen/UnsafeWriter.java |  2 +-
 .../expressions/InterpretedMutableProjection.scala |  4 +-
 .../expressions/InterpretedUnsafeProjection.scala  |  3 ++
 .../expressions/codegen/CodeGenerator.scala        |  5 +--
 .../spark/sql/catalyst/util/UnsafeRowUtils.scala   | 26 ++++++++++++
 .../expressions/MutableProjectionSuite.scala       | 29 ++++++++++++-
 .../expressions/UnsafeRowConverterSuite.scala      | 49 ++++++++++++++++++++++
 8 files changed, 125 insertions(+), 7 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index bf6792313ae..e5b941f9e60 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -21,6 +21,7 @@ import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
 
 import static 
org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
 
@@ -182,4 +183,17 @@ public final class UnsafeArrayWriter extends UnsafeWriter {
       setNull(ordinal);
     }
   }
+
+  @Override
+  public void write(int ordinal, CalendarInterval input) {
+    assertIndexIsValid(ordinal);
+    // the UnsafeWriter version of write(int, CalendarInterval) doesn't handle
+    // null intervals appropriately when the container is an array, so we 
handle
+    // that case here.
+    if (input == null) {
+      setNull(ordinal);
+    } else {
+      super.write(ordinal, input);
+    }
+  }
 }
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 84b2b294794..8d4e187d01a 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -131,7 +131,7 @@ public abstract class UnsafeWriter {
     increaseCursor(roundedSize);
   }
 
-  public final void write(int ordinal, CalendarInterval input) {
+  public void write(int ordinal, CalendarInterval input) {
     // grow the global buffer before writing data.
     grow(16);
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 4e129e96d1c..5d95ac71be8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils.avoidSetNullAt
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DecimalType
 
 
 /**
@@ -73,7 +73,7 @@ class InterpretedMutableProjection(expressions: 
Seq[Expression]) extends Mutable
 
   private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case 
(e, i) =>
     val writer = InternalRow.getWriter(i, e.dataType)
-    if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) {
+    if (!e.nullable || avoidSetNullAt(e.dataType)) {
       (v: Any) => writer(mutableRow, v)
     } else {
       (v: Any) => {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index d02d1e8b55b..3281280f648 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -266,6 +266,9 @@ object InterpretedUnsafeProjection {
         // We can't call setNullAt() for DecimalType with precision larger 
than 18, we call write
         // directly. We can use the unwrapped writer directly.
         unsafeWriter
+      case CalendarIntervalType =>
+        // We can't call setNullAt() for CalendarIntervalType, we call write 
directly.
+        unsafeWriter
       case BooleanType | ByteType =>
         (v, i) => {
           if (!v.isNullAt(i)) {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 132bb259770..3abb717da5b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -38,7 +38,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, SQLOrderingUtil}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, 
SQLOrderingUtil, UnsafeRowUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -1725,8 +1725,7 @@ object CodeGenerator extends Logging {
     if (nullable) {
       // Can't call setNullAt on DecimalType/CalendarIntervalType, because we 
need to keep the
       // offset
-      if (!isVectorized && (dataType.isInstanceOf[DecimalType] ||
-        dataType.isInstanceOf[CalendarIntervalType])) {
+      if (!isVectorized && UnsafeRowUtils.avoidSetNullAt(dataType)) {
         s"""
            |if (!${ev.isNull}) {
            |  ${setColumn(row, dataType, ordinal, ev.value)};
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index 48db0c7d971..2791f404813 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -113,4 +113,30 @@ object UnsafeRowUtils {
     val size = offsetAndSize.toInt
     (offset, size)
   }
+
+  /**
+   * Returns a Boolean indicating whether one should avoid calling
+   * UnsafeRow.setNullAt for a field of the given data type.
+   * Fields of type DecimalType (with precision
+   * greater than Decimal.MAX_LONG_DIGITS) and CalendarIntervalType use
+   * pointers into the variable length region, and those pointers should
+   * never get zeroed out (setNullAt will zero out those pointers) because 
UnsafeRow
+   * may do in-place update for these 2 types even though they are not 
primitive.
+   *
+   * When avoidSetNullAt returns true, callers should not use
+   * UnsafeRow#setNullAt for fields of that data type, but instead pass
+   * a null value to the appropriate set method, e.g.:
+   *
+   *   row.setDecimal(ordinal, null, precision)
+   *
+   * Even though only UnsafeRow has this limitation, it's safe to extend this 
rule
+   * to all subclasses of InternalRow, since you don't always know the 
concrete type
+   * of the row you are dealing with, and all subclasses of InternalRow will
+   * handle a null value appropriately.
+   */
+  def avoidSetNullAt(dt: DataType): Boolean = dt match {
+    case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => true
+    case CalendarIntervalType => true
+    case _ => false
+  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index e3f11283816..b79df0e40e9 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, 
yearMonthIntervalTypes}
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -127,6 +127,33 @@ class MutableProjectionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     testRows(bufferSchema, buffer, scalaRows)
   }
 
+  testBothCodegenAndInterpreted("SPARK-41535: unsafe buffer with null 
intervals") {
+    val bufferSchema = StructType(Array(
+      StructField("intv1", CalendarIntervalType, nullable = true),
+      StructField("intv2", CalendarIntervalType, nullable = true)))
+    val buffer = UnsafeProjection.create(bufferSchema)
+      .apply(new GenericInternalRow(bufferSchema.length))
+    val scalaRows = Seq(
+      Seq(null, null),
+      Seq(
+        new CalendarInterval(0, 7, 0L),
+        new CalendarInterval(12*17, 2, 0L)))
+    testRows(bufferSchema, buffer, scalaRows)
+  }
+
+  testBothCodegenAndInterpreted("SPARK-41535: generic buffer with null 
intervals") {
+    val bufferSchema = StructType(Array(
+      StructField("intv1", CalendarIntervalType, nullable = true),
+      StructField("intv2", CalendarIntervalType, nullable = true)))
+    val buffer = new GenericInternalRow(bufferSchema.length)
+    val scalaRows = Seq(
+      Seq(null, null),
+      Seq(
+        new CalendarInterval(0, 7, 0L),
+        new CalendarInterval(12*17, 2, 0L)))
+    testRows(bufferSchema, buffer, scalaRows)
+  }
+
   testBothCodegenAndInterpreted("variable-length types") {
     val proj = createMutableProjection(variableLengthTypes)
     val scalaValues = Seq("abc", BigDecimal(10),
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 220728fcaa2..83dc8127828 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -277,6 +277,48 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers with PlanTestB
     // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
+  testBothCodegenAndInterpreted("SPARK-41535: intervals initialized as null") {
+    val factory = UnsafeProjection
+    val fieldTypes: Array[DataType] = Array(CalendarIntervalType, 
CalendarIntervalType)
+    val converter = factory.create(fieldTypes)
+
+    val row = new SpecificInternalRow(fieldTypes)
+    for (i <- 0 until row.numFields) {
+      row.setInterval(i, null)
+    }
+
+    val nullAtCreation = converter.apply(row)
+
+    for (i <- 0 until row.numFields) {
+      assert(nullAtCreation.isNullAt(i))
+    }
+
+    val intervals = Array(
+      new CalendarInterval(0, 7, 0L),
+      new CalendarInterval(12*17, 2, 0L)
+    )
+    // set interval values into previously null columns
+    for (i <- intervals.indices) {
+      nullAtCreation.setInterval(i, intervals(i))
+    }
+
+    for (i <- intervals.indices) {
+      assert(nullAtCreation.getInterval(i) == intervals(i))
+    }
+  }
+
+  testBothCodegenAndInterpreted("SPARK-41535: interval array containing 
nulls") {
+    val factory = UnsafeProjection
+    val fieldTypes: Array[DataType] = Array(ArrayType(CalendarIntervalType))
+    val converter = factory.create(fieldTypes)
+
+    val row = new SpecificInternalRow(fieldTypes)
+    val values = Array(new CalendarInterval(0, 7, 0L), null)
+    row.update(0, createArray(values: _*))
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    testArrayInterval(unsafeRow.getArray(0), values)
+  }
+
   testBothCodegenAndInterpreted("basic conversion with struct type") {
     val factory = UnsafeProjection
     val fieldTypes: Array[DataType] = Array(
@@ -330,6 +372,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with 
Matchers with PlanTestB
     }
   }
 
+  private def testArrayInterval(array: UnsafeArrayData, values: 
Seq[CalendarInterval]): Unit = {
+    assert(array.numElements == values.length)
+    values.zipWithIndex.foreach {
+      case (value, index) => assert(array.getInterval(index) == value)
+    }
+  }
+
   private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values: 
Seq[Int]): Unit = {
     assert(keys.length == values.length)
     assert(map.numElements == keys.length)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to