This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 38f067dfcef9 [SPARK-49358][SQL] Mode expression for map types with
collated strings
38f067dfcef9 is described below
commit 38f067dfcef9ae53330fdd73ea89ebba614c965b
Author: Uros Bojanic <[email protected]>
AuthorDate: Thu Oct 3 11:36:26 2024 +0200
[SPARK-49358][SQL] Mode expression for map types with collated strings
### What changes were proposed in this pull request?
Introduce support for collated strings in map types for the `mode` expression.
### Why are the changes needed?
Complete complex type handling for the `mode` expression: map types with collated strings were the only complex type it did not yet support.
### Does this PR introduce _any_ user-facing change?
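Yes, the `mode` expression can now handle map types with collated strings; previously such inputs were rejected with the `DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE` error.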
### How was this patch tested?
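Updated tests in `CollationSQLExpressionsSuite`: the nested-map `mode` test now expects a result for every collation, and the test that expected an exception for map input was removed.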
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48326 from uros-db/mode-map.
Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 5 --
.../sql/catalyst/expressions/aggregate/Mode.scala | 38 ++++++---------
.../spark/sql/CollationSQLExpressionsSuite.scala | 54 ++--------------------
3 files changed, 19 insertions(+), 78 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 3786643125a9..12666fe4ff62 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1016,11 +1016,6 @@
"The input of <functionName> can't be <dataType> type data."
]
},
- "UNSUPPORTED_MODE_DATA_TYPE" : {
- "message" : [
- "The <mode> does not support the <child> data type, because there is
a \"MAP\" type with keys and/or values that have collated sub-fields."
- ]
- },
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index 8998348f0571..97add0b8e45b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder,
TypeCheckResult, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder,
UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending,
Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
GenericArrayData, UnsafeRowUtils}
-import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory,
GenericArrayData, MapData, UnsafeRowUtils}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType,
BooleanType, DataType, MapType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
@@ -52,24 +52,6 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- override def checkInputDataTypes(): TypeCheckResult = {
- // TODO: SPARK-49358: Mode expression for map type with collated fields
- if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
- !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
- !UnsafeRowUtils.isBinaryStable(f))) {
- /*
- * The Mode class uses collation awareness logic to handle string data.
- * All complex types except MapType with collated fields are supported.
- */
- super.checkInputDataTypes()
- } else {
- TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
- messageParameters =
- Map("child" -> toSQLType(child.dataType),
- "mode" -> toSQLId(prettyName)))
- }
- }
-
override def prettyName: String = "mode"
override def update(
@@ -115,6 +97,7 @@ case class Mode(
case st: StructType =>
processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
case at: ArrayType => processArrayTypeWithBuffer(at,
data.asInstanceOf[ArrayData])
+ case mt: MapType => processMapTypeWithBuffer(mt,
data.asInstanceOf[MapData])
case st: StringType =>
CollationFactory.getCollationKey(data.asInstanceOf[UTF8String],
st.collationId)
case _ =>
@@ -140,6 +123,16 @@ case class Mode(
collationAwareTransform(data.get(i, a.elementType), a.elementType))
}
+ private def processMapTypeWithBuffer(mt: MapType, data: MapData): Map[Any,
Any] = {
+ val transformedKeys = (0 until data.numElements()).map { i =>
+ collationAwareTransform(data.keyArray().get(i, mt.keyType), mt.keyType)
+ }
+ val transformedValues = (0 until data.numElements()).map { i =>
+ collationAwareTransform(data.valueArray().get(i, mt.valueType),
mt.valueType)
+ }
+ transformedKeys.zip(transformedValues).toMap
+ }
+
override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
if (buffer.isEmpty) {
return null
@@ -157,8 +150,7 @@ case class Mode(
*
* The new map is then used in the rest of the Mode evaluation logic.
*
- * It is expected to work for all simple and complex types with
- * collated fields, except for MapType (temporarily).
+ * It is expected to work for all simple and complex types with collated
fields.
*/
val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 9930709cd8bf..851160d2fbb9 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -19,11 +19,10 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
-import java.util.Locale
import scala.collection.immutable.Seq
-import org.apache.spark.{SparkConf, SparkException,
SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable}
+import org.apache.spark.{SparkConf, SparkException,
SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
@@ -1924,7 +1923,7 @@ class CollationSQLExpressionsSuite
}
}
- test("Support mode expression with collated in recursively nested struct
with map with keys") {
+ test("Support mode for string expression with collated complex type - nested
map") {
case class ModeTestCase(collationId: String, bufferValues: Map[String,
Long], result: String)
Seq(
ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a ->
1}"),
@@ -1932,22 +1931,6 @@ class CollationSQLExpressionsSuite
ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b ->
1}"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b ->
1}")
).foreach { t1 =>
- def checkThisError(t: ModeTestCase, query: String): Any = {
- val c = s"STRUCT<m1: MAP<STRING COLLATE
${t.collationId.toUpperCase(Locale.ROOT)}, INT>>"
- val c1 = s"\"${c}\""
- checkError(
- exception = intercept[SparkThrowable] {
- sql(query).collect()
- },
- condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE",
- parameters = Map(
- ("sqlExpr", "\"mode(i)\""),
- ("child", c1),
- ("mode", "`mode`")),
- queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray
- )
- }
-
def getValuesToAdd(t: ModeTestCase): String = {
val valuesToAdd = t.bufferValues.map {
case (elt, numRepeats) =>
@@ -1964,41 +1947,12 @@ class CollationSQLExpressionsSuite
sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}")
val query = "SELECT lower(cast(mode(i).m1 as string))" +
s" FROM ${tableName}"
- if (t1.collationId == "utf8_binary") {
- checkAnswer(sql(query), Row(t1.result))
- } else {
- checkThisError(t1, query)
- }
+ val queryResult = sql(query)
+ checkAnswer(queryResult, Row(t1.result))
}
}
}
- test("UDT with collation - Mode (throw exception)") {
- case class ModeTestCase(collationId: String, bufferValues: Map[String,
Long], result: String)
- Seq(
- ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
- ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
- ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
- ).foreach { t1 =>
- checkError(
- exception = intercept[SparkIllegalArgumentException] {
- Mode(
- child = Literal.create(null,
- MapType(StringType(t1.collationId), IntegerType))
- ).collationAwareTransform(
- data = Map.empty[String, Any],
- dataType = MapType(StringType(t1.collationId), IntegerType)
- )
- },
- condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS",
- parameters = Map(
- "expression" -> "\"mode(NULL)\"",
- "functionName" -> "\"MODE\"",
- "dataType" -> s"\"MAP<STRING COLLATE
${t1.collationId.toUpperCase()}, INT>\"")
- )
- }
- }
-
test("SPARK-48430: Map value extraction with collations") {
for {
collateKey <- Seq(true, false)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]