This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 804b2a416781 [SPARK-47506][SQL] Add support to all file source formats
for collated data types
804b2a416781 is described below
commit 804b2a4167813ac33f5d2e61898483a66c389059
Author: Stefan Kandic <[email protected]>
AuthorDate: Mon Mar 25 10:05:01 2024 +0500
[SPARK-47506][SQL] Add support to all file source formats for collated data
types
### What changes were proposed in this pull request?
Adding support and tests for collated types in all the file sources
currently supported by Spark, including:
- parquet
- json
- csv
- orc
- text
It is important to note that collation metadata will only be preserved if
these file sources are specified via the [CREATE TABLE USING
DATA_SOURCE](https://spark.apache.org/docs/latest/sql-ref-syntax-ddl-create-table-datasource.html)
API. Using the DataFrame API to write directly to a file will not
preserve collation metadata (except in the case of Parquet, because it saves the
schema in the file itself).
### Why are the changes needed?
To make collations compatible with all of the file sources that users can
choose from.
### Does this PR introduce _any_ user-facing change?
Yes, users can now create tables with collations using all supported file
sources.
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45641 from stefankandic/fileSources.
Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../spark/sql/catalyst/json/JacksonGenerator.scala | 2 +-
.../execution/datasources/orc/OrcSerializer.scala | 2 +-
.../sql/execution/datasources/orc/OrcUtils.scala | 4 ++
.../datasources/text/TextFileFormat.scala | 2 +-
.../org/apache/spark/sql/CollationSuite.scala | 49 ++++++++++++++++++----
5 files changed, 47 insertions(+), 12 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index e01457ff1025..c2c6117e1e3a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -137,7 +137,7 @@ class JacksonGenerator(
(row: SpecializedGetters, ordinal: Int) =>
gen.writeNumber(row.getDouble(ordinal))
- case StringType =>
+ case _: StringType =>
(row: SpecializedGetters, ordinal: Int) =>
gen.writeString(row.getUTF8String(ordinal).toString)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
index 5ed73c3f78b1..75e3e13b0f7e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -130,7 +130,7 @@ class OrcSerializer(dataSchema: StructType) {
// Don't reuse the result object for string and binary as it would cause
extra data copy.
- case StringType => (getter, ordinal) =>
+ case _: StringType => (getter, ordinal) =>
new Text(getter.getUTF8String(ordinal).getBytes)
case BinaryType => (getter, ordinal) =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 15fa2f88e128..24943b37d059 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -305,6 +305,10 @@ object OrcUtils extends Logging {
val typeDesc = new
TypeDescription(TypeDescription.Category.TIMESTAMP)
typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, t.typeName)
Some(typeDesc)
+ case _: StringType =>
+ val typeDesc = new TypeDescription(TypeDescription.Category.STRING)
+ typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME,
StringType.typeName)
+ Some(typeDesc)
case _ => None
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index e675f70e2a0d..caa4e3ed386b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -138,6 +138,6 @@ class TextFileFormat extends TextBasedFileFormat with
DataSourceRegister {
}
override def supportDataType(dataType: DataType): Boolean =
- dataType == StringType
+ dataType.isInstanceOf[StringType]
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 146ba63cf402..f0b51a5b2c19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -36,6 +36,10 @@ import org.apache.spark.sql.types.{MapType, StringType,
StructField, StructType}
class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
+ private val collationPreservingSources = Seq("parquet")
+ private val collationNonPreservingSources = Seq("orc", "csv", "json", "text")
+ private val allFileBasedDataSources = collationPreservingSources ++
collationNonPreservingSources
+
test("collate returns proper type") {
Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach {
collationName =>
checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa"))
@@ -424,22 +428,49 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
}
test("create table with collation") {
- val tableName = "parquet_dummy_tbl"
+ val tableName = "dummy_tbl"
val collationName = "UTF8_BINARY_LCASE"
val collationId = CollationFactory.collationNameToId(collationName)
- withTable(tableName) {
- sql(
+ allFileBasedDataSources.foreach { format =>
+ withTable(tableName) {
+ sql(
s"""
- |CREATE TABLE $tableName (c1 STRING COLLATE $collationName)
- |USING PARQUET
+ |CREATE TABLE $tableName (
+ | c1 STRING COLLATE $collationName
+ |)
+ |USING $format
|""".stripMargin)
- sql(s"INSERT INTO $tableName VALUES ('aaa')")
- sql(s"INSERT INTO $tableName VALUES ('AAA')")
+ sql(s"INSERT INTO $tableName VALUES ('aaa')")
+ sql(s"INSERT INTO $tableName VALUES ('AAA')")
- checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
Seq(Row(collationName)))
- assert(sql(s"select c1 FROM $tableName").schema.head.dataType ==
StringType(collationId))
+ checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
Seq(Row(collationName)))
+ assert(sql(s"select c1 FROM $tableName").schema.head.dataType ==
StringType(collationId))
+ }
+ }
+ }
+
+ test("write collated data to different data sources with dataframe api") {
+ val collationName = "UNICODE_CI"
+
+ allFileBasedDataSources.foreach { format =>
+ withTempPath { path =>
+ val df = sql(s"SELECT c COLLATE $collationName AS c FROM VALUES
('aaa') AS data(c)")
+ df.write.format(format).save(path.getAbsolutePath)
+
+ val readback = spark.read.format(format).load(path.getAbsolutePath)
+ val readbackCollation = if
(collationPreservingSources.contains(format)) {
+ collationName
+ } else {
+ "UTF8_BINARY"
+ }
+
+ checkAnswer(readback, Row("aaa"))
+ checkAnswer(
+ readback.selectExpr(s"collation(${readback.columns.head})"),
+ Row(readbackCollation))
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]