This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cb869328ea7 [SPARK-41804][SQL] Choose correct element size in
`InterpretedUnsafeProjection` for array of UDTs
cb869328ea7 is described below
commit cb869328ea7fcf95a4178a0db19a6fa821ce3f15
Author: Bruce Robbins <[email protected]>
AuthorDate: Tue Jan 3 10:22:35 2023 +0900
[SPARK-41804][SQL] Choose correct element size in
`InterpretedUnsafeProjection` for array of UDTs
### What changes were proposed in this pull request?
Change `InterpretedUnsafeProjection#getElementSize` to choose the
appropriate element size for the underlying SQL type of a UDT, rather than
simply using the default size of the underlying SQL type.
### Why are the changes needed?
Consider this query:
```
// create a file of vector data
import org.apache.spark.ml.linalg.{DenseVector, Vector}
case class TestRow(varr: Array[Vector])
val values = Array(0.1d, 0.2d, 0.3d)
val dv = new DenseVector(values).asInstanceOf[Vector]
val ds = Seq(TestRow(Array(dv, dv))).toDS
ds.coalesce(1).write.mode("overwrite").format("parquet").save("vector_data")
// this works
spark.read.format("parquet").load("vector_data").collect
sql("set spark.sql.codegen.wholeStage=false")
sql("set spark.sql.codegen.factoryMode=NO_CODEGEN")
// this will get an error
spark.read.format("parquet").load("vector_data").collect
```
The failures vary, including
* `VectorUDT` attempting to deserialize to a `SparseVector` (rather than a
`DenseVector`)
* negative array size (for one of the nested arrays)
* JVM crash (SIGBUS error).
This is because `InterpretedUnsafeProjection` initializes the outer-most
array writer with an element size of 17 (the size of the UDT's underlying
struct), rather than an element size of 8, which would be appropriate for an
array of structs.
When the outer-most array is later accessed, `UnsafeArrayData` assumes an
element size of 8, so it picks up a garbage offset/size tuple for the second
element.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit test.
Closes #39349 from bersprockets/udt_issue.
Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../expressions/InterpretedUnsafeProjection.scala | 2 ++
.../catalyst/expressions/UnsafeRowConverterSuite.scala | 16 ++++++++++++++++
2 files changed, 18 insertions(+)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index d87c0c006cf..9108a045c09 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -294,6 +294,8 @@ object InterpretedUnsafeProjection {
private def getElementSize(dataType: DataType): Int = dataType match {
case NullType | StringType | BinaryType | CalendarIntervalType |
_: DecimalType | _: StructType | _: ArrayType | _: MapType => 8
+ case udt: UserDefinedType[_] =>
+ getElementSize(udt.sqlType)
case _ => dataType.defaultSize
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 83dc8127828..cbab8894cb5 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -687,4 +687,20 @@ class UnsafeRowConverterSuite extends SparkFunSuite with
Matchers with PlanTestB
val fields5 = Array[DataType](udt)
assert(convertBackToInternalRow(udtRow, fields5) === udtRow)
}
+
+ testBothCodegenAndInterpreted("SPARK-41804: Array of UDTs") {
+ val udt = new ExampleBaseTypeUDT
+ val objs = Seq(
+ udt.serialize(new ExampleSubClass(1)),
+ udt.serialize(new ExampleSubClass(2)))
+ val arr = new GenericArrayData(objs)
+ val row = new GenericInternalRow(Array[Any](arr))
+ val unsafeProj = UnsafeProjection.create(Array[DataType](ArrayType(udt)))
+ val unsafeRow = unsafeProj.apply(row)
+ val unsafeOuterArray = unsafeRow.getArray(0)
+ // get second element from unsafe array
+ val unsafeStruct = unsafeOuterArray.getStruct(1, 1)
+ val result = unsafeStruct.getInt(0)
+ assert(result == 2)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]