Repository: spark Updated Branches: refs/heads/master 508573958 -> 7706eea6a
[SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add tests The `__del__` method that explicitly detaches the object was moved from `JavaParams` to the `JavaWrapper` class so that model summaries can also be garbage collected in Java. A test case was added to make sure that relevant error messages are thrown after the objects are deleted. I ran the pyspark tests against the `pyspark-ml` module: `./python/run-tests --python-executables=$(which python) --modules=pyspark-ml` Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com> Closes #20724 from yogeshg/java_wrapper_memory. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7706eea6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7706eea6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7706eea6 Branch: refs/heads/master Commit: 7706eea6a8bdcd73e9dde5212368f8825e2f1801 Parents: 5085739 Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com> Authored: Mon Mar 5 15:53:10 2018 -0800 Committer: Joseph K. 
Bradley <[email protected]> Committed: Mon Mar 5 15:53:10 2018 -0800 ---------------------------------------------------------------------- python/pyspark/ml/tests.py | 39 +++++++++++++++++++++++++++++++++++++++ python/pyspark/ml/wrapper.py | 8 ++++---- 2 files changed, 43 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7706eea6/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1168859..6dee693 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake): pass +class JavaWrapperMemoryTests(SparkSessionTestCase): + + def test_java_object_gets_detached(self): + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + + model = lr.fit(df) + summary = model.summary + + self.assertIsInstance(model, JavaWrapper) + self.assertIsInstance(summary, JavaWrapper) + self.assertIsInstance(model, JavaParams) + self.assertNotIsInstance(summary, JavaParams) + + error_no_object = 'Target Object ID does not exist for this gateway' + + self.assertIn("LinearRegression_", model._java_obj.toString()) + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + model.__del__() + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString()) + + try: + summary.__del__() + except: + pass + + with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object): + model._java_obj.toString() + with self.assertRaisesRegexp(py4j.protocol.Py4JError, 
error_no_object): + summary._java_obj.toString() + + class ParamTypeConversionTests(PySparkTestCase): """ Test that param type conversion happens. http://git-wip-us.apache.org/repos/asf/spark/blob/7706eea6/python/pyspark/ml/wrapper.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0f846fb..5061f64 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -36,6 +36,10 @@ class JavaWrapper(object): super(JavaWrapper, self).__init__() self._java_obj = java_obj + def __del__(self): + if SparkContext._active_spark_context and self._java_obj is not None: + SparkContext._active_spark_context._gateway.detach(self._java_obj) + @classmethod def _create_from_java_class(cls, java_class, *args): """ @@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params): __metaclass__ = ABCMeta - def __del__(self): - if SparkContext._active_spark_context: - SparkContext._active_spark_context._gateway.detach(self._java_obj) - def _make_java_param_pair(self, param, value): """ Makes a Java param pair. --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
