This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new a691e7658f1f [SPARK-53130][SQL][PYTHON] Fix `toJson` behavior of 
collated string types
a691e7658f1f is described below

commit a691e7658f1fbb506a371efc8bb3c97ae5cf5509
Author: Stefan Kandic <stefan.kan...@databricks.com>
AuthorDate: Wed Aug 6 22:29:59 2025 +0800

    [SPARK-53130][SQL][PYTHON] Fix `toJson` behavior of collated string types
    
    ### What changes were proposed in this pull request?
    Change the behavior of collated string types so that they include their 
collation in the `toJson` methods. Backwards compatibility with older engine 
versions reading tables with collations is preserved by propagating this fix 
upstream in `StructField`, where the collation is removed from the type but 
still kept in the metadata.
    
    ### Why are the changes needed?
    The old way of handling `toJson` meant that collated string types would not 
be serialized and deserialized correctly unless they were part of a 
`StructField`. Initially, we thought that this was not a big deal, but later we 
faced some issues regarding it, especially in PySpark, which primarily uses 
JSON to parse types back and forth.
    This could avoid hacky changes in the future, like the one in 
https://github.com/apache/spark/pull/51688, without changing any behavior for 
how tables/schemas work.
    
    ### Does this PR introduce _any_ user-facing change?
    Technically yes, but it is a small change that should not impact any 
queries, just how StringType is represented when not in a StructField object.
    
    ### How was this patch tested?
    New and existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51850 from stefankandic/fixStringJson.
    
    Authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 19ea6ff95ef334cf9655c6b29b5a0123ebf58d87)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 python/pyspark/sql/tests/test_types.py             | 15 +++++++++
 python/pyspark/sql/types.py                        | 39 +++++++++++++++++++---
 .../org/apache/spark/sql/types/DataType.scala      |  2 ++
 .../org/apache/spark/sql/types/StringType.scala    |  7 ----
 .../org/apache/spark/sql/types/StructField.scala   | 21 +++++++++++-
 .../org/apache/spark/sql/types/DataTypeSuite.scala | 13 ++++++++
 6 files changed, 84 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 15247b97664d..f1c6ad11707e 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -623,6 +623,17 @@ class TypesTestsMixin:
         from pyspark.sql.types import _parse_datatype_json_string
 
         unicode_collation = "UNICODE"
+        utf8_lcase_collation = "UTF8_LCASE"
+
+        standalone_string = StringType(unicode_collation)
+
+        standalone_array = ArrayType(StringType(unicode_collation))
+
+        standalone_map = MapType(StringType(utf8_lcase_collation), 
StringType(unicode_collation))
+
+        standalone_nested = ArrayType(
+            MapType(StringType(utf8_lcase_collation), 
ArrayType(StringType(unicode_collation)))
+        )
 
         simple_struct = StructType([StructField("c1", 
StringType(unicode_collation))])
 
