This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c9b072e1808 [SPARK-48833][SQL][VARIANT] Support variant in `InMemoryTableScan`
0c9b072e1808 is described below

commit 0c9b072e1808e180d8670e4100f3344f039cb072
Author: Richard Chen <[email protected]>
AuthorDate: Wed Jul 24 18:02:38 2024 +0800

    [SPARK-48833][SQL][VARIANT] Support variant in `InMemoryTableScan`
    
    ### What changes were proposed in this pull request?
    
    adds support for the variant type in `InMemoryTableScan`, i.e. `df.cache()`, by supporting writing variant values to an in-memory buffer.
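
    For illustration, a minimal sketch of the usage this enables (assumes a running `SparkSession`
    named `spark`, e.g. in `spark-shell`; the query and column name are only illustrative):

    ```scala
    // A DataFrame with a variant column produced by parse_json.
    val df = spark.sql("select parse_json(cast(id as string)) as v from range(0, 10)")

    // Caching writes the variant values into the in-memory columnar buffers...
    df.cache()

    // ...and reading them back goes through InMemoryTableScan.
    df.collect()
    ```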
    
    ### Why are the changes needed?
    
    prior to this PR, calling `df.cache()` on a DataFrame containing a variant column would fail because `InMemoryTableScan` did not support reading variant types.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no

    ### How was this patch tested?

    added UTs

    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #47252 from richardc-db/variant_dfcache_support.
    
    Authored-by: Richard Chen <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/execution/columnar/ColumnAccessor.scala    |  6 +-
 .../sql/execution/columnar/ColumnBuilder.scala     |  4 ++
 .../spark/sql/execution/columnar/ColumnStats.scala | 15 +++++
 .../spark/sql/execution/columnar/ColumnType.scala  | 43 ++++++++++++-
 .../columnar/GenerateColumnAccessor.scala          |  3 +-
 .../scala/org/apache/spark/sql/VariantSuite.scala  | 72 ++++++++++++++++++++++
 6 files changed, 139 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 9652a48e5270..2074649cc986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
 
 /**
  * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -111,6 +111,10 @@ private[columnar] class IntervalColumnAccessor(buffer: ByteBuffer)
   extends BasicColumnAccessor[CalendarInterval](buffer, CALENDAR_INTERVAL)
   with NullableColumnAccessor
 
+private[columnar] class VariantColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[VariantVal](buffer, VARIANT)
+  with NullableColumnAccessor
+
 private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
   extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 9fafdb794841..b65ef12f12d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -131,6 +131,9 @@ class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BI
 private[columnar]
 class IntervalColumnBuilder extends ComplexColumnBuilder(new IntervalColumnStats, CALENDAR_INTERVAL)
 
+private[columnar]
+class VariantColumnBuilder extends ComplexColumnBuilder(new VariantColumnStats, VARIANT)
+
 private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
   extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
 
@@ -189,6 +192,7 @@ private[columnar] object ColumnBuilder {
       case s: StringType => new StringColumnBuilder(s)
       case BinaryType => new BinaryColumnBuilder
       case CalendarIntervalType => new IntervalColumnBuilder
+      case VariantType => new VariantColumnBuilder
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
         new CompactDecimalColumnBuilder(dt)
       case dt: DecimalType => new DecimalColumnBuilder(dt)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 45f489cb13c2..4e4b3667fa24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -297,6 +297,21 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
     Array[Any](null, null, nullCount, count, sizeInBytes)
 }
 
+private[columnar] final class VariantColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    if (!row.isNullAt(ordinal)) {
+      val size = VARIANT.actualSize(row, ordinal)
+      sizeInBytes += size
+      count += 1
+    } else {
+      gatherNullStats()
+    }
+  }
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
+}
+
 private[columnar] final class IntervalColumnStats extends ColumnStats {
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index b8e63294f3cd..5cc3a3d83d4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -24,11 +24,11 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
 
 /**
@@ -815,6 +815,45 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
   }
 }
 
+/**
+ * Used to append/extract Java VariantVals into/from the underlying [[ByteBuffer]] of a column.
+ *
+ * Variants are encoded in `append` as:
+ * | value size | metadata size | value binary | metadata binary |
+ * and are only expected to be decoded in `extract`.
+ */
+private[columnar] object VARIANT
+  extends ColumnType[VariantVal] with DirectCopyColumnType[VariantVal] {
+  override def dataType: PhysicalDataType = PhysicalVariantType
+
+  /** Chosen to match the default size set in `VariantType`. */
+  override def defaultSize: Int = 2048
+
+  override def getField(row: InternalRow, ordinal: Int): VariantVal = row.getVariant(ordinal)
+
+  override def setField(row: InternalRow, ordinal: Int, value: VariantVal): Unit =
+    row.update(ordinal, value)
+
+  override def append(v: VariantVal, buffer: ByteBuffer): Unit = {
+    val valueSize = v.getValue().length
+    val metadataSize = v.getMetadata().length
+    ByteBufferHelper.putInt(buffer, valueSize)
+    ByteBufferHelper.putInt(buffer, metadataSize)
+    ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getValue()), buffer, valueSize)
+    ByteBufferHelper.copyMemory(ByteBuffer.wrap(v.getMetadata()), buffer, metadataSize)
+  }
+
+  override def extract(buffer: ByteBuffer): VariantVal = {
+    val valueSize = ByteBufferHelper.getInt(buffer)
+    val metadataSize = ByteBufferHelper.getInt(buffer)
+    val valueBuffer = ByteBuffer.allocate(valueSize)
+    ByteBufferHelper.copyMemory(buffer, valueBuffer, valueSize)
+    val metadataBuffer = ByteBuffer.allocate(metadataSize)
+    ByteBufferHelper.copyMemory(buffer, metadataBuffer, metadataSize)
+    new VariantVal(valueBuffer.array(), metadataBuffer.array())
+  }
+}
+
 private[columnar] object ColumnType {
   @tailrec
   def apply(dataType: DataType): ColumnType[_] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 75416b878914..d07ebeb843bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -89,6 +89,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
         case _: StringType => classOf[StringColumnAccessor].getName
         case BinaryType => classOf[BinaryColumnAccessor].getName
         case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
+        case VariantType => classOf[VariantColumnAccessor].getName
         case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
           classOf[CompactDecimalColumnAccessor].getName
         case dt: DecimalType => classOf[DecimalColumnAccessor].getName
@@ -101,7 +102,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       val createCode = dt match {
         case t if CodeGenerator.isPrimitiveType(dt) =>
           s"$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType | VariantType =>
           s"$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
         case other =>
           s"""$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index de1e4330c564..ce2643f9e239 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -639,6 +639,78 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     }
   }
 
+  test("variant in a cached row-based df") {
+    val query = """select
+      parse_json(format_string('{\"a\": %s}', id)) v,
+      cast(null as variant) as null_v,
+      case when id % 2 = 0 then parse_json(cast(id as string)) else null end as some_null
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("struct of variant in a cached row-based df") {
+    val query = """select named_struct(
+      'v', parse_json(format_string('{\"a\": %s}', id)),
+      'null_v', cast(null as variant),
+      'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+    ) v
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("array of variant in a cached row-based df") {
+    val query = """select array(
+      parse_json(cast(id as string)),
+      parse_json(format_string('{\"a\": %s}', id)),
+      null,
+      case when id % 2 = 0 then parse_json(cast(id as string)) else null end) v
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("map of variant in a cached row-based df") {
+    val query = """select map(
+      'v', parse_json(format_string('{\"a\": %s}', id)),
+      'null_v', cast(null as variant),
+      'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+    ) v
+    from range(0, 10)"""
+    val df = spark.sql(query)
+    df.cache()
+
+    val expected = spark.sql(query)
+    checkAnswer(df, expected.collect())
+  }
+
+  test("variant in a cached column-based df") {
+    withTable("t") {
+      val query = """select named_struct(
+        'v', parse_json(format_string('{\"a\": %s}', id)),
+        'null_v', cast(null as variant),
+        'some_null', case when id % 2 = 0 then parse_json(cast(id as string)) else null end
+      ) v
+      from range(0, 10)"""
+      spark.sql(query).write.format("parquet").mode("overwrite").saveAsTable("t")
+      val df = spark.sql("select * from t")
+      df.cache()
+
+      val expected = spark.sql(query)
+      checkAnswer(df, expected.collect())
+    }
+  }
+
   test("variant_get size") {
     val largeKey = "x" * 1000
     val df = Seq(s"""{ "$largeKey": {"a" : 1 },


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
