MaxGekk commented on code in PR #45639:
URL: https://github.com/apache/spark/pull/45639#discussion_r1534294248
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -888,6 +888,158 @@ case class MapFromEntries(child: Expression)
copy(child = newChild)
}
+case class MapSort(base: Expression, ascendingOrder: Expression)
Review Comment:
What are the internal use cases for this expression? Do we need this
parameter at all?
It seems you always pass `true` as `ascendingOrder`, for example at
https://github.com/apache/spark/pull/45549/files#diff-11264d807efa58054cca2d220aae8fba644ee0f0f2a4722c46d52828394846efR2488
```scala
case a @ Aggregate(groupingExpr, x, b) =>
val newGrouping = groupingExpr.map { expr =>
(expr, expr.dataType) match {
case (_: MapSort, _) => expr
case (_, _: MapType) =>
MapSort(expr, Literal.TrueLiteral)
case _ => expr
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -888,6 +888,158 @@ case class MapFromEntries(child: Expression)
copy(child = newChild)
}
+case class MapSort(base: Expression, ascendingOrder: Expression)
+ extends BinaryExpression with NullIntolerant with QueryErrorsBase {
+
+ def this(e: Expression) = this(e, Literal(true))
+
+ val keyType: DataType = base.dataType.asInstanceOf[MapType].keyType
+ val valueType: DataType = base.dataType.asInstanceOf[MapType].valueType
+
+ override def left: Expression = base
+ override def right: Expression = ascendingOrder
+ override def dataType: DataType = base.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
+ case m: MapType if RowOrdering.isOrderable(m.keyType) =>
+ ascendingOrder match {
+ case Literal(_: Boolean, BooleanType) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(1),
+ "requiredType" -> toSQLType(BooleanType),
+ "inputSql" -> toSQLExpr(ascendingOrder),
+ "inputType" -> toSQLType(ascendingOrder.dataType))
+ )
+ }
+ case _: MapType =>
+ DataTypeMismatch(
+ errorSubClass = "INVALID_ORDERING_TYPE",
+ messageParameters = Map(
+ "functionName" -> toSQLId(prettyName),
+ "dataType" -> toSQLType(base.dataType)
+ )
+ )
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(MapType),
+ "inputSql" -> toSQLExpr(base),
+ "inputType" -> toSQLType(base.dataType))
+ )
+ }
+
+ override def nullSafeEval(array: Any, ascending: Any): Any = {
+ // put keys and their respective values inside a tuple and sort them
+ // according to the key ordering. Extract the new sorted k/v pairs to form
a sorted map
+
+ val mapData = array.asInstanceOf[MapData]
+ val numElements = mapData.numElements()
+ val keys = mapData.keyArray()
+ val values = mapData.valueArray()
+
+ val ordering = if (ascending.asInstanceOf[Boolean]) {
+ PhysicalDataType.ordering(keyType)
+ } else {
+ PhysicalDataType.ordering(keyType).reverse
+ }
+
+ val sortedMap = Array
+ .tabulate(numElements)(i => (keys.get(i, keyType).asInstanceOf[Any],
+ values.get(i, valueType).asInstanceOf[Any]))
+ .sortBy(_._1)(ordering)
+
+ new ArrayBasedMapData(new GenericArrayData(sortedMap.map(_._1)),
+ new GenericArrayData(sortedMap.map(_._2)))
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (b, order) => sortCodegen(ctx, ev, b, order))
+ }
+
+ private def sortCodegen(ctx: CodegenContext, ev: ExprCode,
+ base: String, order: String): String = {
+
+ val arrayBasedMapData = classOf[ArrayBasedMapData].getName
+ val genericArrayData = classOf[GenericArrayData].getName
+
+ val numElements = ctx.freshName("numElements")
+ val keys = ctx.freshName("keys")
+ val values = ctx.freshName("values")
+ val sortArray = ctx.freshName("sortArray")
+ val i = ctx.freshName("i")
+ val o1 = ctx.freshName("o1")
+ val o1entry = ctx.freshName("o1entry")
+ val o2 = ctx.freshName("o2")
+ val o2entry = ctx.freshName("o2entry")
+ val c = ctx.freshName("c")
+ val newKeys = ctx.freshName("newKeys")
+ val newValues = ctx.freshName("newValues")
+ val sortOrder = ctx.freshName("sortOrder")
+
+ val boxedKeyType = CodeGenerator.boxedType(keyType)
+ val boxedValueType = CodeGenerator.boxedType(valueType)
+ val javaKeyType = CodeGenerator.javaType(keyType)
+
+ val simpleEntryType = s"java.util.AbstractMap.SimpleEntry<$boxedKeyType,
$boxedValueType>"
+
+ val comp = if (CodeGenerator.isPrimitiveType(keyType)) {
+ val v1 = ctx.freshName("v1")
+ val v2 = ctx.freshName("v2")
+ s"""
+ |$javaKeyType $v1 = (($boxedKeyType) $o1).${javaKeyType}Value();
+ |$javaKeyType $v2 = (($boxedKeyType) $o2).${javaKeyType}Value();
+ |int $c = ${ctx.genComp(keyType, v1, v2)};
+ """.stripMargin
+ } else {
+ s"int $c = ${ctx.genComp(keyType, s"(($javaKeyType) $o1)",
s"(($javaKeyType) $o2)")};"
+ }
+
+ s"""
+ |final int $numElements = $base.numElements();
+ |ArrayData $keys = $base.keyArray();
+ |ArrayData $values = $base.valueArray();
+ |
+ |Object[] $sortArray = new Object[$numElements];
+ |
+ |for (int $i = 0; $i < $numElements; $i++) {
+ | $sortArray[$i] = new $simpleEntryType(
+ | ${CodeGenerator.getValue(keys, keyType, i)},
+ | ${CodeGenerator.getValue(values, valueType, i)});
+ |}
+ |final int $sortOrder = $order ? 1 : -1;
+ |java.util.Arrays.sort($sortArray, new java.util.Comparator<Object>() {
+ | @Override public int compare(Object $o1entry, Object $o2entry) {
+ | Object $o1 = (($simpleEntryType) $o1entry).getKey();
+ | Object $o2 = (($simpleEntryType) $o2entry).getKey();
+ | $comp;
+ | return $sortOrder * $c;
+ | }
+ |});
+ |
+ |Object[] $newKeys = new Object[$numElements];
+ |Object[] $newValues = new Object[$numElements];
+ |
+ |for (int $i = 0; $i < $numElements; $i++) {
+ | $newKeys[$i] = (($simpleEntryType) $sortArray[$i]).getKey();
+ | $newValues[$i] = (($simpleEntryType) $sortArray[$i]).getValue();
+ |}
+ |
+ |${ev.value} = new $arrayBasedMapData(
+ | new $genericArrayData($newKeys), new $genericArrayData($newValues));
+ |""".stripMargin
+ }
+
+ override def prettyName: String = "map_sort"
Review Comment:
Remove this, since the expression hasn't been bound to a function name.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]