This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new caf358b46718 [SPARK-50949][ML][PYTHON][CONNECT][FOLLOW-UP] Fix uid issue and empty array args
caf358b46718 is described below
commit caf358b46718cdad436164e5113803d826e3641f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 13:19:42 2025 +0800
[SPARK-50949][ML][PYTHON][CONNECT][FOLLOW-UP] Fix uid issue and empty array args
### What changes were proposed in this pull request?
1. Pass the Python model's uid through to the JVM-side `StringIndexerModel`.
2. Fix serialization of empty array arguments by attaching an explicit `DataType` (see the sketch below).
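A minimal, illustrative sketch (not part of the patch) of why the second fix is needed: an empty Python list carries no element type, so the Connect serializer cannot infer the JVM array type from the value alone.

```python
# Illustrative only: why an empty list needs an explicit element type.
from pyspark.sql.types import ArrayType, StringType

labels = []  # nothing distinguishes an empty Array<String> from Array<Int>
typed_arg = (labels, ArrayType(StringType(), False))  # (value, DataType) pair
# The patched serialize() (see the serialize.py hunk below) recognizes such
# tuples and builds a typed literal instead of guessing the element type.
```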
### Why are the changes needed?
To enable the parity tests for these code paths.
### Does this PR introduce _any_ user-facing change?
Yes, it is a bug fix.
### How was this patch tested?
Enabled the previously skipped parity tests.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49802 from zhengruifeng/ml_connect_fix_labels.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 4b08932a4fbbb72e9244b3b831bcb1fe4cc48df3)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/ml/util/ConnectHelper.scala | 10 ++++++----
 python/pyspark/ml/connect/serialize.py           |  9 +++++++++
 python/pyspark/ml/feature.py                     | 22 +++++++++++++---------
 .../ml/tests/connect/test_parity_feature.py      |  4 ----
 4 files changed, 28 insertions(+), 17 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index dd0781e2752e..3fed85e4a00a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -25,12 +25,14 @@ import org.apache.spark.sql.types.StructType
 private[spark] class ConnectHelper(override val uid: String) extends Model[ConnectHelper] {
   def this() = this(Identifiable.randomUID("ConnectHelper"))
 
-  def stringIndexerModelFromLabels(labels: Array[String]): StringIndexerModel = {
-    new StringIndexerModel(labels)
+  def stringIndexerModelFromLabels(
+      uid: String, labels: Array[String]): StringIndexerModel = {
+    new StringIndexerModel(uid, labels)
   }
 
-  def stringIndexerModelFromLabelsArray(labelsArray: Array[Array[String]]): StringIndexerModel = {
-    new StringIndexerModel(labelsArray)
+  def stringIndexerModelFromLabelsArray(
+      uid: String, labelsArray: Array[Array[String]]): StringIndexerModel = {
+    new StringIndexerModel(uid, labelsArray)
   }
 
   override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
diff --git a/python/pyspark/ml/connect/serialize.py b/python/pyspark/ml/connect/serialize.py
index 417f57cc9a71..42bedfb330b1 100644
--- a/python/pyspark/ml/connect/serialize.py
+++ b/python/pyspark/ml/connect/serialize.py
@@ -17,6 +17,7 @@
 from typing import Any, List, TYPE_CHECKING, Mapping, Dict
 
 import pyspark.sql.connect.proto as pb2
+from pyspark.sql.types import DataType
 from pyspark.ml.linalg import (
     DenseVector,
     SparseVector,
@@ -131,11 +132,19 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
 
 def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+    from pyspark.sql.connect.expressions import LiteralExpression
 
     result = []
     for arg in args:
         if isinstance(arg, ConnectDataFrame):
             result.append(pb2.Fetch.Method.Args(input=arg._plan.plan(client)))
+        elif isinstance(arg, tuple) and len(arg) == 2 and isinstance(arg[1], DataType):
+            # explicitly specify the data type, for cases like empty list[str]
+            result.append(
+                pb2.Fetch.Method.Args(
+                    param=LiteralExpression(value=arg[0], dataType=arg[1]).to_plan(client).literal
+                )
+            )
         else:
             result.append(pb2.Fetch.Method.Args(param=serialize_param(arg, client)))
     return result
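A usage sketch of the new branch, assuming `client` is an active `SparkConnectClient` (the name is illustrative) and `serialize` is the function patched above:

```python
# Sketch, assuming `client` is an active SparkConnectClient.
from pyspark.ml.connect.serialize import serialize
from pyspark.sql.types import ArrayType, StringType

args = serialize(client, ([], ArrayType(StringType(), False)))
# args[0].param is a literal carrying an explicit Array<String> type, which a
# bare empty list could not convey on its own.
```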
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 173d890a12e5..fb9c96bd6114 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -65,6 +65,7 @@ from pyspark.ml.wrapper import (
     _jvm,
 )
 from pyspark.ml.common import inherit_doc
+from pyspark.sql.types import ArrayType, StringType
 from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
@@ -4829,12 +4830,12 @@ class StringIndexerModel(
         requires an active SparkContext.
         """
         if is_remote():
+            model = StringIndexerModel()
             helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model = StringIndexerModel(
-                helper._call_java(
-                    "stringIndexerModelFromLabels",
-                    list(labels),
-                )
+            model._java_obj = helper._call_java(
+                "stringIndexerModelFromLabels",
+                model.uid,
+                (list(labels), ArrayType(StringType(), False)),
             )
         else:
@@ -4869,12 +4870,15 @@ class StringIndexerModel(
         requires an active SparkContext.
         """
         if is_remote():
+            model = StringIndexerModel()
             helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
-            model = StringIndexerModel(
-                helper._call_java(
-                    "stringIndexerModelFromLabelsArray",
+            model._java_obj = helper._call_java(
+                "stringIndexerModelFromLabelsArray",
+                model.uid,
+                (
                     [list(labels) for labels in arrayOfLabels],
-                )
+                    ArrayType(ArrayType(StringType(), False)),
+                ),
             )
         else:
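A hedged usage sketch of the fixed path, assuming a Spark Connect session; column names are illustrative:

```python
# Sketch, assuming a Spark Connect session; column names are illustrative.
from pyspark.ml.feature import StringIndexerModel

model = StringIndexerModel.from_labels(
    ["a", "b", "c"], inputCol="label", outputCol="indexed"
)
# The JVM-side object is now created with model.uid instead of a fresh random
# uid, and an empty labels list also serializes correctly:
empty = StringIndexerModel.from_labels([], inputCol="label", outputCol="indexed")
```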
diff --git a/python/pyspark/ml/tests/connect/test_parity_feature.py b/python/pyspark/ml/tests/connect/test_parity_feature.py
index baa3e6e7e0df..2c19ef24465b 100644
--- a/python/pyspark/ml/tests/connect/test_parity_feature.py
+++ b/python/pyspark/ml/tests/connect/test_parity_feature.py
@@ -26,10 +26,6 @@ class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
     def test_count_vectorizer_from_vocab(self):
         super().test_count_vectorizer_from_vocab()
 
-    @unittest.skip("Need to support.")
-    def test_string_indexer_from_labels(self):
-        super().test_string_indexer_from_labels()
-
     @unittest.skip("Need to support.")
     def test_stop_words_lengague_selection(self):
         super().test_stop_words_lengague_selection()
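For reference, a sketch of the behavior the re-enabled parity test exercises (not the actual test body, which lives in FeatureTestsMixin); column names are illustrative:

```python
# Sketch only; assumes a Spark Connect session.
from pyspark.ml.feature import StringIndexerModel

m = StringIndexerModel.from_arrays_of_labels(
    [["a", "b"], ["x", "y"]], inputCols=["i1", "i2"], outputCols=["o1", "o2"]
)
print(m.uid, m.labelsArray)  # same behavior in classic and Connect modes
```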
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]