Repository: spark
Updated Branches:
  refs/heads/master 187bb7d00 -> 518a3d10c


[SPARK-26033][SPARK-26034][PYTHON][FOLLOW-UP] Small cleanup and deduplication 
in ml/mllib tests

## What changes were proposed in this pull request?

This PR is a small follow-up that moves some logic and functions into a smaller,
localized scope, and deduplicates them.

## How was this patch tested?

Manually tested. Jenkins tests as well.

Closes #23200 from HyukjinKwon/followup-SPARK-26034-SPARK-26033.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/518a3d10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/518a3d10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/518a3d10

Branch: refs/heads/master
Commit: 518a3d10c87bb6d7d442eba7265fc026aa54473e
Parents: 187bb7d
Author: Hyukjin Kwon <[email protected]>
Authored: Mon Dec 3 14:03:10 2018 -0800
Committer: Bryan Cutler <[email protected]>
Committed: Mon Dec 3 14:03:10 2018 -0800

----------------------------------------------------------------------
 python/pyspark/ml/tests/test_linalg.py        | 44 ++++++++-------
 python/pyspark/mllib/tests/test_algorithms.py |  8 +--
 python/pyspark/mllib/tests/test_linalg.py     | 62 +++++++++-------------
 python/pyspark/testing/mllibutils.py          |  5 --
 4 files changed, 51 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/518a3d10/python/pyspark/ml/tests/test_linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests/test_linalg.py 
b/python/pyspark/ml/tests/test_linalg.py
index 71cad5d..995bc35 100644
--- a/python/pyspark/ml/tests/test_linalg.py
+++ b/python/pyspark/ml/tests/test_linalg.py
@@ -20,25 +20,17 @@ import array as pyarray
 
 from numpy import arange, array, array_equal, inf, ones, tile, zeros
 
+from pyspark.serializers import PickleSerializer
 from pyspark.ml.linalg import DenseMatrix, DenseVector, MatrixUDT, 
SparseMatrix, SparseVector, \
     Vector, VectorUDT, Vectors
-from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
+from pyspark.testing.mllibutils import MLlibTestCase
 from pyspark.sql import Row
 
 
-ser = make_serializer()
-
-
-def _squared_distance(a, b):
-    if isinstance(a, Vector):
-        return a.squared_distance(b)
-    else:
-        return b.squared_distance(a)
-
-
 class VectorTests(MLlibTestCase):
 
     def _test_serialize(self, v):
+        ser = PickleSerializer()
         self.assertEqual(v, ser.loads(ser.dumps(v)))
         jvec = 
self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
         nv = 
ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
@@ -77,24 +69,30 @@ class VectorTests(MLlibTestCase):
         self.assertEqual(7.0, sv.dot(arr))
 
     def test_squared_distance(self):
+        def squared_distance(a, b):
+            if isinstance(a, Vector):
+                return a.squared_distance(b)
+            else:
+                return b.squared_distance(a)
+
         sv = SparseVector(4, {1: 1, 3: 2})
         dv = DenseVector(array([1., 2., 3., 4.]))
         lst = DenseVector([4, 3, 2, 1])
         lst1 = [4, 3, 2, 1]
         arr = pyarray.array('d', [0, 2, 1, 3])
         narr = array([0, 2, 1, 3])
