HyukjinKwon commented on a change in pull request #23977: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
URL: https://github.com/apache/spark/pull/23977#discussion_r264072521
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
##########
@@ -546,14 +548,23 @@ case class FlatMapGroupsInRWithArrowExec(
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+      val grouped = GroupedIterator(iter, groupingAttributes, child.output).filter(_._2.hasNext)
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
-        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY)
-      val groupedByRKey = grouped.map { case (key, rowIter) =>
-        val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
-        (newKey, rowIter)
+      // Iterating over keys is relatively cheap.
+      val keys: Iterator[Array[Byte]] =
+        grouped.map { case (key, rowIter) => rowToRBytes(getKey(key).asInstanceOf[Row]) }
+      val groupedByRKey: Iterator[Iterator[InternalRow]] =
+        grouped.map { case (key, rowIter) => rowIter }
+
+      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) {
+        protected override def bufferedWrite(
+            dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
+          super.bufferedWrite(dataOut)(writeFunc)
+          // Don't forget we're sending keys additionally.
+          keys.foreach(dataOut.write)
Review comment:
Oh, right. The types look confusing.
Input iterator:
`Iterator[InternalRow]`
Output iterator:
`Iterator[(InternalRow, Iterator[InternalRow])]`
```
next() call => ([a], Iterator([a, 1]))         # (InternalRow, Iterator[InternalRow])
1. converts the key 'a' into R-readable bytes  # 'a' becomes InternalRow -> Array[Byte] (see [1])
2. keeps the converted key in `keys`           # `keys` is mutable.ArrayBuffer[Array[Byte]] (see [1])
3. converts Iterator([a, 1]) to an Arrow batch # Iterator[InternalRow] -> Array[Byte]
4. buffers the Arrow batch                     # buffer: Array[Byte] + Arrow batch: Array[Byte]
...
Write all buffered Arrow batches.              # See [2]
Write all the keys.                            # See [3]
```
[1] https://github.com/apache/spark/pull/23977/files#diff-a1240a7ba20d0e027d23a690a770bf44R557
[2] https://github.com/apache/spark/pull/23977/files#diff-a0b6a11cc2e2299455c795fe3c96b823R71
[3] https://github.com/apache/spark/pull/23977/files#diff-a1240a7ba20d0e027d23a690a770bf44R567
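
To make the flow above concrete, here is a minimal, self-contained Scala sketch of the same buffer-batches-then-append-keys protocol. It is an approximation, not the PR's code: `Row`, `rowToBytes`, and the raw-byte "batches" are hypothetical stand-ins for `InternalRow`, `rowToRBytes`, and real Arrow record batches.
```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}
import scala.collection.mutable

object GroupedWriteSketch {
  // Hypothetical stand-in for Spark's InternalRow.
  type Row = Seq[Any]

  // Placeholder for the real InternalRow -> R-readable bytes conversion.
  def rowToBytes(row: Row): Array[Byte] = row.mkString(",").getBytes("UTF-8")

  def bufferedWrite(
      grouped: Iterator[(Row, Iterator[Row])],
      dataOut: DataOutputStream): Unit = {
    // Step 2: keys are collected as each group streams through.
    val keys = mutable.ArrayBuffer.empty[Array[Byte]]
    // Step 4: "batches" accumulate here instead of going straight out.
    val buffer = new ByteArrayOutputStream()
    grouped.foreach { case (key, rowIter) =>
      keys.append(rowToBytes(key))                       // steps 1-2
      rowIter.foreach(r => buffer.write(rowToBytes(r)))  // steps 3-4
    }
    dataOut.write(buffer.toByteArray)    // write all buffered batches [2]
    keys.foreach(k => dataOut.write(k))  // then append all the keys [3]
  }
}
```
The ordering matters: `keys` is only fully populated once the grouped iterator has been drained, so the key block has to trail the batch block in the stream.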