This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6f6b4860268d [SPARK-48175][SQL][PYTHON] Store collation information in metadata and not in type for SER/DE 6f6b4860268d is described below commit 6f6b4860268dc250d8e31a251d740733798aa512 Author: Stefan Kandic <stefan.kan...@databricks.com> AuthorDate: Sat May 18 15:17:56 2024 +0800 [SPARK-48175][SQL][PYTHON] Store collation information in metadata and not in type for SER/DE ### What changes were proposed in this pull request? Changing serialization and deserialization of collated strings so that the collation information is put in the metadata of the enclosing struct field - and then read back from there during parsing. Format of serialization will look something like this: ```json { "type": "struct", "fields": [ { "name": "colName", "type": "string", "nullable": true, "metadata": { "__COLLATIONS": { "colName": "UNICODE" } } } ] } ``` If we have a map we will add suffixes `.key` and `.value` in the metadata: ```json { "type": "struct", "fields": [ { "name": "mapField", "type": { "type": "map", "keyType": "string", "valueType": "string", "valueContainsNull": true }, "nullable": true, "metadata": { "__COLLATIONS": { "mapField.key": "UNICODE", "mapField.value": "UNICODE" } } } ] } ``` It will be a similar story for arrays (we will add `.element` suffix). We could have multiple suffixes when working with deeply nested data types (Map[String, Array[Array[String]]] - see tests for this example) ### Why are the changes needed? Putting collation info in field metadata is the only way to not break old clients reading new tables with collations. `CharVarcharUtils` does a similar thing but this is much less hacky, and more friendly for all 3p clients - which is especially important since delta also uses spark for schema ser/de. It will also remove the need for additional logic introduced in #46083 to remove collations before writing to HMS as this way the tables will be fully HMS compatible. 
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? With unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #46280 from stefankandic/newDeltaSchema. Lead-authored-by: Stefan Kandic <stefan.kan...@databricks.com> Co-authored-by: Stefan Kandic <154237371+stefankan...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/catalyst/util/CollationFactory.java | 99 +++++++- .../src/main/resources/error/error-conditions.json | 12 + python/pyspark/errors/error-conditions.json | 10 + .../pyspark/sql/tests/connect/test_parity_types.py | 4 + python/pyspark/sql/tests/test_types.py | 249 +++++++++++++++++++-- python/pyspark/sql/types.py | 178 +++++++++++++-- .../org/apache/spark/sql/types/DataType.scala | 74 +++++- .../org/apache/spark/sql/types/StringType.scala | 7 + .../org/apache/spark/sql/types/StructField.scala | 62 ++++- .../org/apache/spark/sql/types/DataTypeSuite.scala | 181 ++++++++++++++- .../apache/spark/sql/types/StructTypeSuite.scala | 183 +++++++++++++++ .../streaming/StreamingDeduplicationSuite.scala | 2 +- .../spark/sql/streaming/StreamingQuerySuite.scala | 2 +- 13 files changed, 1004 insertions(+), 59 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 8ffff63445b6..0133c3feb611 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -36,11 +36,62 @@ import org.apache.spark.unsafe.types.UTF8String; * Provides functionality to the UTF8String object which respects defined collation settings. */ public final class CollationFactory { + + /** + * Identifier for single a collation. 
+ */ + public static class CollationIdentifier { + private final String provider; + private final String name; + private final String version; + + public CollationIdentifier(String provider, String collationName, String version) { + this.provider = provider; + this.name = collationName; + this.version = version; + } + + public static CollationIdentifier fromString(String identifier) { + long numDots = identifier.chars().filter(ch -> ch == '.').count(); + assert(numDots > 0); + + if (numDots == 1) { + String[] parts = identifier.split("\\.", 2); + return new CollationIdentifier(parts[0], parts[1], null); + } + + String[] parts = identifier.split("\\.", 3); + return new CollationIdentifier(parts[0], parts[1], parts[2]); + } + + /** + * Returns the identifier's string value without the version. + * This is used for the table schema as the schema doesn't care about the version, + * only the statistics do. + */ + public String toStringWithoutVersion() { + return String.format("%s.%s", provider, name); + } + + public String getProvider() { + return provider; + } + + public String getName() { + return name; + } + + public Optional<String> getVersion() { + return Optional.ofNullable(version); + } + } + /** * Entry encapsulating all information about a collation. 
*/ public static class Collation { public final String collationName; + public final String provider; public final Collator collator; public final Comparator<UTF8String> comparator; @@ -89,6 +140,7 @@ public final class CollationFactory { public Collation( String collationName, + String provider, Collator collator, Comparator<UTF8String> comparator, String version, @@ -97,6 +149,7 @@ public final class CollationFactory { boolean supportsBinaryOrdering, boolean supportsLowercaseEquality) { this.collationName = collationName; + this.provider = provider; this.collator = collator; this.comparator = comparator; this.version = version; @@ -110,6 +163,8 @@ public final class CollationFactory { // No Collation can simultaneously support binary equality and lowercase equality assert(!supportsBinaryEquality || !supportsLowercaseEquality); + assert(SUPPORTED_PROVIDERS.contains(provider)); + if (supportsBinaryEquality) { this.equalsFunction = UTF8String::equals; } else { @@ -122,6 +177,7 @@ public final class CollationFactory { */ public Collation( String collationName, + String provider, Collator collator, String version, boolean supportsBinaryEquality, @@ -129,6 +185,7 @@ public final class CollationFactory { boolean supportsLowercaseEquality) { this( collationName, + provider, collator, (s1, s2) -> collator.compare(s1.toString(), s2.toString()), version, @@ -137,6 +194,11 @@ public final class CollationFactory { supportsBinaryOrdering, supportsLowercaseEquality); } + + /** Returns the collation identifier. 
*/ + public CollationIdentifier identifier() { + return new CollationIdentifier(provider, collationName, version); + } } private static final Collation[] collationTable = new Collation[4]; @@ -145,12 +207,17 @@ public final class CollationFactory { public static final int UTF8_BINARY_COLLATION_ID = 0; public static final int UTF8_BINARY_LCASE_COLLATION_ID = 1; + public static final String PROVIDER_SPARK = "spark"; + public static final String PROVIDER_ICU = "icu"; + public static final List<String> SUPPORTED_PROVIDERS = List.of(PROVIDER_SPARK, PROVIDER_ICU); + static { // Binary comparison. This is the default collation. // No custom comparators will be used for this collation. // Instead, we rely on byte for byte comparison. collationTable[0] = new Collation( "UTF8_BINARY", + PROVIDER_SPARK, null, UTF8String::binaryCompare, "1.0", @@ -163,6 +230,7 @@ public final class CollationFactory { // TODO: Do in place comparisons instead of creating new strings. collationTable[1] = new Collation( "UTF8_BINARY_LCASE", + PROVIDER_SPARK, null, UTF8String::compareLowerCase, "1.0", @@ -173,13 +241,28 @@ public final class CollationFactory { // UNICODE case sensitive comparison (ROOT locale, in ICU). collationTable[2] = new Collation( - "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true, false, false); + "UNICODE", + PROVIDER_ICU, + Collator.getInstance(ULocale.ROOT), + "153.120.0.0", + true, + false, + false + ); + collationTable[2].collator.setStrength(Collator.TERTIARY); collationTable[2].collator.freeze(); // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary strength). 
collationTable[3] = new Collation( - "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false, false, false); + "UNICODE_CI", + PROVIDER_ICU, + Collator.getInstance(ULocale.ROOT), + "153.120.0.0", + false, + false, + false + ); collationTable[3].collator.setStrength(Collator.SECONDARY); collationTable[3].collator.freeze(); @@ -263,6 +346,18 @@ public final class CollationFactory { } } + public static void assertValidProvider(String provider) throws SparkException { + if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { + Map<String, String> params = Map.of( + "provider", provider, + "supportedProviders", String.join(", ", SUPPORTED_PROVIDERS) + ); + + throw new SparkException( + "COLLATION_INVALID_PROVIDER", SparkException.constructMessageParams(params), null); + } + } + public static Collation fetchCollation(int collationId) { return collationTable[collationId]; } diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 69889435b02e..c1c0cd6bfb39 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -473,6 +473,12 @@ ], "sqlState" : "42704" }, + "COLLATION_INVALID_PROVIDER" : { + "message" : [ + "The value <provider> does not represent a correct collation provider. Supported providers are: [<supportedProviders>]." + ], + "sqlState" : "42704" + }, "COLLATION_MISMATCH" : { "message" : [ "Could not determine which collation to use for string functions and operators." @@ -2342,6 +2348,12 @@ ], "sqlState" : "2203G" }, + "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS" : { + "message" : [ + "Collations can only be applied to string types, but the JSON data type is <jsonType>." + ], + "sqlState" : "2203G" + }, "INVALID_JSON_ROOT_FIELD" : { "message" : [ "Cannot convert JSON root field to target Spark type." 
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 906bf781e1bb..30db37387249 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -149,6 +149,11 @@ "Cannot <condition1> without <condition2>." ] }, + "COLLATION_INVALID_PROVIDER" : { + "message" : [ + "The value <provider> does not represent a correct collation provider. Supported providers are: [<supportedProviders>]." + ] + }, "COLUMN_IN_LIST": { "message": [ "`<func_name>` does not allow a Column in a list." @@ -357,6 +362,11 @@ "All items in `<arg_name>` should be in <allowed_types>, got <item_type>." ] }, + "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS" : { + "message" : [ + "Collations can only be applied to string types, but the JSON data type is <jsonType>." + ] + }, "INVALID_MULTIPLE_ARGUMENT_CONDITIONS": { "message": [ "[{arg_names}] cannot be <condition>." diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py index 9e81af47ceb0..4f736ac1215c 100644 --- a/python/pyspark/sql/tests/connect/test_parity_types.py +++ b/python/pyspark/sql/tests/connect/test_parity_types.py @@ -90,6 +90,10 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase): def test_udt(self): super().test_udt() + @unittest.skip("Requires JVM access.") + def test_schema_with_collations_json_ser_de(self): + super().test_schema_with_collations_json_ser_de() + @unittest.skip("Does not test anything related to Spark Connect") def test_parse_datatype_string(self): super().test_parse_datatype_string() diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 5942ae2abdb3..4d6fc499b70b 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -600,6 +600,234 @@ class TypesTestsMixin: self.assertEqual(df.count(), 1) self.assertEqual(df.head(), Row(name="[123]", income=120)) + def 
test_schema_with_collations_json_ser_de(self): + from pyspark.sql.types import _parse_datatype_json_string + + unicode_collation = "UNICODE" + + simple_struct = StructType([StructField("c1", StringType(unicode_collation))]) + + nested_struct = StructType([StructField("nested", simple_struct)]) + + array_in_schema = StructType( + [StructField("array", ArrayType(StringType(unicode_collation)))] + ) + + map_in_schema = StructType( + [ + StructField( + "map", MapType(StringType(unicode_collation), StringType(unicode_collation)) + ) + ] + ) + + nested_map = StructType( + [ + StructField( + "nested", + StructType( + [ + StructField( + "mapField", + MapType( + StringType(unicode_collation), StringType(unicode_collation) + ), + ) + ] + ), + ) + ] + ) + + array_in_map = StructType( + [ + StructField( + "arrInMap", + MapType( + StringType(unicode_collation), ArrayType(StringType(unicode_collation)) + ), + ) + ] + ) + + nested_array_in_map_value = StructType( + [ + StructField( + "nestedArrayInMap", + ArrayType( + MapType( + StringType(unicode_collation), + ArrayType(ArrayType(StringType(unicode_collation))), + ) + ), + ) + ] + ) + + schema_with_multiple_fields = StructType( + simple_struct.fields + + nested_struct.fields + + array_in_schema.fields + + map_in_schema.fields + + nested_map.fields + + array_in_map.fields + + nested_array_in_map_value.fields + ) + + schemas = [ + simple_struct, + nested_struct, + array_in_schema, + map_in_schema, + nested_map, + nested_array_in_map_value, + array_in_map, + schema_with_multiple_fields, + ] + + for schema in schemas: + scala_datatype = self.spark._jsparkSession.parseDataType(schema.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert schema == python_datatype + assert schema == _parse_datatype_json_string(schema.json()) + + def test_schema_with_collations_on_non_string_types(self): + from pyspark.sql.types import _parse_datatype_json_string, _COLLATIONS_METADATA_KEY + + collations_on_int_col_json 
= f""" + {{ + "type": "struct", + "fields": [ + {{ + "name": "c1", + "type": "integer", + "nullable": true, + "metadata": {{ + "{_COLLATIONS_METADATA_KEY}": {{ + "c1": "icu.UNICODE" + }} + }} + }} + ] + }} + """ + + collations_in_array_element_json = f""" + {{ + "type": "struct", + "fields": [ + {{ + "name": "arrayField", + "type": {{ + "type": "array", + "elementType": "integer", + "containsNull": true + }}, + "nullable": true, + "metadata": {{ + "{_COLLATIONS_METADATA_KEY}": {{ + "arrayField.element": "icu.UNICODE" + }} + }} + }} + ] + }} + """ + + collations_on_array_json = f""" + {{ + "type": "struct", + "fields": [ + {{ + "name": "arrayField", + "type": {{ + "type": "array", + "elementType": "integer", + "containsNull": true + }}, + "nullable": true, + "metadata": {{ + "{_COLLATIONS_METADATA_KEY}": {{ + "arrayField": "icu.UNICODE" + }} + }} + }} + ] + }} + """ + + collations_in_nested_map_json = f""" + {{ + "type": "struct", + "fields": [ + {{ + "name": "nested", + "type": {{ + "type": "struct", + "fields": [ + {{ + "name": "mapField", + "type": {{ + "type": "map", + "keyType": "string", + "valueType": "integer", + "valueContainsNull": true + }}, + "nullable": true, + "metadata": {{ + "{_COLLATIONS_METADATA_KEY}": {{ + "mapField.value": "icu.UNICODE" + }} + }} + }} + ] + }}, + "nullable": true, + "metadata": {{}} + }} + ] + }} + """ + + self.assertRaises( + PySparkTypeError, lambda: _parse_datatype_json_string(collations_on_int_col_json) + ) + + self.assertRaises( + PySparkTypeError, lambda: _parse_datatype_json_string(collations_in_array_element_json) + ) + + self.assertRaises( + PySparkTypeError, lambda: _parse_datatype_json_string(collations_on_array_json) + ) + + self.assertRaises( + PySparkTypeError, lambda: _parse_datatype_json_string(collations_in_nested_map_json) + ) + + def test_schema_with_bad_collations_provider(self): + from pyspark.sql.types import _parse_datatype_json_string, _COLLATIONS_METADATA_KEY + + schema_json = f""" + {{ + "type": 
"struct", + "fields": [ + {{ + "name": "c1", + "type": "string", + "nullable": "true", + "metadata": {{ + "{_COLLATIONS_METADATA_KEY}": {{ + "c1": "badProvider.UNICODE" + }} + }} + }} + ] + }} + """ + + self.assertRaises(PySparkValueError, lambda: _parse_datatype_json_string(schema_json)) + def test_udt(self): from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier @@ -915,27 +1143,6 @@ class TypesTestsMixin: self.assertEqual(t(), _parse_datatype_string(k)) self.assertEqual(IntegerType(), _parse_datatype_string("int")) self.assertEqual(StringType(), _parse_datatype_string("string")) - self.assertEqual(StringType(), _parse_datatype_string("string collate UTF8_BINARY")) - self.assertEqual(StringType(), _parse_datatype_string("string COLLATE UTF8_BINARY")) - self.assertEqual( - StringType.fromCollationId(0), _parse_datatype_string("string COLLATE UTF8_BINARY") - ) - self.assertEqual( - StringType.fromCollationId(1), - _parse_datatype_string("string COLLATE UTF8_BINARY_LCASE"), - ) - self.assertEqual( - StringType.fromCollationId(2), _parse_datatype_string("string COLLATE UNICODE") - ) - self.assertEqual( - StringType.fromCollationId(2), _parse_datatype_string("string COLLATE `UNICODE`") - ) - self.assertEqual( - StringType.fromCollationId(3), _parse_datatype_string("string COLLATE UNICODE_CI") - ) - self.assertEqual( - StringType.fromCollationId(3), _parse_datatype_string("string COLLATE `UNICODE_CI`") - ) self.assertEqual(CharType(1), _parse_datatype_string("char(1)")) self.assertEqual(CharType(10), _parse_datatype_string("char( 10 )")) self.assertEqual(CharType(11), _parse_datatype_string("char( 11)")) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ed3535e7d4aa..d692fd6f3681 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -255,6 +255,9 @@ class StringType(AtomicType): """ collationNames = ["UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"] + providerSpark = 
"spark" + providerICU = "icu" + providers = [providerSpark, providerICU] def __init__(self, collation: Optional[str] = None): self.collationId = 0 if collation is None else self.collationNameToId(collation) @@ -263,21 +266,32 @@ class StringType(AtomicType): def fromCollationId(self, collationId: int) -> "StringType": return StringType(StringType.collationNames[collationId]) - def collationIdToName(self) -> str: - if self.collationId == 0: - return "" - else: - return " collate %s" % StringType.collationNames[self.collationId] + @classmethod + def collationIdToName(cls, collationId: int) -> str: + return StringType.collationNames[collationId] @classmethod def collationNameToId(cls, collationName: str) -> int: return StringType.collationNames.index(collationName) + @classmethod + def collationProvider(cls, collationName: str) -> str: + # TODO: do this properly like on the scala side + if collationName.startswith("UTF8"): + return StringType.providerSpark + return StringType.providerICU + def simpleString(self) -> str: - return "string" + self.collationIdToName() + if self.isUTF8BinaryCollation(): + return "string" + return f"string collate ${self.collationIdToName(self.collationId)}" + + # For backwards compatibility and compatibility with other readers all string types + # are serialized in json as regular strings and the collation info is written to + # struct field metadata def jsonValue(self) -> str: - return "string" + self.collationIdToName() + return "string" def __repr__(self) -> str: return ( @@ -286,6 +300,9 @@ class StringType(AtomicType): else "StringType()" ) + def isUTF8BinaryCollation(self) -> bool: + return self.collationId == 0 + class CharType(AtomicType): """Char data type @@ -693,8 +710,16 @@ class ArrayType(DataType): } @classmethod - def fromJson(cls, json: Dict[str, Any]) -> "ArrayType": - return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"]) + def fromJson( + cls, + json: Dict[str, Any], + fieldPath: str, + 
collationsMap: Optional[Dict[str, str]], + ) -> "ArrayType": + elementType = _parse_datatype_json_value( + json["elementType"], fieldPath + ".element", collationsMap + ) + return ArrayType(elementType, json["containsNull"]) def needConversion(self) -> bool: return self.elementType.needConversion() @@ -810,10 +835,19 @@ class MapType(DataType): } @classmethod - def fromJson(cls, json: Dict[str, Any]) -> "MapType": + def fromJson( + cls, + json: Dict[str, Any], + fieldPath: str, + collationsMap: Optional[Dict[str, str]], + ) -> "MapType": + keyType = _parse_datatype_json_value(json["keyType"], fieldPath + ".key", collationsMap) + valueType = _parse_datatype_json_value( + json["valueType"], fieldPath + ".value", collationsMap + ) return MapType( - _parse_datatype_json_value(json["keyType"]), - _parse_datatype_json_value(json["valueType"]), + keyType, + valueType, json["valueContainsNull"], ) @@ -884,22 +918,89 @@ class StructField(DataType): return "StructField('%s', %s, %s)" % (self.name, self.dataType, str(self.nullable)) def jsonValue(self) -> Dict[str, Any]: + collationMetadata = self.getCollationMetadata() + metadata = ( + self.metadata + if not collationMetadata + else {**self.metadata, _COLLATIONS_METADATA_KEY: collationMetadata} + ) + return { "name": self.name, "type": self.dataType.jsonValue(), "nullable": self.nullable, - "metadata": self.metadata, + "metadata": metadata, } @classmethod def fromJson(cls, json: Dict[str, Any]) -> "StructField": + metadata = json.get("metadata") + collationsMap = {} + if metadata and _COLLATIONS_METADATA_KEY in metadata: + collationsMap = metadata[_COLLATIONS_METADATA_KEY] + for key, value in collationsMap.items(): + nameParts = value.split(".") + assert len(nameParts) == 2 + provider, name = nameParts[0], nameParts[1] + _assert_valid_collation_provider(provider) + collationsMap[key] = name + + metadata = { + key: value for key, value in metadata.items() if key != _COLLATIONS_METADATA_KEY + } + return StructField( 
json["name"], - _parse_datatype_json_value(json["type"]), + _parse_datatype_json_value(json["type"], json["name"], collationsMap), json.get("nullable", True), - json.get("metadata"), + metadata, ) + def getCollationsMap(self, metadata: Dict[str, Any]) -> Dict[str, str]: + if not metadata or _COLLATIONS_METADATA_KEY not in metadata: + return {} + + collationMetadata: Dict[str, str] = metadata[_COLLATIONS_METADATA_KEY] + collationsMap: Dict[str, str] = {} + + for key, value in collationMetadata.items(): + nameParts = value.split(".") + assert len(nameParts) == 2 + provider, name = nameParts[0], nameParts[1] + _assert_valid_collation_provider(provider) + collationsMap[key] = name + + return collationsMap + + def getCollationMetadata(self) -> Dict[str, str]: + def visitRecursively(dt: DataType, fieldPath: str) -> None: + if isinstance(dt, ArrayType): + processDataType(dt.elementType, fieldPath + ".element") + elif isinstance(dt, MapType): + processDataType(dt.keyType, fieldPath + ".key") + processDataType(dt.valueType, fieldPath + ".value") + elif isinstance(dt, StringType) and self._isCollatedString(dt): + collationMetadata[fieldPath] = self.schemaCollationValue(dt) + + def processDataType(dt: DataType, fieldPath: str) -> None: + if self._isCollatedString(dt): + collationMetadata[fieldPath] = self.schemaCollationValue(dt) + else: + visitRecursively(dt, fieldPath) + + collationMetadata: Dict[str, str] = {} + visitRecursively(self.dataType, self.name) + return collationMetadata + + def _isCollatedString(self, dt: DataType) -> bool: + return isinstance(dt, StringType) and not dt.isUTF8BinaryCollation() + + def schemaCollationValue(self, dt: DataType) -> str: + assert isinstance(dt, StringType) + collationName = StringType.collationIdToName(dt.collationId) + provider = StringType.collationProvider(collationName) + return f"{provider}.{collationName}" + def needConversion(self) -> bool: return self.dataType.needConversion() @@ -1561,13 +1662,14 @@ _all_complex_types: 
Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dic (v.typeName(), v) for v in _complex_types ) -_COLLATED_STRING = re.compile(r"string\s+collate\s+([\w_]+|`[\w_]`)") _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)") _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)") _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)") _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?") +_COLLATIONS_METADATA_KEY = "__COLLATIONS" + def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]: assert isinstance(d, (DataType, StructField)) @@ -1715,9 +1817,17 @@ def _parse_datatype_json_string(json_string: str) -> DataType: return _parse_datatype_json_value(json.loads(json_string)) -def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType: +def _parse_datatype_json_value( + json_value: Union[dict, str], + fieldPath: str = "", + collationsMap: Optional[Dict[str, str]] = None, +) -> DataType: if not isinstance(json_value, dict): if json_value in _all_atomic_types.keys(): + if collationsMap is not None and fieldPath in collationsMap: + _assert_valid_type_for_collation(fieldPath, json_value, collationsMap) + collation_name = collationsMap[fieldPath] + return StringType(collation_name) return _all_atomic_types[json_value]() elif json_value == "decimal": return DecimalType() @@ -1742,9 +1852,6 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType: return YearMonthIntervalType(first_field, second_field) elif json_value == "interval": return CalendarIntervalType() - elif _COLLATED_STRING.match(json_value): - m = _COLLATED_STRING.match(json_value) - return StringType(m.group(1)) # type: ignore[union-attr] elif _LENGTH_CHAR.match(json_value): m = _LENGTH_CHAR.match(json_value) return CharType(int(m.group(1))) # type: ignore[union-attr] @@ -1759,7 +1866,15 @@ def 
_parse_datatype_json_value(json_value: Union[dict, str]) -> DataType: else: tpe = json_value["type"] if tpe in _all_complex_types: - return _all_complex_types[tpe].fromJson(json_value) + if collationsMap is not None and fieldPath in collationsMap: + _assert_valid_type_for_collation(fieldPath, tpe, collationsMap) + + complex_type = _all_complex_types[tpe] + if complex_type is ArrayType: + return ArrayType.fromJson(json_value, fieldPath, collationsMap) + elif complex_type is MapType: + return MapType.fromJson(json_value, fieldPath, collationsMap) + return StructType.fromJson(json_value) elif tpe == "udt": return UserDefinedType.fromJson(json_value) else: @@ -1769,6 +1884,27 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType: ) +def _assert_valid_type_for_collation( + fieldPath: str, fieldType: Any, collationMap: Dict[str, str] +) -> None: + if fieldPath in collationMap and fieldType != "string": + raise PySparkTypeError( + error_class="INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + message_parameters={"jsonType": fieldType}, + ) + + +def _assert_valid_collation_provider(provider: str) -> None: + if provider.lower() not in StringType.providers: + raise PySparkValueError( + error_class="COLLATION_INVALID_PROVIDER", + message_parameters={ + "provider": provider, + "supportedProviders": ", ".join(StringType.providers), + }, + ) + + # Mapping Python types to Spark SQL DataType _type_mappings = { type(None): NullType, diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 16cf6224ce27..0d53f5ae7902 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -117,7 +117,8 @@ object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r private val VARCHAR_TYPE = 
"""varchar\(\s*(\d+)\s*\)""".r - private val COLLATED_STRING_TYPE = """string\s+collate\s+([\w_]+|`[\w_]`)""".r + + val COLLATIONS_METADATA_KEY = "__COLLATIONS" def fromDDL(ddl: String): DataType = { parseTypeWithFallback( @@ -182,9 +183,6 @@ object DataType { /** Given the string representation of a type, return its DataType */ private def nameToType(name: String): DataType = { name match { - case COLLATED_STRING_TYPE(collation) => - val collationId = CollationFactory.collationNameToId(collation) - StringType(collationId) case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) case CHAR_TYPE(length) => CharType(length.toInt) @@ -208,26 +206,40 @@ object DataType { } // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. - private[sql] def parseDataType(json: JValue): DataType = json match { + private[sql] def parseDataType( + json: JValue, + fieldPath: String = "", + collationsMap: Map[String, String] = Map.empty): DataType = json match { case JString(name) => - nameToType(name) + collationsMap.get(fieldPath) match { + case Some(collation) => + assertValidTypeForCollations(fieldPath, name, collationsMap) + stringTypeWithCollation(collation) + case _ => nameToType(name) + } case JSortedObject( ("containsNull", JBool(n)), ("elementType", t: JValue), ("type", JString("array"))) => - ArrayType(parseDataType(t), n) + assertValidTypeForCollations(fieldPath, "array", collationsMap) + val elementType = parseDataType(t, fieldPath + ".element", collationsMap) + ArrayType(elementType, n) case JSortedObject( ("keyType", k: JValue), ("type", JString("map")), ("valueContainsNull", JBool(n)), ("valueType", v: JValue)) => - MapType(parseDataType(k), parseDataType(v), n) + assertValidTypeForCollations(fieldPath, "map", collationsMap) + val keyType = parseDataType(k, fieldPath + ".key", collationsMap) + val valueType = parseDataType(v, fieldPath + ".value", collationsMap) 
+ MapType(keyType, valueType, n) case JSortedObject( ("fields", JArray(fields)), ("type", JString("struct"))) => + assertValidTypeForCollations(fieldPath, "struct", collationsMap) StructType(fields.map(parseStructField)) // Scala/Java UDT @@ -253,11 +265,18 @@ object DataType { private def parseStructField(json: JValue): StructField = json match { case JSortedObject( - ("metadata", metadata: JObject), + ("metadata", JObject(metadataFields)), ("name", JString(name)), ("nullable", JBool(nullable)), ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) + val collationsMap = getCollationsMap(metadataFields) + val metadataWithoutCollations = + JObject(metadataFields.filterNot(_._1 == COLLATIONS_METADATA_KEY)) + StructField( + name, + parseDataType(dataType, name, collationsMap), + nullable, + Metadata.fromJObject(metadataWithoutCollations)) // Support reading schema when 'metadata' is missing. case JSortedObject( ("name", JString(name)), @@ -274,6 +293,41 @@ object DataType { messageParameters = Map("other" -> compact(render(other)))) } + private def assertValidTypeForCollations( + fieldPath: String, + fieldType: String, + collationMap: Map[String, String]): Unit = { + if (collationMap.contains(fieldPath) && fieldType != "string") { + throw new SparkIllegalArgumentException( + errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + messageParameters = Map("jsonType" -> fieldType)) + } + } + + /** + * Returns a map of field path to collation name. 
+ */ + private def getCollationsMap(metadataFields: List[JField]): Map[String, String] = { + val collationsJsonOpt = metadataFields.find(_._1 == COLLATIONS_METADATA_KEY).map(_._2) + collationsJsonOpt match { + case Some(JObject(fields)) => + fields.collect { + case (fieldPath, JString(collation)) => + collation.split("\\.", 2) match { + case Array(provider: String, collationName: String) => + CollationFactory.assertValidProvider(provider) + fieldPath -> collationName + } + }.toMap + + case _ => Map.empty + } + } + + private def stringTypeWithCollation(collationName: String): StringType = { + StringType(CollationFactory.collationNameToId(collationName)) + } + protected[types] def buildFormattedString( dataType: DataType, prefix: String, diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 74c714ff63f4..b8dadbc9e1dc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import org.json4s.JsonAST.{JString, JValue} + import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.CollationFactory @@ -61,6 +63,11 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa if (isUTF8BinaryCollation) "string" else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}" + // Due to backwards compatibility and compatibility with other readers + // all string types are serialized in json as regular strings and + // the collation information is written to struct field metadata + override def jsonValue: JValue = JString("string") + override def equals(obj: Any): Boolean = obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId == collationId diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala index 66f9557db213..3ff96fea9ee0 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.types +import scala.collection.mutable + +import org.json4s.{JObject, JString} import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.apache.spark.SparkException import org.apache.spark.annotation.Stable -import org.apache.spark.sql.catalyst.util.{QuotingUtils, StringConcat} +import org.apache.spark.sql.catalyst.util.{CollationFactory, QuotingUtils, StringConcat} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY} import org.apache.spark.util.SparkSchemaUtils @@ -63,7 +67,61 @@ case class StructField( ("name" -> name) ~ ("type" -> dataType.jsonValue) ~ ("nullable" -> nullable) ~ - ("metadata" -> metadata.jsonValue) + ("metadata" -> metadataJson) + } + + private def metadataJson: JValue = { + val metadataJsonValue = metadata.jsonValue + metadataJsonValue match { + case JObject(fields) if collationMetadata.nonEmpty => + val collationFields = collationMetadata.map(kv => kv._1 -> JString(kv._2)).toList + JObject(fields :+ (DataType.COLLATIONS_METADATA_KEY -> JObject(collationFields))) + + case _ => metadataJsonValue + } + } + + /** Map of field path to collation name. 
*/ + private lazy val collationMetadata: Map[String, String] = { + val fieldToCollationMap = mutable.Map[String, String]() + + def visitRecursively(dt: DataType, path: String): Unit = dt match { + case at: ArrayType => + processDataType(at.elementType, path + ".element") + + case mt: MapType => + processDataType(mt.keyType, path + ".key") + processDataType(mt.valueType, path + ".value") + + case st: StringType if isCollatedString(st) => + fieldToCollationMap(path) = schemaCollationValue(st) + + case _ => + } + + def processDataType(dt: DataType, path: String): Unit = { + if (isCollatedString(dt)) { + fieldToCollationMap(path) = schemaCollationValue(dt) + } else { + visitRecursively(dt, path) + } + } + + visitRecursively(dataType, name) + fieldToCollationMap.toMap + } + + private def isCollatedString(dt: DataType): Boolean = dt match { + case st: StringType => !st.isUTF8BinaryCollation + case _ => false + } + + private def schemaCollationValue(dt: DataType): String = dt match { + case st: StringType => + val collation = CollationFactory.fetchCollation(st.collationId) + collation.identifier().toStringWithoutVersion() + case _ => + throw SparkException.internalError(s"Unexpected data type $dt") } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 3293957282e2..721d7c25d17b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -23,11 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentExce import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.catalyst.util.StringConcat +import 
org.apache.spark.sql.catalyst.util.{CollationFactory, StringConcat} import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} class DataTypeSuite extends SparkFunSuite { + private val UNICODE_COLLATION_ID = CollationFactory.collationNameToId("UNICODE") + test("construct an ArrayType") { val array = ArrayType(StringType) @@ -712,4 +714,181 @@ class DataTypeSuite extends SparkFunSuite { assert(result === expected) } + + test("schema with collation should not change during ser/de") { + val simpleStruct = StructType( + StructField("c1", StringType(UNICODE_COLLATION_ID)) :: Nil) + + val nestedStruct = StructType( + StructField("nested", simpleStruct) :: Nil) + + val caseInsensitiveNames = StructType( + StructField("c1", StringType(UNICODE_COLLATION_ID)) :: + StructField("C1", StringType(UNICODE_COLLATION_ID)) :: Nil) + + val specialCharsInName = StructType( + StructField("c1.*23?", StringType(UNICODE_COLLATION_ID)) :: Nil) + + val arrayInSchema = StructType( + StructField("arrayField", ArrayType(StringType(UNICODE_COLLATION_ID))) :: Nil) + + val mapInSchema = StructType( + StructField("mapField", + MapType(StringType(UNICODE_COLLATION_ID), StringType(UNICODE_COLLATION_ID))) :: Nil) + + val mapWithKeyInNameInSchema = StructType( + StructField("name.key", StringType) :: + StructField("name", + MapType(StringType(UNICODE_COLLATION_ID), StringType(UNICODE_COLLATION_ID))) :: Nil) + + val arrayInMapInNestedSchema = StructType( + StructField("arrInMap", + MapType(StringType(UNICODE_COLLATION_ID), + ArrayType(StringType(UNICODE_COLLATION_ID)))) :: Nil) + + val nestedArrayInMap = StructType( + StructField("nestedArrayInMap", + ArrayType(MapType(StringType(UNICODE_COLLATION_ID), + ArrayType(ArrayType(StringType(UNICODE_COLLATION_ID)))))) :: Nil) + + val schemaWithMultipleFields = StructType( + simpleStruct.fields ++ nestedStruct.fields ++ arrayInSchema.fields ++ mapInSchema.fields ++ + mapWithKeyInNameInSchema ++ 
arrayInMapInNestedSchema.fields ++ nestedArrayInMap.fields) + + Seq( + simpleStruct, caseInsensitiveNames, specialCharsInName, nestedStruct, arrayInSchema, + mapInSchema, mapWithKeyInNameInSchema, nestedArrayInMap, arrayInMapInNestedSchema, + schemaWithMultipleFields) + .foreach { schema => + val json = schema.json + val parsed = DataType.fromJson(json) + assert(parsed === schema) + } + } + + test("non string field has collation metadata") { + val json = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "c1", + | "type": "integer", + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "c1": "icu.UNICODE" + | } + | } + | } + | ] + |} + |""".stripMargin + + checkError( + exception = intercept[SparkIllegalArgumentException] { + DataType.fromJson(json) + }, + errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + parameters = Map("jsonType" -> "integer") + ) + } + + test("non string field in map key has collation metadata") { + val json = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "mapField", + | "type": { + | "type": "map", + | "keyType": "string", + | "valueType": "integer", + | "valueContainsNull": true + | }, + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "mapField.value": "icu.UNICODE" + | } + | } + | } + | ] + |} + |""".stripMargin + + checkError( + exception = intercept[SparkIllegalArgumentException] { + DataType.fromJson(json) + }, + errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + parameters = Map("jsonType" -> "integer") + ) + } + + test("map field has collation metadata") { + val json = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "mapField", + | "type": { + | "type": "map", + | "keyType": "string", + | "valueType": "integer", + | "valueContainsNull": true + | }, + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "mapField": "icu.UNICODE" + | } + | } + | } + | ] + |} + 
|""".stripMargin + + checkError( + exception = intercept[SparkIllegalArgumentException] { + DataType.fromJson(json) + }, + errorClass = "INVALID_JSON_DATA_TYPE_FOR_COLLATIONS", + parameters = Map("jsonType" -> "map") + ) + } + + test("non existing collation provider") { + val json = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "c1", + | "type": "string", + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "c1": "badProvider.UNICODE" + | } + | } + | } + | ] + |} + |""".stripMargin + + checkError( + exception = intercept[SparkException] { + DataType.fromJson(json) + }, + errorClass = "COLLATION_INVALID_PROVIDER", + parameters = Map("provider" -> "badProvider", "supportedProviders" -> "spark, icu") + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index c165ab1bf61b..bd0685e10832 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.{SparkException, SparkFunSuite, SparkIllegalArgumentException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} @@ -36,6 +38,10 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { private val s = StructType.fromDDL("a INT, b STRING") + private val UNICODE_COLLATION = "UNICODE" + private val UTF8_BINARY_LCASE_COLLATION = "UTF8_BINARY_LCASE" + private val mapper = new ObjectMapper() + test("lookup a single missing field should output existing fields") { checkError( exception = intercept[SparkIllegalArgumentException](s("c")), @@ -606,4 +612,181 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper 
{ "b STRING NOT NULL,c STRING COMMENT 'nullable comment'") assert(fromDDL(struct.toDDL) === struct) } + + test("simple struct with collations to json") { + val simpleStruct = StructType( + StructField("c1", StringType(UNICODE_COLLATION)) :: Nil) + + val expectedJson = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "c1", + | "type": "string", + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "c1": "icu.$UNICODE_COLLATION" + | } + | } + | } + | ] + |} + |""".stripMargin + + assert(mapper.readTree(simpleStruct.json) == mapper.readTree(expectedJson)) + } + + test("nested struct with collations to json") { + val nestedStruct = StructType( + StructField("nested", StructType( + StructField("c1", StringType(UTF8_BINARY_LCASE_COLLATION)) :: Nil)) :: Nil) + + val expectedJson = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "nested", + | "type": { + | "type": "struct", + | "fields": [ + | { + | "name": "c1", + | "type": "string", + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "c1": "spark.$UTF8_BINARY_LCASE_COLLATION" + | } + | } + | } + | ] + | }, + | "nullable": true, + | "metadata": {} + | } + | ] + |} + |""".stripMargin + + assert(mapper.readTree(nestedStruct.json) == mapper.readTree(expectedJson)) + } + + test("array with collations in schema to json") { + val arrayInSchema = StructType( + StructField("arrayField", ArrayType(StringType(UNICODE_COLLATION))) :: Nil) + + val expectedJson = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "arrayField", + | "type": { + | "type": "array", + | "elementType": "string", + | "containsNull": true + | }, + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "arrayField.element": "icu.$UNICODE_COLLATION" + | } + | } + | } + | ] + |} + |""".stripMargin + + assert(mapper.readTree(arrayInSchema.json) == mapper.readTree(expectedJson)) + } + + test("map with 
collations in schema to json") { + val arrayInSchema = StructType( + StructField("mapField", + MapType(StringType(UNICODE_COLLATION), StringType(UNICODE_COLLATION))) :: Nil) + + val expectedJson = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "mapField", + | "type": { + | "type": "map", + | "keyType": "string", + | "valueType": "string", + | "valueContainsNull": true + | }, + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "mapField.key": "icu.$UNICODE_COLLATION", + | "mapField.value": "icu.$UNICODE_COLLATION" + | } + | } + | } + | ] + |} + |""".stripMargin + + assert(mapper.readTree(arrayInSchema.json) == mapper.readTree(expectedJson)) + } + + test("nested array with collations in map to json" ) { + val mapWithNestedArray = StructType( + StructField("column", ArrayType(MapType( + StringType(UNICODE_COLLATION), + ArrayType(ArrayType(ArrayType(StringType(UNICODE_COLLATION))))))) :: Nil) + + val expectedJson = + s""" + |{ + | "type": "struct", + | "fields": [ + | { + | "name": "column", + | "type": { + | "type": "array", + | "elementType": { + | "type": "map", + | "keyType": "string", + | "valueType": { + | "type": "array", + | "elementType": { + | "type": "array", + | "elementType": { + | "type": "array", + | "elementType": "string", + | "containsNull": true + | }, + | "containsNull": true + | }, + | "containsNull": true + | }, + | "valueContainsNull": true + | }, + | "containsNull": true + | }, + | "nullable": true, + | "metadata": { + | "${DataType.COLLATIONS_METADATA_KEY}": { + | "column.element.key": "icu.$UNICODE_COLLATION", + | "column.element.value.element.element.element": "icu.$UNICODE_COLLATION" + | } + | } + | } + | ] + |} + |""".stripMargin + + assert( + mapper.readTree(mapWithNestedArray.json) == mapper.readTree(expectedJson)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala index 5c816c5cddc7..7c84e3e2d018 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala @@ -527,7 +527,7 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest { ex.getCause.asInstanceOf[SparkUnsupportedOperationException], errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY", parameters = Map( - "schema" -> ".+\"type\":\"string collate UTF8_BINARY_LCASE\".+" + "schema" -> ".+\"str\":\"spark.UTF8_BINARY_LCASE\".+" ), matchPVals = true ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 227b50509afe..8b761c24b604 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -1425,7 +1425,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi ex.getCause.asInstanceOf[SparkUnsupportedOperationException], errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY", parameters = Map( - "schema" -> ".+\"type\":\"string collate UTF8_BINARY_LCASE\".+" + "schema" -> ".+\"c1\":\"spark.UTF8_BINARY_LCASE\".+" ), matchPVals = true ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org