-        self.assertEqual(15.0, _squared_distance(sv, dv))
-        self.assertEqual(25.0, _squared_distance(sv, lst))
-        self.assertEqual(20.0, _squared_distance(dv, lst))
-        self.assertEqual(15.0, _squared_distance(dv, sv))
-        self.assertEqual(25.0, _squared_distance(lst, sv))
-        self.assertEqual(20.0, _squared_distance(lst, dv))
-        self.assertEqual(0.0, _squared_distance(sv, sv))
-        self.assertEqual(0.0, _squared_distance(dv, dv))
-        self.assertEqual(0.0, _squared_distance(lst, lst))
-        self.assertEqual(25.0, _squared_distance(sv, lst1))
-        self.assertEqual(3.0, _squared_distance(sv, arr))
-        self.assertEqual(3.0, _squared_distance(sv, narr))
+        self.assertEqual(15.0, squared_distance(sv, dv))
+        self.assertEqual(25.0, squared_distance(sv, lst))
+        self.assertEqual(20.0, squared_distance(dv, lst))
+        self.assertEqual(15.0, squared_distance(dv, sv))
+        self.assertEqual(25.0, squared_distance(lst, sv))
+        self.assertEqual(20.0, squared_distance(lst, dv))
+        self.assertEqual(0.0, squared_distance(sv, sv))
+        self.assertEqual(0.0, squared_distance(dv, dv))
+        self.assertEqual(0.0, squared_distance(lst, lst))
+        self.assertEqual(25.0, squared_distance(sv, lst1))
+        self.assertEqual(3.0, squared_distance(sv, arr))
+        self.assertEqual(3.0, squared_distance(sv, narr))
 
     def test_hash(self):
         v1 = DenseVector([0.0, 1.0, 0.0, 5.5])

http://git-wip-us.apache.org/repos/asf/spark/blob/518a3d10/python/pyspark/mllib/tests/test_algorithms.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests/test_algorithms.py 
b/python/pyspark/mllib/tests/test_algorithms.py
index cc3b64b..21a2d64 100644
--- a/python/pyspark/mllib/tests/test_algorithms.py
+++ b/python/pyspark/mllib/tests/test_algorithms.py
@@ -26,10 +26,8 @@ from py4j.protocol import Py4JJavaError
 from pyspark.mllib.fpm import FPGrowth
 from pyspark.mllib.recommendation import Rating
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
-
-
-ser = make_serializer()
+from pyspark.serializers import PickleSerializer
+from pyspark.testing.mllibutils import MLlibTestCase
 
 
 class ListTests(MLlibTestCase):
@@ -265,6 +263,7 @@ class ListTests(MLlibTestCase):
 class ALSTests(MLlibTestCase):
 
     def test_als_ratings_serialize(self):
+        ser = PickleSerializer()
         r = Rating(7, 1123, 3.14)
         jr = 
self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
         nr = 
ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
@@ -273,6 +272,7 @@ class ALSTests(MLlibTestCase):
         self.assertAlmostEqual(r.rating, nr.rating, 2)
 
     def test_als_ratings_id_long_error(self):
