This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-4.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push: new a691e7658f1f [SPARK-53130][SQL][PYTHON] Fix `toJson` behavior of collated string types a691e7658f1f is described below commit a691e7658f1fbb506a371efc8bb3c97ae5cf5509 Author: Stefan Kandic <stefan.kan...@databricks.com> AuthorDate: Wed Aug 6 22:29:59 2025 +0800 [SPARK-53130][SQL][PYTHON] Fix `toJson` behavior of collated string types ### What changes were proposed in this pull request? Changing the behavior of collated string types to return their collation in the `toJson` methods and to still keep backwards compatibility with older engine versions reading tables with collations by propagating this fix upstream in `StructField` where the collation will be removed from the type but still kept in the metadata. ### Why are the changes needed? The old way of handling `toJson` meant that collated string types could not be serialized and deserialized correctly unless they were a part of `StructField`. Initially, we thought that this was not a big deal, but then later we faced some issues regarding this, especially in PySpark, which primarily uses JSON to parse types back and forth. This could avoid hacky changes in the future like the one in https://github.com/apache/spark/pull/51688 without changing any behavior for how tables/schemas work. ### Does this PR introduce _any_ user-facing change? Technically yes, but it is a small change that should not impact any queries, just how StringType is represented when not in a StructField object. ### How was this patch tested? New and existing unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51850 from stefankandic/fixStringJson. 
Authored-by: Stefan Kandic <stefan.kan...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 19ea6ff95ef334cf9655c6b29b5a0123ebf58d87) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- python/pyspark/sql/tests/test_types.py | 15 +++++++++ python/pyspark/sql/types.py | 39 +++++++++++++++++++--- .../org/apache/spark/sql/types/DataType.scala | 2 ++ .../org/apache/spark/sql/types/StringType.scala | 7 ---- .../org/apache/spark/sql/types/StructField.scala | 21 +++++++++++- .../org/apache/spark/sql/types/DataTypeSuite.scala | 13 ++++++++ 6 files changed, 84 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 15247b97664d..f1c6ad11707e 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -623,6 +623,17 @@ class TypesTestsMixin: from pyspark.sql.types import _parse_datatype_json_string unicode_collation = "UNICODE" + utf8_lcase_collation = "UTF8_LCASE" + + standalone_string = StringType(unicode_collation) + + standalone_array = ArrayType(StringType(unicode_collation)) + + standalone_map = MapType(StringType(utf8_lcase_collation), StringType(unicode_collation)) + + standalone_nested = ArrayType( + MapType(StringType(utf8_lcase_collation), ArrayType(StringType(unicode_collation))) + ) simple_struct = StructType([StructField("c1", StringType(unicode_collation))]) @@ -694,6 +705,10 @@ class TypesTestsMixin: ) schemas = [ + standalone_string, + standalone_array, + standalone_map, + standalone_nested, simple_struct, nested_struct, array_in_schema, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 451a80e2a53a..e6c48084f98d 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -296,11 +296,8 @@ class StringType(AtomicType): return f"string collate {self.collation}" - # For backwards compatibility and compatibility with other readers all string types - # are 
serialized in json as regular strings and the collation info is written to - # struct field metadata def jsonValue(self) -> str: - return "string" + return self.simpleString() def __repr__(self) -> str: return ( @@ -1010,11 +1007,39 @@ class StructField(DataType): return { "name": self.name, - "type": self.dataType.jsonValue(), + "type": self._dataTypeJsonValue(collationMetadata), "nullable": self.nullable, "metadata": metadata, } + def _dataTypeJsonValue(self, collationMetadata: Dict[str, str]) -> Union[str, Dict[str, Any]]: + if not collationMetadata: + return self.dataType.jsonValue() + + def removeCollations(dt: DataType) -> DataType: + # Only recurse into map and array types as any child struct type + # will have already been processed. + if isinstance(dt, ArrayType): + return ArrayType(removeCollations(dt.elementType), dt.containsNull) + elif isinstance(dt, MapType): + return MapType( + removeCollations(dt.keyType), + removeCollations(dt.valueType), + dt.valueContainsNull, + ) + elif isinstance(dt, StringType): + return StringType() + elif isinstance(dt, VarcharType): + return VarcharType(dt.length) + elif isinstance(dt, CharType): + return CharType(dt.length) + else: + return dt + + # As we want to be backwards compatible we should remove all collations information from the + # json and only keep that information in the metadata. 
+ return removeCollations(self.dataType).jsonValue() + @classmethod def fromJson(cls, json: Dict[str, Any]) -> "StructField": metadata = json.get("metadata") @@ -1843,6 +1868,7 @@ _all_mappable_types: Dict[str, Type[DataType]] = { _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)") _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)") +_STRING_WITH_COLLATION = re.compile(r"string\s+collate\s+(\w+)") _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)") _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?") @@ -2003,6 +2029,9 @@ def _parse_datatype_json_value( if first_field is not None and second_field is None: return YearMonthIntervalType(first_field) return YearMonthIntervalType(first_field, second_field) + elif _STRING_WITH_COLLATION.match(json_value): + m = _STRING_WITH_COLLATION.match(json_value) + return StringType(m.group(1)) # type: ignore[union-attr] elif _LENGTH_CHAR.match(json_value): m = _LENGTH_CHAR.match(json_value) return CharType(int(m.group(1))) # type: ignore[union-attr] diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 58ea5ef2a673..d5d85c316741 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -126,6 +126,7 @@ object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r + private val STRING_WITH_COLLATION = """string\s+collate\s+(\w+)""".r val COLLATIONS_METADATA_KEY = "__COLLATIONS" @@ -214,6 +215,7 @@ object DataType { case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) case CHAR_TYPE(length) => CharType(length.toInt) case 
VARCHAR_TYPE(length) => VarcharType(length.toInt) + case STRING_WITH_COLLATION(collation) => StringType(collation) // For backwards compatibility, previously the type name of NullType is "null" case "null" => NullType case "timestamp_ltz" => TimestampType diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 5fec578b0358..787730f77508 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import org.json4s.JsonAST.{JString, JValue} - import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.internal.SqlApiConf @@ -90,11 +88,6 @@ class StringType private[sql] ( private[sql] def collationName: String = CollationFactory.fetchCollation(collationId).collationName - // Due to backwards compatibility and compatibility with other readers - // all string types are serialized in json as regular strings and - // the collation information is written to struct field metadata - override def jsonValue: JValue = JString("string") - override def equals(obj: Any): Boolean = { obj match { case s: StringType => s.collationId == collationId && s.constraint == constraint diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala index 60362ec46f53..87b02c2a2926 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -70,11 +70,30 @@ case class StructField( private[sql] def jsonValue: JValue = { ("name" -> name) ~ - ("type" -> dataType.jsonValue) ~ + ("type" -> dataTypeJsonValue) ~ ("nullable" -> nullable) ~ ("metadata" -> metadataJson) } + private[sql] def dataTypeJsonValue: 
JValue = { + if (collationMetadata.isEmpty) return dataType.jsonValue + + def removeCollations(dt: DataType): DataType = dt match { + // Only recurse into map and array types as any child struct type + // will have already been processed. + case ArrayType(et, nullable) => + ArrayType(removeCollations(et), nullable) + case MapType(kt, vt, nullable) => + MapType(removeCollations(kt), removeCollations(vt), nullable) + case st: StringType => StringHelper.removeCollation(st) + case _ => dt + } + + // As we want to be backwards compatible we should remove all collations information from the + // json and only keep that information in the metadata. + removeCollations(dataType).jsonValue + } + private def metadataJson: JValue = { val metadataJsonValue = metadata.jsonValue metadataJsonValue match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c4287df94d61..34fe8e7abefc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM class DataTypeSuite extends SparkFunSuite { private val UNICODE_COLLATION_ID = CollationFactory.collationNameToId("UNICODE") + private val UTF8_LCASE_COLLATION_ID = CollationFactory.collationNameToId("UTF8_LCASE") test("construct an ArrayType") { val array = ArrayType(StringType) @@ -1143,6 +1144,17 @@ class DataTypeSuite extends SparkFunSuite { } test("schema with collation should not change during ser/de") { + val standaloneString = StringType(UNICODE_COLLATION_ID) + + val standaloneArray = ArrayType(StringType(UNICODE_COLLATION_ID)) + + val standaloneMap = MapType(StringType(UNICODE_COLLATION_ID), + StringType(UTF8_LCASE_COLLATION_ID)) + + val standaloneNested = ArrayType(MapType( + StringType(UNICODE_COLLATION_ID), + 
ArrayType(StringType(UTF8_LCASE_COLLATION_ID)))) + val simpleStruct = StructType( StructField("c1", StringType(UNICODE_COLLATION_ID)) :: Nil) @@ -1183,6 +1195,7 @@ class DataTypeSuite extends SparkFunSuite { mapWithKeyInNameInSchema ++ arrayInMapInNestedSchema.fields ++ nestedArrayInMap.fields) Seq( + standaloneString, standaloneArray, standaloneMap, standaloneNested, simpleStruct, caseInsensitiveNames, specialCharsInName, nestedStruct, arrayInSchema, mapInSchema, mapWithKeyInNameInSchema, nestedArrayInMap, arrayInMapInNestedSchema, schemaWithMultipleFields) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org