cstavr commented on code in PR #46280:
URL: https://github.com/apache/spark/pull/46280#discussion_r1597674391
##########
python/pyspark/sql/types.py:
##########
@@ -263,21 +263,23 @@ def __init__(self, collation: Optional[str] = None):
def fromCollationId(self, collationId: int) -> "StringType":
return StringType(StringType.collationNames[collationId])
- def collationIdToName(self) -> str:
- if self.collationId == 0:
- return ""
- else:
- return " collate %s" % StringType.collationNames[self.collationId]
+ @classmethod
+ def collationIdToName(cls, collationId: int) -> str:
+ return StringType.collationNames[collationId]
@classmethod
def collationNameToId(cls, collationName: str) -> int:
return StringType.collationNames.index(collationName)
def simpleString(self) -> str:
- return "string" + self.collationIdToName()
+ return (
+ "string"
+ if self.isUTF8BinaryCollation()
+ else "string collate" + self.collationIdToName(self.collationId)
Review Comment:
```suggestion
else "string collate " + self.collationIdToName(self.collationId)
```
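For illustration, a quick check of the missing space (a sketch; assumes "UNICODE" is a registered collation name, as in the tests below):
```python
from pyspark.sql.types import StringType

st = StringType("UNICODE")
# Without the space: "string collate" + "UNICODE" -> "string collateUNICODE"
# With the suggestion: "string collate " + "UNICODE" -> "string collate UNICODE"
assert st.simpleString() == "string collate UNICODE"
```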
##########
python/pyspark/sql/tests/test_types.py:
##########
@@ -549,6 +549,76 @@ def test_convert_list_to_str(self):
self.assertEqual(df.count(), 1)
self.assertEqual(df.head(), Row(name="[123]", income=120))
+ def test_schema_with_collations_json_ser_de(self):
+ from pyspark.sql.types import _parse_datatype_json_string
+
+ unicode_collation = "UNICODE"
+
+ simple_struct = StructType([StructField("c1", StringType(unicode_collation))])
+
+ nested_struct = StructType([StructField("nested", simple_struct)])
+
+ array_in_schema = StructType(
+ [StructField("array", ArrayType(StringType(unicode_collation)))]
+ )
+
+ map_in_schema = StructType(
+ [
+ StructField(
+ "map", MapType(StringType(unicode_collation),
StringType(unicode_collation))
+ )
+ ]
+ )
+
+ array_in_map_in_nested_schema = StructType(
Review Comment:
What does "in nested schema" mean here?
```suggestion
array_in_map = StructType(
```
##########
python/pyspark/sql/tests/test_types.py:
##########
@@ -549,6 +549,76 @@ def test_convert_list_to_str(self):
self.assertEqual(df.count(), 1)
self.assertEqual(df.head(), Row(name="[123]", income=120))
+ def test_schema_with_collations_json_ser_de(self):
+ from pyspark.sql.types import _parse_datatype_json_string
+
+ unicode_collation = "UNICODE"
+
+ simple_struct = StructType([StructField("c1", StringType(unicode_collation))])
+
+ nested_struct = StructType([StructField("nested", simple_struct)])
+
+ array_in_schema = StructType(
+ [StructField("array", ArrayType(StringType(unicode_collation)))]
+ )
+
+ map_in_schema = StructType(
+ [
+ StructField(
+ "map", MapType(StringType(unicode_collation),
StringType(unicode_collation))
+ )
+ ]
+ )
+
+ array_in_map_in_nested_schema = StructType(
+ [
+ StructField(
+ "arrInMap",
+ MapType(
+ StringType(unicode_collation), ArrayType(StringType(unicode_collation))
+ ),
+ )
+ ]
+ )
+
+ nested_array_in_map = StructType(
+ [
+ StructField(
+ "nestedArrayInMap",
+ ArrayType(
+ MapType(
+ StringType(unicode_collation),
+ ArrayType(ArrayType(StringType(unicode_collation))),
+ )
+ ),
+ )
+ ]
+ )
+
+ schema_with_multiple_fields = StructType(
+ simple_struct.fields
+ + nested_struct.fields
+ + array_in_schema.fields
+ + map_in_schema.fields
+ + array_in_map_in_nested_schema.fields
+ + nested_array_in_map.fields
+ )
+
+ schemas = [
+ simple_struct,
+ nested_struct,
+ array_in_schema,
+ map_in_schema,
+ nested_array_in_map,
+ array_in_map_in_nested_schema,
+ schema_with_multiple_fields,
+ ]
+
+ for schema in schemas:
+ scala_datatype = self.spark._jsparkSession.parseDataType(schema.json())
+ python_datatype = _parse_datatype_json_string(scala_datatype.json())
+ assert schema == python_datatype
Review Comment:
```suggestion
scala_datatype = self.spark._jsparkSession.parseDataType(schema.json())
assert schema == scala_datatype
python_datatype = _parse_datatype_json_string(schema.json())
assert schema == python_datatype
```
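For reference, a hand-written sketch of the JSON this round-trip exercises: the "type" stays a plain "string" and the collation travels in the field metadata. The "__COLLATIONS" key below is an assumption; the actual name is whatever DataType.COLLATIONS_METADATA_KEY defines on the Scala side.
```python
import json

# Assumed serialized shape of simple_struct; "__COLLATIONS" stands in for
# DataType.COLLATIONS_METADATA_KEY.
schema_json = {
    "type": "struct",
    "fields": [
        {
            "name": "c1",
            "type": "string",  # collation deliberately not encoded in the type
            "nullable": True,
            "metadata": {"__COLLATIONS": {"c1": "UNICODE"}},
        }
    ],
}
print(json.dumps(schema_json, indent=2))
```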
##########
sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala:
##########
@@ -63,7 +66,60 @@ case class StructField(
("name" -> name) ~
("type" -> dataType.jsonValue) ~
("nullable" -> nullable) ~
- ("metadata" -> metadata.jsonValue)
+ ("metadata" -> metadataJson)
+ }
+
+ private def metadataJson: JValue = {
+ val metadataJsonValue = metadata.jsonValue
+ metadataJsonValue match {
+ case JObject(fields) if collationMetadata.nonEmpty =>
+ val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList
+ JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields)))
+
+ case _ => metadataJsonValue
+ }
+ }
+
+ /** Map of field path to collation name. */
+ private lazy val collationMetadata: mutable.Map[String, String] = {
Review Comment:
```suggestion
private lazy val collationMetadata: Map[String, String] = {
```
Is this modified somewhere?
##########
sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala:
##########
@@ -274,6 +288,39 @@ object DataType {
messageParameters = Map("other" -> compact(render(other))))
}
+ /**
+ * Checks if the current field is in the collation map, and if so returns
+ * a StringType with the given collation. Otherwise, it further parses its type.
+ */
+ private def resolveType(
Review Comment:
```suggestion
private def parseDataTypeWithCollation(
```
##########
sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala:
##########
@@ -61,6 +63,8 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
if (isUTF8BinaryCollation) "string"
else s"string collate
${CollationFactory.fetchCollation(collationId).collationName}"
+ override def jsonValue: JValue = JString("string")
Review Comment:
Please add a comment here explaining that collations are not included.
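For example, the comment could make the contract explicit. A sketch of the mirrored Python-side method (assumed; the body is hypothetical, only the "collations live in StructField metadata" contract comes from this PR):
```python
def jsonValue(self) -> str:
    # Collations are intentionally not serialized here: the type is emitted
    # as plain "string", and any collation is written into the enclosing
    # StructField's metadata instead (see StructField.metadataJson).
    return "string"
```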
##########
sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala:
##########
@@ -63,7 +66,60 @@ case class StructField(
("name" -> name) ~
("type" -> dataType.jsonValue) ~
("nullable" -> nullable) ~
- ("metadata" -> metadata.jsonValue)
+ ("metadata" -> metadataJson)
+ }
+
+ private def metadataJson: JValue = {
+ val metadataJsonValue = metadata.jsonValue
+ metadataJsonValue match {
+ case JObject(fields) if collationMetadata.nonEmpty =>
+ val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList
+ JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields)))
+
+ case _ => metadataJsonValue
+ }
+ }
+
+ /** Map of field path to collation name. */
+ private lazy val collationMetadata: mutable.Map[String, String] = {
+ val fieldToCollationMap = mutable.Map[String, String]()
+
+ def visitRecursively(dt: DataType, path: String): Unit = dt match {
+ case at: ArrayType =>
+ processDataType(at.elementType, path + ".element")
+
+ case mt: MapType =>
+ processDataType(mt.keyType, path + ".key")
+ processDataType(mt.valueType, path + ".value")
+
+ case st: StringType if isCollatedString(st) =>
+ fieldToCollationMap(path) = collationName(st)
+
+ case _ =>
+ }
+
+ def processDataType(dt: DataType, path: String): Unit = {
+ if (isCollatedString(dt)) {
+ fieldToCollationMap(path) = collationName(dt)
+ } else {
+ visitRecursively(dt, path)
+ }
+ }
+
+ visitRecursively(dataType, name)
+ fieldToCollationMap
+ }
+
+ private def isCollatedString(dt: DataType): Boolean = dt match {
+ case st: StringType => !st.isUTF8BinaryCollation
+ case _ => false
+ }
+
+ private def collationName(dt: DataType): String = dt match {
+ case st: StringType =>
Review Comment:
This can result in weird MatchError failures. Can we make this method take a StringType as an argument?
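Separately, to make the traversal above concrete: arrays append ".element", map keys ".key", and map values ".value", so for the nested_array_in_map schema from the tests the map should come out as (my reading of this hunk, expressed in Python):
```python
expected_collation_metadata = {
    "nestedArrayInMap.element.key": "UNICODE",
    "nestedArrayInMap.element.value.element.element": "UNICODE",
}
```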
##########
python/pyspark/sql/types.py:
##########
@@ -876,30 +894,86 @@ def __init__(
self.dataType = dataType
self.nullable = nullable
self.metadata = metadata or {}
+ self._collationMetadata: Optional[Dict[str, str]] = None
def simpleString(self) -> str:
return "%s:%s" % (self.name, self.dataType.simpleString())
def __repr__(self) -> str:
return "StructField('%s', %s, %s)" % (self.name, self.dataType,
str(self.nullable))
+ def __eq__(self, other: Any) -> bool:
+ # since collationMetadata is lazily evaluated, we should not use it in the equality check
Review Comment:
Isn't this dangerous?
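To spell out the concern: excluding the cache from equality is only safe while _collationMetadata stays purely derived from dataType. A sketch of the invariant being relied on (not necessarily the PR's exact implementation):
```python
def __eq__(self, other: Any) -> bool:
    # _collationMetadata is derived lazily from dataType, so two fields that
    # agree on (name, dataType, nullable, metadata) can never disagree on it.
    # This holds only as long as it remains a pure, never independently
    # mutated cache.
    return (
        isinstance(other, StructField)
        and self.name == other.name
        and self.dataType == other.dataType
        and self.nullable == other.nullable
        and self.metadata == other.metadata
    )
```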
##########
sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala:
##########
@@ -274,6 +288,39 @@ object DataType {
messageParameters = Map("other" -> compact(render(other))))
}
+ /**
+ * Checks if the current field is in the collation map, and if so returns
+ * a StringType with the given collation. Otherwise, it further parses its type.
+ */
+ private def resolveType(
+ json: JValue,
+ path: String,
+ collationsMap: Map[String, String]): DataType = {
+ collationsMap.get(path) match {
+ case Some(collation) => stringTypeWithCollation(collation)
Review Comment:
What will happen if someone adds a collation for a field that is not a string? Should we have a check that it is actually a string type?
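A possible guard, sketched in Python against the pyspark parser (the function and error message are hypothetical; _parse_datatype_json_value is the existing pyspark helper):
```python
from pyspark.sql.types import StringType, _parse_datatype_json_value

def _resolve_type(json_value, path, collations_map):
    if path in collations_map:
        # Refuse to apply a collation when the underlying JSON type is not
        # the plain "string" type.
        if json_value != "string":
            raise ValueError(f"collation specified for non-string type at {path!r}")
        return StringType(collations_map[path])
    return _parse_datatype_json_value(json_value)
```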