This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 804b2a416781 [SPARK-47506][SQL] Add support to all file source formats for collated data types 804b2a416781 is described below commit 804b2a4167813ac33f5d2e61898483a66c389059 Author: Stefan Kandic <stefan.kan...@databricks.com> AuthorDate: Mon Mar 25 10:05:01 2024 +0500 [SPARK-47506][SQL] Add support to all file source formats for collated data types ### What changes were proposed in this pull request? Adds support and tests for collated types in all the file sources currently supported by Spark, including: - parquet - json - csv - orc - text Note that collation metadata is only preserved when these file sources are specified via the [CREATE TABLE USING DATA_SOURCE](https://spark.apache.org/docs/latest/sql-ref-syntax-ddl-create-table-datasource.html) API. Writing directly to files with the DataFrame API alone does not preserve collation metadata, except for Parquet, which stores the schema in the file itself. ### Why are the changes needed? To make collations compatible with every file source that users can choose from. ### Does this PR introduce _any_ user-facing change? Yes, users can now create tables with collations using all supported file sources. ### How was this patch tested? New unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45641 from stefankandic/fileSources. 
Authored-by: Stefan Kandic <stefan.kan...@databricks.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../spark/sql/catalyst/json/JacksonGenerator.scala | 2 +- .../execution/datasources/orc/OrcSerializer.scala | 2 +- .../sql/execution/datasources/orc/OrcUtils.scala | 4 ++ .../datasources/text/TextFileFormat.scala | 2 +- .../org/apache/spark/sql/CollationSuite.scala | 49 ++++++++++++++++++---- 5 files changed, 47 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index e01457ff1025..c2c6117e1e3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -137,7 +137,7 @@ class JacksonGenerator( (row: SpecializedGetters, ordinal: Int) => gen.writeNumber(row.getDouble(ordinal)) - case StringType => + case _: StringType => (row: SpecializedGetters, ordinal: Int) => gen.writeString(row.getUTF8String(ordinal).toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala index 5ed73c3f78b1..75e3e13b0f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala @@ -130,7 +130,7 @@ class OrcSerializer(dataSchema: StructType) { // Don't reuse the result object for string and binary as it would cause extra data copy. 
- case StringType => (getter, ordinal) => + case _: StringType => (getter, ordinal) => new Text(getter.getUTF8String(ordinal).getBytes) case BinaryType => (getter, ordinal) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index 15fa2f88e128..24943b37d059 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -305,6 +305,10 @@ object OrcUtils extends Logging { val typeDesc = new TypeDescription(TypeDescription.Category.TIMESTAMP) typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, t.typeName) Some(typeDesc) + case _: StringType => + val typeDesc = new TypeDescription(TypeDescription.Category.STRING) + typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, StringType.typeName) + Some(typeDesc) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index e675f70e2a0d..caa4e3ed386b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -138,6 +138,6 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { } override def supportDataType(dataType: DataType): Boolean = - dataType == StringType + dataType.isInstanceOf[StringType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 146ba63cf402..f0b51a5b2c19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -36,6 
+36,10 @@ import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName + private val collationPreservingSources = Seq("parquet") + private val collationNonPreservingSources = Seq("orc", "csv", "json", "text") + private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources + test("collate returns proper type") { Seq("utf8_binary", "utf8_binary_lcase", "unicode", "unicode_ci").foreach { collationName => checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa")) @@ -424,22 +428,49 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } test("create table with collation") { - val tableName = "parquet_dummy_tbl" + val tableName = "dummy_tbl" val collationName = "UTF8_BINARY_LCASE" val collationId = CollationFactory.collationNameToId(collationName) - withTable(tableName) { - sql( + allFileBasedDataSources.foreach { format => + withTable(tableName) { + sql( s""" - |CREATE TABLE $tableName (c1 STRING COLLATE $collationName) - |USING PARQUET + |CREATE TABLE $tableName ( + | c1 STRING COLLATE $collationName + |) + |USING $format |""".stripMargin) - sql(s"INSERT INTO $tableName VALUES ('aaa')") - sql(s"INSERT INTO $tableName VALUES ('AAA')") + sql(s"INSERT INTO $tableName VALUES ('aaa')") + sql(s"INSERT INTO $tableName VALUES ('AAA')") - checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName))) - assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId)) + checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName))) + assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId)) + } + } + } + + test("write collated data to different data sources with dataframe api") { + val collationName = "UNICODE_CI" + + 
allFileBasedDataSources.foreach { format => + withTempPath { path => + val df = sql(s"SELECT c COLLATE $collationName AS c FROM VALUES ('aaa') AS data(c)") + df.write.format(format).save(path.getAbsolutePath) + + val readback = spark.read.format(format).load(path.getAbsolutePath) + val readbackCollation = if (collationPreservingSources.contains(format)) { + collationName + } else { + "UTF8_BINARY" + } + + checkAnswer(readback, Row("aaa")) + checkAnswer( + readback.selectExpr(s"collation(${readback.columns.head})"), + Row(readbackCollation)) + } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org