This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0bc38acc615a [SPARK-48675][SQL] Fix cache table with collated column
0bc38acc615a is described below
commit 0bc38acc615ad411a97779c6a1ff43d4391c0c3d
Author: Nikola Mandic <[email protected]>
AuthorDate: Fri Jun 21 22:45:47 2024 +0800
[SPARK-48675][SQL] Fix cache table with collated column
### What changes were proposed in this pull request?
Following sequence of queries produces the error:
```
> cache lazy table t as select col from values ('a' collate utf8_lcase) as
(col);
> select col from t;
org.apache.spark.SparkException: not support type:
org.apache.spark.sql.types.StringType1.
at
org.apache.spark.sql.errors.QueryExecutionErrors$.notSupportTypeError(QueryExecutionErrors.scala:1069)
at
org.apache.spark.sql.execution.columnar.ColumnBuilder$.apply(ColumnBuilder.scala:200)
at
org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.$anonfun$next$1(InMemoryRelation.scala:85)
at scala.collection.immutable.List.map(List.scala:247)
at scala.collection.immutable.List.map(List.scala:79)
at
org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:84)
at
org.apache.spark.sql.execution.columnar.DefaultCachedBatchSerializer$$anon$1.next(InMemoryRelation.scala:82)
at
org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:296)
at
org.apache.spark.sql.execution.columnar.CachedRDDBuilder$$anon$2.next(InMemoryRelation.scala:293)
...
```
This is also a problem with non-lazy cached tables.
It turns out that the problem occurs during the execution of
`InMemoryTableScanExec`, where we need to update `ColumnAccessor`,
`ColumnBuilder`, `ColumnType` and `ColumnStats` to handle collated strings.
### Why are the changes needed?
To fix the described error.
### Does this PR introduce _any_ user-facing change?
Yes, the described sequence of queries should produce valid results after
these changes are applied, instead of throwing an error.
### How was this patch tested?
Added checks to columnar suites for the mentioned classes and integration
test to `CollationSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47045 from nikolamand-db/SPARK-48675.
Authored-by: Nikola Mandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/execution/columnar/ColumnAccessor.scala | 6 +--
.../sql/execution/columnar/ColumnBuilder.scala | 5 +-
.../spark/sql/execution/columnar/ColumnStats.scala | 10 ++--
.../spark/sql/execution/columnar/ColumnType.scala | 12 +++--
.../columnar/GenerateColumnAccessor.scala | 4 +-
.../columnar/compression/compressionSchemes.scala | 4 +-
.../org/apache/spark/sql/CollationSuite.scala | 34 +++++++++++++
.../sql/execution/columnar/ColumnStatsSuite.scala | 59 +++++++++++++++++++++-
.../sql/execution/columnar/ColumnTypeSuite.scala | 33 +++++++++---
.../sql/execution/columnar/ColumnarTestUtils.scala | 2 +-
.../columnar/NullableColumnAccessorSuite.scala | 23 +++++++--
.../columnar/NullableColumnBuilderSuite.scala | 23 +++++++--
.../compression/CompressionSchemeBenchmark.scala | 5 +-
.../compression/DictionaryEncodingSuite.scala | 14 +++--
.../compression/RunLengthEncodingSuite.scala | 14 +++--
15 files changed, 205 insertions(+), 43 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 4a922dcb062e..9652a48e5270 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -100,8 +100,8 @@ private[columnar] class FloatColumnAccessor(buffer:
ByteBuffer)
private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
extends NativeColumnAccessor(buffer, DOUBLE)
-private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, STRING)
+private[columnar] class StringColumnAccessor(buffer: ByteBuffer, dataType:
StringType)
+ extends NativeColumnAccessor(buffer, STRING(dataType))
private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
@@ -147,7 +147,7 @@ private[sql] object ColumnAccessor {
new LongColumnAccessor(buf)
case FloatType => new FloatColumnAccessor(buf)
case DoubleType => new DoubleColumnAccessor(buf)
- case StringType => new StringColumnAccessor(buf)
+ case s: StringType => new StringColumnAccessor(buf, s)
case BinaryType => new BinaryColumnAccessor(buf)
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
new CompactDecimalColumnAccessor(buf, dt)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 367547155bee..9fafdb794841 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -122,7 +122,8 @@ private[columnar]
class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats,
DOUBLE)
private[columnar]
-class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats,
STRING)
+class StringColumnBuilder(dataType: StringType)
+ extends NativeColumnBuilder(new StringColumnStats(dataType),
STRING(dataType))
private[columnar]
class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats,
BINARY)
@@ -185,7 +186,7 @@ private[columnar] object ColumnBuilder {
new LongColumnBuilder
case FloatType => new FloatColumnBuilder
case DoubleType => new DoubleColumnBuilder
- case StringType => new StringColumnBuilder
+ case s: StringType => new StringColumnBuilder(s)
case BinaryType => new BinaryColumnBuilder
case CalendarIntervalType => new IntervalColumnBuilder
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 18ef84262aad..45f489cb13c2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -255,14 +255,16 @@ private[columnar] final class DoubleColumnStats extends
ColumnStats {
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
-private[columnar] final class StringColumnStats extends ColumnStats {
+private[columnar] final class StringColumnStats(collationId: Int) extends
ColumnStats {
+ def this(dt: StringType) = this(dt.collationId)
+
protected var upper: UTF8String = null
protected var lower: UTF8String = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
- val size = STRING.actualSize(row, ordinal)
+ val size = STRING(collationId).actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats()
@@ -270,8 +272,8 @@ private[columnar] final class StringColumnStats extends
ColumnStats {
}
def gatherValueStats(value: UTF8String, size: Int): Unit = {
- if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone()
- if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone()
+ if (upper == null || value.semanticCompare(upper, collationId) > 0) upper
= value.clone()
+ if (lower == null || value.semanticCompare(lower, collationId) < 0) lower
= value.clone()
sizeInBytes += size
count += 1
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index ee1f9b413302..b8e63294f3cd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -491,8 +491,8 @@ private[columnar] trait DirectCopyColumnType[JvmType]
extends ColumnType[JvmType
}
}
-private[columnar] object STRING
- extends NativeColumnType(PhysicalStringType(StringType.collationId), 8)
+private[columnar] case class STRING(collationId: Int)
+ extends NativeColumnType(PhysicalStringType(collationId), 8)
with DirectCopyColumnType[UTF8String] {
override def actualSize(row: InternalRow, ordinal: Int): Int = {
@@ -532,6 +532,12 @@ private[columnar] object STRING
override def clone(v: UTF8String): UTF8String = v.clone()
}
+private[columnar] object STRING {
+ def apply(dt: StringType): STRING = {
+ STRING(dt.collationId)
+ }
+}
+
private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(PhysicalDecimalType(precision, scale), 8) {
@@ -821,7 +827,7 @@ private[columnar] object ColumnType {
case LongType | TimestampType | TimestampNTZType | _:
DayTimeIntervalType => LONG
case FloatType => FLOAT
case DoubleType => DOUBLE
- case StringType => STRING
+ case s: StringType => STRING(s)
case BinaryType => BINARY
case i: CalendarIntervalType => CALENDAR_INTERVAL
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
COMPACT_DECIMAL(dt)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index 5eadc7d47c92..75416b878914 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -86,7 +86,7 @@ object GenerateColumnAccessor extends
CodeGenerator[Seq[DataType], ColumnarItera
classOf[LongColumnAccessor].getName
case FloatType => classOf[FloatColumnAccessor].getName
case DoubleType => classOf[DoubleColumnAccessor].getName
- case StringType => classOf[StringColumnAccessor].getName
+ case _: StringType => classOf[StringColumnAccessor].getName
case BinaryType => classOf[BinaryColumnAccessor].getName
case CalendarIntervalType => classOf[IntervalColumnAccessor].getName
case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
@@ -101,7 +101,7 @@ object GenerateColumnAccessor extends
CodeGenerator[Seq[DataType], ColumnarItera
val createCode = dt match {
case t if CodeGenerator.isPrimitiveType(dt) =>
s"$accessorName = new
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
- case NullType | StringType | BinaryType | CalendarIntervalType =>
+ case NullType | BinaryType | CalendarIntervalType =>
s"$accessorName = new
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));"
case other =>
s"""$accessorName = new
$accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder),
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
index 46044f6919d1..86d76856e12b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala
@@ -176,7 +176,7 @@ private[columnar] case object RunLengthEncoding extends
CompressionScheme {
}
override def supports(columnType: ColumnType[_]): Boolean = columnType match
{
- case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+ case INT | LONG | SHORT | BYTE | _: STRING | BOOLEAN => true
case _ => false
}
@@ -373,7 +373,7 @@ private[columnar] case object DictionaryEncoding extends
CompressionScheme {
}
override def supports(columnType: ColumnType[_]): Boolean = columnType match
{
- case INT | LONG | STRING => true
+ case INT | LONG | _: STRING => true
case _ => false
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index e2a11fc137c3..c4eaedfb215e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -31,6 +31,7 @@ import
org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
ObjectHashAggregateExec}
+import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.types.{MapType, StringType, StructField,
StructType}
@@ -1431,4 +1432,37 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
})
}
+ test("cache table with collated columns") {
+ val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")
+ val lazyOptions = Seq(false, true)
+
+ for (
+ collation <- collations;
+ lazyTable <- lazyOptions
+ ) {
+ val lazyStr = if (lazyTable) "LAZY" else ""
+
+ def checkCacheTable(values: String): Unit = {
+ sql(s"CACHE $lazyStr TABLE tbl AS SELECT col FROM VALUES ($values) AS
(col)")
+ // Checks in-memory fetching code path.
+ val all = sql("SELECT col FROM tbl")
+ assert(all.queryExecution.executedPlan.collectFirst {
+ case _: InMemoryTableScanExec => true
+ }.nonEmpty)
+ checkAnswer(all, Row("a"))
+ // Checks column stats code path.
+ checkAnswer(sql("SELECT col FROM tbl WHERE col = 'a'"), Row("a"))
+ checkAnswer(sql("SELECT col FROM tbl WHERE col = 'b'"), Seq.empty)
+ }
+
+ withTable("tbl") {
+ checkCacheTable(s"'a' COLLATE $collation")
+ }
+ withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) {
+ withTable("tbl") {
+ checkCacheTable("'a'")
+ }
+ }
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
index f39057013e64..bdb118b91fa2 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.types.PhysicalDataType
+import org.apache.spark.sql.types.StringType
class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
@@ -28,9 +29,9 @@ class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue,
Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue,
Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue,
Double.MinValue, 0))
- testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
testDecimalColumnStats(Array(null, null, 0))
testIntervalColumnStats(Array(null, null, 0))
+ testStringColumnStats(Array(null, null, 0))
def testColumnStats[T <: PhysicalDataType, U <: ColumnStats](
columnStatsClass: Class[U],
@@ -141,4 +142,60 @@ class ColumnStatsSuite extends SparkFunSuite {
}
}
}
+
+ def testStringColumnStats[T <: PhysicalDataType, U <: ColumnStats](
+ initialStatistics: Array[Any]): Unit = {
+
+ Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE",
"UNICODE_CI").foreach(collation => {
+ val columnType = STRING(StringType(collation))
+
+ test(s"STRING($collation): empty") {
+ val columnStats = new
StringColumnStats(StringType(collation).collationId)
+ columnStats.collectedStatistics.zip(initialStatistics).foreach {
+ case (actual, expected) => assert(actual === expected)
+ }
+ }
+
+ test(s"STRING($collation): non-empty") {
+ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+
+ val columnStats = new
StringColumnStats(StringType(collation).collationId)
+ val rows = Seq.fill(10)(makeRandomRow(columnType)) ++
Seq.fill(10)(makeNullRow(1))
+ rows.foreach(columnStats.gatherStats(_, 0))
+
+ val values = rows.take(10).map(_.get(0,
+ ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
+ val ordering = PhysicalDataType.ordering(
+ ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType))
+ val stats = columnStats.collectedStatistics
+
+ assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+ assertResult(10, "Wrong null count")(stats(2))
+ assertResult(20, "Wrong row count")(stats(3))
+ assertResult(stats(4), "Wrong size in bytes") {
+ rows.map { row =>
+ if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+ }.sum
+ }
+ }
+ })
+
+ test("STRING(UTF8_LCASE): collation-defined ordering") {
+ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+ import org.apache.spark.unsafe.types.UTF8String
+
+ val columnStats = new
StringColumnStats(StringType("UTF8_LCASE").collationId)
+ val rows = Seq("b", "a", "C", "A").map(str => {
+ val row = new GenericInternalRow(1)
+ row(0) = UTF8String.fromString(str)
+ row
+ })
+ rows.foreach(columnStats.gatherStats(_, 0))
+
+ val stats = columnStats.collectedStatistics
+ assertResult(UTF8String.fromString("a"), "Wrong lower bound")(stats(0))
+ assertResult(UTF8String.fromString("C"), "Wrong upper bound")(stats(1))
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
index d79ac8dc3545..a95bda9bf71d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeProjection}
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType,
PhysicalDataType, PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -40,7 +41,9 @@ class ColumnTypeSuite extends SparkFunSuite {
val checks = Map(
NULL -> 0, BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, LONG -> 8,
FLOAT -> 4, DOUBLE -> 8, COMPACT_DECIMAL(15, 10) -> 8, LARGE_DECIMAL(20,
10) -> 12,
- STRING -> 8, BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE
-> 68,
+ STRING(StringType) -> 8, STRING(StringType("UTF8_LCASE")) -> 8,
+ STRING(StringType("UNICODE")) -> 8, STRING(StringType("UNICODE_CI")) ->
8,
+ BINARY -> 16, STRUCT_TYPE -> 20, ARRAY_TYPE -> 28, MAP_TYPE -> 68,
CALENDAR_INTERVAL -> 16)
checks.foreach { case (columnType, expectedSize) =>
@@ -73,7 +76,12 @@ class ColumnTypeSuite extends SparkFunSuite {
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(DOUBLE, Double.MaxValue, 8)
- checkActualSize(STRING, "hello", 4 +
"hello".getBytes(StandardCharsets.UTF_8).length)
+ Seq(
+ "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+ ).foreach(collation => {
+ checkActualSize(STRING(StringType(collation)),
+ "hello", 4 + "hello".getBytes(StandardCharsets.UTF_8).length)
+ })
checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4)
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
@@ -93,7 +101,10 @@ class ColumnTypeSuite extends SparkFunSuite {
testNativeColumnType(FLOAT)
testNativeColumnType(DOUBLE)
testNativeColumnType(COMPACT_DECIMAL(15, 10))
- testNativeColumnType(STRING)
+ testNativeColumnType(STRING(StringType)) // UTF8_BINARY
+ testNativeColumnType(STRING(StringType("UTF8_LCASE")))
+ testNativeColumnType(STRING(StringType("UNICODE")))
+ testNativeColumnType(STRING(StringType("UNICODE_CI")))
testColumnType(NULL)
testColumnType(BINARY)
@@ -104,11 +115,18 @@ class ColumnTypeSuite extends SparkFunSuite {
testColumnType(CALENDAR_INTERVAL)
def testNativeColumnType[T <: PhysicalDataType](columnType:
NativeColumnType[T]): Unit = {
- testColumnType[T#InternalType](columnType)
+ val typeName = columnType match {
+ case s: STRING =>
+ val collation =
CollationFactory.fetchCollation(s.collationId).collationName
+ Some(if (collation == "UTF8_BINARY") "STRING" else
s"STRING($collation)")
+ case _ => None
+ }
+ testColumnType[T#InternalType](columnType, typeName)
}
- def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
-
+ def testColumnType[JvmType](
+ columnType: ColumnType[JvmType],
+ typeName: Option[String] = None): Unit = {
val proj = UnsafeProjection.create(
Array[DataType](ColumnarDataTypeUtils.toLogicalDataType(columnType.dataType)))
val converter = CatalystTypeConverters.createToScalaConverter(
@@ -116,8 +134,9 @@ class ColumnTypeSuite extends SparkFunSuite {
val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
val totalSize = seq.map(_.getSizeInBytes).sum
val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize)
+ val testName = typeName.getOrElse(columnType.toString)
- test(s"$columnType append/extract") {
+ test(s"$testName append/extract") {
val buffer =
ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder())
seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
index e7b509c087b7..d08c34056f56 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
@@ -50,7 +50,7 @@ object ColumnarTestUtils {
case LONG => Random.nextLong()
case FLOAT => Random.nextFloat()
case DOUBLE => Random.nextDouble()
- case STRING =>
UTF8String.fromString(Random.nextString(Random.nextInt(32)))
+ case _: STRING =>
UTF8String.fromString(Random.nextString(Random.nextInt(32)))
case BINARY => randomBytes(Random.nextInt(32))
case CALENDAR_INTERVAL =>
new CalendarInterval(Random.nextInt(), Random.nextInt(),
Random.nextLong())
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
index 169d9356c00c..ee622793ee0a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeProjection}
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType,
PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types._
class TestNullableColumnAccessor[JvmType](
@@ -41,21 +42,33 @@ object TestNullableColumnAccessor {
class NullableColumnAccessorSuite extends SparkFunSuite {
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
- Seq(
+ val stringTypes = Seq(
+ STRING(StringType), // UTF8_BINARY
+ STRING(StringType("UTF8_LCASE")),
+ STRING(StringType("UNICODE")),
+ STRING(StringType("UNICODE_CI")))
+ val otherTypes = Seq(
NULL, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
- STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+ BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
ARRAY(PhysicalArrayType(IntegerType, true)),
MAP(PhysicalMapType(IntegerType, StringType, true)),
CALENDAR_INTERVAL)
- .foreach {
+
+ stringTypes.foreach(s => {
+ val collation =
CollationFactory.fetchCollation(s.collationId).collationName
+ val typeName = if (collation == "UTF8_BINARY") "STRING" else
s"STRING($collation)"
+ testNullableColumnAccessor(s, Some(typeName))
+ })
+ otherTypes.foreach {
testNullableColumnAccessor(_)
}
def testNullableColumnAccessor[JvmType](
- columnType: ColumnType[JvmType]): Unit = {
+ columnType: ColumnType[JvmType],
+ testTypeName: Option[String] = None): Unit = {
- val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val typeName =
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
val nullRow = makeNullRow(1)
test(s"Nullable $typeName column accessor: empty column") {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
index 22f557e49ded..609212c95e98 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeProjection}
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType,
PhysicalMapType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types._
class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@@ -39,21 +40,33 @@ object TestNullableColumnBuilder {
class NullableColumnBuilderSuite extends SparkFunSuite {
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
- Seq(
+ val stringTypes = Seq(
+ STRING(StringType), // UTF8_BINARY
+ STRING(StringType("UTF8_LCASE")),
+ STRING(StringType("UNICODE")),
+ STRING(StringType("UNICODE_CI")))
+ val otherTypes = Seq(
BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE,
- STRING, BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
+ BINARY, COMPACT_DECIMAL(15, 10), LARGE_DECIMAL(20, 10),
STRUCT(PhysicalStructType(Array(StructField("a", StringType)))),
ARRAY(PhysicalArrayType(IntegerType, true)),
MAP(PhysicalMapType(IntegerType, StringType, true)),
CALENDAR_INTERVAL)
- .foreach {
+
+ stringTypes.foreach(s => {
+ val collation =
CollationFactory.fetchCollation(s.collationId).collationName
+ val typeName = if (collation == "UTF8_BINARY") "STRING" else
s"STRING($collation)"
+ testNullableColumnBuilder(s, Some(typeName))
+ })
+ otherTypes.foreach {
testNullableColumnBuilder(_)
}
def testNullableColumnBuilder[JvmType](
- columnType: ColumnType[JvmType]): Unit = {
+ columnType: ColumnType[JvmType],
+ testTypeName: Option[String] = None): Unit = {
- val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val typeName =
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
val dataType = columnType.dataType
val proj = UnsafeProjection.create(Array[DataType](
ColumnarDataTypeUtils.toLogicalDataType(dataType)))
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
index 2da0adf439da..05ae57530529 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala
@@ -27,6 +27,7 @@ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.execution.columnar.{BOOLEAN, INT, LONG,
NativeColumnType, SHORT, STRING}
+import org.apache.spark.sql.types.StringType
import org.apache.spark.util.Utils._
/**
@@ -231,8 +232,8 @@ object CompressionSchemeBenchmark extends BenchmarkBase
with AllCompressionSchem
}
testData.rewind()
- runEncodeBenchmark("STRING Encode", iters, count, STRING, testData)
- runDecodeBenchmark("STRING Decode", iters, count, STRING, testData)
+ runEncodeBenchmark("STRING Encode", iters, count, STRING(StringType),
testData)
+ runDecodeBenchmark("STRING Decode", iters, count, STRING(StringType),
testData)
}
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
index 10d5e8a0eb9a..2b2bc7e76136 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala
@@ -25,19 +25,27 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.StringType
class DictionaryEncodingSuite extends SparkFunSuite {
val nullValue = -1
testDictionaryEncoding(new IntColumnStats, INT)
testDictionaryEncoding(new LongColumnStats, LONG)
- testDictionaryEncoding(new StringColumnStats, STRING, false)
+ Seq(
+ "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+ ).foreach(collation => {
+ val dt = StringType(collation)
+ val typeName = if (collation == "UTF8_BINARY") "STRING" else
s"STRING($collation)"
+ testDictionaryEncoding(new StringColumnStats(dt), STRING(dt), false,
Some(typeName))
+ })
def testDictionaryEncoding[T <: PhysicalDataType](
columnStats: ColumnStats,
columnType: NativeColumnType[T],
- testDecompress: Boolean = true): Unit = {
+ testDecompress: Boolean = true,
+ testTypeName: Option[String] = None): Unit = {
- val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val typeName =
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
def buildDictionary(buffer: ByteBuffer) = {
(0 until buffer.getInt()).map(columnType.extract(buffer) ->
_.toShort).toMap
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
index 00f242a6b9c4..9b0067fd2983 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.types.StringType
class RunLengthEncodingSuite extends SparkFunSuite {
val nullValue = -1
@@ -31,14 +32,21 @@ class RunLengthEncodingSuite extends SparkFunSuite {
testRunLengthEncoding(new ShortColumnStats, SHORT)
testRunLengthEncoding(new IntColumnStats, INT)
testRunLengthEncoding(new LongColumnStats, LONG)
- testRunLengthEncoding(new StringColumnStats, STRING, false)
+ Seq(
+ "UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI"
+ ).foreach(collation => {
+ val dt = StringType(collation)
+ val typeName = if (collation == "UTF8_BINARY") "STRING" else
s"STRING($collation)"
+ testRunLengthEncoding(new StringColumnStats(dt), STRING(dt), false,
Some(typeName))
+ })
def testRunLengthEncoding[T <: PhysicalDataType](
columnStats: ColumnStats,
columnType: NativeColumnType[T],
- testDecompress: Boolean = true): Unit = {
+ testDecompress: Boolean = true,
+ testTypeName: Option[String] = None): Unit = {
- val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val typeName =
testTypeName.getOrElse(columnType.getClass.getSimpleName.stripSuffix("$"))
def skeleton(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]): Unit = {
// -------------
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]