This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0bc38acc615a [SPARK-48675][SQL] Fix cache table with collated column
0bc38acc615a is described below

commit 0bc38acc615ad411a97779c6a1ff43d4391c0c3d
Author: Nikola Mandic <[email protected]>
AuthorDate: Fri Jun 21 22:45:47 2024 +0800

    [SPARK-48675][SQL] Fix cache table with collated column
    
    ### What changes were proposed in this pull request?
    
    The following sequence of queries produces the error:
    ```
    > cache lazy table t as select col from values ('a' collate utf8_lcase) as 
(col);
    > select col from t;
    org.apache.spark.SparkException: not support type: 
org.apache.spark.sql.types.StringType1.
            at 
org.apache.spark.sql.errors.QueryExecutionErrors$.notSupportTypeError(QueryExecutionErrors.scala:1069)
            at 
org.apache.spark.sql.execution.columnar.ColumnBuilder$.apply(ColumnBuilder.scala:200)
            at 
org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.$anonfun$next$1(InMemoryRelation.scala:85)
            at scala.collection.immutable.List.map(List.scala:247)
            at scala.collection.immutable.List.map(List.scala:79)
            at 
org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:84)
            at 
org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:82)
            at 
org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:296)
            at 
org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:293)
    ...
    ```
    The same problem also occurs with non-lazy cached tables.
    
    The error is raised during the execution of `InMemoryTableScanExec`. To 
fix it, `ColumnAccessor`, `ColumnBuilder`, `ColumnType` and `ColumnStats` 
are updated to accept collated string types.
    
    ### Why are the changes needed?
    
    To fix the described error.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the described sequence of queries should produce valid results after 
these changes are applied instead of throwing an error.
    
    ### How was this patch tested?
    
    Added checks to the columnar suites for the mentioned classes and an 
