HyukjinKwon commented on a change in pull request #23977: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
URL: https://github.com/apache/spark/pull/23977#discussion_r263334756
 
 

 ##########
 File path: sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
 ##########
 @@ -546,14 +548,23 @@ case class FlatMapGroupsInRWithArrowExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+      val grouped = GroupedIterator(iter, groupingAttributes, child.output).filter(_._2.hasNext)
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
-        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY)
 
-      val groupedByRKey = grouped.map { case (key, rowIter) =>
-        val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
-        (newKey, rowIter)
+      // Iterating over keys is relatively cheap.
+      val keys: Iterator[Array[Byte]] =
+        grouped.map { case (key, rowIter) => rowToRBytes(getKey(key).asInstanceOf[Row]) }
+      val groupedByRKey: Iterator[Iterator[InternalRow]] =
+        grouped.map { case (key, rowIter) => rowIter }
+
+      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) {
+        protected override def bufferedWrite(
+            dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
+          super.bufferedWrite(dataOut)(writeFunc)
+          // Don't forget we're sending keys additionally.
+          keys.foreach(dataOut.write)
 
 Review comment:
   Yea, keys are expected to be small. Let me also explain this from scratch to make sure we're in sync.
   
   Input iterator:
   
   ```
   [a, 1]
   [b, 2]
   [b, 3]
   ```
   
   Output iterator:
   
   ```
   ([a], Iterator([a, 1]))
   ([b], Iterator([b, 2], [b, 3]))
   ```
   
   Here, I am doing:
   
   ```
   next() call => ([a], Iterator([a, 1]))
     keeps the key 'a'
     converts Iterator([a, 1]) to an Arrow batch
     buffers the Arrow batch
   next() call => ([b], Iterator([b, 2], [b, 3]))
     keeps the key 'b'
     converts Iterator([b, 2], [b, 3]) to an Arrow batch
     buffers the Arrow batch
   ...
   
   Write all buffered Arrow batches
   Write all kept keys
   ```
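   
   To make the write order concrete, here is a minimal self-contained sketch in plain Scala (no Spark or Arrow; `KeyedBatchWriteSketch` and the byte-array "batches" are hypothetical stand-ins for the real Arrow conversion, not code from this PR):
   
   ```
   import java.io.{ByteArrayOutputStream, DataOutputStream}
   import scala.collection.mutable.ArrayBuffer
   
   object KeyedBatchWriteSketch {
     def main(args: Array[String]): Unit = {
       // Stand-in for the grouped iterator above: (key bytes, rows in the group).
       val grouped: Iterator[(Array[Byte], Iterator[String])] = Iterator(
         ("a".getBytes("UTF-8"), Iterator("[a, 1]")),
         ("b".getBytes("UTF-8"), Iterator("[b, 2]", "[b, 3]")))
   
       val keys = ArrayBuffer.empty[Array[Byte]]
   
       // Each next() keeps the key and converts the group's rows into a "batch"
       // (a hypothetical stand-in for the Arrow conversion in ArrowRRunner).
       val batches = grouped.map { case (key, rowIter) =>
         keys += key
         rowIter.mkString(",").getBytes("UTF-8")
       }.toArray  // materializing here buffers all batches and fills `keys`
   
       val out = new ByteArrayOutputStream()
       val dataOut = new DataOutputStream(out)
       batches.foreach(dataOut.write)  // write all buffered batches first
       keys.foreach(dataOut.write)     // then write all the kept keys
       dataOut.flush()
       println(s"wrote ${out.size()} bytes for ${keys.length} groups")
     }
   }
   ```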
   
   I tried to leave only the Arrow conversion logic in `ArrowRRunner` (which is shared by both `dapply` and `gapply`), and to keep the key-collecting/writing logic separate here.
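   
   As a rough illustration of that split, a hypothetical base runner could own the buffered write while the gapply side layers the key write on top via an anonymous subclass, as the diff above does with `ArrowRRunner` (`BaseBufferedRunner` and `makeGapplyRunner` are made-up names, not the real API):
   
   ```
   import java.io.{ByteArrayOutputStream, DataOutputStream}
   
   object RunnerOverrideSketch {
     // Hypothetical stand-in for the shared runner: it only knows how to
     // buffer and flush batches, and knows nothing about gapply keys.
     class BaseBufferedRunner {
       protected def bufferedWrite(
           dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
         val buffer = new ByteArrayOutputStream()
         writeFunc(buffer)        // the caller fills the buffer with batches
         buffer.writeTo(dataOut)  // flush everything downstream in one go
       }
     }
   
     // The gapply-specific key write is layered on at the call site.
     def makeGapplyRunner(keys: Iterator[Array[Byte]]): BaseBufferedRunner =
       new BaseBufferedRunner {
         override protected def bufferedWrite(
             dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
           super.bufferedWrite(dataOut)(writeFunc)
           keys.foreach(dataOut.write)  // append the kept keys after the batches
         }
       }
   }
   ```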
   
   (BTW, I tried to avoid such refactorings because I know PRs like this are difficult to review.)
   
