This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 492df8926cbf [SPARK-50949][ML][PYTHON][CONNECT] Introduce a helper model to support `StringIndexerModel.from_labels_xxx`
492df8926cbf is described below

commit 492df8926cbfc65d883cc16a1c8d88f346a86ee2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 10:24:50 2025 +0800

    [SPARK-50949][ML][PYTHON][CONNECT] Introduce a helper model to support `StringIndexerModel.from_labels_xxx`
    
    ### What changes were proposed in this pull request?
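    1. Introduce a helper model, `ConnectHelper`, that the Connect server exposes under a fixed,
    reserved ID;
    2. Use it to support `StringIndexerModel.from_labels` and
    `StringIndexerModel.from_arrays_of_labels` in Spark Connect.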
    
    ### Why are the changes needed?
    For feature parity between Spark Connect and classic PySpark ML.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `StringIndexerModel.from_labels` and `StringIndexerModel.from_arrays_of_labels` now work
    under Spark Connect.
    
    ### How was this patch tested?
    Added tests in `python/pyspark/ml/tests/test_feature.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49799 from zhengruifeng/mlc_dummy_model.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit c4e200895ecc01fd176e56cb6fbadee139d9e5b2)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       |  3 +
 .../org/apache/spark/ml/util/ConnectHelper.scala   | 43 +++++++++++++
 python/pyspark/ml/feature.py                       | 73 ++++++++++++++++------
 python/pyspark/ml/tests/test_feature.py            | 20 +++++-
 python/pyspark/ml/util.py                          |  3 +
 .../org/apache/spark/sql/connect/ml/MLCache.scala  | 10 ++-
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  9 ++-
 .../apache/spark/sql/connect/ml/Serializer.scala   | 37 +++++++++++
 8 files changed, 174 insertions(+), 24 deletions(-)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 7f694d7a5b7d..84f3631e5475 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,6 +18,9 @@
 # Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
 # So register the supported transformer here if you're trying to add a new one.
 
+########### Helper Model
+org.apache.spark.ml.util.ConnectHelper
+
 ########### Transformers
 org.apache.spark.ml.feature.DCT
 org.apache.spark.ml.feature.NGram
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
new file mode 100644
index 000000000000..dd0781e2752e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.feature.StringIndexerModel
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Helper that Spark Connect serves under a reserved cache id; transform is the identity.
+ */
+private[spark] class ConnectHelper(override val uid: String) extends Model[ConnectHelper] {
+  def this() = this(Identifiable.randomUID("ConnectHelper"))
+
+  /** Builds a single-column [[StringIndexerModel]] from one ordered array of labels. */
+  def stringIndexerModelFromLabels(labels: Array[String]): StringIndexerModel =
+    new StringIndexerModel(labels)
+
+  /** Builds a multi-column [[StringIndexerModel]], one ordered label array per column. */
+  def stringIndexerModelFromLabelsArray(labelsArray: Array[Array[String]]): StringIndexerModel =
+    new StringIndexerModel(labelsArray)
+
+  override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
+  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
+  override def transformSchema(schema: StructType): StructType = schema
+  override def hasParent: Boolean = false
+}
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 20eadb326b63..173d890a12e5 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -50,8 +50,20 @@ from pyspark.ml.param.shared import (
     Param,
     Params,
 )
-from pyspark.ml.util import JavaMLReadable, JavaMLWritable, 
try_remote_attribute_relation
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, 
JavaTransformer, _jvm
+from pyspark.ml.util import (
+    JavaMLReadable,
+    JavaMLWritable,
+    try_remote_attribute_relation,
+    ML_CONNECT_HELPER_ID,
+)
+from pyspark.ml.wrapper import (
+    JavaWrapper,
+    JavaEstimator,
+    JavaModel,
+    JavaParams,
+    JavaTransformer,
+    _jvm,
+)
 from pyspark.ml.common import inherit_doc
 from pyspark.sql.utils import is_remote
 
@@ -4816,15 +4828,26 @@ class StringIndexerModel(
         Construct the model directly from an array of label strings,
         requires an active SparkContext.
         """
-        from pyspark.core.context import SparkContext
+        if is_remote():
+            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+            model = StringIndexerModel(
+                helper._call_java(
+                    "stringIndexerModelFromLabels",
+                    list(labels),
+                )
+            )
+
+        else:
+            from pyspark.core.context import SparkContext
+
+            sc = SparkContext._active_spark_context
+            assert sc is not None and sc._gateway is not None
+            java_class = getattr(sc._gateway.jvm, "java.lang.String")
+            jlabels = StringIndexerModel._new_java_array(labels, java_class)
+            model = StringIndexerModel._create_from_java_class(
+                "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+            )
 
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._gateway is not None
-        java_class = getattr(sc._gateway.jvm, "java.lang.String")
-        jlabels = StringIndexerModel._new_java_array(labels, java_class)
-        model = StringIndexerModel._create_from_java_class(
-            "org.apache.spark.ml.feature.StringIndexerModel", jlabels
-        )
         model.setInputCol(inputCol)
         if outputCol is not None:
             model.setOutputCol(outputCol)
@@ -4845,15 +4868,25 @@ class StringIndexerModel(
         Construct the model directly from an array of array of label strings,
         requires an active SparkContext.
         """
-        from pyspark.core.context import SparkContext
+        if is_remote():
+            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+            model = StringIndexerModel(
+                helper._call_java(
+                    "stringIndexerModelFromLabelsArray",
+                    [list(labels) for labels in arrayOfLabels],
+                )
+            )
 
-        sc = SparkContext._active_spark_context
-        assert sc is not None and sc._gateway is not None
-        java_class = getattr(sc._gateway.jvm, "java.lang.String")
-        jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
-        model = StringIndexerModel._create_from_java_class(
-            "org.apache.spark.ml.feature.StringIndexerModel", jlabels
-        )
+        else:
+            from pyspark.core.context import SparkContext
+
+            sc = SparkContext._active_spark_context
+            assert sc is not None and sc._gateway is not None
+            java_class = getattr(sc._gateway.jvm, "java.lang.String")
+            jlabels = StringIndexerModel._new_java_array(arrayOfLabels, 
java_class)
+            model = StringIndexerModel._create_from_java_class(
+                "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+            )
         model.setInputCols(inputCols)
         if outputCols is not None:
             model.setOutputCols(outputCols)
@@ -4874,12 +4907,12 @@ class StringIndexerModel(
 
     @property
     @since("3.0.2")
-    def labelsArray(self) -> List[str]:
+    def labelsArray(self) -> List[List[str]]:
         """
         Array of ordered list of labels, corresponding to indices to be assigned
         for each input column.
         """
-        return self._call_java("labelsArray")
+        return [list(labels) for labels in self._call_java("labelsArray")]
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index f42f89b3014f..4298905e452b 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -1439,7 +1439,11 @@ class FeatureTestsMixin:
             ["a", "b", "c"], inputCol="label", outputCol="indexed", 
handleInvalid="keep"
         )
         self.assertEqual(model.labels, ["a", "b", "c"])
-        self.assertEqual(model.labelsArray, [("a", "b", "c")])
+        self.assertEqual(model.labelsArray, [["a", "b", "c"]])
+
+        self.assertEqual(model.getInputCol(), "label")
+        self.assertEqual(model.getOutputCol(), "indexed")
+        self.assertEqual(model.getHandleInvalid(), "keep")
 
         df1 = self.spark.createDataFrame(
             [(0, "a"), (1, "c"), (2, None), (3, "b"), (4, "b")], ["id", 
"label"]
@@ -1481,6 +1485,20 @@ class FeatureTestsMixin:
         )
         self.assertEqual(len(transformed_list), 5)
 
+    def test_string_indexer_from_arrays_of_labels(self):
+        labels_per_column = [["a", "b", "c"], ["x", "y", "z"]]
+        model = StringIndexerModel.from_arrays_of_labels(
+            labels_per_column,
+            inputCols=["label1", "label2"],
+            outputCols=["indexed1", "indexed2"],
+            handleInvalid="keep",
+        )
+        # labelsArray should come back as plain lists, one list per input column.
+        self.assertEqual(model.labelsArray, labels_per_column)
+        self.assertEqual(model.getInputCols(), ["label1", "label2"])
+        self.assertEqual(model.getOutputCols(), ["indexed1", "indexed2"])
+        self.assertEqual(model.getHandleInvalid(), "keep")
+
     def test_target_encoder_binary(self):
         df = self.spark.createDataFrame(
             [
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7d4c1b9460f8..666ebb0071c7 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -61,6 +61,9 @@ JR = TypeVar("JR", bound="JavaMLReader")
 FuncT = TypeVar("FuncT", bound=Callable[..., Any])
 
 
+ML_CONNECT_HELPER_ID = "______ML_CONNECT_HELPER______"  # must match MLCache.helperID (server)
+
+
 def try_remote_intermediate_result(f: FuncT) -> FuncT:
     """Mark the function/property that returns the intermediate result of the 
remote call.
     Eg, model.summary"""
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index a036f8b67350..beb06065d04a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -20,11 +20,15 @@ import java.util.UUID
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.util.ConnectHelper
 
 /**
  * MLCache is for caching ML objects, typically for models and summaries 
evaluated by a model.
  */
 private[connect] class MLCache extends Logging {
+  private val helper = new ConnectHelper()
+  private val helperID = "______ML_CONNECT_HELPER______"
+
   private val cachedModel: ConcurrentHashMap[String, Object] =
     new ConcurrentHashMap[String, Object]()
 
@@ -49,7 +53,11 @@ private[connect] class MLCache extends Logging {
    *   the cached object
    */
   def get(refId: String): Object = {
-    cachedModel.get(refId)
+    // The helper is kept outside cachedModel and is resolved by its reserved id instead.
+    refId match {
+      case `helperID` => helper
+      case _ => cachedModel.get(refId)
+    }
   }
 
   /**
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 86f5879aa99d..4418dee68ed1 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -36,7 +36,7 @@ import org.apache.spark.ml.param.Params
 import org.apache.spark.ml.recommendation._
 import org.apache.spark.ml.regression._
 import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.{HasTrainingSummary, Identifiable, MLWritable}
+import org.apache.spark.ml.util.{ConnectHelper, HasTrainingSummary, 
Identifiable, MLWritable}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.classic.Dataset
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
@@ -650,7 +650,12 @@ private[ml] object MLUtils {
     (classOf[OneHotEncoderModel], Set("categorySizes")),
     (classOf[StringIndexerModel], Set("labels", "labelsArray")),
     (classOf[RFormulaModel], Set("resolvedFormulaString")),
-    (classOf[IDFModel], Set("idf", "docFreq", "numDocs")))
+    (classOf[IDFModel], Set("idf", "docFreq", "numDocs")),
+
+    // Utils
+    (
+      classOf[ConnectHelper],
+      Set("stringIndexerModelFromLabels", 
"stringIndexerModelFromLabelsArray")))
 
   private def validate(obj: Any, method: String): Unit = {
     assert(obj != null)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
index df3e97398012..2bbc0b258cad 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
@@ -164,6 +164,22 @@ private[ml] object Serializer {
             (literal.getDouble.asInstanceOf[Object], classOf[Double])
           case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
             (literal.getBoolean.asInstanceOf[Object], classOf[Boolean])
+          case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+            val array = literal.getArray
+            array.getElementType.getKindCase match {
+              case proto.DataType.KindCase.STRING =>
+                (parseStringArray(array), classOf[Array[String]])
+              case proto.DataType.KindCase.ARRAY =>
+                array.getElementType.getArray.getElementType.getKindCase match 
{
+                  case proto.DataType.KindCase.STRING =>
+                    (parseStringArrayArray(array), 
classOf[Array[Array[String]]])
+                  case _ =>
+                    throw MlUnsupportedException(s"Unsupported inner array 
$array")
+                }
+              case _ =>
+                throw MlUnsupportedException(s"Unsupported array $literal")
+            }
+
           case other =>
             throw MlUnsupportedException(s"$other not supported")
         }
@@ -175,6 +191,27 @@ private[ml] object Serializer {
     }
   }
 
+  /**
+   * Converts an array literal whose elements are strings into an `Array[String]`, keeping the
+   * element order. Callers check beforehand that the array's element type is STRING.
+   */
+  private def parseStringArray(array: proto.Expression.Literal.Array): Array[String] = {
+    Array.tabulate(array.getElementsCount) { i =>
+      array.getElements(i).getString
+    }
+  }
+
+  /**
+   * Converts an array literal whose elements are string array literals into an
+   * `Array[Array[String]]`, keeping both the outer and the inner element order.
+   */
+  private def parseStringArrayArray(
+      array: proto.Expression.Literal.Array): Array[Array[String]] = {
+    Array.tabulate(array.getElementsCount) { i =>
+      parseStringArray(array.getElements(i).getArray)
+    }
+  }
+
   /**
    * Serialize an instance of "Params" which could be 
estimator/model/evaluator ...
    * @param instance


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to