Repository: spark
Updated Branches:
  refs/heads/master 508573958 -> 7706eea6a


[SPARK-18630][PYTHON][ML] Move del method from JavaParams to JavaWrapper; add 
tests

The `__del__` method that explicitly detaches the object was moved from the
`JavaParams` class to the `JavaWrapper` class so that model summaries can also
be garbage collected in Java. A test case was added to make sure that the
relevant error messages are thrown after the objects are deleted.

I ran the pyspark tests against the `pyspark-ml` module:
`./python/run-tests --python-executables=$(which python) --modules=pyspark-ml`

Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>

Closes #20724 from yogeshg/java_wrapper_memory.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7706eea6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7706eea6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7706eea6

Branch: refs/heads/master
Commit: 7706eea6a8bdcd73e9dde5212368f8825e2f1801
Parents: 5085739
Author: Yogesh Garg <yogesh(dot)garg()databricks(dot)com>
Authored: Mon Mar 5 15:53:10 2018 -0800
Committer: Joseph K. Bradley <[email protected]>
Committed: Mon Mar 5 15:53:10 2018 -0800

----------------------------------------------------------------------
 python/pyspark/ml/tests.py   | 39 +++++++++++++++++++++++++++++++++++++++
 python/pyspark/ml/wrapper.py |  8 ++++----
 2 files changed, 43 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7706eea6/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 1168859..6dee693 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -173,6 +173,45 @@ class MockModel(MockTransformer, Model, HasFake):
     pass
 
 
+class JavaWrapperMemoryTests(SparkSessionTestCase):
+
+    def test_java_object_gets_detached(self):
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+                                         (0.0, 2.0, Vectors.sparse(1, [], 
[]))],
+                                        ["label", "weight", "features"])
+        lr = LinearRegression(maxIter=1, regParam=0.0, solver="normal", 
weightCol="weight",
+                              fitIntercept=False)
+
+        model = lr.fit(df)
+        summary = model.summary
+
+        self.assertIsInstance(model, JavaWrapper)
+        self.assertIsInstance(summary, JavaWrapper)
+        self.assertIsInstance(model, JavaParams)
+        self.assertNotIsInstance(summary, JavaParams)
+
+        error_no_object = 'Target Object ID does not exist for this gateway'
+
+        self.assertIn("LinearRegression_", model._java_obj.toString())
+        self.assertIn("LinearRegressionTrainingSummary", 
summary._java_obj.toString())
+
+        model.__del__()
+
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            model._java_obj.toString()
+        self.assertIn("LinearRegressionTrainingSummary", 
summary._java_obj.toString())
+
+        try:
+            summary.__del__()
+        except:
+            pass
+
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            model._java_obj.toString()
+        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            summary._java_obj.toString()
+
+
 class ParamTypeConversionTests(PySparkTestCase):
     """
     Test that param type conversion happens.

http://git-wip-us.apache.org/repos/asf/spark/blob/7706eea6/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 0f846fb..5061f64 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -36,6 +36,10 @@ class JavaWrapper(object):
         super(JavaWrapper, self).__init__()
         self._java_obj = java_obj
 
+    def __del__(self):
+        if SparkContext._active_spark_context and self._java_obj is not None:
+            SparkContext._active_spark_context._gateway.detach(self._java_obj)
+
     @classmethod
     def _create_from_java_class(cls, java_class, *args):
         """
@@ -100,10 +104,6 @@ class JavaParams(JavaWrapper, Params):
 
     __metaclass__ = ABCMeta
 
-    def __del__(self):
-        if SparkContext._active_spark_context:
-            SparkContext._active_spark_context._gateway.detach(self._java_obj)
-
     def _make_java_param_pair(self, param, value):
         """
         Makes a Java param pair.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to