Repository: spark
Updated Branches:
  refs/heads/master 9afcf127d -> 1816eb3be
[SPARK-20631][FOLLOW-UP] Fix incorrect tests.

## What changes were proposed in this pull request?

- Fix incorrect tests for `_check_thresholds`.
- Move test to `ParamTests`.

## How was this patch tested?

Unit tests.

Author: zero323 <zero...@users.noreply.github.com>

Closes #18085 from zero323/SPARK-20631-FOLLOW-UP.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1816eb3b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1816eb3b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1816eb3b

Branch: refs/heads/master
Commit: 1816eb3bef930407dc9e083de08f5105725c55d1
Parents: 9afcf12
Author: zero323 <zero...@users.noreply.github.com>
Authored: Wed May 24 19:57:44 2017 +0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed May 24 19:57:44 2017 +0800

----------------------------------------------------------------------
 python/pyspark/ml/tests.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/1816eb3b/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index a3393c6..0daf29d 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -404,6 +404,18 @@ class ParamTests(PySparkTestCase):
         self.assertEqual(tp._paramMap, copied_no_extra)
         self.assertEqual(tp._defaultParamMap, tp_copy._defaultParamMap)
 
+    def test_logistic_regression_check_thresholds(self):
+        self.assertIsInstance(
+            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
+            LogisticRegression
+        )
+
+        self.assertRaisesRegexp(
+            ValueError,
+            "Logistic Regression getThreshold found inconsistent.*$",
+            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
+        )
+
 
 class EvaluatorTests(SparkSessionTestCase):
 
@@ -807,18 +819,6 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
-    def logistic_regression_check_thresholds(self):
-        self.assertIsInstance(
-            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
-            LogisticRegressionModel
-        )
-
-        self.assertRaisesRegexp(
-            ValueError,
-            "Logistic Regression getThreshold found inconsistent.*$",
-            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
-        )
-
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org