integration test to `CollationSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47045 from nikolamand-db/SPARK-48675.
    
    Authored-by: Nikola Mandic <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/execution/columnar/ColumnAccessor.scala    |  6 +--
 .../sql/execution/columnar/ColumnBuilder.scala     |  5 +-
 .../spark/sql/execution/columnar/ColumnStats.scala | 10 ++--
 .../spark/sql/execution/columnar/ColumnType.scala  | 12 +++--
 .../columnar/GenerateColumnAccessor.scala          |  4 +-
 .../columnar/compression/compressionSchemes.scala  |  4 +-
 .../org/apache/spark/sql/CollationSuite.scala      | 34 +++++++++++++
 .../sql/execution/columnar/ColumnStatsSuite.scala  | 59 +++++++++++++++++++++-
 .../sql/execution/columnar/ColumnTypeSuite.scala   | 33 +++++++++---
 .../sql/execution/columnar/ColumnarTestUtils.scala |  2 +-
 .../columnar/NullableColumnAccessorSuite.scala     | 23 +++++++--
 .../columnar/NullableColumnBuilderSuite.scala      | 23 +++++++--
 .../compression/CompressionSchemeBenchmark.scala   |  5 +-
 .../compression/DictionaryEncodingSuite.scala      | 14 +++--
 .../compression/RunLengthEncodingSuite.scala       | 14 +++--
 15 files changed, 205 insertions(+), 43 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 4a922dcb062e..9652a48e5270 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -100,8 +100,8 @@ private[columnar] class FloatColumnAccessor(buffer: 
ByteBuffer)
 private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
   extends NativeColumnAccessor(buffer, DOUBLE)
 
-private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
-  extends NativeColumnAccessor(buffer, STRING)
+private[columnar] class StringColumnAccessor(buffer: ByteBuffer, dataType: 
StringType)
+  extends NativeColumnAccessor(buffer, STRING(dataType))
 
 private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
   extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
@@ -147,7 +147,7 @@ private[sql] object ColumnAccessor {
         new LongColumnAccessor(buf)
       case FloatType => new FloatColumnAccessor(buf)
       case DoubleType => new DoubleColumnAccessor(buf)
-      case StringType => new StringColumnAccessor(buf)
+      case s: StringType => new StringColumnAccessor(buf, s)
       case BinaryType => new BinaryColumnAccessor(buf)
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
         new CompactDecimalColumnAccessor(buf, dt)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 367547155bee..9fafdb794841 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -122,7 +122,8 @@ private[columnar]
 class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, 
DOUBLE)
 
 private[columnar]
-class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, 
STRING)
+class StringColumnBuilder(dataType: StringType)
+  extends NativeColumnBuilder(new StringColumnStats(dataType), 
STRING(dataType))
 
 private[columnar]
 class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, 
BINARY)
@@ -185,7 +186,7 @@ private[columnar] object ColumnBuilder {
         new LongColumnBuilder
       case FloatType => new FloatColumnBuilder
       case DoubleType => new DoubleColumnBuilder
-      case StringType => new StringColumnBuilder
+      case s: StringType => new StringColumnBuilder(s)
       case BinaryType => new BinaryColumnBuilder
       case CalendarIntervalType => new IntervalColumnBuilder
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 18ef84262aad..45f489cb13c2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -255,14 +255,16 @@ private[columnar] final class DoubleColumnStats extends 
ColumnStats {
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats(collationId: Int) extends 
ColumnStats {
+  def this(dt: StringType) = this(dt.collationId)
+
   protected var upper: UTF8String = null
   protected var lower: UTF8String = null
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
       val value = row.getUTF8String(ordinal)
-      val size = STRING.actualSize(row, ordinal)
+      val size = STRING(collationId).actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
       gatherNullStats()
@@ -270,8 +272,8 @@ private[columnar] final class StringColumnStats extends 
ColumnStats {
   }
 
   def gatherValueStats(value: UTF8String, size: Int): Unit = {
-    if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone()
-    if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone()
+    if (upper == null || value.semanticCompare(upper, collationId) > 0) upper 
= value.clone()
+    if (lower == null || value.semanticCompare(lower, collationId) < 0) lower 
= value.clone()
     sizeInBytes += size
     count += 1
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index ee1f9b413302..b8e63294f3cd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -491,8 +491,8 @@ private[columnar] trait DirectCopyColumnType[JvmType] 
extends ColumnType[JvmType
   }
 }
 
-private[columnar] object STRING
-  extends NativeColumnType(PhysicalStringType(StringType.collationId), 8)
+private[columnar] case class STRING(collationId: Int)
+  extends NativeColumnType(PhysicalStringType(collationId), 8)
     with DirectCopyColumnType[UTF8String] {
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
@@ -532,6 +532,12 @@ private[columnar] object STRING
   override def clone(v: UTF8String): UTF8String = v.clone()
 }
 
+private[columnar] object STRING {
+  def apply(dt: StringType): STRING = {
+    STRING(dt.collationId)
+  }
+}
+
 private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
   extends NativeColumnType(PhysicalDecimalType(precision, scale), 8) {
 
@@ -821,7 +827,7 @@ private[columnar] object ColumnType {
       case LongType | TimestampType | TimestampNTZType | _: 
DayTimeIntervalType => LONG
       case FloatType => FLOAT
       case DoubleType => DOUBLE
-      case StringType => STRING
+      case s: StringType => STRING(s)
       case BinaryType => BINARY
       case i: CalendarIntervalType => CALENDAR_INTERVAL
       case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => 
COMPACT_DECIMAL(dt)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 5eadc7d47c92..75416b878914 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -86,7 +86,7 @@ object GenerateColumnAccessor extends 
CodeGenerator[Seq[DataType], ColumnarItera
           classOf[LongColumnAccessor].getName
         case FloatType => classOf[FloatColumnAccessor].getName
         case DoubleType => classOf[DoubleColumnAccessor].getName
-        case StringType => classOf[StringColumnAccessor].getName
+        case _: StringType => classOf[StringColumnAccessor].getName
         case BinaryType => classOf[BinaryColumnAccessor].getName
         case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
         case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
@@ -101,7 +101,7 @@ object GenerateColumnAccessor extends 
CodeGenerator[Seq[DataType], ColumnarItera
       val createCode = dt match {
         case t if CodeGenerator.isPrimitiveType(dt) =>
           s"$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
-        case NullType | StringType | BinaryType | CalendarIntervalType =>
+        case NullType | BinaryType | CalendarIntervalType =>
           s"$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
         case other =>
           s"""$accessorName = new 
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
index 46044f6919d1..86d76856e12b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
@@ -176,7 +176,7 @@ private[columnar] case object RunLengthEncoding extends 
CompressionScheme {
   }
 
   override def supports(columnType: ColumnType[_]): Boolean = columnType match 
{
-    case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+    case INT | LONG | SHORT | BYTE | _: STRING | BOOLEAN => true
     case _ => false
   }
 
@@ -373,7 +373,7 @@ private[columnar] case object DictionaryEncoding extends 
CompressionScheme {
   }
 
   override def supports(columnType: ColumnType[_]): Boolean = columnType match 
{
-    case INT | LONG | STRING => true
+    case INT | LONG | _: STRING => true
     case _ => false
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index e2a11fc137c3..c4eaedfb215e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -31,6 +31,7 @@ import 
org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, 
ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.types.{MapType, StringType, StructField, 
StructType}
@@ -1431,4 +1432,37 @@ class CollationSuite extends DatasourceV2SQLBase with 
AdaptiveSparkPlanHelper {
     })
   }
 
+  test("cache table with collated columns") {
+    val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")
+    val lazyOptions = Seq(false, true)
+
+    for (
+      collation <- collations;
+      lazyTable <- lazyOptions
+    ) {
+      val lazyStr = if (lazyTable) "LAZY" else ""
+
+      def checkCacheTable(values: String): Unit = {
+        sql(s"CACHE $lazyStr TABLE tbl AS SELECT col FROM VALUES ($values) AS 
(col)")
+        // Checks in-memory fetching code path.
+        val all = sql("SELECT col FROM tbl")
+        assert(all.queryExecution.executedPlan.collectFirst {
+          case _: InMemoryTableScanExec => true
+        }.nonEmpty)
+        checkAnswer(all, Row("a"))
+        // Checks column stats code path.
+        checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Row("a"))
+        checkAnswer(sql("SELECT col FROM tbl WHERE col = 'b'"), Seq.empty)
+      }
+
+      withTable("tbl") {
+        checkCacheTable(s"'a' COLLATE $collation")
+      }
+      withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
+        withTable("tbl") {
+          checkCacheTable("'a'")
+        }
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
index f39057013e64..bdb118b91fa2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.types.StringType
 
 class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
@@ -28,9 +29,9 @@ class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, 
Long.MinValue, 0))
   testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, 
Float.MinValue, 0))
   testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, 
Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
   testDecimalColumnStats(Array(null, null, 0))
   testIntervalColumnStats(Array(null, null, 0))
+  testStringColumnStats(Array(null, null, 0))
 
   def testColumnStats[T <: PhysicalDataType, U <: ColumnStats](
       columnStatsClass: Class[U],
@@ -141,4 +142,60 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats](
+      initialStatistics: Array[Any]): Unit = {
+
+    Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", 
"UNICODE_CI").foreach(collation => {
+      val columnType = STRING(StringType(collation))
+
+      test(s"STRING($collation): empty") {
+        val columnStats = new 
StringColumnStats(StringType(collation).collationId)
+        columnStats.collectedStatistics.zip(initialStatistics).foreach {
+          case (actual, expected) => assert(actual === expected)
+        }
+      }
+
+      test(s"STRING($collation): non-empty") {
+        import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+
+        val columnStats = new 
StringColumnStats(StringType(collation).collationId)
+        val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ 
Seq.fill(10)(makeNullRow(1))
+        rows.foreach(columnStats.gatherStats(_, 0))
+
+        val values = rows.take(10).map(_.get(0,
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
+        val ordering = PhysicalDataType.ordering(
+          ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
+        val stats = columnStats.collectedStatistics
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+        assertResult(10, "Wrong null count")(stats(2))
+        assertResult(20, "Wrong row count")(stats(3))
+        assertResult(stats(4), "Wrong size in bytes") {
+          rows.map { row =>
+            if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+          }.sum
+        }
+      }
+    })
+
+    test("STRING(UTF8_LCASE): collation-defined ordering") {
+      import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+      import org.apache.spark.unsafe.types.UTF8String
+
+      val columnStats = new 
StringColumnStats(StringType("UTF8_LCASE").collationId)
+      val rows = Seq("b", "a", "C", "A").map(str => {
+        val row = new GenericInternalRow(1)
+        row(0) = UTF8String.fromString(str)
+        row
+      })
+      rows.foreach(columnStats.gatherStats(_, 0))
+
+      val stats = columnStats.collectedStatistics
+      assertResult(UTF8String.fromString("a"), "Wrong lower bound")(stats(0))
+      assertResult(UTF8String.fromString("C"), "Wrong upper bound")(stats(1))
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
index d79ac8dc3545..a95bda9bf71d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, 
PhysicalDataType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -40,7 +41,9 @@ class ColumnTypeSuite extends SparkFunSuite {
     val checks = Map(
       NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
       FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20, 
10) -> 12,
-      STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE 
-> 68,
+      STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
+      STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) -> 
8,
+      BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
       CALENDAR_INTERVAL -> 16)
 
     checks.foreach { case (columnType, expectedSize) =>
@@ -73,7 +76,12 @@ class ColumnTypeSuite extends SparkFunSuite {
     checkActualSize(LONG, Long.MaxValue, 8)
     checkActualSize(FLOAT, Float.MaxValue, 4)
     checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(STRING, "hello", 4 + 
"hello".getBytes(StandardCharsets.UTF_8).length)
+    Seq(
+      "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+    ).foreach(collation => {
+      checkActualSize(STRING(StringType(collation)),
+        "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+    })
     checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
     checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
@@ -93,7 +101,10 @@ class ColumnTypeSuite extends SparkFunSuite {
   testNativeColumnType(FLOAT)
   testNativeColumnType(DOUBLE)
   testNativeColumnType(COMPACT_DECIMAL(15, 10))
-  testNativeColumnType(STRING)
+  testNativeColumnType(STRING(StringType)) // UTF8_BINARY
+  testNativeColumnType(STRING(StringType("UTF8_LCASE")))
+  testNativeColumnType(STRING(StringType("UNICODE")))
+  testNativeColumnType(STRING(StringType("UNICODE_CI")))
 
   testColumnType(NULL)
   testColumnType(BINARY)
@@ -104,11 +115,18 @@ class ColumnTypeSuite extends SparkFunSuite {
   testColumnType(CALENDAR_INTERVAL)
 
   def testNativeColumnType[T <: PhysicalDataType](columnType: 
NativeColumnType[T]): Unit = {
-    testColumnType[T#InternalType](columnType)
+    val typeName = columnType match {
+      case s: STRING =>
+        val collation = 
CollationFactory.fetchCollation(s.collationId).collationName
+        Some(if (collation == "UTF8_BINARY") "STRING" else 
s"STRING($collation)")
+      case _ => None
+    }
+    testColumnType[T#InternalType](columnType, typeName)
   }
 
-  def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
-
+  def testColumnType[JvmType](
+      columnType: ColumnType[JvmType],
+      typeName: Option[String] = None): Unit = {
     val proj = UnsafeProjection.create(
       
Array[DataType](ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
     val converter = CatalystTypeConverters.createToScalaConverter(
@@ -116,8 +134,9 @@ class ColumnTypeSuite extends SparkFunSuite {
     val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
     val totalSize = seq.map(_.getSizeInBytes).sum
     val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize)
+    val testName = typeName.getOrElse(columnType.toString)
 
-    test(s"$columnType append/extract") {
+    test(s"$testName append/extract") {
       val buffer = 
ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder())
       seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer))
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
index e7b509c087b7..d08c34056f56 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
@@ -50,7 +50,7 @@ object ColumnarTestUtils {
       case LONG => Random.nextLong()
       case FLOAT => Random.nextFloat()
       case DOUBLE => Random.nextDouble()
-      case STRING => 
UTF8String.fromString(Random.nextString(Random.nextInt(32)))
+      case _: STRING => 
UTF8String.fromString(Random.nextString(Random.nextInt(32)))
       case BINARY => randomBytes(Random.nextInt(32))
       case CALENDAR_INTERVAL =>
         new CalendarInterval(Random.nextInt(), Random.nextInt(), 
Random.nextLong())
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
index 169d9356c00c..ee622793ee0a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, 
PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
 
 class TestNullableColumnAccessor[JvmType](
@@ -41,21 +42,33 @@ object TestNullableColumnAccessor {
 class NullableColumnAccessorSuite extends SparkFunSuite {
   import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 
-  Seq(
+  val stringTypes = Seq(
+    STRING(StringType), // UTF8_BINARY
+    STRING(StringType("UTF8_LCASE")),
+    STRING(StringType("UNICODE")),
+    STRING(StringType("UNICODE_CI")))
+  val otherTypes = Seq(
     NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
-    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+    BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
     STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
     ARRAY(PhysicalArrayType(IntegerType, true)),
     MAP(PhysicalMapType(IntegerType, StringType, true)),
     CALENDAR_INTERVAL)
-    .foreach {
+
+  stringTypes.foreach(s => {
+    val collation = 
CollationFactory.fetchCollation(s.collationId).collationName
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else 
s"STRING($collation)"
+    testNullableColumnAccessor(s, Some(typeName))
+  })
+  otherTypes.foreach {
     testNullableColumnAccessor(_)
   }
 
   def testNullableColumnAccessor[JvmType](
-      columnType: ColumnType[JvmType]): Unit = {
+      columnType: ColumnType[JvmType],
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = 
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
     val nullRow = makeNullRow(1)
 
     test(s"Nullable $typeName column accessor: empty column") {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
index 22f557e49ded..609212c95e98 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection}
 import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, 
PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
 
 class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@@ -39,21 +40,33 @@ object TestNullableColumnBuilder {
 class NullableColumnBuilderSuite extends SparkFunSuite {
   import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 
-  Seq(
+  val stringTypes = Seq(
+    STRING(StringType), // UTF8_BINARY
+    STRING(StringType("UTF8_LCASE")),
+    STRING(StringType("UNICODE")),
+    STRING(StringType("UNICODE_CI")))
+  val otherTypes = Seq(
     BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
-    STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+    BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
     STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
     ARRAY(PhysicalArrayType(IntegerType, true)),
     MAP(PhysicalMapType(IntegerType, StringType, true)),
     CALENDAR_INTERVAL)
-    .foreach {
+
+  stringTypes.foreach(s => {
+    val collation = 
CollationFactory.fetchCollation(s.collationId).collationName
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else 
s"STRING($collation)"
+    testNullableColumnBuilder(s, Some(typeName))
+  })
+  otherTypes.foreach {
     testNullableColumnBuilder(_)
   }
 
   def testNullableColumnBuilder[JvmType](
-      columnType: ColumnType[JvmType]): Unit = {
+      columnType: ColumnType[JvmType],
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = 
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
     val dataType = columnType.dataType
     val proj = UnsafeProjection.create(Array[DataType](
       ColumnarDataTypeUtils.toLogicalDataType(dataType)))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
index 2da0adf439da..05ae57530529 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
@@ -27,6 +27,7 @@ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG, 
NativeColumnType, SHORT, STRING}
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.Utils._
 
 /**
@@ -231,8 +232,8 @@ object CompressionSchemeBenchmark extends BenchmarkBase 
with AllCompressionSchem
     }
     testData.rewind()
 
-    runEncodeBenchmark("STRING Encode", iters, count, STRING, testData)
-    runDecodeBenchmark("STRING Decode", iters, count, STRING, testData)
+    runEncodeBenchmark("STRING Encode", iters, count, STRING(StringType), 
testData)
+    runDecodeBenchmark("STRING Decode", iters, count, STRING(StringType), 
testData)
   }
 
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
index 10d5e8a0eb9a..2b2bc7e76136 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
@@ -25,19 +25,27 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.StringType
 
 class DictionaryEncodingSuite extends SparkFunSuite {
   val nullValue = -1
   testDictionaryEncoding(new IntColumnStats, INT)
   testDictionaryEncoding(new LongColumnStats, LONG)
-  testDictionaryEncoding(new StringColumnStats, STRING, false)
+  Seq(
+    "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+  ).foreach(collation => {
+    val dt = StringType(collation)
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else 
s"STRING($collation)"
+    testDictionaryEncoding(new StringColumnStats(dt), STRING(dt), false, 
Some(typeName))
+  })
 
   def testDictionaryEncoding[T <: PhysicalDataType](
       columnStats: ColumnStats,
       columnType: NativeColumnType[T],
-      testDecompress: Boolean = true): Unit = {
+      testDecompress: Boolean = true,
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = 
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
 
     def buildDictionary(buffer: ByteBuffer) = {
       (0 until buffer.getInt()).map(columnType.extract(buffer) -> 
_.toShort).toMap
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
index 00f242a6b9c4..9b0067fd2983 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.StringType
 
 class RunLengthEncodingSuite extends SparkFunSuite {
   val nullValue = -1
@@ -31,14 +32,21 @@ class RunLengthEncodingSuite extends SparkFunSuite {
   testRunLengthEncoding(new ShortColumnStats, SHORT)
   testRunLengthEncoding(new IntColumnStats, INT)
   testRunLengthEncoding(new LongColumnStats, LONG)
-  testRunLengthEncoding(new StringColumnStats, STRING, false)
+  Seq(
+    "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+  ).foreach(collation => {
+    val dt = StringType(collation)
+    val typeName = if (collation == "UTF8_BINARY") "STRING" else 
s"STRING($collation)"
+    testRunLengthEncoding(new StringColumnStats(dt), STRING(dt), false, 
Some(typeName))
+  })
 
   def testRunLengthEncoding[T <: PhysicalDataType](
       columnStats: ColumnStats,
       columnType: NativeColumnType[T],
-      testDecompress: Boolean = true): Unit = {
+      testDecompress: Boolean = true,
+      testTypeName: Option[String] = None): Unit = {
 
-    val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val typeName = 
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
 
     def skeleton(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]): Unit = {
       // -------------


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to