cloud-fan commented on code in PR #47331:
URL: https://github.com/apache/spark/pull/47331#discussion_r1694560083
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -892,132 +892,108 @@ case class MapFromEntries(child: Expression)
copy(child = newChild)
}
+// Sorts all MapType expressions based on the ordering of their keys.
+// This is used when GROUP BY is done with a MapType (possibly nested) column.
case class MapSort(base: Expression)
- extends UnaryExpression with NullIntolerant with QueryErrorsBase {
+ extends UnaryExpression with NullIntolerant with QueryErrorsBase with
CodegenFallback {
- val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
- val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
+ override lazy val canonicalized: Expression = base.canonicalized
+
+ override lazy val deterministic: Boolean = base.deterministic
override def child: Expression = base
override def dataType: DataType = base.dataType
- override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
- case m: MapType if RowOrdering.isOrderable(m.keyType) =>
- TypeCheckResult.TypeCheckSuccess
+ def recursiveCheckDataTypes(dataType: DataType): TypeCheckResult = dataType
match {
+ case a: ArrayType => recursiveCheckDataTypes(a.elementType)
+ case StructType(fields) =>
+ fields.collect(sf =>
recursiveCheckDataTypes(sf.dataType)).filter(_.isFailure).headOption
+ .getOrElse(TypeCheckResult.TypeCheckSuccess)
+ case m: MapType if RowOrdering.isOrderable(m.keyType) =>
TypeCheckResult.TypeCheckSuccess
case _: MapType =>
DataTypeMismatch(
errorSubClass = "INVALID_ORDERING_TYPE",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
- "dataType" -> toSQLType(base.dataType)
+ "dataType" -> toSQLType(dataType)
)
)
- case _ =>
- DataTypeMismatch(
+ case _ => TypeCheckResult.TypeCheckSuccess
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!dataType.existsRecursively(_.isInstanceOf[MapType])) {
+ return DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(0),
"requiredType" -> toSQLType(MapType),
"inputSql" -> toSQLExpr(base),
"inputType" -> toSQLType(base.dataType))
)
- }
-
- override def nullSafeEval(array: Any): Any = {
- // put keys and their respective values inside a tuple and sort them
- // according to the key ordering. Extract the new sorted k/v pairs to form
a sorted map
-
- val mapData = array.asInstanceOf[MapData]
- val numElements = mapData.numElements()
- val keys = mapData.keyArray()
- val values = mapData.valueArray()
-
- val ordering = PhysicalDataType.ordering(keyType)
-
- val sortedMap = Array
- .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
- values.get(i, valueType).asInstanceOf[Any]))
- .sortBy(_._1)(ordering)
-
- new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)),
- new GenericArrayData(sortedMap.map(_._2)))
- }
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- nullSafeCodeGen(ctx, ev, b => sortCodegen(ctx, ev, b))
- }
-
- private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
- base: String): String = {
-
- val arrayBasedMapData = classOf[ArrayBasedMapData].getName
- val genericArrayData = classOf[GenericArrayData].getName
-
- val numElements = ctx.freshName("numElements")
- val keys = ctx.freshName("keys")
- val values = ctx.freshName("values")
- val sortArray = ctx.freshName("sortArray")
- val i = ctx.freshName("i")
- val o1 = ctx.freshName("o1")
- val o1entry = ctx.freshName("o1entry")
- val o2 = ctx.freshName("o2")
- val o2entry = ctx.freshName("o2entry")
- val c = ctx.freshName("c")
- val newKeys = ctx.freshName("newKeys")
- val newValues = ctx.freshName("newValues")
-
- val boxedKeyType = CodeGenerator.boxedType(keyType)
- val boxedValueType = CodeGenerator.boxedType(valueType)
- val javaKeyType = CodeGenerator.javaType(keyType)
+ }
- val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType,
$boxedValueType>"
+ if (dataType.existsRecursively(dt =>
+ dt.isInstanceOf[MapType] &&
!RowOrdering.isOrderable(dt.asInstanceOf[MapType].keyType))) {
+ DataTypeMismatch(
+ errorSubClass = "INVALID_ORDERING_TYPE",
+ messageParameters = Map(
+ "functionName" -> toSQLId(prettyName),
+ "dataType" -> toSQLType(dataType)
+ )
+ )
+ }
- val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
- val v1 = ctx.freshName("v1")
- val v2 = ctx.freshName("v2")
- s"""
- |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
- |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
- |int $c = ${ctx.genComp(keyType, v1, v2)};
- """.stripMargin
- } else {
- s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)",
s"(($javaKeyType) $o2)")};"
+ TypeCheckResult.TypeCheckSuccess
+ }
+
+ // Evaluates the expression recursively by taking into
+ // account complex types and nesting
+ def nullSafeEvalRecursive(input: Any, dataType: DataType): Any = {
+
+ dataType match {
+ // For ArrayType recursively call evaluate for
+ // all its children since MapType can be nested
+ // as array element
+ case ArrayType(elementType, _) =>
Review Comment:
We know the type information at query-plan compile time, so we can always build
the correct expression to handle the actual nested types.
`CharVarcharUtils.stringLengthCheck` handles struct types, doesn't it?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]