HyukjinKwon commented on a change in pull request #23977: [SPARK-26923][SQL][R]
Refactor ArrowRRunner and RRunner to share one BaseRRunner
URL: https://github.com/apache/spark/pull/23977#discussion_r263334756
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
##########
@@ -546,14 +548,23 @@ case class FlatMapGroupsInRWithArrowExec(
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+      val grouped = GroupedIterator(iter, groupingAttributes, child.output).filter(_._2.hasNext)
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
-        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY)
-      val groupedByRKey = grouped.map { case (key, rowIter) =>
-        val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
-        (newKey, rowIter)
+      // Iterating over keys is relatively cheap.
+      val keys: Iterator[Array[Byte]] =
+        grouped.map { case (key, rowIter) =>
+          rowToRBytes(getKey(key).asInstanceOf[Row]) }
+      val groupedByRKey: Iterator[Iterator[InternalRow]] =
+        grouped.map { case (key, rowIter) => rowIter }
+
+      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) {
+        protected override def bufferedWrite(
+            dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
+          super.bufferedWrite(dataOut)(writeFunc)
+          // Don't forget we're sending keys additionally.
+          keys.foreach(dataOut.write)
Review comment:
Yea, keys are expected to be small. Let me also explain from scratch to make
sure we're on the same page.
Input iterator:
```
[a, 1]
[b, 2]
[b, 3]
```
Output iterator:
```
([a], Iterator([a, 1])
([b], Iterator([b, 2], [b, 3])
```
Here, I am doing:
```
([a], Iterator([a, 1]))
([b], Iterator([b, 2], [b, 3]))
next() call => ([a], Iterator([a, 1]))
  converts the key 'a' into R-readable bytes
  keeps the converted key in `keys`
  converts Iterator([a, 1]) to an Arrow batch
  buffers the Arrow batch
next() call => ([b], Iterator([b, 2], [b, 3]))
  converts the key 'b' into R-readable bytes
  keeps the converted key in `keys`
  converts Iterator([b, 2], [b, 3]) to an Arrow batch
  buffers the Arrow batch
...
Write all the buffered Arrow batches.
Write all the keys.
```
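If it helps, here is a minimal, self-contained sketch of that order (not the PR code; `keyToBytes` and `toBatch` are hypothetical stand-ins for `rowToRBytes` and the actual Arrow conversion):
```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer

object WriteOrderSketch {
  // Hypothetical stand-in for rowToRBytes: a key -> R-readable bytes.
  def keyToBytes(key: String): Array[Byte] = key.getBytes("UTF-8")

  // Hypothetical stand-in for the Arrow conversion: one group -> one "batch".
  def toBatch(rows: Iterator[Int]): Array[Byte] = rows.map(_.toByte).toArray

  def main(args: Array[String]): Unit = {
    // Already-grouped input, as GroupedIterator hands it over.
    val grouped: Iterator[(String, Iterator[Int])] =
      Iterator(("a", Iterator(1)), ("b", Iterator(2, 3)))

    val keys = ArrayBuffer.empty[Array[Byte]]
    val batchBuffer = new ByteArrayOutputStream()

    // Per next() call: remember the converted key, buffer the converted batch.
    grouped.foreach { case (key, rows) =>
      keys += keyToBytes(key)
      batchBuffer.write(toBatch(rows))
    }

    val dataOut = new DataOutputStream(new ByteArrayOutputStream())
    dataOut.write(batchBuffer.toByteArray) // first, all the buffered Arrow batches
    keys.foreach(k => dataOut.write(k))    // then, additionally, all the keys
    dataOut.flush()
  }
}
```
The only point here is the ordering: all batches go out before the first key.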
I tried to leave only the Arrow conversion logic in `ArrowRRunner` (which is
shared by both `dapply` and `gapply`), and keep the key-tracking/writing logic
separate here.
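A rough sketch of that separation (simplified, hypothetical classes; the real `ArrowRRunner`/`BaseRRunner` signatures differ):
```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

object KeyedWriteSketch {
  // Simplified, hypothetical base: it buffers and writes batches only,
  // knowing nothing about keys (the shared ArrowRRunner side).
  class BatchWriter {
    protected def bufferedWrite(
        dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
      val buffer = new ByteArrayOutputStream()
      writeFunc(buffer)                 // the caller buffers its batches here
      dataOut.write(buffer.toByteArray) // then the buffered bytes are written
    }
  }

  // gapply side: an anonymous subclass appends keys after the batches, so
  // key handling never leaks into the shared runner.
  def withKeys(keys: Iterator[Array[Byte]]): BatchWriter = new BatchWriter {
    override protected def bufferedWrite(
        dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
      super.bufferedWrite(dataOut)(writeFunc)
      keys.foreach(k => dataOut.write(k)) // keys follow the buffered batches
    }
  }
}
```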
(BTW, I tried to avoid refactorings like this because I know such PRs are
difficult to review.)