cloud-fan commented on code in PR #46697:
URL: https://github.com/apache/spark/pull/46697#discussion_r1609739643


##########
core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala:
##########
@@ -104,12 +104,37 @@ private[spark] object SerDeUtil extends Logging {
     }
   }
 
+  /**
+   * Use a fixed batch size
+   */
+  private[spark] class FixedBatchedPickler(iter: Iterator[Any], batch: 
Integer) extends Iterator[Array[Byte]] {
+    private val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+    private val buffer = new mutable.ArrayBuffer[Any]
+
+    override def hasNext: Boolean = iter.hasNext
+
+    override def next(): Array[Byte] = {
+      while (iter.hasNext && buffer.length < batch) {
+        buffer += iter.next()
+      }
+      val bytes = pickle.dumps(buffer.toArray)
+      buffer.clear()
+      bytes
+    }
+  }
+
   /**
    * Convert an RDD of Java objects to an RDD of serialized Python objects, 
that is usable by
    * PySpark.
+   * @param batchSize if non-zero, the fixed batch size to use when pickling 
objects
    */
-  def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
-    jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+  def javaToPython(jRDD: JavaRDD[_], batchSize: Integer = 0): 
JavaRDD[Array[Byte]] = {
+    if (batchSize == 0) {

Review Comment:
   shall we fail if `batchSize < 0`?



##########
core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala:
##########
@@ -104,12 +104,37 @@ private[spark] object SerDeUtil extends Logging {
     }
   }
 
+  /**
+   * Use a fixed batch size
+   */
+  private[spark] class FixedBatchedPickler(iter: Iterator[Any], batch: 
Integer) extends Iterator[Array[Byte]] {
+    private val pickle = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+    private val buffer = new mutable.ArrayBuffer[Any]
+
+    override def hasNext: Boolean = iter.hasNext
+
+    override def next(): Array[Byte] = {
+      while (iter.hasNext && buffer.length < batch) {
+        buffer += iter.next()
+      }
+      val bytes = pickle.dumps(buffer.toArray)
+      buffer.clear()
+      bytes
+    }
+  }
+
   /**
    * Convert an RDD of Java objects to an RDD of serialized Python objects, 
that is usable by
    * PySpark.
+   * @param batchSize if non-zero, the fixed batch size to use when pickling 
objects
    */
-  def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
-    jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+  def javaToPython(jRDD: JavaRDD[_], batchSize: Integer = 0): 
JavaRDD[Array[Byte]] = {

Review Comment:
   ```suggestion
     def javaToPython(jRDD: JavaRDD[_], batchSize: Int = 0): 
JavaRDD[Array[Byte]] = {
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to