Repository: spark
Updated Branches:
  refs/heads/master 20ea54cc7 -> 4f01265f7


[SPARK-3786] [PySpark] speedup tests

This patch speeds up the PySpark tests: it re-uses the SparkContext in 
tests.py and mllib/tests.py to reduce the overhead of creating a SparkContext, 
and removes some test cases that did not make sense. It also improves the 
performance of some cases, such as MergerTests and SorterTests.
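
The key idea is to build one SparkContext per test class via unittest's
setUpClass/tearDownClass hooks, instead of creating and stopping a fresh
context in setUp/tearDown for every test method. A minimal sketch of that
pattern (ExampleTests is a hypothetical subclass for illustration; the real
ReusedPySparkTestCase is defined in the tests.py diff below):

import unittest

from pyspark.context import SparkContext


class ReusedPySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # Build the SparkContext once for the whole test class.
        cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)

    @classmethod
    def tearDownClass(cls):
        # Stop it only after every test in the class has run.
        cls.sc.stop()


class ExampleTests(ReusedPySparkTestCase):
    # Hypothetical subclass: all test methods share cls.sc.
    def test_count(self):
        self.assertEqual(10, self.sc.parallelize(range(10)).count())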

before this patch:

real    21m27.320s
user    4m42.967s
sys     0m17.343s

after this patch:

real    9m47.541s
user    2m12.947s
sys     0m14.543s

This cuts the total running time by more than half.

Author: Davies Liu <[email protected]>

Closes #2646 from davies/tests and squashes the following commits:

c54de60 [Davies Liu] revert change about memory limit
6a2a4b0 [Davies Liu] refactor of tests, speedup 100%


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f01265f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f01265f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f01265f

Branch: refs/heads/master
Commit: 4f01265f7d62e070ba42c251255e385644c1b16c
Parents: 20ea54c
Author: Davies Liu <[email protected]>
Authored: Mon Oct 6 14:07:53 2014 -0700
Committer: Josh Rosen <[email protected]>
Committed: Mon Oct 6 14:07:53 2014 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/tests.py |  2 +-
 python/pyspark/shuffle.py     |  5 +--
 python/pyspark/tests.py       | 92 +++++++++++++++++---------------------
 python/run-tests              | 74 +++++++++++++++---------------
 4 files changed, 82 insertions(+), 91 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f01265f/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f72e88b..5c20e10 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -32,7 +32,7 @@ else:
 from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.tests import PySparkTestCase
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 
 
 _have_scipy = False

http://git-wip-us.apache.org/repos/asf/spark/blob/4f01265f/python/pyspark/shuffle.py
----------------------------------------------------------------------
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index ce597cb..d57a802 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -396,7 +396,6 @@ class ExternalMerger(Merger):
                 for v in self.data.iteritems():
                     yield v
                 self.data.clear()
-                gc.collect()
 
                 # remove the merged partition
                 for j in range(self.spills):
@@ -428,7 +427,7 @@ class ExternalMerger(Merger):
             subdirs = [os.path.join(d, "parts", str(i))
                        for d in self.localdirs]
             m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
-                               subdirs, self.scale * self.partitions)
+                               subdirs, self.scale * self.partitions, self.partitions)
             m.pdata = [{} for _ in range(self.partitions)]
             limit = self._next_limit()
 
@@ -486,7 +485,7 @@ class ExternalSorter(object):
         goes above the limit.
         """
         global MemoryBytesSpilled, DiskBytesSpilled
-        batch = 10
+        batch = 100
         chunks, current_chunk = [], []
         iterator = iter(iterator)
         while True:

http://git-wip-us.apache.org/repos/asf/spark/blob/4f01265f/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6fb6bc9..7f05d48 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -67,10 +67,10 @@ except:
 SPARK_HOME = os.environ["SPARK_HOME"]
 
 
-class TestMerger(unittest.TestCase):
+class MergerTests(unittest.TestCase):
 
     def setUp(self):
-        self.N = 1 << 16
+        self.N = 1 << 14
         self.l = [i for i in xrange(self.N)]
         self.data = zip(self.l, self.l)
         self.agg = Aggregator(lambda x: [x],
@@ -115,7 +115,7 @@ class TestMerger(unittest.TestCase):
                          sum(xrange(self.N)) * 3)
 
     def test_huge_dataset(self):
-        m = ExternalMerger(self.agg, 10)
+        m = ExternalMerger(self.agg, 10, partitions=3)
         m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10))
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)),
@@ -123,7 +123,7 @@ class TestMerger(unittest.TestCase):
         m._cleanup()
 
 
-class TestSorter(unittest.TestCase):
+class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
         l = range(1024)
         random.shuffle(l)
@@ -244,16 +244,25 @@ class PySparkTestCase(unittest.TestCase):
         sys.path = self._old_sys_path
 
 
-class TestCheckpoint(PySparkTestCase):
+class ReusedPySparkTestCase(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sc.stop()
+
+
+class CheckpointTests(ReusedPySparkTestCase):
 
     def setUp(self):
-        PySparkTestCase.setUp(self)
         self.checkpointDir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(self.checkpointDir.name)
         self.sc.setCheckpointDir(self.checkpointDir.name)
 
     def tearDown(self):
-        PySparkTestCase.tearDown(self)
         shutil.rmtree(self.checkpointDir.name)
 
     def test_basic_checkpointing(self):
@@ -288,7 +297,7 @@ class TestCheckpoint(PySparkTestCase):
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
-class TestAddFile(PySparkTestCase):
+class AddFileTests(PySparkTestCase):
 
     def test_add_py_file(self):
         # To ensure that we're actually testing addPyFile's effects, check that
@@ -354,7 +363,7 @@ class TestAddFile(PySparkTestCase):
         self.assertEqual(["My Server"], 
self.sc.parallelize(range(1)).map(func).collect())
 
 
-class TestRDDFunctions(PySparkTestCase):
+class RDDTests(ReusedPySparkTestCase):
 
     def test_id(self):
         rdd = self.sc.parallelize(range(10))
@@ -365,12 +374,6 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertEqual(id + 1, id2)
         self.assertEqual(id2, rdd2.id())
 
-    def test_failed_sparkcontext_creation(self):
-        # Regression test for SPARK-1550
-        self.sc.stop()
-        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
-        self.sc = SparkContext("local")
-
     def test_save_as_textfile_with_unicode(self):
         # Regression test for SPARK-970
         x = u"\u00A1Hola, mundo!"
@@ -636,7 +639,7 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertEquals(result.count(), 3)
 
 
-class TestProfiler(PySparkTestCase):
+class ProfilerTests(PySparkTestCase):
 
     def setUp(self):
         self._old_sys_path = list(sys.path)
@@ -666,10 +669,9 @@ class TestProfiler(PySparkTestCase):
         self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
 
 
-class TestSQL(PySparkTestCase):
+class SQLTests(ReusedPySparkTestCase):
 
     def setUp(self):
-        PySparkTestCase.setUp(self)
         self.sqlCtx = SQLContext(self.sc)
 
     def test_udf(self):
@@ -754,27 +756,19 @@ class TestSQL(PySparkTestCase):
         self.assertEqual("2", row.d)
 
 
-class TestIO(PySparkTestCase):
-
-    def test_stdout_redirection(self):
-        import subprocess
-
-        def func(x):
-            subprocess.check_call('ls', shell=True)
-        self.sc.parallelize([1]).foreach(func)
+class InputFormatTests(ReusedPySparkTestCase):
 
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc)
 
-class TestInputFormat(PySparkTestCase):
-
-    def setUp(self):
-        PySparkTestCase.setUp(self)
-        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
-        os.unlink(self.tempdir.name)
-        self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
-
-    def tearDown(self):
-        PySparkTestCase.tearDown(self)
-        shutil.rmtree(self.tempdir.name)
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name)
 
     def test_sequencefiles(self):
         basepath = self.tempdir.name
@@ -954,15 +948,13 @@ class TestInputFormat(PySparkTestCase):
         self.assertEqual(maps, em)
 
 
-class TestOutputFormat(PySparkTestCase):
+class OutputFormatTests(ReusedPySparkTestCase):
 
     def setUp(self):
-        PySparkTestCase.setUp(self)
         self.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(self.tempdir.name)
 
     def tearDown(self):
-        PySparkTestCase.tearDown(self)
         shutil.rmtree(self.tempdir.name, ignore_errors=True)
 
     def test_sequencefiles(self):
@@ -1243,8 +1235,7 @@ class TestOutputFormat(PySparkTestCase):
             basepath + "/malformed/sequence"))
 
 
-class TestDaemon(unittest.TestCase):
-
+class DaemonTests(unittest.TestCase):
     def connect(self, port):
         from socket import socket, AF_INET, SOCK_STREAM
         sock = socket(AF_INET, SOCK_STREAM)
@@ -1290,7 +1281,7 @@ class TestDaemon(unittest.TestCase):
         self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
 
 
-class TestWorker(PySparkTestCase):
+class WorkerTests(PySparkTestCase):
 
     def test_cancel_task(self):
         temp = tempfile.NamedTemporaryFile(delete=True)
@@ -1342,11 +1333,6 @@ class TestWorker(PySparkTestCase):
         rdd = self.sc.parallelize(range(100), 1)
         self.assertEqual(100, rdd.map(str).count())
 
-    def test_fd_leak(self):
-        N = 1100  # fd limit is 1024 by default
-        rdd = self.sc.parallelize(range(N), N)
-        self.assertEquals(N, rdd.count())
-
     def test_after_exception(self):
         def raise_exception(_):
             raise Exception()
@@ -1379,7 +1365,7 @@ class TestWorker(PySparkTestCase):
         self.assertEqual(sum(range(100)), acc1.value)
 
 
-class TestSparkSubmit(unittest.TestCase):
+class SparkSubmitTests(unittest.TestCase):
 
     def setUp(self):
         self.programDir = tempfile.mkdtemp()
@@ -1492,6 +1478,8 @@ class TestSparkSubmit(unittest.TestCase):
             |sc = SparkContext()
             |print sc.parallelize([1, 2, 3]).map(foo).collect()
             """)
+        # this will fail if you have different spark.executor.memory
+        # in conf/spark-defaults.conf
         proc = subprocess.Popen(
             [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script],
             stdout=subprocess.PIPE)
@@ -1500,7 +1488,11 @@ class TestSparkSubmit(unittest.TestCase):
         self.assertIn("[2, 4, 6]", out)
 
 
-class ContextStopTests(unittest.TestCase):
+class ContextTests(unittest.TestCase):
+
+    def test_failed_sparkcontext_creation(self):
+        # Regression test for SPARK-1550
+        self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
 
     def test_stop(self):
         sc = SparkContext()

http://git-wip-us.apache.org/repos/asf/spark/blob/4f01265f/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index a7ec270..c713861 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -34,7 +34,7 @@ rm -rf metastore warehouse
 function run_test() {
     echo "Running test: $1"
 
-    SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+    SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
 
     FAILED=$((PIPESTATUS[0]||$FAILED))
 
@@ -48,6 +48,37 @@ function run_test() {
     fi
 }
 
+function run_core_tests() {
+    echo "Run core tests ..."
+    run_test "pyspark/rdd.py"
+    run_test "pyspark/context.py"
+    run_test "pyspark/conf.py"
+    PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
+    PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
+    PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+    run_test "pyspark/shuffle.py"
+    run_test "pyspark/tests.py"
+}
+
+function run_sql_tests() {
+    echo "Run sql tests ..."
+    run_test "pyspark/sql.py"
+}
+
+function run_mllib_tests() {
+    echo "Run mllib tests ..."
+    run_test "pyspark/mllib/classification.py"
+    run_test "pyspark/mllib/clustering.py"
+    run_test "pyspark/mllib/linalg.py"
+    run_test "pyspark/mllib/random.py"
+    run_test "pyspark/mllib/recommendation.py"
+    run_test "pyspark/mllib/regression.py"
+    run_test "pyspark/mllib/stat.py"
+    run_test "pyspark/mllib/tree.py"
+    run_test "pyspark/mllib/util.py"
+    run_test "pyspark/mllib/tests.py"
+}
+
 echo "Running PySpark tests. Output is in python/unit-tests.log."
 
 export PYSPARK_PYTHON="python"
@@ -60,29 +91,9 @@ fi
 echo "Testing with Python version:"
 $PYSPARK_PYTHON --version
 
-run_test "pyspark/rdd.py"
-run_test "pyspark/context.py"
-run_test "pyspark/conf.py"
-run_test "pyspark/sql.py"
-# These tests are included in the module-level docs, and so must
-# be handled on a higher level rather than within the python file.
-export PYSPARK_DOC_TEST=1
-run_test "pyspark/broadcast.py"
-run_test "pyspark/accumulators.py"
-run_test "pyspark/serializers.py"
-unset PYSPARK_DOC_TEST
-run_test "pyspark/shuffle.py"
-run_test "pyspark/tests.py"
-run_test "pyspark/mllib/classification.py"
-run_test "pyspark/mllib/clustering.py"
-run_test "pyspark/mllib/linalg.py"
-run_test "pyspark/mllib/random.py"
-run_test "pyspark/mllib/recommendation.py"
-run_test "pyspark/mllib/regression.py"
-run_test "pyspark/mllib/stat.py"
-run_test "pyspark/mllib/tests.py"
-run_test "pyspark/mllib/tree.py"
-run_test "pyspark/mllib/util.py"
+run_core_tests
+run_sql_tests
+run_mllib_tests
 
 # Try to test with PyPy
 if [ $(which pypy) ]; then
@@ -90,19 +101,8 @@ if [ $(which pypy) ]; then
     echo "Testing with PyPy version:"
     $PYSPARK_PYTHON --version
 
-    run_test "pyspark/rdd.py"
-    run_test "pyspark/context.py"
-    run_test "pyspark/conf.py"
-    run_test "pyspark/sql.py"
-    # These tests are included in the module-level docs, and so must
-    # be handled on a higher level rather than within the python file.
-    export PYSPARK_DOC_TEST=1
-    run_test "pyspark/broadcast.py"
-    run_test "pyspark/accumulators.py"
-    run_test "pyspark/serializers.py"
-    unset PYSPARK_DOC_TEST
-    run_test "pyspark/shuffle.py"
-    run_test "pyspark/tests.py"
+    run_core_tests
+    run_sql_tests
 fi
 
 if [[ $FAILED == 0 ]]; then


