This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 6a7a710c4ada [SPARK-51215][ML][PYTHON][CONNECT] Add a helper function to invoke helper model attr
6a7a710c4ada is described below

commit 6a7a710c4ada14fd373c6578040363a42fbcc662
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Feb 17 11:03:19 2025 +0800

    [SPARK-51215][ML][PYTHON][CONNECT] Add a helper function to invoke helper model attr
    
    ### What changes were proposed in this pull request?
    Add a helper function to invoke helper model attr
    
    ### Why are the changes needed?
    deduplicate code
    
    ### Does this PR introduce _any_ user-facing change?
    no, minor refactor
    
    ### How was this patch tested?
    existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49951 from zhengruifeng/ml_connect_invoke_help_attr.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
    (cherry picked from commit b3dac8814a61d094c2ed8ec2136018c82ace9fbf)
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/feature.py | 18 ++++++------------
 python/pyspark/ml/util.py    |  7 +++++++
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 6a4a9dc99875..d669fab27d50 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -54,10 +54,9 @@ from pyspark.ml.util import (
     JavaMLReadable,
     JavaMLWritable,
     try_remote_attribute_relation,
-    ML_CONNECT_HELPER_ID,
+    invoke_helper_attr,
 )
 from pyspark.ml.wrapper import (
-    JavaWrapper,
     JavaEstimator,
     JavaModel,
     JavaParams,
@@ -1225,8 +1224,7 @@ class CountVectorizerModel(
 
         if is_remote():
             model = CountVectorizerModel()
-            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model._java_obj = helper._call_java(
+            model._java_obj = invoke_helper_attr(
                 "countVectorizerModelFromVocabulary",
                 model.uid,
                 list(vocabulary),
@@ -4845,8 +4843,7 @@ class StringIndexerModel(
         """
         if is_remote():
             model = StringIndexerModel()
-            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model._java_obj = helper._call_java(
+            model._java_obj = invoke_helper_attr(
                 "stringIndexerModelFromLabels",
                 model.uid,
                 (list(labels), ArrayType(StringType())),
@@ -4885,8 +4882,7 @@ class StringIndexerModel(
         """
         if is_remote():
             model = StringIndexerModel()
-            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model._java_obj = helper._call_java(
+            model._java_obj = invoke_helper_attr(
                 "stringIndexerModelFromLabelsArray",
                 model.uid,
                 (
@@ -5142,8 +5138,7 @@ class StopWordsRemover(
             "org.apache.spark.ml.feature.StopWordsRemover", self.uid
         )
         if is_remote():
-            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            locale = helper._call_java("stopWordsRemoverGetDefaultOrUS")
+            locale = invoke_helper_attr("stopWordsRemoverGetDefaultOrUS")
         else:
             locale = self._java_obj.getLocale()
 
@@ -5274,8 +5269,7 @@ class StopWordsRemover(
         italian, norwegian, portuguese, russian, spanish, swedish, turkish
         """
         if is_remote():
-            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            stopWords = helper._call_java("stopWordsRemoverLoadDefaultStopWords", language)
+            stopWords = invoke_helper_attr("stopWordsRemoverLoadDefaultStopWords", language)
             return list(stopWords)
 
         else:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 67921d312d37..4919b828a35c 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -78,6 +78,13 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def invoke_helper_attr(method: str, *args: Any) -> Any:
+    from pyspark.ml.wrapper import JavaWrapper
+
+    helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+    return helper._call_java(method, *args)
+
+
 def invoke_helper_relation(method: str, *args: Any) -> "ConnectDataFrame":
     from pyspark.ml.wrapper import JavaWrapper
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to