MaxGekk commented on code in PR #47154:
URL: https://github.com/apache/spark/pull/47154#discussion_r1766314751


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala:
##########
@@ -86,6 +91,49 @@ case class Mode(
     buffer
   }
 
+  private def getCollationAwareBuffer(
+         childDataType: DataType,
+         buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
+    def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = {
+      buffer.groupMapReduce(t =>
+        groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
+    }
+    def determineBufferingFunction(
+                                    childDataType: DataType): Option[AnyRef => _] = {
+      childDataType match {
+        case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
+        case _ => Some(collationAwareTransform(_, childDataType))
+      }
+    }
+    determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer)
+  }
+
+  private def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = {
+    dataType match {
+      case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
+      case st: StructType =>
+        processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
+      case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData])
+      case st: StringType =>
+        CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId)
+      case _ =>
+        throw new SparkUnsupportedOperationException(
+          s"Unsupported data type for collation-aware mode: $dataType")
+    }
+  }
+
+  private def processStructTypeWithBuffer(
+                                           tuples: Seq[(Any, StructField)]): Seq[Any] = {
+    tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef], t._2.dataType))
+  }
+
+  private def processArrayTypeWithBuffer(
+                                          a: ArrayType,
+                                          data: ArrayData): Seq[Any] = {

Review Comment:
   Fix the indentation.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala:
##########
@@ -86,6 +91,49 @@ case class Mode(
     buffer
   }
 
+  private def getCollationAwareBuffer(
+         childDataType: DataType,
+         buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {

Review Comment:
   Wrong indentation.



##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala:
##########
@@ -1736,6 +1737,40 @@ class CollationSQLExpressionsSuite
     })
   }
 
+  test("Support Mode.eval(buffer) with complex types") {
+    case class UTF8StringModeTestCase[R](
+      collationId: String,
+      bufferValues: Map[InternalRow, Long],
+      result: R)

Review Comment:
   Use 4 spaces here; see https://github.com/databricks/scala-style-guide?tab=readme-ov-file#spacing-and-indentation



##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -1005,6 +1005,11 @@
           "The input of <functionName> can't be <dataType> type data."
         ]
       },
+      "UNSUPPORTED_MODE_DATA_TYPE" : {
+        "message" : [
+          "The <mode> does not support the <child> data type, because there is MapType with collated fields."

Review Comment:
   PySpark and SQL users might not be aware of Scala/Java types like `MapType`. Let's follow the common convention and refer to SQL type names instead.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala:
##########
@@ -86,6 +91,49 @@ case class Mode(
     buffer
   }
 
+  private def getCollationAwareBuffer(
+         childDataType: DataType,
+         buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
+    def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = {
+      buffer.groupMapReduce(t =>
+        groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
+    }
+    def determineBufferingFunction(
+                                    childDataType: DataType): Option[AnyRef => _] = {

Review Comment:
   Please fix the indentation.



##########
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala:
##########
@@ -1711,9 +1712,9 @@ class CollationSQLExpressionsSuite
 
   test("Support Mode.eval(buffer)") {
     case class UTF8StringModeTestCase[R](
-        collationId: String,
-        bufferValues: Map[UTF8String, Long],
-        result: R)
+      collationId: String,
+      bufferValues: Map[UTF8String, Long],
+      result: R)

Review Comment:
   Revert this change; see https://github.com/databricks/scala-style-guide?tab=readme-ov-file#spacing-and-indentation



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala:
##########
@@ -86,6 +91,49 @@ case class Mode(
     buffer
   }
 
+  private def getCollationAwareBuffer(
+         childDataType: DataType,
+         buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
+    def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = {
+      buffer.groupMapReduce(t =>
+        groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
+    }
+    def determineBufferingFunction(
+                                    childDataType: DataType): Option[AnyRef => _] = {
+      childDataType match {
+        case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
+        case _ => Some(collationAwareTransform(_, childDataType))
+      }
+    }
+    determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer)
+  }
+
+  private def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = {
+    dataType match {
+      case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
+      case st: StructType =>
+        processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
+      case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData])
+      case st: StringType =>
+        CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId)
+      case _ =>
+        throw new SparkUnsupportedOperationException(
+          s"Unsupported data type for collation-aware mode: $dataType")
+    }
+  }
+
+  private def processStructTypeWithBuffer(
+                                           tuples: Seq[(Any, StructField)]): Seq[Any] = {

Review Comment:
   Fix indentation



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to