This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c4eb99  [SPARK-27870][SQL][PYSPARK] Flush batch timely for pandas UDF (for improving pandas UDFs pipeline)
9c4eb99 is described below

commit 9c4eb99c52803f2488ac3787672aa8d3e4d1544e
Author: WeichenXu <weichen...@databricks.com>
AuthorDate: Fri Jun 7 14:02:43 2019 -0700

    [SPARK-27870][SQL][PYSPARK] Flush batch timely for pandas UDF (for improving pandas UDFs pipeline)
    
    ## What changes were proposed in this pull request?
    
    Flush batches in a timely manner for pandas UDFs.
    
    This could improve performance when multiple pandas UDF plans are pipelined.
    
    When each batch is flushed in time, downstream pandas UDFs get pipelined as soon as possible, and the pipeline helps hide the downstream UDFs' computation time. For example:
    
    While the first UDF is computing on batch-3, the second pipelined UDF can start computing on batch-2, and the third pipelined UDF can start computing on batch-1.
    
    If we do not flush each batch in time, the downstream UDFs' pipeline will lag behind too much, which may increase the total processing time.
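    
    As a rough illustration of such a chain (a minimal sketch, assuming an active `spark` session; the UDFs `fp1`/`fp2`/`fp3` here are hypothetical stand-ins for three pipelined scalar pandas UDFs):
    ```
    from pyspark.sql.functions import col, pandas_udf, PandasUDFType

    @pandas_udf("long", PandasUDFType.SCALAR)
    def fp1(x):
        return x + 1

    @pandas_udf("long", PandasUDFType.SCALAR)
    def fp2(x):
        return x * 2

    @pandas_udf("long", PandasUDFType.SCALAR)
    def fp3(x):
        return x - 3

    # Three chained scalar pandas UDFs: with timely flushing, fp2 can start on
    # batch-2 while fp3 is still on batch-1 and fp1 has moved on to batch-3.
    df = spark.range(0, 100).select(fp3(fp2(fp1(col("id")))).alias("out"))
    ```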
    
    I add a flush at two places:
    * JVM process feeding data into the Python worker: on the JVM side, flush after writing each batch.
    * JVM process reading data from the Python worker output: on the Python worker side, flush after writing each batch.
    
    Without an explicit flush, the default buffer size on both sides is 65536 bytes. Especially in ML cases, where we want real-time prediction, the batch size is made very small; the buffer is then far too large for such batches, which causes the downstream pandas UDF pipeline to lag behind too much.
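    
    As a rough, self-contained illustration of why a 64 KiB buffer delays tiny batches (this is not Spark's actual I/O code; `CountingSink` is a hypothetical stand-in for the stream feeding the downstream reader):
    ```
    import io

    class CountingSink(io.RawIOBase):
        """Counts the bytes that actually reach the downstream reader."""
        def __init__(self):
            self.bytes_received = 0

        def writable(self):
            return True

        def write(self, b):
            self.bytes_received += len(b)
            return len(b)

    sink = CountingSink()
    out = io.BufferedWriter(sink, buffer_size=65536)

    out.write(b"x" * 100)         # a tiny batch's worth of bytes
    print(sink.bytes_received)    # 0 -- still sitting in the 64 KiB buffer
    out.flush()
    print(sink.bytes_received)    # 100 -- visible downstream only after the flush
    ```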
    
    ### Note
    * This only applies to scalar pandas UDFs.
    * We do not flush after every single batch: the minimum interval between two flushes is 0.1 seconds, which avoids flushing too frequently when the batch size is small. It works like:
    ```
    last_flush_time = time.time()
    for batch in iterator:
        writer.write_batch(batch)
        flush_time = time.time()
        if flush_time - last_flush_time > 0.1:
            stream.flush()
            last_flush_time = flush_time
    ```
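    
    A self-contained version of the same idea, mirroring the `ArrowStreamSerializer.dump_stream` change in the diff below (a sketch rather than the exact shipped code; `dump_batches` and `min_flush_interval` are illustrative names):
    ```
    import time

    import pyarrow as pa

    def dump_batches(batch_iterator, stream, min_flush_interval=0.1):
        """Write Arrow record batches to `stream`, flushing at most every 0.1 s."""
        writer = None
        last_flush_time = time.time()
        try:
            for batch in batch_iterator:
                if writer is None:
                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                writer.write_batch(batch)
                now = time.time()
                # Throttled flush: make small batches visible downstream without
                # flushing after every single write.
                if now - last_flush_time > min_flush_interval:
                    stream.flush()
                    last_flush_time = now
        finally:
            if writer is not None:
                writer.close()
    ```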
    
    ## How was this patch tested?
    
    ### Benchmark to make sure the flush does not cause a performance regression
    #### Test code:
    ```
    import time

    from pyspark.sql.functions import col, pandas_udf, sum, PandasUDFType

    numRows = ...
    batchSize = ...

    spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', str(batchSize))
    df = spark.range(1, numRows + 1, numPartitions=1).select(col('id').alias('a'))

    @pandas_udf("int", PandasUDFType.SCALAR)
    def fp1(x):
        return x + 10

    beg_time = time.time()
    result = df.select(sum(fp1('a'))).head()
    print("result: " + str(result[0]))
    print("consume time: " + str(time.time() - beg_time))
    ```
    #### Test Result:
    
     params        | Consume time (Before) | Consume time (After)
    ------------ | ----------------------- | ----------------------
    numRows=100000000, batchSize=10000 | 23.43s | 24.64s
    numRows=100000000, batchSize=1000 | 36.73s | 34.50s
    numRows=10000000, batchSize=100 | 35.67s | 32.64s
    numRows=1000000, batchSize=10 | 33.60s | 32.11s
    numRows=100000, batchSize=1 | 33.36s | 31.82s
    
    ### Benchmark pipelined pandas UDF
    #### Test code:
    ```
    import time

    from pyspark.sql.functions import col, pandas_udf, sum, PandasUDFType

    spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1')
    df = spark.range(1, 31, numPartitions=1).select(col('id').alias('a'))

    @pandas_udf("int", PandasUDFType.SCALAR)
    def fp1(x):
        print("run fp1")
        time.sleep(1)
        return x + 100

    @pandas_udf("int", PandasUDFType.SCALAR)
    def fp2(x, y):
        print("run fp2")
        time.sleep(1)
        return x + y

    beg_time = time.time()
    result = df.select(sum(fp2(fp1('a'), col('a')))).head()
    print("result: " + str(result[0]))
    print("consume time: " + str(time.time() - beg_time))
    ```
    #### Test Result:
    
    **Before**: consume time: 63.57s
    **After**: consume time: 32.43s
    **So the PR improves performance by making the downstream UDFs get pipelined earlier.**
    
    Please review https://spark.apache.org/contributing.html before opening a pull request.
    
    Closes #24734 from WeichenXu123/improve_pandas_udf_pipeline.
    
    Lead-authored-by: WeichenXu <weichen...@databricks.com>
    Co-authored-by: Xiangrui Meng <m...@databricks.com>
    Signed-off-by: gatorsmile <gatorsm...@gmail.com>
---
 python/pyspark/serializers.py                         | 18 ++++++++++++++++--
 python/pyspark/testing/utils.py                       |  3 +++
 python/pyspark/tests/test_serializers.py              | 10 ++++++++++
 .../sql/execution/python/ArrowPythonRunner.scala      | 19 ++++++++++++-------
 4 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 516ee7e..1b17e60 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -58,6 +58,7 @@ import types
 import collections
 import zlib
 import itertools
+import time
 
 if sys.version < '3':
     import cPickle as pickle
@@ -230,11 +231,19 @@ class ArrowStreamSerializer(Serializer):
     def dump_stream(self, iterator, stream):
         import pyarrow as pa
         writer = None
+        last_flush_time = time.time()
         try:
             for batch in iterator:
                 if writer is None:
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                 writer.write_batch(batch)
+                current_time = time.time()
+                # If it takes time to compute each input batch but per-batch data is very small,
+                # the data might stay in the buffer for long and downstream reader cannot read it.
+                # We want to flush timely in this case.
+                if current_time - last_flush_time > 0.1:
+                    stream.flush()
+                    last_flush_time = current_time
         finally:
             if writer is not None:
                 writer.close()
@@ -872,11 +881,16 @@ class ChunkedStream(object):
                 byte_pos = new_byte_pos
                 self.current_pos = 0
 
-    def close(self):
-        # if there is anything left in the buffer, write it out first
+    def flush(self):
         if self.current_pos > 0:
             write_int(self.current_pos, self.wrapped)
             self.wrapped.write(self.buffer[:self.current_pos])
+            self.current_pos = 0
+        self.wrapped.flush()
+
+    def close(self):
+        # If there is anything left in the buffer, write it out first.
+        self.flush()
         # -1 length indicates to the receiving end that we're done.
         write_int(-1, self.wrapped)
         self.wrapped.close()
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 2b42b89..61c342b 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -99,6 +99,9 @@ class ByteArrayOutput(object):
     def write(self, b):
         self.buffer += b
 
+    def flush(self):
+        pass
+
     def close(self):
         pass
 
diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py
index bce9406..498076d 100644
--- a/python/pyspark/tests/test_serializers.py
+++ b/python/pyspark/tests/test_serializers.py
@@ -225,6 +225,16 @@ class SerializersTest(unittest.TestCase):
                 # ends with a -1
                 self.assertEqual(dest.buffer[-4:], write_int(-1))
 
+    def test_chunked_stream_flush(self):
+        wrapped = ByteArrayOutput()
+        stream = serializers.ChunkedStream(wrapped, 10)
+        stream.write(bytearray([0]))
+        self.assertEqual(len(wrapped.buffer), 0, "small write should be buffered")
+        stream.flush()
+        # Expect buffer size 4 bytes + buffer data 1 byte.
+        self.assertEqual(len(wrapped.buffer), 5, "flush should work")
+        stream.close()
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_serializers import *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 3710218..ddb65a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -78,16 +78,21 @@ class ArrowPythonRunner(
           val arrowWriter = ArrowWriter.create(root)
           val writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
-
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
+          var lastFlushTime = System.currentTimeMillis()
+          inputIterator.foreach { batch =>
+            batch.foreach { row =>
+              arrowWriter.write(row)
             }
-
             arrowWriter.finish()
             writer.writeBatch()
+            val currentTime = System.currentTimeMillis()
+            // If it takes time to compute each input batch but per-batch data is very small,
+            // the data might stay in the buffer for long and downstream reader cannot read it.
+            // We want to flush timely in this case.
+            if (currentTime - lastFlushTime > 100) {
+              dataOut.flush()
+              lastFlushTime = currentTime
+            }
             arrowWriter.reset()
           }
          // end writes footer to the output stream and doesn't clean any resources.

