Repository: spark
Updated Branches:
  refs/heads/master 9afcf127d -> 1816eb3be


[SPARK-20631][FOLLOW-UP] Fix incorrect tests.

## What changes were proposed in this pull request?

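- Fix the incorrect test for `_check_thresholds`. The method lacked the `test` prefix, so the test runner never executed it. It also asserted that a `LogisticRegression` estimator is an instance of `LogisticRegressionModel`.
- Rename the test to `test_logistic_regression_check_thresholds` and move it from `PersistenceTest` to `ParamTests`.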
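The standard library alone can show why the old method never ran. This is a minimal sketch, not Spark code. `ExampleTests`, `collected_names`, and `run_quietly` are hypothetical names, and the deliberately wrong assertion stands in for the `LogisticRegressionModel` check. The sketch relies only on the default `unittest.TestLoader`, which collects methods whose names start with `test`.

```python
import io
import unittest


class ExampleTests(unittest.TestCase):
    # No "test" prefix: the default loader skips this method, so its
    # deliberately wrong assertion is never evaluated.
    def logistic_regression_check_thresholds(self):
        self.assertIsInstance(1, str)

    # "test" prefix: collected and run by the default loader.
    def test_logistic_regression_check_thresholds(self):
        self.assertIsInstance(1, int)


def collected_names(case):
    """Return the method names the default TestLoader would run for `case`."""
    return list(unittest.TestLoader().getTestCaseNames(case))


def run_quietly(case):
    """Run every collected test in `case`; return (tests run, success flag)."""
    suite = unittest.TestLoader().loadTestsFromTestCase(case)
    result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
    return result.testsRun, result.wasSuccessful()


if __name__ == "__main__":
    print(collected_names(ExampleTests))  # ['test_logistic_regression_check_thresholds']
    print(run_quietly(ExampleTests))  # (1, True)
```

The suite reports one test run and overall success, even though the unprefixed method contains a failing assertion. For that reason, correcting the assertion alone would not have been enough. The method also had to be renamed.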

## How was this patch tested?

Unit tests.
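The rule that the moved test exercises can be sketched in plain Python. This is a hypothetical stand-in, not PySpark's implementation. The function name `check_thresholds`, the `1e-5` tolerance, and the exact message text are all assumptions. The sketch also assumes the binary equivalence `threshold = 1 / (1 + thresholds[0] / thresholds[1])`. Only the message prefix comes from the source, where it appears in the test's regular expression.

```python
def check_thresholds(threshold=None, thresholds=None, tol=1e-5):
    """Hypothetical consistency check between `threshold` and `thresholds`.

    Assumes a binary classifier and a non-zero thresholds[1].
    Raises ValueError when both params are set but disagree.
    """
    if threshold is None or thresholds is None:
        return  # only one (or neither) is set: nothing to compare
    if len(thresholds) != 2:
        raise ValueError(
            "Logistic Regression getThreshold only applies to binary "
            "classification, but thresholds has length != 2: %s" % (list(thresholds),))
    t0, t1 = thresholds
    # Assumed binary equivalence: threshold = 1 / (1 + t0 / t1).
    equivalent = 1.0 / (1.0 + t0 / t1)
    if abs(threshold - equivalent) >= tol:
        raise ValueError(
            "Logistic Regression getThreshold found inconsistent values for "
            "threshold (%g) and thresholds (equivalent to %g)" % (threshold, equivalent))


if __name__ == "__main__":
    check_thresholds(threshold=0.5, thresholds=[0.5, 0.5])  # consistent: no error
    try:
        check_thresholds(threshold=0.42, thresholds=[0.5, 0.5])
    except ValueError as err:
        print(err)
```

With `thresholds=[0.5, 0.5]`, the equivalent threshold is `0.5`. As a result, `threshold=0.5` passes and `threshold=0.42` raises `ValueError`, which mirrors the two assertions in `test_logistic_regression_check_thresholds`.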

Author: zero323 <zero...@users.noreply.github.com>

Closes #18085 from zero323/SPARK-20631-FOLLOW-UP.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1816eb3b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1816eb3b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1816eb3b

Branch: refs/heads/master
Commit: 1816eb3bef930407dc9e083de08f5105725c55d1
Parents: 9afcf12
Author: zero323 <zero...@users.noreply.github.com>
Authored: Wed May 24 19:57:44 2017 +0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed May 24 19:57:44 2017 +0800

----------------------------------------------------------------------
 python/pyspark/ml/tests.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1816eb3b/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a3393c6..0daf29d 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -404,6 +404,18 @@ class ParamTests(PySparkTestCase):
         self.assertEqual(tp._paramMap, copied_no_extra)
         self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
 
+    def test_logistic_regression_check_thresholds(self):
+        self.assertIsInstance(
+            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
+            LogisticRegression
+        )
+
+        self.assertRaisesRegexp(
+            ValueError,
+            "Logistic Regression getThreshold found inconsistent.*$",
+            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
+        )
+
 
 class EvaluatorTests(SparkSessionTestCase):
 
@@ -807,18 +819,6 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
-    def logistic_regression_check_thresholds(self):
-        self.assertIsInstance(
-            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
-            LogisticRegressionModel
-        )
-
-        self.assertRaisesRegexp(
-            ValueError,
-            "Logistic Regression getThreshold found inconsistent.*$",
-            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
-        )
-
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value

