hvanhovell commented on code in PR #40729:
URL: https://github.com/apache/spark/pull/40729#discussion_r1177045727
##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -520,54 +515,205 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformTypedMapPartitions(
fun: proto.CommonInlineUserDefinedFunction,
child: LogicalPlan): LogicalPlan = {
- val udf = fun.getScalarScalaUdf
- val udfPacket =
- Utils.deserialize[UdfPacket](
- udf.getPayload.toByteArray,
- SparkConnectArtifactManager.classLoaderWithArtifacts)
- assert(udfPacket.inputEncoders.size == 1)
- val iEnc = ExpressionEncoder(udfPacket.inputEncoders.head)
- val rEnc = ExpressionEncoder(udfPacket.outputEncoder)
-
- val deserializer = UnresolvedDeserializer(iEnc.deserializer)
- val deserialized = DeserializeToObject(deserializer,
generateObjAttr(iEnc), child)
+ val udf = ScalaUdf(fun)
+ val deserialized = DeserializeToObject(udf.inputDeserializer(),
udf.inputObjAttr, child)
val mapped = MapPartitions(
- udfPacket.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
- generateObjAttr(rEnc),
+ udf.function.asInstanceOf[Iterator[Any] => Iterator[Any]],
+ udf.outputObjAttr,
deserialized)
- SerializeFromObject(rEnc.namedExpressions, mapped)
+ SerializeFromObject(udf.outputNamedExpression, mapped)
}
private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
- val pythonUdf = transformPythonUDF(rel.getFunc)
- val cols =
- rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
Column(transformExpression(expr)))
+ val commonUdf = rel.getFunc
+ commonUdf.getFunctionCase match {
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF
=>
+ transformTypedGroupMap(rel, commonUdf)
- Dataset
- .ofRows(session, transformRelation(rel.getInput))
- .groupBy(cols: _*)
- .flatMapGroupsInPandas(pythonUdf)
- .logicalPlan
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+ val pythonUdf = transformPythonUDF(commonUdf)
+ val cols =
+ rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
+ Column(transformExpression(expr)))
+
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .groupBy(cols: _*)
+ .flatMapGroupsInPandas(pythonUdf)
+ .logicalPlan
+
+ case _ =>
+ throw InvalidPlanInput(
+ s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not
supported")
+ }
+ }
+
+ private def transformTypedGroupMap(
+ rel: proto.GroupMap,
+ commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+ val udf = ScalaUdf(commonUdf)
+ val ds = UntypedKeyValueGroupedDataset(
+ rel.getInput,
+ rel.getGroupingExpressionsList,
+ rel.getSortingExpressionsList)
+
+ val mapped = new MapGroups(
+ udf.function.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+ udf.inputDeserializer(ds.groupingAttributes),
+ ds.valueDeserializer,
+ ds.groupingAttributes,
+ ds.dataAttributes,
+ ds.sortOrder,
+ udf.outputObjAttr,
+ ds.analyzed)
+ SerializeFromObject(udf.outputNamedExpression, mapped)
}
private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
- val pythonUdf = transformPythonUDF(rel.getFunc)
+ val commonUdf = rel.getFunc
+ commonUdf.getFunctionCase match {
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF
=>
+ transformTypedCoGroupMap(rel, commonUdf)
- val inputCols =
- rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
- Column(transformExpression(expr)))
- val otherCols =
- rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
- Column(transformExpression(expr)))
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+ val pythonUdf = transformPythonUDF(commonUdf)
- val input = Dataset
- .ofRows(session, transformRelation(rel.getInput))
- .groupBy(inputCols: _*)
- val other = Dataset
- .ofRows(session, transformRelation(rel.getOther))
- .groupBy(otherCols: _*)
+ val inputCols =
+ rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
+ Column(transformExpression(expr)))
+ val otherCols =
+ rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
+ Column(transformExpression(expr)))
- input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+ val input = Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .groupBy(inputCols: _*)
+ val other = Dataset
+ .ofRows(session, transformRelation(rel.getOther))
+ .groupBy(otherCols: _*)
+
+ input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+
+ case _ =>
+ throw InvalidPlanInput(
+ s"Function with ID: ${commonUdf.getFunctionCase.getNumber} is not
supported")
+ }
+ }
+
+ private def transformTypedCoGroupMap(
+ rel: proto.CoGroupMap,
+ commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
+ val udf = ScalaUdf(commonUdf)
+ val left = UntypedKeyValueGroupedDataset(
+ rel.getInput,
+ rel.getInputGroupingExpressionsList,
+ rel.getInputSortingExpressionsList)
+ val right = UntypedKeyValueGroupedDataset(
+ rel.getOther,
+ rel.getOtherGroupingExpressionsList,
+ rel.getOtherSortingExpressionsList)
+
+ val mapped = CoGroup(
+ udf.function.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) =>
TraversableOnce[Any]],
+ // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema,
so it's safe to
+ // resolve the `keyDeserializer` based on either of them, here we pick
the left one.
+ udf.inputDeserializer(left.groupingAttributes),
+ left.valueDeserializer,
+ right.valueDeserializer,
+ left.groupingAttributes,
+ right.groupingAttributes,
+ left.dataAttributes,
+ right.dataAttributes,
+ left.sortOrder,
+ right.sortOrder,
+ udf.outputObjAttr,
+ left.analyzed,
+ right.analyzed)
+ SerializeFromObject(udf.outputNamedExpression, mapped)
+ }
+
+ /**
+ * This is the untyped version of [[KeyValueGroupedDataset]].
+ */
+ private case class UntypedKeyValueGroupedDataset(
+ kEncoder: ExpressionEncoder[_],
+ vEncoder: ExpressionEncoder[_],
+ valueDeserializer: Expression,
+ analyzed: LogicalPlan,
+ dataAttributes: Seq[Attribute],
+ groupingAttributes: Seq[Attribute],
+ sortOrder: Seq[SortOrder])
+ private object UntypedKeyValueGroupedDataset {
+ def apply(
+ input: proto.Relation,
+ groupingExprs: java.util.List[proto.Expression],
+ sortingExprs: java.util.List[proto.Expression]):
UntypedKeyValueGroupedDataset = {
+ val logicalPlan = transformRelation(input)
+ assert(groupingExprs.size() == 1)
+ val groupFunc = groupingExprs.asScala.toSeq
+ .map(expr => unpackUdf(expr.getCommonInlineUserDefinedFunction))
+ .head
+
+ assert(groupFunc.inputEncoders.size == 1)
+ val vEnc = ExpressionEncoder(groupFunc.inputEncoders.head)
+ val kEnc = ExpressionEncoder(groupFunc.outputEncoder)
+
+ val withGroupingKey = new AppendColumns(
+ groupFunc.function.asInstanceOf[Any => Any],
+ vEnc.clsTag.runtimeClass,
+ vEnc.schema,
+ UnresolvedDeserializer(vEnc.deserializer),
+ kEnc.namedExpressions,
+ logicalPlan)
+
+ // The input logical plan of KeyValueGroupedDataset needs to be executed
and analyzed
+ val analyzed = session.sessionState.executePlan(withGroupingKey).analyzed
+ val dataAttributes = logicalPlan.output
+ val groupingAttributes = withGroupingKey.newColumns
+ val valueDeserializer = UnresolvedDeserializer(vEnc.deserializer,
dataAttributes)
+
+ // Compute sort order
+ val sortExprs =
+ sortingExprs.asScala.toSeq.map(expr => transformExpression(expr))
+ val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs)
+
+ UntypedKeyValueGroupedDataset(
+ kEnc,
+ vEnc,
+ valueDeserializer,
+ analyzed,
+ dataAttributes,
+ groupingAttributes,
+ sortOrder)
+ }
+ }
+
+ /**
+ * The UDF used in typed APIs, where the input column is absent.
+ */
+ private case class ScalaUdf(
Review Comment:
Maybe name this `TypedScalaUdf`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]