+        ser = PickleSerializer()
         r = Rating(1205640308657491975, 50233468418, 1.0)
         # rating user id exceeds max int value, should fail when pickled
         self.assertRaises(Py4JJavaError, 
self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,

http://git-wip-us.apache.org/repos/asf/spark/blob/518a3d10/python/pyspark/mllib/tests/test_linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests/test_linalg.py 
b/python/pyspark/mllib/tests/test_linalg.py
index d0ebd9b..f26e28d 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -22,33 +22,18 @@ import unittest
 from numpy import array, array_equal, zeros, arange, tile, ones, inf
 
 import pyspark.ml.linalg as newlinalg
+from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, 
_convert_to_vector, \
     DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
 from pyspark.mllib.regression import LabeledPoint
-from pyspark.testing.mllibutils import make_serializer, MLlibTestCase
-
-_have_scipy = False
-try:
-    import scipy.sparse
-    _have_scipy = True
-except:
-    # No SciPy, but that's okay, we'll skip those tests
-    pass
-
-
-ser = make_serializer()
-
-
-def _squared_distance(a, b):
-    if isinstance(a, Vector):
-        return a.squared_distance(b)
-    else:
-        return b.squared_distance(a)
+from pyspark.testing.mllibutils import MLlibTestCase
+from pyspark.testing.utils import have_scipy
 
 
 class VectorTests(MLlibTestCase):
 
     def _test_serialize(self, v):
+        ser = PickleSerializer()
         self.assertEqual(v, ser.loads(ser.dumps(v)))
         jvec = 
self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
         nv = 
ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
@@ -87,24 +72,30 @@ class VectorTests(MLlibTestCase):
         self.assertEqual(7.0, sv.dot(arr))
 
     def test_squared_distance(self):
+        def squared_distance(a, b):
+            if isinstance(a, Vector):
+                return a.squared_distance(b)
+            else:
+                return b.squared_distance(a)
+
         sv = SparseVector(4, {1: 1, 3: 2})
         dv = DenseVector(array([1., 2., 3., 4.]))
         lst = DenseVector([4, 3, 2, 1])
         lst1 = [4, 3, 2, 1]
         arr = pyarray.array('d', [0, 2, 1, 3])
         narr = array([0, 2, 1, 3])
-        self.assertEqual(15.0, _squared_distance(sv, dv))
-        self.assertEqual(25.0, _squared_distance(sv, lst))
-        self.assertEqual(20.0, _squared_distance(dv, lst))
-        self.assertEqual(15.0, _squared_distance(dv, sv))
-        self.assertEqual(25.0, _squared_distance(lst, sv))
-        self.assertEqual(20.0, _squared_distance(lst, dv))
-        self.assertEqual(0.0, _squared_distance(sv, sv))
-        self.assertEqual(0.0, _squared_distance(dv, dv))
-        self.assertEqual(0.0, _squared_distance(lst, lst))
-        self.assertEqual(25.0, _squared_distance(sv, lst1))
-        self.assertEqual(3.0, _squared_distance(sv, arr))
-        self.assertEqual(3.0, _squared_distance(sv, narr))
+        self.assertEqual(15.0, squared_distance(sv, dv))
+        self.assertEqual(25.0, squared_distance(sv, lst))
+        self.assertEqual(20.0, squared_distance(dv, lst))
+        self.assertEqual(15.0, squared_distance(dv, sv))
+        self.assertEqual(25.0, squared_distance(lst, sv))
+        self.assertEqual(20.0, squared_distance(lst, dv))
+        self.assertEqual(0.0, squared_distance(sv, sv))
+        self.assertEqual(0.0, squared_distance(dv, dv))
+        self.assertEqual(0.0, squared_distance(lst, lst))
+        self.assertEqual(25.0, squared_distance(sv, lst1))
+        self.assertEqual(3.0, squared_distance(sv, arr))
+        self.assertEqual(3.0, squared_distance(sv, narr))
 
     def test_hash(self):
         v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
@@ -466,7 +457,7 @@ class MatrixUDTTests(MLlibTestCase):
                 raise ValueError("Expected a matrix but got type %r" % type(m))
 
 
[email protected](not _have_scipy, "SciPy not installed")
[email protected](not have_scipy, "SciPy not installed")
 class SciPyTests(MLlibTestCase):
 
     """
@@ -476,6 +467,8 @@ class SciPyTests(MLlibTestCase):
 
     def test_serialize(self):
         from scipy.sparse import lil_matrix
+
+        ser = PickleSerializer()
         lil = lil_matrix((4, 1))
         lil[1, 0] = 1
         lil[3, 0] = 2
@@ -621,13 +614,10 @@ class SciPyTests(MLlibTestCase):
 
 if __name__ == "__main__":
     from pyspark.mllib.tests.test_linalg import *
-    if not _have_scipy:
-        print("NOTE: Skipping SciPy tests as it does not seem to be installed")
+
     try:
         import xmlrunner
         testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
     except ImportError:
         testRunner = None
     unittest.main(testRunner=testRunner, verbosity=2)
-    if not _have_scipy:
-        print("NOTE: SciPy tests were skipped as it does not seem to be 
installed")

http://git-wip-us.apache.org/repos/asf/spark/blob/518a3d10/python/pyspark/testing/mllibutils.py
----------------------------------------------------------------------
diff --git a/python/pyspark/testing/mllibutils.py 
b/python/pyspark/testing/mllibutils.py
index 25f1bba..c09fb50 100644
--- a/python/pyspark/testing/mllibutils.py
+++ b/python/pyspark/testing/mllibutils.py
@@ -18,14 +18,9 @@
 import unittest
 
 from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer
 from pyspark.sql import SparkSession
 
 
-def make_serializer():
-    return PickleSerializer()
-
-
 class MLlibTestCase(unittest.TestCase):
     def setUp(self):
         self.sc = SparkContext('local[4]', "MLlib tests")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to