This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c4e200895ecc [SPARK-50949][ML][PYTHON][CONNECT] Introduce a helper
model to support `StringIndexerModel.from_labels_xxx`
c4e200895ecc is described below
commit c4e200895ecc01fd176e56cb6fbadee139d9e5b2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 10:24:50 2025 +0800
[SPARK-50949][ML][PYTHON][CONNECT] Introduce a helper model to support
`StringIndexerModel.from_labels_xxx`
### What changes were proposed in this pull request?
1. Introduce a helper model with a specified id;
2. Use it to support `StringIndexerModel.from_labels_xxx`
### Why are the changes needed?
for feature parity
### Does this PR introduce _any_ user-facing change?
yes, new APIs supported
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49799 from zhengruifeng/mlc_dummy_model.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Transformer | 3 +
.../org/apache/spark/ml/util/ConnectHelper.scala | 43 +++++++++++++
python/pyspark/ml/feature.py | 73 ++++++++++++++++------
python/pyspark/ml/tests/test_feature.py | 20 +++++-
python/pyspark/ml/util.py | 3 +
.../org/apache/spark/sql/connect/ml/MLCache.scala | 10 ++-
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 9 ++-
.../apache/spark/sql/connect/ml/Serializer.scala | 37 +++++++++++
8 files changed, 174 insertions(+), 24 deletions(-)
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 7f694d7a5b7d..84f3631e5475 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,6 +18,9 @@
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml
non-model transformer.
# So register the supported transformer here if you're trying to add a new one.
+########### Helper Model
+org.apache.spark.ml.util.ConnectHelper
+
########### Transformers
org.apache.spark.ml.feature.DCT
org.apache.spark.ml.feature.NGram
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
new file mode 100644
index 000000000000..dd0781e2752e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.feature.StringIndexerModel
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+private[spark] class ConnectHelper(override val uid: String) extends
Model[ConnectHelper] {
+ def this() = this(Identifiable.randomUID("ConnectHelper"))
+
+ def stringIndexerModelFromLabels(labels: Array[String]): StringIndexerModel
= {
+ new StringIndexerModel(labels)
+ }
+
+ def stringIndexerModelFromLabelsArray(labelsArray: Array[Array[String]]):
StringIndexerModel = {
+ new StringIndexerModel(labelsArray)
+ }
+
+ override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
+
+ override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
+
+ override def transformSchema(schema: StructType): StructType = schema
+
+ override def hasParent: Boolean = false
+}
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 20eadb326b63..173d890a12e5 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -50,8 +50,20 @@ from pyspark.ml.param.shared import (
Param,
Params,
)
-from pyspark.ml.util import JavaMLReadable, JavaMLWritable,
try_remote_attribute_relation
-from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams,
JavaTransformer, _jvm
+from pyspark.ml.util import (
+ JavaMLReadable,
+ JavaMLWritable,
+ try_remote_attribute_relation,
+ ML_CONNECT_HELPER_ID,
+)
+from pyspark.ml.wrapper import (
+ JavaWrapper,
+ JavaEstimator,
+ JavaModel,
+ JavaParams,
+ JavaTransformer,
+ _jvm,
+)
from pyspark.ml.common import inherit_doc
from pyspark.sql.utils import is_remote
@@ -4816,15 +4828,26 @@ class StringIndexerModel(
Construct the model directly from an array of label strings,
requires an active SparkContext.
"""
- from pyspark.core.context import SparkContext
+ if is_remote():
+ helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+ model = StringIndexerModel(
+ helper._call_java(
+ "stringIndexerModelFromLabels",
+ list(labels),
+ )
+ )
+
+ else:
+ from pyspark.core.context import SparkContext
+
+ sc = SparkContext._active_spark_context
+ assert sc is not None and sc._gateway is not None
+ java_class = getattr(sc._gateway.jvm, "java.lang.String")
+ jlabels = StringIndexerModel._new_java_array(labels, java_class)
+ model = StringIndexerModel._create_from_java_class(
+ "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+ )
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._gateway is not None
- java_class = getattr(sc._gateway.jvm, "java.lang.String")
- jlabels = StringIndexerModel._new_java_array(labels, java_class)
- model = StringIndexerModel._create_from_java_class(
- "org.apache.spark.ml.feature.StringIndexerModel", jlabels
- )
model.setInputCol(inputCol)
if outputCol is not None:
model.setOutputCol(outputCol)
@@ -4845,15 +4868,25 @@ class StringIndexerModel(
Construct the model directly from an array of array of label strings,
requires an active SparkContext.
"""
- from pyspark.core.context import SparkContext
+ if is_remote():
+ helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+ model = StringIndexerModel(
+ helper._call_java(
+ "stringIndexerModelFromLabelsArray",
+ [list(labels) for labels in arrayOfLabels],
+ )
+ )
- sc = SparkContext._active_spark_context
- assert sc is not None and sc._gateway is not None
- java_class = getattr(sc._gateway.jvm, "java.lang.String")
- jlabels = StringIndexerModel._new_java_array(arrayOfLabels, java_class)
- model = StringIndexerModel._create_from_java_class(
- "org.apache.spark.ml.feature.StringIndexerModel", jlabels
- )
+ else:
+ from pyspark.core.context import SparkContext
+
+ sc = SparkContext._active_spark_context
+ assert sc is not None and sc._gateway is not None
+ java_class = getattr(sc._gateway.jvm, "java.lang.String")
+ jlabels = StringIndexerModel._new_java_array(arrayOfLabels,
java_class)
+ model = StringIndexerModel._create_from_java_class(
+ "org.apache.spark.ml.feature.StringIndexerModel", jlabels
+ )
model.setInputCols(inputCols)
if outputCols is not None:
model.setOutputCols(outputCols)
@@ -4874,12 +4907,12 @@ class StringIndexerModel(
@property
@since("3.0.2")
- def labelsArray(self) -> List[str]:
+ def labelsArray(self) -> List[List[str]]:
"""
Array of ordered list of labels, corresponding to indices to be
assigned
for each input column.
"""
- return self._call_java("labelsArray")
+ return [list(labels) for labels in self._call_java("labelsArray")]
@inherit_doc
diff --git a/python/pyspark/ml/tests/test_feature.py
b/python/pyspark/ml/tests/test_feature.py
index f42f89b3014f..4298905e452b 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -1439,7 +1439,11 @@ class FeatureTestsMixin:
["a", "b", "c"], inputCol="label", outputCol="indexed",
handleInvalid="keep"
)
self.assertEqual(model.labels, ["a", "b", "c"])
- self.assertEqual(model.labelsArray, [("a", "b", "c")])
+ self.assertEqual(model.labelsArray, [["a", "b", "c"]])
+
+ self.assertEqual(model.getInputCol(), "label")
+ self.assertEqual(model.getOutputCol(), "indexed")
+ self.assertEqual(model.getHandleInvalid(), "keep")
df1 = self.spark.createDataFrame(
[(0, "a"), (1, "c"), (2, None), (3, "b"), (4, "b")], ["id",
"label"]
@@ -1481,6 +1485,20 @@ class FeatureTestsMixin:
)
self.assertEqual(len(transformed_list), 5)
+ def test_string_indexer_from_arrays_of_labels(self):
+ model = StringIndexerModel.from_arrays_of_labels(
+ [["a", "b", "c"], ["x", "y", "z"]],
+ inputCols=["label1", "label2"],
+ outputCols=["indexed1", "indexed2"],
+ handleInvalid="keep",
+ )
+
+ self.assertEqual(model.labelsArray, [["a", "b", "c"], ["x", "y", "z"]])
+
+ self.assertEqual(model.getInputCols(), ["label1", "label2"])
+ self.assertEqual(model.getOutputCols(), ["indexed1", "indexed2"])
+ self.assertEqual(model.getHandleInvalid(), "keep")
+
def test_target_encoder_binary(self):
df = self.spark.createDataFrame(
[
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7d4c1b9460f8..666ebb0071c7 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -61,6 +61,9 @@ JR = TypeVar("JR", bound="JavaMLReader")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+ML_CONNECT_HELPER_ID = "______ML_CONNECT_HELPER______"
+
+
def try_remote_intermediate_result(f: FuncT) -> FuncT:
"""Mark the function/property that returns the intermediate result of the
remote call.
Eg, model.summary"""
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index a036f8b67350..beb06065d04a 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -20,11 +20,15 @@ import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.util.ConnectHelper
/**
* MLCache is for caching ML objects, typically for models and summaries
evaluated by a model.
*/
private[connect] class MLCache extends Logging {
+ private val helper = new ConnectHelper()
+ private val helperID = "______ML_CONNECT_HELPER______"
+
private val cachedModel: ConcurrentHashMap[String, Object] =
new ConcurrentHashMap[String, Object]()
@@ -49,7 +53,11 @@ private[connect] class MLCache extends Logging {
* the cached object
*/
def get(refId: String): Object = {
- cachedModel.get(refId)
+ if (refId == helperID) {
+ helper
+ } else {
+ cachedModel.get(refId)
+ }
}
/**
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 86f5879aa99d..4418dee68ed1 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -36,7 +36,7 @@ import org.apache.spark.ml.param.Params
import org.apache.spark.ml.recommendation._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
-import org.apache.spark.ml.util.{HasTrainingSummary, Identifiable, MLWritable}
+import org.apache.spark.ml.util.{ConnectHelper, HasTrainingSummary,
Identifiable, MLWritable}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
@@ -650,7 +650,12 @@ private[ml] object MLUtils {
(classOf[OneHotEncoderModel], Set("categorySizes")),
(classOf[StringIndexerModel], Set("labels", "labelsArray")),
(classOf[RFormulaModel], Set("resolvedFormulaString")),
- (classOf[IDFModel], Set("idf", "docFreq", "numDocs")))
+ (classOf[IDFModel], Set("idf", "docFreq", "numDocs")),
+
+ // Utils
+ (
+ classOf[ConnectHelper],
+ Set("stringIndexerModelFromLabels",
"stringIndexerModelFromLabelsArray")))
private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
index df3e97398012..2bbc0b258cad 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
@@ -164,6 +164,22 @@ private[ml] object Serializer {
(literal.getDouble.asInstanceOf[Object], classOf[Double])
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
(literal.getBoolean.asInstanceOf[Object], classOf[Boolean])
+ case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
+ val array = literal.getArray
+ array.getElementType.getKindCase match {
+ case proto.DataType.KindCase.STRING =>
+ (parseStringArray(array), classOf[Array[String]])
+ case proto.DataType.KindCase.ARRAY =>
+ array.getElementType.getArray.getElementType.getKindCase match
{
+ case proto.DataType.KindCase.STRING =>
+ (parseStringArrayArray(array),
classOf[Array[Array[String]]])
+ case _ =>
+ throw MlUnsupportedException(s"Unsupported inner array
$array")
+ }
+ case _ =>
+ throw MlUnsupportedException(s"Unsupported array $literal")
+ }
+
case other =>
throw MlUnsupportedException(s"$other not supported")
}
@@ -175,6 +191,27 @@ private[ml] object Serializer {
}
}
+ private def parseStringArray(array: proto.Expression.Literal.Array):
Array[String] = {
+ val values = new Array[String](array.getElementsCount)
+ var i = 0
+ while (i < array.getElementsCount) {
+ values(i) = array.getElements(i).getString
+ i += 1
+ }
+ values
+ }
+
+ private def parseStringArrayArray(
+ array: proto.Expression.Literal.Array): Array[Array[String]] = {
+ val values = new Array[Array[String]](array.getElementsCount)
+ var i = 0
+ while (i < array.getElementsCount) {
+ values(i) = parseStringArray(array.getElements(i).getArray)
+ i += 1
+ }
+ values
+ }
+
/**
* Serialize an instance of "Params" which could be
estimator/model/evaluator ...
* @param instance
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]