cstavr commented on code in PR #46280:
URL: https://github.com/apache/spark/pull/46280#discussion_r1597674391


##########
python/pyspark/sql/types.py:
##########
@@ -263,21 +263,23 @@ def __init__(self, collation: Optional[str] = None):
     def fromCollationId(self, collationId: int) -> "StringType":
         return StringType(StringType.collationNames[collationId])
 
-    def collationIdToName(self) -> str:
-        if self.collationId == 0:
-            return ""
-        else:
-            return " collate %s" % StringType.collationNames[self.collationId]
+    @classmethod
+    def collationIdToName(cls, collationId: int) -> str:
+        return StringType.collationNames[collationId]
 
     @classmethod
     def collationNameToId(cls, collationName: str) -> int:
         return StringType.collationNames.index(collationName)
 
     def simpleString(self) -> str:
-        return "string" + self.collationIdToName()
+        return (
+            "string"
+            if self.isUTF8BinaryCollation()
+            else "string collate" + self.collationIdToName(self.collationId)

Review Comment:
   ```suggestion
               else "string collate " + self.collationIdToName(self.collationId)
   ```
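   Without the trailing space the two literals concatenate directly. A quick
   illustration (using a hypothetical `UNICODE` collation name):
   
   ```python
   # Original:  "string collate" + "UNICODE"  -> "string collateUNICODE"
   # Suggested: "string collate " + "UNICODE" -> "string collate UNICODE"
   assert "string collate " + "UNICODE" == "string collate UNICODE"
   ```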



##########
python/pyspark/sql/tests/test_types.py:
##########
@@ -549,6 +549,76 @@ def test_convert_list_to_str(self):
         self.assertEqual(df.count(), 1)
         self.assertEqual(df.head(), Row(name="[123]", income=120))
 
+    def test_schema_with_collations_json_ser_de(self):
+        from pyspark.sql.types import _parse_datatype_json_string
+
+        unicode_collation = "UNICODE"
+
+        simple_struct = StructType([StructField("c1", StringType(unicode_collation))])
+
+        nested_struct = StructType([StructField("nested", simple_struct)])
+
+        array_in_schema = StructType(
+            [StructField("array", ArrayType(StringType(unicode_collation)))]
+        )
+
+        map_in_schema = StructType(
+            [
+                StructField(
+                    "map", MapType(StringType(unicode_collation), 
StringType(unicode_collation))
+                )
+            ]
+        )
+
+        array_in_map_in_nested_schema = StructType(

Review Comment:
   What does in "in nested schema" mean here?
   
   ```suggestion
           array_in_map = StructType(
   ```



##########
python/pyspark/sql/tests/test_types.py:
##########
@@ -549,6 +549,76 @@ def test_convert_list_to_str(self):
         self.assertEqual(df.count(), 1)
         self.assertEqual(df.head(), Row(name="[123]", income=120))
 
+    def test_schema_with_collations_json_ser_de(self):
+        from pyspark.sql.types import _parse_datatype_json_string
+
+        unicode_collation = "UNICODE"
+
+        simple_struct = StructType([StructField("c1", StringType(unicode_collation))])
+
+        nested_struct = StructType([StructField("nested", simple_struct)])
+
+        array_in_schema = StructType(
+            [StructField("array", ArrayType(StringType(unicode_collation)))]
+        )
+
+        map_in_schema = StructType(
+            [
+                StructField(
+                    "map", MapType(StringType(unicode_collation), 
StringType(unicode_collation))
+                )
+            ]
+        )
+
+        array_in_map_in_nested_schema = StructType(
+            [
+                StructField(
+                    "arrInMap",
+                    MapType(
+                        StringType(unicode_collation), ArrayType(StringType(unicode_collation))
+                    ),
+                )
+            ]
+        )
+
+        nested_array_in_map = StructType(
+            [
+                StructField(
+                    "nestedArrayInMap",
+                    ArrayType(
+                        MapType(
+                            StringType(unicode_collation),
+                            ArrayType(ArrayType(StringType(unicode_collation))),
+                        )
+                    ),
+                )
+            ]
+        )
+
+        schema_with_multiple_fields = StructType(
+            simple_struct.fields
+            + nested_struct.fields
+            + array_in_schema.fields
+            + map_in_schema.fields
+            + array_in_map_in_nested_schema.fields
+            + nested_array_in_map.fields
+        )
+
+        schemas = [
+            simple_struct,
+            nested_struct,
+            array_in_schema,
+            map_in_schema,
+            nested_array_in_map,
+            array_in_map_in_nested_schema,
+            schema_with_multiple_fields,
+        ]
+
+        for schema in schemas:
+            scala_datatype = self.spark._jsparkSession.parseDataType(schema.json())
+            python_datatype = _parse_datatype_json_string(scala_datatype.json())
+            assert schema == python_datatype

Review Comment:
   ```suggestion
            scala_datatype = self.spark._jsparkSession.parseDataType(schema.json())
               assert schema == scala_datatype
               
               python_datatype = _parse_datatype_json_string(schema.json())
               assert schema == python_datatype
   ```



##########
sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala:
##########
@@ -63,7 +66,60 @@ case class StructField(
     ("name" -> name) ~
       ("type" -> dataType.jsonValue) ~
       ("nullable" -> nullable) ~
-      ("metadata" -> metadata.jsonValue)
+      ("metadata" -> metadataJson)
+  }
+
+  private def metadataJson: JValue = {
+    val metadataJsonValue = metadata.jsonValue
+    metadataJsonValue match {
+      case JObject(fields) if collationMetadata.nonEmpty =>
+        val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList
+        JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields)))
+
+      case _ => metadataJsonValue
+    }
+  }
+
+  /** Map of field path to collation name. */
+  private lazy val collationMetadata: mutable.Map[String, String] = {

Review Comment:
   ```suggestion
     private lazy val collationMetadata: Map[String, String] = {
   ```
   Is this modified somewhere?
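   
   If it is only populated during initialization, here is a sketch of how the
   `lazy val` could expose an immutable map while still being built imperatively
   (assuming the `visitRecursively`/`processDataType` helpers from this diff
   stay as-is):
   
   ```scala
   private lazy val collationMetadata: Map[String, String] = {
     // The mutable map stays local to the initializer; only an immutable
     // snapshot escapes, so callers cannot modify it after construction.
     val fieldToCollationMap = scala.collection.mutable.Map[String, String]()
     // ... visitRecursively / processDataType definitions unchanged ...
     visitRecursively(dataType, name)
     fieldToCollationMap.toMap
   }
   ```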



##########
sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala:
##########
@@ -274,6 +288,39 @@ object DataType {
       messageParameters = Map("other" -> compact(render(other))))
   }
 
+  /**
+   * Checks if the current field is in the collation map, and if it is, returns
+   * a StringType with the given collation. Otherwise, it further parses its type.
+   */
+  private def resolveType(

Review Comment:
   ```suggestion
     private def parseDataTypeWithCollation(
   ```



##########
sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala:
##########
@@ -61,6 +63,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
     if (isUTF8BinaryCollation) "string"
     else s"string collate 
${CollationFactory.fetchCollation(collationId).collationName}"
 
+  override def jsonValue: JValue = JString("string")

Review Comment:
   Please add a comment here explaining that collations are not included.
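   
   For instance, the comment could point at where the collation does get
   serialized (a sketch; the exact wording is up to the author):
   
   ```scala
   // Collations are intentionally not serialized into the type name: every
   // StringType emits plain "string" here, and the collation is written
   // separately into the parent StructField's metadata under
   // DataType.COLLATIONS_METADATA_KEY.
   override def jsonValue: JValue = JString("string")
   ```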



##########
sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala:
##########
@@ -63,7 +66,60 @@ case class StructField(
     ("name" -> name) ~
       ("type" -> dataType.jsonValue) ~
       ("nullable" -> nullable) ~
-      ("metadata" -> metadata.jsonValue)
+      ("metadata" -> metadataJson)
+  }
+
+  private def metadataJson: JValue = {
+    val metadataJsonValue = metadata.jsonValue
+    metadataJsonValue match {
+      case JObject(fields) if collationMetadata.nonEmpty =>
+        val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList
+        JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields)))
+
+      case _ => metadataJsonValue
+    }
+  }
+
+  /** Map of field path to collation name. */
+  private lazy val collationMetadata: mutable.Map[String, String] = {
+    val fieldToCollationMap = mutable.Map[String, String]()
+
+    def visitRecursively(dt: DataType, path: String): Unit = dt match {
+      case at: ArrayType =>
+        processDataType(at.elementType, path + ".element")
+
+      case mt: MapType =>
+        processDataType(mt.keyType, path + ".key")
+        processDataType(mt.valueType, path + ".value")
+
+      case st: StringType if isCollatedString(st) =>
+        fieldToCollationMap(path) = collationName(st)
+
+      case _ =>
+    }
+
+    def processDataType(dt: DataType, path: String): Unit = {
+      if (isCollatedString(dt)) {
+        fieldToCollationMap(path) = collationName(dt)
+      } else {
+        visitRecursively(dt, path)
+      }
+    }
+
+    visitRecursively(dataType, name)
+    fieldToCollationMap
+  }
+
+  private def isCollatedString(dt: DataType): Boolean = dt match {
+    case st: StringType => !st.isUTF8BinaryCollation
+    case _ => false
+  }
+
+  private def collationName(dt: DataType): String = dt match {
+    case st: StringType =>

Review Comment:
   This can result in weird MatchError failures. Can we make this method take a StringType as its argument?
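   
   A minimal sketch of the narrowed signature (reusing
   `CollationFactory.fetchCollation`, which this PR already calls elsewhere):
   
   ```scala
   // Accepting StringType directly turns a non-string argument into a
   // compile-time error instead of a MatchError at runtime.
   private def collationName(st: StringType): String =
     CollationFactory.fetchCollation(st.collationId).collationName
   ```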



##########
python/pyspark/sql/types.py:
##########
@@ -876,30 +894,86 @@ def __init__(
         self.dataType = dataType
         self.nullable = nullable
         self.metadata = metadata or {}
+        self._collationMetadata: Optional[Dict[str, str]] = None
 
     def simpleString(self) -> str:
         return "%s:%s" % (self.name, self.dataType.simpleString())
 
     def __repr__(self) -> str:
         return "StructField('%s', %s, %s)" % (self.name, self.dataType, 
str(self.nullable))
 
+    def __eq__(self, other: Any) -> bool:
+        # since collationMetadata is lazily evaluated we should not use it in the equality check

Review Comment:
   Isn't this dangerous?
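   
   For reference, a sketch of an equality check that stays safe as long as
   `_collationMetadata` is a pure cache derived from `dataType` (an assumption
   worth confirming):
   
   ```python
   def __eq__(self, other: Any) -> bool:
       # _collationMetadata is lazily derived from dataType, so comparing
       # the declared attributes is equivalent, provided the cache is
       # never mutated independently of them.
       return (
           isinstance(other, StructField)
           and self.name == other.name
           and self.dataType == other.dataType
           and self.nullable == other.nullable
           and self.metadata == other.metadata
       )
   ```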



##########
sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala:
##########
@@ -274,6 +288,39 @@ object DataType {
       messageParameters = Map("other" -> compact(render(other))))
   }
 
+  /**
+   * Checks if the current field is in the collation map, and if it is, returns
+   * a StringType with the given collation. Otherwise, it further parses its type.
+   */
+  private def resolveType(
+      json: JValue,
+      path: String,
+      collationsMap: Map[String, String]): DataType = {
+    collationsMap.get(path) match {
+      case Some(collation) => stringTypeWithCollation(collation)

Review Comment:
   What will happen if someone adds a collation for a field that is not a string? Should we have a check that it is actually a string type?
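   
   A sketch of such a guard (the error type here is a placeholder, not Spark's
   actual error class, and `parseDataType(json)` stands in for the existing
   fall-through parsing):
   
   ```scala
   collationsMap.get(path) match {
     case Some(collation) =>
       // A collation annotation only makes sense on a string field,
       // so reject anything else instead of silently overriding the type.
       if (json != JString("string")) {
         throw new IllegalArgumentException(
           s"Collation '$collation' declared for non-string field '$path'")
       }
       stringTypeWithCollation(collation)
     case None =>
       parseDataType(json) // assumed existing recursive parse
   }
   ```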



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