@@ -694,6 +705,10 @@ class TypesTestsMixin:
         )
 
         schemas = [
+            standalone_string,
+            standalone_array,
+            standalone_map,
+            standalone_nested,
             simple_struct,
             nested_struct,
             array_in_schema,
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 451a80e2a53a..e6c48084f98d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -296,11 +296,8 @@ class StringType(AtomicType):
 
         return f"string collate {self.collation}"
 
-    # For backwards compatibility and compatibility with other readers all 
string types
-    # are serialized in json as regular strings and the collation info is 
written to
-    # struct field metadata
     def jsonValue(self) -> str:
-        return "string"
+        return self.simpleString()
 
     def __repr__(self) -> str:
         return (
@@ -1010,11 +1007,39 @@ class StructField(DataType):
 
         return {
             "name": self.name,
-            "type": self.dataType.jsonValue(),
+            "type": self._dataTypeJsonValue(collationMetadata),
             "nullable": self.nullable,
             "metadata": metadata,
         }
 
+    def _dataTypeJsonValue(self, collationMetadata: Dict[str, str]) -> 
Union[str, Dict[str, Any]]:
+        if not collationMetadata:
+            return self.dataType.jsonValue()
+
+        def removeCollations(dt: DataType) -> DataType:
+            # Only recurse into map and array types as any child struct type
+            # will have already been processed.
+            if isinstance(dt, ArrayType):
+                return ArrayType(removeCollations(dt.elementType), 
dt.containsNull)
+            elif isinstance(dt, MapType):
+                return MapType(
+                    removeCollations(dt.keyType),
+                    removeCollations(dt.valueType),
+                    dt.valueContainsNull,
+                )
+            elif isinstance(dt, StringType):
+                return StringType()
+            elif isinstance(dt, VarcharType):
+                return VarcharType(dt.length)
+            elif isinstance(dt, CharType):
+                return CharType(dt.length)
+            else:
+                return dt
+
+        # As we want to be backwards compatible we should remove all 
collations information from the
+        # json and only keep that information in the metadata.
+        return removeCollations(self.dataType).jsonValue()
+
     @classmethod
     def fromJson(cls, json: Dict[str, Any]) -> "StructField":
         metadata = json.get("metadata")
@@ -1843,6 +1868,7 @@ _all_mappable_types: Dict[str, Type[DataType]] = {
 
 _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
+_STRING_WITH_COLLATION = re.compile(r"string\s+collate\s+(\w+)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to 
(day|hour|minute|second))?")
 _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
@@ -2003,6 +2029,9 @@ def _parse_datatype_json_value(
             if first_field is not None and second_field is None:
                 return YearMonthIntervalType(first_field)
             return YearMonthIntervalType(first_field, second_field)
+        elif _STRING_WITH_COLLATION.match(json_value):
+            m = _STRING_WITH_COLLATION.match(json_value)
+            return StringType(m.group(1))  # type: ignore[union-attr]
         elif _LENGTH_CHAR.match(json_value):
             m = _LENGTH_CHAR.match(json_value)
             return CharType(int(m.group(1)))  # type: ignore[union-attr]
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 58ea5ef2a673..d5d85c316741 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -126,6 +126,7 @@ object DataType {
   private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
   private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r
   private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r
+  private val STRING_WITH_COLLATION = """string\s+collate\s+(\w+)""".r
 
   val COLLATIONS_METADATA_KEY = "__COLLATIONS"
 
@@ -214,6 +215,7 @@ object DataType {
       case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, 
scale.toInt)
       case CHAR_TYPE(length) => CharType(length.toInt)
       case VARCHAR_TYPE(length) => VarcharType(length.toInt)
+      case STRING_WITH_COLLATION(collation) => StringType(collation)
       // For backwards compatibility, previously the type name of NullType is 
"null"
       case "null" => NullType
       case "timestamp_ltz" => TimestampType
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index 5fec578b0358..787730f77508 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.types
 
-import org.json4s.JsonAST.{JString, JValue}
-
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.internal.SqlApiConf
@@ -90,11 +88,6 @@ class StringType private[sql] (
   private[sql] def collationName: String =
     CollationFactory.fetchCollation(collationId).collationName
 
-  // Due to backwards compatibility and compatibility with other readers
-  // all string types are serialized in json as regular strings and
-  // the collation information is written to struct field metadata
-  override def jsonValue: JValue = JString("string")
-
   override def equals(obj: Any): Boolean = {
     obj match {
       case s: StringType => s.collationId == collationId && s.constraint == 
constraint
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
index 60362ec46f53..87b02c2a2926 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala
@@ -70,11 +70,30 @@ case class StructField(
 
   private[sql] def jsonValue: JValue = {
     ("name" -> name) ~
-      ("type" -> dataType.jsonValue) ~
+      ("type" -> dataTypeJsonValue) ~
       ("nullable" -> nullable) ~
       ("metadata" -> metadataJson)
   }
 
+  private[sql] def dataTypeJsonValue: JValue = {
+    if (collationMetadata.isEmpty) return dataType.jsonValue
+
+    def removeCollations(dt: DataType): DataType = dt match {
+      // Only recurse into map and array types as any child struct type
+      // will have already been processed.
+      case ArrayType(et, nullable) =>
+        ArrayType(removeCollations(et), nullable)
+      case MapType(kt, vt, nullable) =>
+        MapType(removeCollations(kt), removeCollations(vt), nullable)
+      case st: StringType => StringHelper.removeCollation(st)
+      case _ => dt
+    }
+
+    // As we want to be backwards compatible we should remove all collations 
information from the
+    // json and only keep that information in the metadata.
+    removeCollations(dataType).jsonValue
+  }
+
   private def metadataJson: JValue = {
     val metadataJsonValue = metadata.jsonValue
     metadataJsonValue match {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index c4287df94d61..34fe8e7abefc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -30,6 +30,7 @@ import 
org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM
 class DataTypeSuite extends SparkFunSuite {
 
   private val UNICODE_COLLATION_ID = 
CollationFactory.collationNameToId("UNICODE")
+  private val UTF8_LCASE_COLLATION_ID = 
CollationFactory.collationNameToId("UTF8_LCASE")
 
   test("construct an ArrayType") {
     val array = ArrayType(StringType)
@@ -1143,6 +1144,17 @@ class DataTypeSuite extends SparkFunSuite {
   }
 
   test("schema with collation should not change during ser/de") {
+    val standaloneString = StringType(UNICODE_COLLATION_ID)
+
+    val standaloneArray = ArrayType(StringType(UNICODE_COLLATION_ID))
+
+    val standaloneMap = MapType(StringType(UNICODE_COLLATION_ID),
+      StringType(UTF8_LCASE_COLLATION_ID))
+
+    val standaloneNested = ArrayType(MapType(
+      StringType(UNICODE_COLLATION_ID),
+      ArrayType(StringType(UTF8_LCASE_COLLATION_ID))))
+
     val simpleStruct = StructType(
       StructField("c1", StringType(UNICODE_COLLATION_ID)) :: Nil)
 
@@ -1183,6 +1195,7 @@ class DataTypeSuite extends SparkFunSuite {
         mapWithKeyInNameInSchema ++ arrayInMapInNestedSchema.fields ++ 
nestedArrayInMap.fields)
 
     Seq(
+      standaloneString, standaloneArray, standaloneMap, standaloneNested,
       simpleStruct, caseInsensitiveNames, specialCharsInName, nestedStruct, 
arrayInSchema,
       mapInSchema, mapWithKeyInNameInSchema, nestedArrayInMap, 
arrayInMapInNestedSchema,
       schemaWithMultipleFields)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to