This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 906c646a9473 [SPARK-54741][ML][CONNECT][TESTS] Restore ClusteringParityTests
906c646a9473 is described below

commit 906c646a9473e69119e03c55e08b751fa7aba6d4
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Dec 18 09:39:53 2025 +0800

    [SPARK-54741][ML][CONNECT][TESTS] Restore ClusteringParityTests
    
    ### What changes were proposed in this pull request?
    Restore ClusteringParityTests
    
    ### Why are the changes needed?
    to recover test coverage
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #53514 from zhengruifeng/restore_clu.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/ml/clustering.py                        |  7 +++++--
 .../pyspark/ml/tests/connect/test_parity_clustering.py |  2 --
 python/pyspark/ml/tests/test_clustering.py             | 18 ------------------
 3 files changed, 5 insertions(+), 22 deletions(-)

diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 0fc2b34d1748..f6543e707680 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -1542,9 +1542,12 @@ class DistributedLDAModel(LDAModel, JavaMLReadable["DistributedLDAModel"], JavaM
 
         .. warning:: This involves collecting a large :py:func:`topicsMatrix` to the driver.
         """
-        model = LocalLDAModel(self._call_java("toLocal"))
         if is_remote():
-            return model
+            from pyspark.ml.util import RemoteModelRef
+
+            return LocalLDAModel(RemoteModelRef(self._call_java("toLocal")))
+
+        model = LocalLDAModel(self._call_java("toLocal"))
 
         # SPARK-10931: Temporary fix to be removed once LDAModel defines Params
         model._create_params_from_java()
diff --git a/python/pyspark/ml/tests/connect/test_parity_clustering.py b/python/pyspark/ml/tests/connect/test_parity_clustering.py
index bbfd2a2aea80..99714b0d6962 100644
--- a/python/pyspark/ml/tests/connect/test_parity_clustering.py
+++ b/python/pyspark/ml/tests/connect/test_parity_clustering.py
@@ -21,8 +21,6 @@ from pyspark.ml.tests.test_clustering import ClusteringTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-# TODO(SPARK-52764): Re-enable this test after fixing the flakiness.
-@unittest.skip("Disabled due to flakiness, should be enabled after fixing the issue")
 class ClusteringParityTests(ClusteringTestsMixin, ReusedConnectTestCase):
     pass
 
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index d624b6398881..c1ec03b5ecc2 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -37,7 +37,6 @@ from pyspark.ml.clustering import (
     DistributedLDAModel,
     PowerIterationClustering,
 )
-from pyspark.sql import is_remote
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -107,18 +106,6 @@ class ClusteringTestsMixin:
         # check summary before model offloading occurs
         check_summary()
 
-        if is_remote():
-            self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
-            # check summary "try_remote_call" path after model offloading occurs
-            self.assertEqual(model.summary.numIter, 2)
-
-            self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
-            # check summary "invoke_remote_attribute_relation" path after model offloading occurs
-            self.assertEqual(model.summary.cluster.count(), 6)
-
-            self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
-            check_summary()
-
         # save & load
         with tempfile.TemporaryDirectory(prefix="kmeans_model") as d:
             km.write().overwrite().save(d)
@@ -323,11 +310,6 @@ class ClusteringTestsMixin:
             self.assertEqual(summary.probability.columns, ["probability"])
             self.assertEqual(summary.predictions.count(), 6)
 
-        check_summary()
-        if is_remote():
-            self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True)
-            check_summary()
-
         # save & load
         with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d:
             gmm.write().overwrite().save(d)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to