This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 97e9bb3ac4b6 [SPARK-48700][SQL] Mode expression for complex types (all
collations)
97e9bb3ac4b6 is described below
commit 97e9bb3ac4b66711ced640ea466eeea5da6d1fd2
Author: Gideon P <[email protected]>
AuthorDate: Tue Oct 1 15:09:35 2024 +0200
[SPARK-48700][SQL] Mode expression for complex types (all collations)
### What changes were proposed in this pull request?
Add support for the Mode operator over complex types whose subfields are
collated strings.
### Why are the changes needed?
Full support for collations, as tracked by SPARK-48700.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Unit tests only, so far.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47154 from GideonPotok/collationmodecomplex.
Lead-authored-by: Gideon P <[email protected]>
Co-authored-by: Gideon Potok <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 10 +
.../sql/catalyst/expressions/aggregate/Mode.scala | 85 +++++--
.../spark/sql/CollationSQLExpressionsSuite.scala | 257 ++++++++++++++-------
3 files changed, 250 insertions(+), 102 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index fcaf2b1d9d30..3786643125a9 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -631,6 +631,11 @@
"Cannot process input data types for the expression: <expression>."
],
"subClass" : {
+ "BAD_INPUTS" : {
+ "message" : [
+ "The input data types to <functionName> must be valid, but found the
input types <dataType>."
+ ]
+ },
"MISMATCHED_TYPES" : {
"message" : [
"All input types must be the same except nullable, containsNull,
valueContainsNull flags, but found the input types <inputTypes>."
@@ -1011,6 +1016,11 @@
"The input of <functionName> can't be <dataType> type data."
]
},
+ "UNSUPPORTED_MODE_DATA_TYPE" : {
+ "message" : [
+ "The <mode> does not support the <child> data type, because there is
a \"MAP\" type with keys and/or values that have collated sub-fields."
+ ]
+ },
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index e254a670991a..8998348f0571 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -17,14 +17,17 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder,
TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending,
Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData,
UnsafeRowUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
GenericArrayData, UnsafeRowUtils}
+import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType,
BooleanType, DataType, StringType}
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType,
BooleanType, DataType, MapType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap
@@ -50,17 +53,20 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
override def checkInputDataTypes(): TypeCheckResult = {
- if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
child.dataType.isInstanceOf[StringType]) {
+ // TODO: SPARK-49358: Mode expression for map type with collated fields
+ if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
+ !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
+ !UnsafeRowUtils.isBinaryStable(f))) {
/*
* The Mode class uses collation awareness logic to handle string data.
- * Complex types with collated fields are not yet supported.
+ * All complex types except MapType with collated fields are supported.
*/
- // TODO: SPARK-48700: Mode expression for complex types (all collations)
super.checkInputDataTypes()
} else {
- TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
- " a type of binary-unstable type that is " +
- s"not currently supported by ${prettyName}.")
+ TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
+ messageParameters =
+ Map("child" -> toSQLType(child.dataType),
+ "mode" -> toSQLId(prettyName)))
}
}
@@ -86,6 +92,54 @@ case class Mode(
buffer
}
+ private def getCollationAwareBuffer(
+ childDataType: DataType,
+ buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
+ def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef,
Long)] = {
+ buffer.groupMapReduce(t =>
+ groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
+ }
+ def determineBufferingFunction(
+ childDataType: DataType): Option[AnyRef => _] = {
+ childDataType match {
+ case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
+ case _ => Some(collationAwareTransform(_, childDataType))
+ }
+ }
+
determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer)
+ }
+
+ protected[sql] def collationAwareTransform(data: AnyRef, dataType:
DataType): AnyRef = {
+ dataType match {
+ case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
+ case st: StructType =>
+
processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
+ case at: ArrayType => processArrayTypeWithBuffer(at,
data.asInstanceOf[ArrayData])
+ case st: StringType =>
+ CollationFactory.getCollationKey(data.asInstanceOf[UTF8String],
st.collationId)
+ case _ =>
+ throw new SparkIllegalArgumentException(
+ errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
+ messageParameters = Map(
+ "expression" -> toSQLExpr(this),
+ "functionName" -> toSQLType(prettyName),
+ "dataType" -> toSQLType(child.dataType))
+ )
+ }
+ }
+
+ private def processStructTypeWithBuffer(
+ tuples: Seq[(Any, StructField)]): Seq[Any] = {
+ tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef],
t._2.dataType))
+ }
+
+ private def processArrayTypeWithBuffer(
+ a: ArrayType,
+ data: ArrayData): Seq[Any] = {
+ (0 until data.numElements()).map(i =>
+ collationAwareTransform(data.get(i, a.elementType), a.elementType))
+ }
+
override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
if (buffer.isEmpty) {
return null
@@ -102,17 +156,12 @@ case class Mode(
* to a single value (the sum of the counts), and finally reduces the
groups to a single map.
*
* The new map is then used in the rest of the Mode evaluation logic.
+ *
+ * It is expected to work for all simple and complex types with
+ * collated fields, except for MapType (temporarily).
*/
- val collationAwareBuffer = child.dataType match {
- case c: StringType if
- !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality
=>
- val collationId = c.collationId
- val modeMap = buffer.toSeq.groupMapReduce {
- case (k, _) =>
CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
- }(x => x)((x, y) => (x._1, x._2 + y._2)).values
- modeMap
- case _ => buffer
- }
+ val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)
+
reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 941d5cd31db4..9930709cd8bf 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
+import java.util.Locale
import scala.collection.immutable.Seq
-import org.apache.spark.{SparkConf, SparkException,
SparkIllegalArgumentException, SparkRuntimeException}
-import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.{SparkConf, SparkException,
SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable}
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
@@ -1752,7 +1753,7 @@ class CollationSQLExpressionsSuite
UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"),
UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a"))
- testCasesUTF8String.foreach(t => {
+ testCasesUTF8String.foreach ( t => {
val buffer = new OpenHashMap[AnyRef, Long](5)
val myMode = Mode(child = Literal.create("some_column_name",
StringType(t.collationId)))
t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
@@ -1760,6 +1761,40 @@ class CollationSQLExpressionsSuite
})
}
+ test("Support Mode.eval(buffer) with complex types") {
+ case class UTF8StringModeTestCase[R](
+ collationId: String,
+ bufferValues: Map[InternalRow, Long],
+ result: R)
+
+ val bufferValuesUTF8String: Map[Any, Long] = Map(
+ UTF8String.fromString("a") -> 5L,
+ UTF8String.fromString("b") -> 4L,
+ UTF8String.fromString("B") -> 3L,
+ UTF8String.fromString("d") -> 2L,
+ UTF8String.fromString("e") -> 1L)
+
+ val bufferValuesComplex = bufferValuesUTF8String.map{
+ case (k, v) => (InternalRow.fromSeq(Seq(k, k, k)), v)
+ }
+ val testCasesUTF8String = Seq(
+ UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"),
+ UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"),
+ UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"),
+ UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]"))
+
+ testCasesUTF8String.foreach { t =>
+ val buffer = new OpenHashMap[AnyRef, Long](5)
+ val myMode = Mode(child = Literal.create(null, StructType(Seq(
+ StructField("f1", StringType(t.collationId), true),
+ StructField("f2", StringType(t.collationId), true),
+ StructField("f3", StringType(t.collationId), true)
+ ))))
+ t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
+ assert(myMode.eval(buffer).toString.toLowerCase() ==
t.result.toLowerCase())
+ }
+ }
+
test("Support mode for string expression with collated strings in struct") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String,
Long], result: R)
val testCases = Seq(
@@ -1780,33 +1815,7 @@ class CollationSQLExpressionsSuite
t.collationId + ", f2: INT>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
- if(t.collationId == "UTF8_LCASE" ||
- t.collationId == "unicode_ci" ||
- t.collationId == "unicode") {
- // Cannot resolve "mode(i)" due to data type mismatch:
- // Input to function mode was a complex type with strings collated
on non-binary
- // collations, which is not yet supported.. SQLSTATE: 42K09; line 1
pos 13;
- val params = Seq(("sqlExpr", "\"mode(i)\""),
- ("msg", "The input to the function 'mode'" +
- " was a type of binary-unstable type that is not currently
supported by mode."),
- ("hint", "")).toMap
- checkError(
- exception = intercept[AnalysisException] {
- sql(query)
- },
- condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
- parameters = params,
- queryContext = Array(
- ExpectedContext(objectType = "",
- objectName = "",
- startIndex = 13,
- stopIndex = 19,
- fragment = "mode(i)")
- )
- )
- } else {
- checkAnswer(sql(query), Row(t.result))
- }
+ checkAnswer(sql(query), Row(t.result))
}
})
}
@@ -1819,47 +1828,21 @@ class CollationSQLExpressionsSuite
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
)
- testCases.foreach(t => {
+ testCases.foreach { t =>
val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
(0L to numRepeats).map(_ => s"named_struct('f1', " +
s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3',
1)").mkString(",")
}.mkString(",")
- val tableName = s"t_${t.collationId}_mode_nested_struct"
+ val tableName = s"t_${t.collationId}_mode_nested_struct1"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE
" +
t.collationId + ">, f3: INT>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
- if(t.collationId == "UTF8_LCASE" ||
- t.collationId == "unicode_ci" ||
- t.collationId == "unicode") {
- // Cannot resolve "mode(i)" due to data type mismatch:
- // Input to function mode was a complex type with strings collated
on non-binary
- // collations, which is not yet supported.. SQLSTATE: 42K09; line 1
pos 13;
- val params = Seq(("sqlExpr", "\"mode(i)\""),
- ("msg", "The input to the function 'mode' " +
- "was a type of binary-unstable type that is not currently
supported by mode."),
- ("hint", "")).toMap
- checkError(
- exception = intercept[AnalysisException] {
- sql(query)
- },
- condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
- parameters = params,
- queryContext = Array(
- ExpectedContext(objectType = "",
- objectName = "",
- startIndex = 13,
- stopIndex = 19,
- fragment = "mode(i)")
- )
- )
- } else {
- checkAnswer(sql(query), Row(t.result))
- }
+ checkAnswer(sql(query), Row(t.result))
}
- })
+ }
}
test("Support mode for string expression with collated strings in array
complex type") {
@@ -1870,44 +1853,150 @@ class CollationSQLExpressionsSuite
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
)
- testCases.foreach(t => {
+ testCases.foreach { t =>
+ val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+ (0L to numRepeats).map(_ => s"array(named_struct('f2', " +
+ s"collate('$elt', '${t.collationId}'), 'f3', 1))").mkString(",")
+ }.mkString(",")
+
+ val tableName = s"t_${t.collationId}_mode_nested_struct2"
+ withTable(tableName) {
+ sql(s"CREATE TABLE ${tableName}(" +
+ s"i ARRAY< STRUCT<f2: STRING COLLATE ${t.collationId}, f3: INT>>)" +
+ s" USING parquet")
+ sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+ val query = s"SELECT lower(element_at(mode(i).f2, 1)) FROM
${tableName}"
+ checkAnswer(sql(query), Row(t.result))
+ }
+ }
+ }
+
+ test("Support mode for string expression with collated strings in 3D array
type") {
+ case class ModeTestCase[R](collationId: String, bufferValues: Map[String,
Long], result: R)
+ val testCases = Seq(
+ ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+ ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+ ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+ ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+ )
+ testCases.foreach { t =>
+ val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
+ (0L to numRepeats).map(_ =>
+ s"array(array(array(collate('$elt',
'${t.collationId}'))))").mkString(",")
+ }.mkString(",")
+
+ val tableName = s"t_${t.collationId}_mode_nested_3d_array"
+ withTable(tableName) {
+ sql(s"CREATE TABLE ${tableName}(i ARRAY<ARRAY<ARRAY" +
+ s"<STRING COLLATE ${t.collationId}>>>) USING parquet")
+ sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
+ val query = s"SELECT lower(" +
+ s"element_at(element_at(element_at(mode(i),1),1),1)) FROM
${tableName}"
+ checkAnswer(sql(query), Row(t.result))
+ }
+ }
+ }
+
+ test("Support mode for string expression with collated complex type - Highly
nested") {
+ case class ModeTestCase[R](collationId: String, bufferValues: Map[String,
Long], result: R)
+ val testCases = Seq(
+ ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+ ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+ ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+ ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+ )
+ testCases.foreach { t =>
val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
(0L to numRepeats).map(_ => s"array(named_struct('s1',
named_struct('a2', " +
s"array(collate('$elt', '${t.collationId}'))), 'f3',
1))").mkString(",")
}.mkString(",")
- val tableName = s"t_${t.collationId}_mode_nested_struct"
+ val tableName = s"t_${t.collationId}_mode_highly_nested_struct"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(" +
s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE
${t.collationId}>>, f3: INT>>)" +
s" USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2,
1)) FROM ${tableName}"
- if(t.collationId == "UTF8_LCASE" ||
- t.collationId == "unicode_ci" || t.collationId == "unicode") {
- val params = Seq(("sqlExpr", "\"mode(i)\""),
- ("msg", "The input to the function 'mode' was a type" +
- " of binary-unstable type that is not currently supported by
mode."),
- ("hint", "")).toMap
- checkError(
- exception = intercept[AnalysisException] {
- sql(query)
- },
- condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
- parameters = params,
- queryContext = Array(
- ExpectedContext(objectType = "",
- objectName = "",
- startIndex = 35,
- stopIndex = 41,
- fragment = "mode(i)")
- )
- )
- } else {
+
checkAnswer(sql(query), Row(t.result))
+ }
+ }
+ }
+
+ test("Support mode expression with collated in recursively nested struct
with map with keys") {
+ case class ModeTestCase(collationId: String, bufferValues: Map[String,
Long], result: String)
+ Seq(
+ ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a ->
1}"),
+ ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a ->
1}"),
+ ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b ->
1}"),
+ ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b ->
1}")
+ ).foreach { t1 =>
+ def checkThisError(t: ModeTestCase, query: String): Any = {
+ val c = s"STRUCT<m1: MAP<STRING COLLATE
${t.collationId.toUpperCase(Locale.ROOT)}, INT>>"
+ val c1 = s"\"${c}\""
+ checkError(
+ exception = intercept[SparkThrowable] {
+ sql(query).collect()
+ },
+ condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE",
+ parameters = Map(
+ ("sqlExpr", "\"mode(i)\""),
+ ("child", c1),
+ ("mode", "`mode`")),
+ queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray
+ )
+ }
+
+ def getValuesToAdd(t: ModeTestCase): String = {
+ val valuesToAdd = t.bufferValues.map {
+ case (elt, numRepeats) =>
+ (0L to numRepeats).map(i =>
+ s"named_struct('m1', map(collate('$elt', '${t.collationId}'),
1))"
+ ).mkString(",")
+ }.mkString(",")
+ valuesToAdd
+ }
+ val tableName = s"t_${t1.collationId}_mode_nested_map_struct1"
+ withTable(tableName) {
+ sql(s"CREATE TABLE ${tableName}(" +
+ s"i STRUCT<m1: MAP<STRING COLLATE ${t1.collationId}, INT>>) USING
parquet")
+ sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
+ val query = "SELECT lower(cast(mode(i).m1 as string))" +
+ s" FROM ${tableName}"
+ if (t1.collationId == "utf8_binary") {
+ checkAnswer(sql(query), Row(t1.result))
+ } else {
+ checkThisError(t1, query)
}
}
- })
+ }
+ }
+
+ test("UDT with collation - Mode (throw exception)") {
+ case class ModeTestCase(collationId: String, bufferValues: Map[String,
Long], result: String)
+ Seq(
+ ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
+ ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
+ ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
+ ).foreach { t1 =>
+ checkError(
+ exception = intercept[SparkIllegalArgumentException] {
+ Mode(
+ child = Literal.create(null,
+ MapType(StringType(t1.collationId), IntegerType))
+ ).collationAwareTransform(
+ data = Map.empty[String, Any],
+ dataType = MapType(StringType(t1.collationId), IntegerType)
+ )
+ },
+ condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
+ parameters = Map(
+ "expression" -> "\"mode(NULL)\"",
+ "functionName" -> "\"MODE\"",
+ "dataType" -> s"\"MAP<STRING COLLATE
${t1.collationId.toUpperCase()}, INT>\"")
+ )
+ }
}
test("SPARK-48430: Map value extraction with collations") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]