cstavr commented on code in PR #46280: URL: https://github.com/apache/spark/pull/46280#discussion_r1597674391
########## python/pyspark/sql/types.py: ########## @@ -263,21 +263,23 @@ def __init__(self, collation: Optional[str] = None): def fromCollationId(self, collationId: int) -> "StringType": return StringType(StringType.collationNames[collationId]) - def collationIdToName(self) -> str: - if self.collationId == 0: - return "" - else: - return " collate %s" % StringType.collationNames[self.collationId] + @classmethod + def collationIdToName(cls, collationId: int) -> str: + return StringType.collationNames[collationId] @classmethod def collationNameToId(cls, collationName: str) -> int: return StringType.collationNames.index(collationName) def simpleString(self) -> str: - return "string" + self.collationIdToName() + return ( + "string" + if self.isUTF8BinaryCollation() + else "string collate" + self.collationIdToName(self.collationId) Review Comment: ```suggestion else "string collate " + self.collationIdToName(self.collationId) ``` ########## python/pyspark/sql/tests/test_types.py: ########## @@ -549,6 +549,76 @@ def test_convert_list_to_str(self): self.assertEqual(df.count(), 1) self.assertEqual(df.head(), Row(name="[123]", income=120)) + def test_schema_with_collations_json_ser_de(self): + from pyspark.sql.types import _parse_datatype_json_string + + unicode_collation = "UNICODE" + + simple_struct = StructType([StructField("c1", StringType(unicode_collation))]) + + nested_struct = StructType([StructField("nested", simple_struct)]) + + array_in_schema = StructType( + [StructField("array", ArrayType(StringType(unicode_collation)))] + ) + + map_in_schema = StructType( + [ + StructField( + "map", MapType(StringType(unicode_collation), StringType(unicode_collation)) + ) + ] + ) + + array_in_map_in_nested_schema = StructType( Review Comment: What does in "in nested schema" mean here? 
```suggestion array_in_map = StructType( ``` ########## python/pyspark/sql/tests/test_types.py: ########## @@ -549,6 +549,76 @@ def test_convert_list_to_str(self): self.assertEqual(df.count(), 1) self.assertEqual(df.head(), Row(name="[123]", income=120)) + def test_schema_with_collations_json_ser_de(self): + from pyspark.sql.types import _parse_datatype_json_string + + unicode_collation = "UNICODE" + + simple_struct = StructType([StructField("c1", StringType(unicode_collation))]) + + nested_struct = StructType([StructField("nested", simple_struct)]) + + array_in_schema = StructType( + [StructField("array", ArrayType(StringType(unicode_collation)))] + ) + + map_in_schema = StructType( + [ + StructField( + "map", MapType(StringType(unicode_collation), StringType(unicode_collation)) + ) + ] + ) + + array_in_map_in_nested_schema = StructType( + [ + StructField( + "arrInMap", + MapType( + StringType(unicode_collation), ArrayType(StringType(unicode_collation)) + ), + ) + ] + ) + + nested_array_in_map = StructType( + [ + StructField( + "nestedArrayInMap", + ArrayType( + MapType( + StringType(unicode_collation), + ArrayType(ArrayType(StringType(unicode_collation))), + ) + ), + ) + ] + ) + + schema_with_multiple_fields = StructType( + simple_struct.fields + + nested_struct.fields + + array_in_schema.fields + + map_in_schema.fields + + array_in_map_in_nested_schema.fields + + nested_array_in_map.fields + ) + + schemas = [ + simple_struct, + nested_struct, + array_in_schema, + map_in_schema, + nested_array_in_map, + array_in_map_in_nested_schema, + schema_with_multiple_fields, + ] + + for schema in schemas: + scala_datatype = self.spark._jsparkSession.parseDataType(schema.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert schema == python_datatype Review Comment: ```suggestion scala_datatype = self.spark._jsparkSession.parseDataType(schema.json()) assert schema == scala_datatype python_datatype = _parse_datatype_json_string(schema.json()) 
assert schema == python_datatype ``` ########## sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala: ########## @@ -63,7 +66,60 @@ case class StructField( ("name" -> name) ~ ("type" -> dataType.jsonValue) ~ ("nullable" -> nullable) ~ - ("metadata" -> metadata.jsonValue) + ("metadata" -> metadataJson) + } + + private def metadataJson: JValue = { + val metadataJsonValue = metadata.jsonValue + metadataJsonValue match { + case JObject(fields) if collationMetadata.nonEmpty => + val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList + JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields))) + + case _ => metadataJsonValue + } + } + + /** Map of field path to collation name. */ + private lazy val collationMetadata: mutable.Map[String, String] = { Review Comment: ```suggestion private lazy val collationMetadata: .Map[String, String] = { ``` Is this modified somewhere? ########## sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala: ########## @@ -274,6 +288,39 @@ object DataType { messageParameters = Map("other" -> compact(render(other)))) } + /** + * Checks if the current field is in the collation map, and if it is it returns + * a StringType with the given collation. Otherwise, it further parses its type. + */ + private def resolveType( Review Comment: ```suggestion private def parseDataTypeWithCollation( ``` ########## sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala: ########## @@ -61,6 +63,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa if (isUTF8BinaryCollation) "string" else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" + override def jsonValue: JValue = JString("string") Review Comment: Please add a comment here and explain that collations are not included. 
########## sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala: ########## @@ -63,7 +66,60 @@ case class StructField( ("name" -> name) ~ ("type" -> dataType.jsonValue) ~ ("nullable" -> nullable) ~ - ("metadata" -> metadata.jsonValue) + ("metadata" -> metadataJson) + } + + private def metadataJson: JValue = { + val metadataJsonValue = metadata.jsonValue + metadataJsonValue match { + case JObject(fields) if collationMetadata.nonEmpty => + val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList + JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields))) + + case _ => metadataJsonValue + } + } + + /** Map of field path to collation name. */ + private lazy val collationMetadata: mutable.Map[String, String] = { + val fieldToCollationMap = mutable.Map[String, String]() + + def visitRecursively(dt: DataType, path: String): Unit = dt match { + case at: ArrayType => + processDataType(at.elementType, path + ".element") + + case mt: MapType => + processDataType(mt.keyType, path + ".key") + processDataType(mt.valueType, path + ".value") + + case st: StringType if isCollatedString(st) => + fieldToCollationMap(path) = collationName(st) + + case _ => + } + + def processDataType(dt: DataType, path: String): Unit = { + if (isCollatedString(dt)) { + fieldToCollationMap(path) = collationName(dt) + } else { + visitRecursively(dt, path) + } + } + + visitRecursively(dataType, name) + fieldToCollationMap + } + + private def isCollatedString(dt: DataType): Boolean = dt match { + case st: StringType => !st.isUTF8BinaryCollation + case _ => false + } + + private def collationName(dt: DataType): String = dt match { + case st: StringType => Review Comment: This can result in weird non-match errors. Can we make this method take a StringType as argument? 
########## python/pyspark/sql/types.py: ########## @@ -876,30 +894,86 @@ def __init__( self.dataType = dataType self.nullable = nullable self.metadata = metadata or {} + self._collationMetadata: Optional[Dict[str, str]] = None def simpleString(self) -> str: return "%s:%s" % (self.name, self.dataType.simpleString()) def __repr__(self) -> str: return "StructField('%s', %s, %s)" % (self.name, self.dataType, str(self.nullable)) + def __eq__(self, other: Any) -> bool: + # since collationMetadata is lazy evaluated we should not use it in equality check Review Comment: Isn't this dangerous? ########## sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala: ########## @@ -274,6 +288,39 @@ object DataType { messageParameters = Map("other" -> compact(render(other)))) } + /** + * Checks if the current field is in the collation map, and if it is it returns + * a StringType with the given collation. Otherwise, it further parses its type. + */ + private def resolveType( + json: JValue, + path: String, + collationsMap: Map[String, String]): DataType = { + collationsMap.get(path) match { + case Some(collation) => stringTypeWithCollation(collation) Review Comment: What will happen if someone adds a collation for a field that is not a string? Should we have a check that it is actually a string type? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org