This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new caf358b46718 [SPARK-50949][ML][PYTHON][CONNECT][FOLLOW-UP] Fix uid issue and empty array args
caf358b46718 is described below

commit caf358b46718cdad436164e5113803d826e3641f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 13:19:42 2025 +0800

    [SPARK-50949][ML][PYTHON][CONNECT][FOLLOW-UP] Fix uid issue and empty array args
    
    ### What changes were proposed in this pull request?
    1. Pass the Python-side uid through to the server-side helper.
    2. Fix serialization of empty array arguments (a minimal sketch follows below).
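    
    For issue 2, a minimal, self-contained sketch of why an empty list is
    ambiguous without an explicit type (the (value, DataType) pairing
    mirrors the serialize() change below; tag_with_type is a hypothetical
    helper, not part of this patch):
    
        from pyspark.sql.types import ArrayType, DataType, StringType
    
        def tag_with_type(value, data_type):
            # Hypothetical helper: pair a value with an explicit DataType
            # so the Connect serializer can build a typed literal even when
            # the value is an empty list, whose element type cannot be
            # inferred from the value alone.
            assert isinstance(data_type, DataType)
            return (value, data_type)
    
        # An empty labels array keeps its array<string> element type:
        arg = tag_with_type([], ArrayType(StringType(), False))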
    
    ### Why are the changes needed?
    To enable the parity tests.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this is a bug fix.
    
    ### How was this patch tested?
    Enabled the previously skipped parity tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49802 from zhengruifeng/ml_connect_fix_labels.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 4b08932a4fbbb72e9244b3b831bcb1fe4cc48df3)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/ml/util/ConnectHelper.scala   | 10 ++++++----
 python/pyspark/ml/connect/serialize.py             |  9 +++++++++
 python/pyspark/ml/feature.py                       | 22 +++++++++++++---------
 .../ml/tests/connect/test_parity_feature.py        |  4 ----
 4 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index dd0781e2752e..3fed85e4a00a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -25,12 +25,14 @@ import org.apache.spark.sql.types.StructType
 private[spark] class ConnectHelper(override val uid: String) extends Model[ConnectHelper] {
   def this() = this(Identifiable.randomUID("ConnectHelper"))
 
-  def stringIndexerModelFromLabels(labels: Array[String]): StringIndexerModel = {
-    new StringIndexerModel(labels)
+  def stringIndexerModelFromLabels(
+      uid: String, labels: Array[String]): StringIndexerModel = {
+    new StringIndexerModel(uid, labels)
   }
 
-  def stringIndexerModelFromLabelsArray(labelsArray: Array[Array[String]]): StringIndexerModel = {
-    new StringIndexerModel(labelsArray)
+  def stringIndexerModelFromLabelsArray(
+      uid: String, labelsArray: Array[Array[String]]): StringIndexerModel = {
+    new StringIndexerModel(uid, labelsArray)
   }
 
   override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
diff --git a/python/pyspark/ml/connect/serialize.py b/python/pyspark/ml/connect/serialize.py
index 417f57cc9a71..42bedfb330b1 100644
--- a/python/pyspark/ml/connect/serialize.py
+++ b/python/pyspark/ml/connect/serialize.py
@@ -17,6 +17,7 @@
 from typing import Any, List, TYPE_CHECKING, Mapping, Dict
 
 import pyspark.sql.connect.proto as pb2
+from pyspark.sql.types import DataType
 from pyspark.ml.linalg import (
     DenseVector,
     SparseVector,
@@ -131,11 +132,19 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
 def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.sql.connect.expressions import LiteralExpression
 
     result = []
     for arg in args:
         if isinstance(arg, ConnectDataFrame):
             result.append(pb2.Fetch.Method.Args(input=arg._plan.plan(client)))
+        elif isinstance(arg, tuple) and len(arg) == 2 and isinstance(arg[1], DataType):
+            # explicitly specify the data type, for cases like empty list[str]
+            result.append(
+                pb2.Fetch.Method.Args(
+                    param=LiteralExpression(value=arg[0], dataType=arg[1]).to_plan(client).literal
+                )
+            )
         else:
             result.append(pb2.Fetch.Method.Args(param=serialize_param(arg, client)))
     return result
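
For reference, a sketch of the calling convention this new branch enables
(assumes an active Spark Connect session; client stands in for the
session's SparkConnectClient):

    from pyspark.sql.types import ArrayType, StringType

    labels = []  # empty, so the element type must be supplied explicitly
    arg = (labels, ArrayType(StringType(), False))
    # serialize(client, arg) now emits a typed array<string> literal
    # instead of relying on inference, which has nothing to go on for an
    # empty list.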
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 173d890a12e5..fb9c96bd6114 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -65,6 +65,7 @@ from pyspark.ml.wrapper import (
     _jvm,
 )
 from pyspark.ml.common import inherit_doc
+from pyspark.sql.types import ArrayType, StringType
 from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
@@ -4829,12 +4830,12 @@ class StringIndexerModel(
         requires an active SparkContext.
         """
         if is_remote():
+            model = StringIndexerModel()
             helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model = StringIndexerModel(
-                helper._call_java(
-                    "stringIndexerModelFromLabels",
-                    list(labels),
-                )
+            model._java_obj = helper._call_java(
+                "stringIndexerModelFromLabels",
+                model.uid,
+                (list(labels), ArrayType(StringType(), False)),
             )
 
         else:
@@ -4869,12 +4870,15 @@ class StringIndexerModel(
         requires an active SparkContext.
         """
         if is_remote():
+            model = StringIndexerModel()
             helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model = StringIndexerModel(
-                helper._call_java(
-                    "stringIndexerModelFromLabelsArray",
+            model._java_obj = helper._call_java(
+                "stringIndexerModelFromLabelsArray",
+                model.uid,
+                (
                     [list(labels) for labels in arrayOfLabels],
-                )
+                    ArrayType(ArrayType(StringType(), False)),
+                ),
             )
 
         else:
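
With the fix in place, from_labels behaves the same under Spark Connect as
in classic mode; a hedged usage sketch (assumes an active Spark session;
the column names are illustrative):

    from pyspark.ml.feature import StringIndexerModel

    # Under Spark Connect this now routes through ConnectHelper, passing
    # the Python-side uid and an explicitly typed labels array.
    model = StringIndexerModel.from_labels(
        ["a", "b", "c"], inputCol="label", outputCol="indexed"
    )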
diff --git a/python/pyspark/ml/tests/connect/test_parity_feature.py b/python/pyspark/ml/tests/connect/test_parity_feature.py
index baa3e6e7e0df..2c19ef24465b 100644
--- a/python/pyspark/ml/tests/connect/test_parity_feature.py
+++ b/python/pyspark/ml/tests/connect/test_parity_feature.py
@@ -26,10 +26,6 @@ class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
     def test_count_vectorizer_from_vocab(self):
         super().test_count_vectorizer_from_vocab()
 
-    @unittest.skip("Need to support.")
-    def test_string_indexer_from_labels(self):
-        super().test_string_indexer_from_labels()
-
     @unittest.skip("Need to support.")
     def test_stop_words_lengague_selection(self):
         super().test_stop_words_lengague_selection()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
