amaliujia commented on code in PR #40796:
URL: https://github.com/apache/spark/pull/40796#discussion_r1174198988


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -545,34 +540,94 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
+    val udf = unpackUdf(fun)
+    assert(udf.inputEncoders.size == 1)
+    val iEnc = ExpressionEncoder(udf.inputEncoders.head)
+    val rEnc = ExpressionEncoder(udf.outputEncoder)
 
     val deserializer = UnresolvedDeserializer(iEnc.deserializer)
     val deserialized = DeserializeToObject(deserializer, 
generateObjAttr(iEnc), child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
       generateObjAttr(rEnc),
       deserialized)
     SerializeFromObject(rEnc.namedExpressions, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => 
Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
=>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not 
supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: GroupMap,
+      commonUdf: CommonInlineUserDefinedFunction): LogicalPlan = {
+    // Compute grouping key
+    val logicalPlan = transformRelation(rel.getInput)
+    val udf = unpackUdf(commonUdf)
+    assert(rel.getGroupingExpressionsCount == 1)
+    val groupFunc = rel.getGroupingExpressionsList.asScala.toSeq
+      .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
+      .head
+
+    assert(groupFunc.inputEncoders.size == 1)
+    val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
+    val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
+    val uEnc = ExpressionEncoder(udf.outputEncoder)
+    assert(udf.inputEncoders.nonEmpty)
+    // ukEnc != kEnc if user has called kvDS.keyAs
+    val ukEnc = ExpressionEncoder(udf.inputEncoders.head)
+
+    val withGroupingKey = new AppendColumns(
+      groupFunc.function.asInstanceOf[Any => Any],
+      vEnc.clsTag.runtimeClass,
+      vEnc.schema,
+      UnresolvedDeserializer(vEnc.deserializer),
+      kEnc.namedExpressions,
+      logicalPlan)
+
+    // Compute sort order
+    val sortExprs =
+      rel.getSortingExpressionsList.asScala.toSeq.map(expr => 
transformExpression(expr))
+    val sortOrder: Seq[SortOrder] = sortExprs.map {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)

Review Comment:
   Who will be in charge of checking non-supported expr? 



##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -545,34 +540,94 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformTypedMapPartitions(
       fun: proto.CommonInlineUserDefinedFunction,
       child: LogicalPlan): LogicalPlan = {
-    val udf = fun.getScalarScalaUdf
-    val udfPacket =
-      Utils.deserialize[UdfPacket](
-        udf.getPayload.toByteArray,
-        SparkConnectArtifactManager.classLoaderWithArtifacts)
-    assert(udfPacket.inputEncoders.size == 1)
-    val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
-    val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
+    val udf = unpackUdf(fun)
+    assert(udf.inputEncoders.size == 1)
+    val iEnc = ExpressionEncoder(udf.inputEncoders.head)
+    val rEnc = ExpressionEncoder(udf.outputEncoder)
 
     val deserializer = UnresolvedDeserializer(iEnc.deserializer)
     val deserialized = DeserializeToObject(deserializer, 
generateObjAttr(iEnc), child)
     val mapped = MapPartitions(
-      udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
       generateObjAttr(rEnc),
       deserialized)
     SerializeFromObject(rEnc.namedExpressions, mapped)
   }
 
   private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
-    val pythonUdf = transformPythonUDF(rel.getFunc)
-    val cols =
-      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => 
Column(transformExpression(expr)))
+    val commonUdf = rel.getFunc
+    commonUdf.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
=>
+        transformTypedGroupMap(rel, commonUdf)
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .groupBy(cols: _*)
-      .flatMapGroupsInPandas(pythonUdf)
-      .logicalPlan
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        val pythonUdf = transformPythonUDF(commonUdf)
+        val cols =
+          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+            Column(transformExpression(expr)))
+
+        Dataset
+          .ofRows(session, transformRelation(rel.getInput))
+          .groupBy(cols: _*)
+          .flatMapGroupsInPandas(pythonUdf)
+          .logicalPlan
+
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not 
supported")
+    }
+  }
+
+  private def transformTypedGroupMap(
+      rel: GroupMap,
+      commonUdf: CommonInlineUserDefinedFunction): LogicalPlan = {
+    // Compute grouping key
+    val logicalPlan = transformRelation(rel.getInput)
+    val udf = unpackUdf(commonUdf)
+    assert(rel.getGroupingExpressionsCount == 1)
+    val groupFunc = rel.getGroupingExpressionsList.asScala.toSeq
+      .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
+      .head
+
+    assert(groupFunc.inputEncoders.size == 1)
+    val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
+    val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
+    val uEnc = ExpressionEncoder(udf.outputEncoder)
+    assert(udf.inputEncoders.nonEmpty)
+    // ukEnc != kEnc if user has called kvDS.keyAs
+    val ukEnc = ExpressionEncoder(udf.inputEncoders.head)
+
+    val withGroupingKey = new AppendColumns(
+      groupFunc.function.asInstanceOf[Any => Any],
+      vEnc.clsTag.runtimeClass,
+      vEnc.schema,
+      UnresolvedDeserializer(vEnc.deserializer),
+      kEnc.namedExpressions,
+      logicalPlan)
+
+    // Compute sort order
+    val sortExprs =
+      rel.getSortingExpressionsList.asScala.toSeq.map(expr => 
transformExpression(expr))
+    val sortOrder: Seq[SortOrder] = sortExprs.map {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)

Review Comment:
   Who will be in charge of checking non-supported expr?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to