This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0c9b072e1808 [SPARK-48833][SQL][VARIANT] Support variant in `InMemoryTableScan`
0c9b072e1808 is described below
commit 0c9b072e1808e180d8670e4100f3344f039cb072
Author: Richard Chen <[email protected]>
AuthorDate: Wed Jul 24 18:02:38 2024 +0800
[SPARK-48833][SQL][VARIANT] Support variant in `InMemoryTableScan`
### What changes were proposed in this pull request?
Adds support for the variant type in `InMemoryTableScan` (i.e. `df.cache()`) by supporting writing variant values to an in-memory buffer.
### Why are the changes needed?
Prior to this PR, calling `df.cache()` on a DataFrame containing a variant column would fail because `InMemoryTableScan` did not support reading variant types.
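For illustration, a minimal sketch of the usage this enables (the query is adapted from the new tests in `VariantSuite` and assumes an active `spark` session):

```scala
// Before this PR, caching a DataFrame with a variant column failed when the
// cached relation was scanned via InMemoryTableScan.
val df = spark.sql(
  """select parse_json(format_string('{\"a\": %s}', id)) v from range(0, 10)""")
df.cache()    // materialize the rows into the in-memory columnar buffer
df.collect()  // returns the same rows as the uncached query
```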
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
added UTs
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #47252 from richardc-db/variant_dfcache_support.
Authored-by: Richard Chen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/execution/columnar/ColumnAccessor.scala | 6 +-
.../sql/execution/columnar/ColumnBuilder.scala | 4 ++
.../spark/sql/execution/columnar/ColumnStats.scala | 15 +++++
.../spark/sql/execution/columnar/ColumnType.scala | 43 ++++++++++++-
.../columnar/GenerateColumnAccessor.scala | 3 +-
.../scala/org/apache/spark/sql/VariantSuite.scala | 72 ++++++++++++++++++++++
6 files changed, 139 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 9652a48e5270..2074649cc986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -111,6 +111,10 @@ private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL)
with NullableColumnAccessor
+private[columnar] class VariantColumnAccessor(buffer: ByteBuffer)
+ extends BasicColumnAccessor[VariantVal](buffer, VARIANT)
+ with NullableColumnAccessor
+
private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer,
dataType: DecimalType)
extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 9fafdb794841..b65ef12f12d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -131,6 +131,9 @@ class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BI
private[columnar]
class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats, CALENDAR_INTERVAL)
+private[columnar]
+class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT)
+
private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
extends NativeColumnBuilder(new DecimalColumnStats(dataType),
COMPACT_DECIMAL(dataType))
@@ -189,6 +192,7 @@ private[columnar] object ColumnBuilder {
case s: StringType => new StringColumnBuilder(s)
case BinaryType => new BinaryColumnBuilder
case CalendarIntervalType => new IntervalColumnBuilder
+ case VariantType => new VariantColumnBuilder
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnBuilder(dt)
case dt: DecimalType => new DecimalColumnBuilder(dt)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 45f489cb13c2..4e4b3667fa24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -297,6 +297,21 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
Array[Any](null, null, nullCount, count, sizeInBytes)
}
+private[columnar] final class VariantColumnStats extends ColumnStats {
+ override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+ if (!row.isNullAt(ordinal)) {
+ val size = VARIANT.actualSize(row, ordinal)
+ sizeInBytes += size
+ count += 1
+ } else {
+ gatherNullStats()
+ }
+ }
+
+ override def collectedStatistics: Array[Any] =
+ Array[Any](null, null, nullCount, count, sizeInBytes)
+}
+
private[columnar] final class IntervalColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index b8e63294f3cd..5cc3a3d83d4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -24,11 +24,11 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
/**
@@ -815,6 +815,45 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
}
}
+/**
+ * Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column.
+ *
+ * Variants are encoded in `append` as:
+ * | value size | metadata size | value binary | metadata binary |
+ * and are only expected to be decoded in `extract`.
+ */
+private[columnar] object VARIANT
+ extends ColumnType[VariantVal] with DirectCopyColumnType[VariantVal] {
+ override def dataType: PhysicalDataType = PhysicalVariantType
+
+ /** Chosen to match the default size set in `VariantType`. */
+ override def defaultSize: Int = 2048
+
+ override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)
+
+ override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
+ row.update(ordinal, value)
+
+ override def append(v: VariantVal, buffer: ByteBuffer): Unit = {
+ val valueSize = v.getValue().length
+ val metadataSize = v.getMetadata().length
+ ByteBufferHelper.putInt(buffer, valueSize)
+ ByteBufferHelper.putInt(buffer, metadataSize)
+ ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getValue()), buffer, valueSize)
+ ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getMetadata()), buffer, metadataSize)
+ }
+
+ override def extract(buffer: ByteBuffer): VariantVal = {
+ val valueSize = ByteBufferHelper.getInt(buffer)
+ val metadataSize = ByteBufferHelper.getInt(buffer)
+ val valueBuffer = ByteBuffer.allocate(valueSize)
+ ByteBufferHelper.copyMemory(buffer, valueBuffer, valueSize)
+ val metadataBuffer = ByteBuffer.allocate(metadataSize)
+ ByteBufferHelper.copyMemory(buffer, metadataBuffer, metadataSize)
+ new VariantVal(valueBuffer.array(), metadataBuffer.array())
+ }
+}
+
private[columnar] object ColumnType {
@tailrec
def apply(dataType: DataType): ColumnType[_] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 75416b878914..d07ebeb843bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -89,6 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
case _: StringType => classOf[StringColumnAccessor].getName
case BinaryType => classOf[BinaryColumnAccessor].getName
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
+ case VariantType => classOf[VariantColumnAccessor].getName
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
classOf[CompactDecimalColumnAccessor].getName
case dt: DecimalType => classOf[DecimalColumnAccessor].getName
@@ -101,7 +102,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
val createCode = dt match {
case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
- case NullType | BinaryType | CalendarIntervalType =>
+ case NullType | BinaryType | CalendarIntervalType | VariantType =>
s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case other =>
s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index de1e4330c564..ce2643f9e239 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -639,6 +639,78 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
}
}
+ test("variant in a cached row-based df") {
+ val query = """select
+ parse_json(format_string('{\"a\": %s}', id)) v,
+ cast(null as variant) as null_v,
+ case when id % 2 = 0 then parse_json(cast(id as string)) else null end as some_null
+ from range(0, 10)"""
+ val df = spark.sql(query)
+ df.cache()
+
+ val expected = spark.sql(query)
+ checkAnswer(df, expected.collect())
+ }
+
+ test("struct of variant in a cached row-based df") {
+ val query = """select named_struct(
+ 'v', parse_json(format_string('{\"a\": %s}', id)),
+ 'null_v', cast(null as variant),
+ 'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+ ) v
+ from range(0, 10)"""
+ val df = spark.sql(query)
+ df.cache()
+
+ val expected = spark.sql(query)
+ checkAnswer(df, expected.collect())
+ }
+
+ test("array of variant in a cached row-based df") {
+ val query = """select array(
+ parse_json(cast(id as string)),
+ parse_json(format_string('{\"a\": %s}', id)),
+ null,
+ case when id % 2 = 0 then parse_json(cast(id as string)) else null end) v
+ from range(0, 10)"""
+ val df = spark.sql(query)
+ df.cache()
+
+ val expected = spark.sql(query)
+ checkAnswer(df, expected.collect())
+ }
+
+ test("map of variant in a cached row-based df") {
+ val query = """select map(
+ 'v', parse_json(format_string('{\"a\": %s}', id)),
+ 'null_v', cast(null as variant),
+ 'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+ ) v
+ from range(0, 10)"""
+ val df = spark.sql(query)
+ df.cache()
+
+ val expected = spark.sql(query)
+ checkAnswer(df, expected.collect())
+ }
+
+ test("variant in a cached column-based df") {
+ withTable("t") {
+ val query = """select named_struct(
+ 'v', parse_json(format_string('{\"a\": %s}', id)),
+ 'null_v', cast(null as variant),
+ 'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+ ) v
+ from range(0, 10)"""
+ spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
+ val df = spark.sql("select * from t")
+ df.cache()
+
+ val expected = spark.sql(query)
+ checkAnswer(df, expected.collect())
+ }
+ }
+
test("variant_get size") {
val largeKey = "x" * 1000
val df = Seq(s"""{ "$largeKey": {"a" : 1 },