Github user davies commented on a diff in the pull request:
https://github.com/apache/spark/pull/12087#discussion_r58410362
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala ---
@@ -68,6 +71,69 @@ case class MapPartitions(
}
+/**
+ * Applies the given function to each input row and encodes the result.
+ *
+ * TODO: Each serializer expression needs the result object returned by the given function as
+ * input. This operator uses some tricks to make sure we only calculate the result object once;
+ * once subexpression elimination works in whole-stage codegen, we can replace this operator
+ * with [[Project]].
+ */
+case class MapElements(
+    func: AnyRef,
+    deserializer: Expression,
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val (funcClass, methodName) = func match {
+      case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
+      case _ => classOf[Any => Any] -> "apply"
+    }
+    val funcObj = Literal.create(func, ObjectType(funcClass))
+    val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
+    val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+
+    val bound = ExpressionCanonicalizer.execute(
+      BindReferences.bindReference(callFunc, child.output))
+    ctx.currentVars = input
+    val evaluated = bound.gen(ctx)
+
+    val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
+    val outputFields = serializer.map(_ transform {
+      case _: BoundReference => resultObj
+    })
+    val resultVars = outputFields.map(_.gen(ctx))
+    s"""
+      ${evaluated.code}
+      ${consume(ctx, resultVars)}
+    """
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val callFunc: Any => Any = func match {
+      case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
+      case _ => func.asInstanceOf[Any => Any]
+    }
+    child.execute().mapPartitionsInternal { iter =>
+      val getObject = generateToObject(deserializer, child.output)
+      val outputObject = generateToRow(serializer)
+      iter.map(getObject).map(callFunc).map(outputObject)
--- End diff ---
It's better to combine these into a single `map` call to avoid unnecessary
intermediate iterators; each chained `Iterator.map` wraps the previous one and
adds per-row indirection.
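A minimal sketch of what the fused version could look like, reusing the names
from the diff above (the actual change may of course end up looking different):

```scala
override protected def doExecute(): RDD[InternalRow] = {
  val callFunc: Any => Any = func match {
    case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
    case _ => func.asInstanceOf[Any => Any]
  }
  child.execute().mapPartitionsInternal { iter =>
    val getObject = generateToObject(deserializer, child.output)
    val outputObject = generateToRow(serializer)
    // Deserialize, apply the user function, and re-serialize in one map call,
    // so only a single iterator is created per partition.
    iter.map(row => outputObject(callFunc(getObject(row))))
  }
}
```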