Repository: spark
Updated Branches:
refs/heads/master ab76900fe -> ecaa495b1
[SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record
batches to improve performance
## What changes were proposed in this pull request?
When executing `toPandas` with Arrow enabled, partitions that arrive in the JVM
out-of-order must be buffered before they can be sent to Python. This causes
excess memory use in the driver JVM and increases the time to complete, because
data must sit in the JVM waiting for preceding partitions to come in.
This change sends un-ordered partitions to Python as soon as they arrive in the
JVM, followed by a list of partition indices so that Python can assemble the
data in the correct order. This way, data is not buffered in the JVM and there
is no waiting on particular partitions, so performance is improved.
Follow-up to #21546
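To illustrate, the Python side now receives the record batches in whatever order they arrived at the driver, followed by a single list of indices, and re-assembly is just an index lookup. A minimal sketch of that step (the placeholder batch names are illustrative; the real logic lives in `DataFrame._collectAsArrow`):
```python
# `results` stands in for what the new serializer yields: un-ordered batches,
# then one trailing list of order indices (strings here instead of real
# Arrow RecordBatches).
results = ["batch_from_part2", "batch_from_part0", "batch_from_part1", [1, 2, 0]]

batches = results[:-1]     # batches in arrival order
batch_order = results[-1]  # batch_order[i] = arrival position of the i-th logical batch

ordered = [batches[i] for i in batch_order]
print(ordered)  # ['batch_from_part0', 'batch_from_part1', 'batch_from_part2']
```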
## How was this patch tested?
Added a new test with a large number of batches per partition, and a test that
forces a small delay in the first partition. These verify that partitions
collected out-of-order are put back in the correct order in Python.
## Performance Tests - toPandas
Tests were run on a 4-node standalone cluster with 32 cores total, Ubuntu
14.04.1 and OpenJDK 8, measuring wall-clock time to execute `toPandas()` and
taking the average best time over 5 runs of 5 loops each.
Test code
```python
import time
from pyspark.sql.functions import rand

df = spark.range(1 << 25, numPartitions=32).toDF("id") \
    .withColumn("x1", rand()).withColumn("x2", rand()) \
    .withColumn("x3", rand()).withColumn("x4", rand())

for i in range(5):
    start = time.time()
    _ = df.toPandas()
    elapsed = time.time() - start
```
Spark config
```
spark.driver.memory 5g
spark.executor.memory 5g
spark.driver.maxResultSize 2g
spark.sql.execution.arrow.enabled true
```
Current Master w/ Arrow stream | This PR
---------------------|------------
5.16207 | 4.342533
5.133671 | 4.399408
5.147513 | 4.468471
5.105243 | 4.36524
5.018685 | 4.373791
Avg Master | Avg This PR
------------------|--------------
5.1134364 | 4.3898886
Speedup of **1.164821449**
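For reference, the averages and the speedup figure follow directly from the per-run times above:
```python
# Quick check of the reported averages and speedup (times in seconds)
master = [5.16207, 5.133671, 5.147513, 5.105243, 5.018685]
this_pr = [4.342533, 4.399408, 4.468471, 4.36524, 4.373791]

avg_master = sum(master) / len(master)     # 5.1134364
avg_this_pr = sum(this_pr) / len(this_pr)  # 4.3898886
print(avg_master / avg_this_pr)            # ~1.1648, i.e. roughly 14% less wall-clock time
```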
Closes #22275 from BryanCutler/arrow-toPandas-oo-batches-SPARK-25274.
Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ecaa495b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ecaa495b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ecaa495b
Branch: refs/heads/master
Commit: ecaa495b1fe532c36e952ccac42f4715809476af
Parents: ab76900
Author: Bryan Cutler <[email protected]>
Authored: Thu Dec 6 10:07:28 2018 -0800
Committer: Bryan Cutler <[email protected]>
Committed: Thu Dec 6 10:07:28 2018 -0800
----------------------------------------------------------------------
python/pyspark/serializers.py | 33 ++++++++++++++
python/pyspark/sql/dataframe.py | 11 ++++-
python/pyspark/sql/tests/test_arrow.py | 28 ++++++++++++
.../scala/org/apache/spark/sql/Dataset.scala | 45 +++++++++++---------
4 files changed, 95 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ff9a612..f3ebd37 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -185,6 +185,39 @@ class FramedSerializer(Serializer):
raise NotImplementedError
+class ArrowCollectSerializer(Serializer):
+    """
+    Deserialize a stream of batches followed by batch order information. Used in
+    DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
+    """
+
+    def __init__(self):
+        self.serializer = ArrowStreamSerializer()
+
+    def dump_stream(self, iterator, stream):
+        return self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        """
+        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
+        a list of indices that can be used to put the RecordBatches in the correct order.
+        """
+        # load the batches
+        for batch in self.serializer.load_stream(stream):
+            yield batch
+
+        # load the batch order indices
+        num = read_int(stream)
+        batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            batch_order.append(index)
+        yield batch_order
+
+    def __repr__(self):
+        return "ArrowCollectSerializer(%s)" % self.serializer
+
+
 class ArrowStreamSerializer(Serializer):
     """
     Serializes Arrow record batches as a stream.
http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1b1092c..a1056d0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@ import warnings
 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2168,7 +2168,14 @@ class DataFrame(object):
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.collectAsArrowToPython()
- return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+
+ # Collect list of un-ordered batches where last element is a list of
correct order indices
+ results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+ batches = results[:-1]
+ batch_order = results[-1]
+
+ # Re-order the batch list using the correct order
+ return [batches[i] for i in batch_order]
##########################################################################################
# Pandas compatibility
http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/sql/tests/test_arrow.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_arrow.py
b/python/pyspark/sql/tests/test_arrow.py
index 6e75e82..21fe500 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -381,6 +381,34 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df_from_python.toPandas())
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
+    def test_toPandas_batch_order(self):
+
+        def delay_first_part(partition_index, iterator):
+            if partition_index == 0:
+                time.sleep(0.1)
+            return iterator
+
+        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
+        def run_test(num_records, num_parts, max_records, use_delay=False):
+            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
+            if use_delay:
+                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
+
+            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
+                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+                self.assertPandasEqual(pdf, pdf_arrow)
+
+        cases = [
+            (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
+            (64, 8, 2, True),  # Use delay in first partition to force collecting out of order
+            (64, 64, 1),       # Test single batch per partition
+            (64, 1, 64),       # Test single partition, single batch
+            (64, 1, 8),        # Test single partition, multiple batches
+            (30, 7, 2),        # Test different sized partitions
+        ]
+
+        for case in cases:
+            run_test(*case)
+
+
 class EncryptionArrowTests(ArrowTests):
http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b10d66d..a664c73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,10 @@
 package org.apache.spark.sql
 
-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.util.control.NonFatal
@@ -3200,34 +3201,38 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+        val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length
 
-        // Store collection results for worst case of 1 to N-1 partitions
-        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
-        var lastIndex = -1  // index of last partition written
+        // Batches ordered by (index of partition, batch index in that partition) tuple
+        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        var partitionCount = 0
 
-        // Handler to eagerly write partitions to Python in order
+        // Handler to eagerly write batches to Python as they arrive, un-ordered
         def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
-          // If result is from next partition in order
-          if (index - 1 == lastIndex) {
+          if (arrowBatches.nonEmpty) {
+            // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
-            lastIndex += 1
-            // Write stored partitions that come next in order
-            while (lastIndex < results.length && results(lastIndex) != null) {
-              batchWriter.writeBatches(results(lastIndex).iterator)
-              results(lastIndex) = null
-              lastIndex += 1
+            arrowBatches.indices.foreach {
+              partition_batch_index => batchOrder.append((index, partition_batch_index))
             }
-            // After last batch, end the stream
-            if (lastIndex == results.length) {
-              batchWriter.end()
+          }
+          partitionCount += 1
+
+          // After last batch, end the stream and write batch order indices
+          if (partitionCount == numPartitions) {
+            batchWriter.end()
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+              out.writeInt(overall_batch_index)
             }
-          } else {
-            // Store partitions received out of order
-            results(index - 1) = arrowBatches
+            out.flush()
           }
         }
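The ordering trick on the JVM side is the `batchOrder.zipWithIndex.sortBy(_._1)` line above: batches are recorded in arrival order, tagged with their (partition index, batch index within that partition); sorting those tags recovers the logical order, and the arrival position of each tag is what gets written to Python. A small sketch of the same computation (not part of the patch, just an illustration in Python):
```python
# Suppose partition 1 arrived first with two batches, then partition 0 with one batch.
# Tags are appended in arrival order: (partition index, batch index within that partition).
batch_order = [(1, 0), (1, 1), (0, 0)]

# Pair each tag with its arrival position, then sort by the tag to get logical order.
indices = [arrival for tag, arrival in sorted(
    (tag, arrival) for arrival, tag in enumerate(batch_order))]
print(indices)  # [2, 0, 1] -> Python picks batches[2], batches[0], batches[1]
```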
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]