This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new b5c4bb76db0 [SPARK-41535][SQL] Set null correctly for calendar
interval fields in `InterpretedUnsafeProjection` and
`InterpretedMutableProjection`
b5c4bb76db0 is described below
commit b5c4bb76db06623e4252e02f76f235d9a476059b
Author: Bruce Robbins <[email protected]>
AuthorDate: Tue Dec 20 09:29:18 2022 +0900
[SPARK-41535][SQL] Set null correctly for calendar interval fields in
`InterpretedUnsafeProjection` and `InterpretedMutableProjection`
In `InterpretedUnsafeProjection`, use `UnsafeWriter.write`, rather than
`UnsafeWriter.setNullAt`, to set null for interval fields. Also, in
`InterpretedMutableProjection`, use `InternalRow.setInterval`, rather than
`InternalRow.setNullAt`, to set null for interval fields.
This returns the wrong answer:
```
set spark.sql.codegen.wholeStage=false;
set spark.sql.codegen.factoryMode=NO_CODEGEN;
select first(col1), last(col2) from values
(make_interval(0, 0, 0, 7, 0, 0, 0), make_interval(17, 0, 0, 2, 0, 0, 0))
as data(col1, col2);
+---------------+---------------+
|first(col1) |last(col2) |
+---------------+---------------+
|16 years 2 days|16 years 2 days|
+---------------+---------------+
```
In the above case, `TungstenAggregationIterator` uses
`InterpretedUnsafeProjection` to create the aggregation buffer and to
initialize all the fields to null. `InterpretedUnsafeProjection` incorrectly
calls `UnsafeRowWriter#setNullAt`, rather than `UnsafeRowWriter#write`, for the
two calendar interval fields. As a result, the writer never allocates memory
from the variable length region for the two intervals, and the pointers in the
fixed region get left as zero. Later, when `Interpre [...]
Even after one fixes the above bug in `InterpretedUnsafeProjection` so that
the buffer is created correctly, `InterpretedMutableProjection` has a similar
bug to SPARK-41395, except this time for calendar interval data:
```
set spark.sql.codegen.wholeStage=false;
set spark.sql.codegen.factoryMode=NO_CODEGEN;
select first(col1), last(col2), max(col3) from values
(null, null, 1),
(make_interval(0, 0, 0, 7, 0, 0, 0), make_interval(17, 0, 0, 2, 0, 0, 0), 3)
as data(col1, col2, col3);
+---------------+---------------+---------+
|first(col1) |last(col2) |max(col3)|
+---------------+---------------+---------+
|16 years 2 days|16 years 2 days|3 |
+---------------+---------------+---------+
```
These two bugs could get exercised during codegen fallback.
No.
New unit tests.
Closes #39117 from bersprockets/unsafe_interval_issue.
Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 7f153842041d66e9cf0465262f4458cfffda4f43)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../expressions/codegen/UnsafeArrayWriter.java | 14 +++++++
.../catalyst/expressions/codegen/UnsafeWriter.java | 2 +-
.../expressions/InterpretedMutableProjection.scala | 4 +-
.../expressions/InterpretedUnsafeProjection.scala | 3 ++
.../expressions/codegen/CodeGenerator.scala | 5 +--
.../spark/sql/catalyst/util/UnsafeRowUtils.scala | 26 ++++++++++++
.../expressions/MutableProjectionSuite.scala | 29 ++++++++++++-
.../expressions/UnsafeRowConverterSuite.scala | 49 ++++++++++++++++++++++
8 files changed, 125 insertions(+), 7 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index bf6792313ae..e5b941f9e60 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -21,6 +21,7 @@ import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
import static
org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
@@ -182,4 +183,17 @@ public final class UnsafeArrayWriter extends UnsafeWriter {
setNull(ordinal);
}
}
+
+ @Override
+ public void write(int ordinal, CalendarInterval input) {
+ assertIndexIsValid(ordinal);
+ // the UnsafeWriter version of write(int, CalendarInterval) doesn't handle
+ // null intervals appropriately when the container is an array, so we
handle
+ // that case here.
+ if (input == null) {
+ setNull(ordinal);
+ } else {
+ super.write(ordinal, input);
+ }
+ }
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 84b2b294794..8d4e187d01a 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -131,7 +131,7 @@ public abstract class UnsafeWriter {
increaseCursor(roundedSize);
}
- public final void write(int ordinal, CalendarInterval input) {
+ public void write(int ordinal, CalendarInterval input) {
// grow the global buffer before writing data.
grow(16);
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 4e129e96d1c..5d95ac71be8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils.avoidSetNullAt
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DecimalType
/**
@@ -73,7 +73,7 @@ class InterpretedMutableProjection(expressions:
Seq[Expression]) extends Mutable
private[this] val fieldWriters: Array[Any => Unit] = validExprs.map { case
(e, i) =>
val writer = InternalRow.getWriter(i, e.dataType)
- if (!e.nullable || e.dataType.isInstanceOf[DecimalType]) {
+ if (!e.nullable || avoidSetNullAt(e.dataType)) {
(v: Any) => writer(mutableRow, v)
} else {
(v: Any) => {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index d02d1e8b55b..3281280f648 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -266,6 +266,9 @@ object InterpretedUnsafeProjection {
// We can't call setNullAt() for DecimalType with precision larger
than 18, we call write
// directly. We can use the unwrapped writer directly.
unsafeWriter
+ case CalendarIntervalType =>
+ // We can't call setNullAt() for CalendarIntervalType, we call write
directly.
+ unsafeWriter
case BooleanType | ByteType =>
(v, i) => {
if (!v.isNullAt(i)) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 132bb259770..3abb717da5b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -38,7 +38,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, SQLOrderingUtil}
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData,
SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -1725,8 +1725,7 @@ object CodeGenerator extends Logging {
if (nullable) {
// Can't call setNullAt on DecimalType/CalendarIntervalType, because we
need to keep the
// offset
- if (!isVectorized && (dataType.isInstanceOf[DecimalType] ||
- dataType.isInstanceOf[CalendarIntervalType])) {
+ if (!isVectorized && UnsafeRowUtils.avoidSetNullAt(dataType)) {
s"""
|if (!${ev.isNull}) {
| ${setColumn(row, dataType, ordinal, ev.value)};
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index 48db0c7d971..2791f404813 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -113,4 +113,30 @@ object UnsafeRowUtils {
val size = offsetAndSize.toInt
(offset, size)
}
+
+ /**
+ * Returns a Boolean indicating whether one should avoid calling
+ * UnsafeRow.setNullAt for a field of the given data type.
+ * Fields of type DecimalType (with precision
+ * greater than Decimal.MAX_LONG_DIGITS) and CalendarIntervalType use
+ * pointers into the variable length region, and those pointers should
+ * never get zeroed out (setNullAt will zero out those pointers) because
UnsafeRow
+ * may do in-place update for these 2 types even though they are not
primitive.
+ *
+ * When avoidSetNullAt returns true, callers should not use
+ * UnsafeRow#setNullAt for fields of that data type, but instead pass
+ * a null value to the appropriate set method, e.g.:
+ *
+ * row.setDecimal(ordinal, null, precision)
+ *
+ * Even though only UnsafeRow has this limitation, it's safe to extend this
rule
+ * to all subclasses of InternalRow, since you don't always know the
concrete type
+ * of the row you are dealing with, and all subclasses of InternalRow will
+ * handle a null value appropriately.
+ */
+ def avoidSetNullAt(dt: DataType): Boolean = dt match {
+ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => true
+ case CalendarIntervalType => true
+ case _ => false
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
index e3f11283816..b79df0e40e9 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MutableProjectionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes,
yearMonthIntervalTypes}
import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class MutableProjectionSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -127,6 +127,33 @@ class MutableProjectionSuite extends SparkFunSuite with
ExpressionEvalHelper {
testRows(bufferSchema, buffer, scalaRows)
}
+ testBothCodegenAndInterpreted("SPARK-41535: unsafe buffer with null
intervals") {
+ val bufferSchema = StructType(Array(
+ StructField("intv1", CalendarIntervalType, nullable = true),
+ StructField("intv2", CalendarIntervalType, nullable = true)))
+ val buffer = UnsafeProjection.create(bufferSchema)
+ .apply(new GenericInternalRow(bufferSchema.length))
+ val scalaRows = Seq(
+ Seq(null, null),
+ Seq(
+ new CalendarInterval(0, 7, 0L),
+ new CalendarInterval(12*17, 2, 0L)))
+ testRows(bufferSchema, buffer, scalaRows)
+ }
+
+ testBothCodegenAndInterpreted("SPARK-41535: generic buffer with null
intervals") {
+ val bufferSchema = StructType(Array(
+ StructField("intv1", CalendarIntervalType, nullable = true),
+ StructField("intv2", CalendarIntervalType, nullable = true)))
+ val buffer = new GenericInternalRow(bufferSchema.length)
+ val scalaRows = Seq(
+ Seq(null, null),
+ Seq(
+ new CalendarInterval(0, 7, 0L),
+ new CalendarInterval(12*17, 2, 0L)))
+ testRows(bufferSchema, buffer, scalaRows)
+ }
+
testBothCodegenAndInterpreted("variable-length types") {
val proj = createMutableProjection(variableLengthTypes)
val scalaValues = Seq("abc", BigDecimal(10),
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 220728fcaa2..83dc8127828 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -277,6 +277,48 @@ class UnsafeRowConverterSuite extends SparkFunSuite with
Matchers with PlanTestB
// assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
+ testBothCodegenAndInterpreted("SPARK-41535: intervals initialized as null") {
+ val factory = UnsafeProjection
+ val fieldTypes: Array[DataType] = Array(CalendarIntervalType,
CalendarIntervalType)
+ val converter = factory.create(fieldTypes)
+
+ val row = new SpecificInternalRow(fieldTypes)
+ for (i <- 0 until row.numFields) {
+ row.setInterval(i, null)
+ }
+
+ val nullAtCreation = converter.apply(row)
+
+ for (i <- 0 until row.numFields) {
+ assert(nullAtCreation.isNullAt(i))
+ }
+
+ val intervals = Array(
+ new CalendarInterval(0, 7, 0L),
+ new CalendarInterval(12*17, 2, 0L)
+ )
+ // set interval values into previously null columns
+ for (i <- intervals.indices) {
+ nullAtCreation.setInterval(i, intervals(i))
+ }
+
+ for (i <- intervals.indices) {
+ assert(nullAtCreation.getInterval(i) == intervals(i))
+ }
+ }
+
+ testBothCodegenAndInterpreted("SPARK-41535: interval array containing
nulls") {
+ val factory = UnsafeProjection
+ val fieldTypes: Array[DataType] = Array(ArrayType(CalendarIntervalType))
+ val converter = factory.create(fieldTypes)
+
+ val row = new SpecificInternalRow(fieldTypes)
+ val values = Array(new CalendarInterval(0, 7, 0L), null)
+ row.update(0, createArray(values: _*))
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ testArrayInterval(unsafeRow.getArray(0), values)
+ }
+
testBothCodegenAndInterpreted("basic conversion with struct type") {
val factory = UnsafeProjection
val fieldTypes: Array[DataType] = Array(
@@ -330,6 +372,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with
Matchers with PlanTestB
}
}
+ private def testArrayInterval(array: UnsafeArrayData, values:
Seq[CalendarInterval]): Unit = {
+ assert(array.numElements == values.length)
+ values.zipWithIndex.foreach {
+ case (value, index) => assert(array.getInterval(index) == value)
+ }
+ }
+
private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values:
Seq[Int]): Unit = {
assert(keys.length == values.length)
assert(map.numElements == keys.length)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]