Repository: spark
Updated Branches:
  refs/heads/master ab76900fe -> ecaa495b1


[SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record batches to improve performance

## What changes were proposed in this pull request?

When executing `toPandas` with Arrow enabled, partitions that arrive in the JVM 
out-of-order must be buffered before they can be sent to Python. This uses 
excess memory in the driver JVM and increases the time to complete, because 
data must sit in the JVM waiting for preceding partitions to arrive.

This change sends un-ordered partitions to Python as soon as they arrive in the 
JVM, followed by a list of partition indices so that Python can assemble the 
data in the correct order. This way, data is not buffered in the JVM and there 
is no waiting on particular partitions, which improves performance.
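The re-ordering on the Python side can be sketched as follows (a minimal illustration, not the actual pyspark code: plain strings stand in for Arrow record batches, and the index list mirrors the one the JVM writes after the final batch):

```python
# Un-ordered batches as they arrive at Python, followed by an index list.
arrived = ["batch-from-partition-2", "batch-from-partition-0", "batch-from-partition-1"]

# For each logical position 0..N-1, the JVM writes the arrival index of the
# batch that belongs there (logical batch 0 arrived second, and so on).
batch_order = [1, 2, 0]

# Python re-orders the received batches using the trailing index list.
ordered = [arrived[i] for i in batch_order]
# ordered == ["batch-from-partition-0", "batch-from-partition-1", "batch-from-partition-2"]
```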

Followup to #21546

## How was this patch tested?

Added a new test with a large number of batches per partition, and a test that 
forces a small delay in the first partition. These verify that partitions are 
collected out-of-order and then put in the correct order in Python.

## Performance Tests - toPandas

Tests were run on a 4-node standalone cluster with 32 cores total, Ubuntu 
14.04.1, and OpenJDK 8. Wall-clock time to execute `toPandas()` was measured, 
taking the average best time over 5 runs of 5 loops each.

Test code
```python
import time

from pyspark.sql.functions import rand

df = spark.range(1 << 25, numPartitions=32).toDF("id") \
    .withColumn("x1", rand()).withColumn("x2", rand()) \
    .withColumn("x3", rand()).withColumn("x4", rand())
for i in range(5):
    start = time.time()
    _ = df.toPandas()
    elapsed = time.time() - start
```

Spark config
```
spark.driver.memory 5g
spark.executor.memory 5g
spark.driver.maxResultSize 2g
spark.sql.execution.arrow.enabled true
```

Current Master w/ Arrow stream (s) | This PR (s)
---------------------|------------
5.16207 | 4.342533
5.133671 | 4.399408
5.147513 | 4.468471
5.105243 | 4.36524
5.018685 | 4.373791

Avg Master (s) | Avg This PR (s)
------------------|--------------
5.1134364 | 4.3898886

Speedup of **1.164821449**

Closes #22275 from BryanCutler/arrow-toPandas-oo-batches-SPARK-25274.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ecaa495b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ecaa495b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ecaa495b

Branch: refs/heads/master
Commit: ecaa495b1fe532c36e952ccac42f4715809476af
Parents: ab76900
Author: Bryan Cutler <[email protected]>
Authored: Thu Dec 6 10:07:28 2018 -0800
Committer: Bryan Cutler <[email protected]>
Committed: Thu Dec 6 10:07:28 2018 -0800

----------------------------------------------------------------------
 python/pyspark/serializers.py                   | 33 ++++++++++++++
 python/pyspark/sql/dataframe.py                 | 11 ++++-
 python/pyspark/sql/tests/test_arrow.py          | 28 ++++++++++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 45 +++++++++++---------
 4 files changed, 95 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ff9a612..f3ebd37 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -185,6 +185,39 @@ class FramedSerializer(Serializer):
         raise NotImplementedError
 
 
+class ArrowCollectSerializer(Serializer):
+    """
+    Deserialize a stream of batches followed by batch order information. Used in
+    DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
+    """
+
+    def __init__(self):
+        self.serializer = ArrowStreamSerializer()
+
+    def dump_stream(self, iterator, stream):
+        return self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        """
+        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
+        a list of indices that can be used to put the RecordBatches in the correct order.
+        """
+        # load the batches
+        for batch in self.serializer.load_stream(stream):
+            yield batch
+
+        # load the batch order indices
+        num = read_int(stream)
+        batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            batch_order.append(index)
+        yield batch_order
+
+    def __repr__(self):
+        return "ArrowCollectSerializer(%s)" % self.serializer
+
+
 class ArrowStreamSerializer(Serializer):
     """
     Serializes Arrow record batches as a stream.

http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1b1092c..a1056d0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@ import warnings
 
 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2168,7 +2168,14 @@ class DataFrame(object):
         """
         with SCCallSiteSync(self._sc) as css:
             sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+
+        # Collect list of un-ordered batches where last element is a list of correct order indices
+        results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+        batches = results[:-1]
+        batch_order = results[-1]
+
+        # Re-order the batch list using the correct order
+        return [batches[i] for i in batch_order]
 
     ##########################################################################################
     # Pandas compatibility

http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/sql/tests/test_arrow.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 6e75e82..21fe500 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -381,6 +381,34 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df_from_python.toPandas())
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
+    def test_toPandas_batch_order(self):
+
+        def delay_first_part(partition_index, iterator):
+            if partition_index == 0:
+                time.sleep(0.1)
+            return iterator
+
+        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
+        def run_test(num_records, num_parts, max_records, use_delay=False):
+            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
+            if use_delay:
+                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
+            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
+                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+                self.assertPandasEqual(pdf, pdf_arrow)
+
+        cases = [
+            (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
+            (64, 8, 2, True),  # Use delay in first partition to force collecting out of order
+            (64, 64, 1),       # Test single batch per partition
+            (64, 1, 64),       # Test single partition, single batch
+            (64, 1, 8),        # Test single partition, multiple batches
+            (30, 7, 2),        # Test different sized partitions
+        ]
+
+        for case in cases:
+            run_test(*case)
+
 
 class EncryptionArrowTests(ArrowTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b10d66d..a664c73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql
 
-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.util.control.NonFatal
 
@@ -3200,34 +3201,38 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+        val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length
 
-        // Store collection results for worst case of 1 to N-1 partitions
-        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
-        var lastIndex = -1  // index of last partition written
+        // Batches ordered by (index of partition, batch index in that partition) tuple
+        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        var partitionCount = 0
 
-        // Handler to eagerly write partitions to Python in order
+        // Handler to eagerly write batches to Python as they arrive, un-ordered
         def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
-          // If result is from next partition in order
-          if (index - 1 == lastIndex) {
+          if (arrowBatches.nonEmpty) {
+            // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
-            lastIndex += 1
-            // Write stored partitions that come next in order
-            while (lastIndex < results.length && results(lastIndex) != null) {
-              batchWriter.writeBatches(results(lastIndex).iterator)
-              results(lastIndex) = null
-              lastIndex += 1
+            arrowBatches.indices.foreach {
+              partition_batch_index => batchOrder.append((index, partition_batch_index))
             }
-            // After last batch, end the stream
-            if (lastIndex == results.length) {
-              batchWriter.end()
+          }
+          partitionCount += 1
+
+          // After last batch, end the stream and write batch order indices
+          if (partitionCount == numPartitions) {
+            batchWriter.end()
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+              out.writeInt(overall_batch_index)
             }
-          } else {
-            // Store partitions received out of order
-            results(index - 1) = arrowBatches
+            out.flush()
           }
         }
 

