This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 4c00d639b913 [SPARK-51004][ML][PYTHON][CONNECT] Add support for
IndexToString
4c00d639b913 is described below
commit 4c00d639b913ad709b8f826f298ced2ffdc28800
Author: Bobby Wang <[email protected]>
AuthorDate: Mon Jan 27 21:44:28 2025 +0800
[SPARK-51004][ML][PYTHON][CONNECT] Add support for IndexToString
### What changes were proposed in this pull request?
This PR adds support for IndexToString and adds labels/labelsArray to
ALLOWED_LIST.
### Why are the changes needed?
Provides new feature parity (IndexToString on Spark Connect) and fixes a bug.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
CI passes
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49690 from wbo4958/index-str.
Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit b5deb8dc25b8434f07649887da63f7172e9abdeb)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Transformer | 1 +
python/pyspark/ml/tests/test_feature.py | 48 ++++++++++++++++++++++
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 1 +
3 files changed, 50 insertions(+)
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 06375a701010..ec3c5abe6278 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -32,6 +32,7 @@ org.apache.spark.ml.feature.SQLTransformer
org.apache.spark.ml.feature.StopWordsRemover
org.apache.spark.ml.feature.FeatureHasher
org.apache.spark.ml.feature.HashingTF
+org.apache.spark.ml.feature.IndexToString
########### Model for loading
# classification
diff --git a/python/pyspark/ml/tests/test_feature.py
b/python/pyspark/ml/tests/test_feature.py
index ee6f8a78cc4c..3a6d3adf2cbf 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -72,6 +72,7 @@ from pyspark.ml.feature import (
BucketedRandomProjectionLSHModel,
MinHashLSH,
MinHashLSHModel,
+ IndexToString,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
@@ -80,6 +81,51 @@ from pyspark.testing.mlutils import SparkSessionTestCase
class FeatureTestsMixin:
+ def test_index_string(self):
+ dataset = self.spark.createDataFrame(
+ [
+ (0, "a"),
+ (1, "b"),
+ (2, "c"),
+ (3, "a"),
+ (4, "a"),
+ (5, "c"),
+ ],
+ ["id", "label"],
+ )
+
+ indexer = StringIndexer(inputCol="label",
outputCol="labelIndex").fit(dataset)
+ transformed = indexer.transform(dataset)
+ idx2str = (
+ IndexToString()
+ .setInputCol("labelIndex")
+ .setOutputCol("sameLabel")
+ .setLabels(indexer.labels)
+ )
+
+ def check(t: IndexToString) -> None:
+ self.assertEqual(t.getInputCol(), "labelIndex")
+ self.assertEqual(t.getOutputCol(), "sameLabel")
+ self.assertEqual(t.getLabels(), indexer.labels)
+
+ check(idx2str)
+
+ ret = idx2str.transform(transformed)
+ self.assertEqual(
+ sorted(ret.schema.names), sorted(["id", "label", "labelIndex",
"sameLabel"])
+ )
+
+ rows = ret.select("label", "sameLabel").collect()
+ for r in rows:
+ self.assertEqual(r.label, r.sameLabel)
+
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="index_string") as d:
+ idx2str.write().overwrite().save(d)
+ idx2str2 = IndexToString.load(d)
+ self.assertEqual(str(idx2str), str(idx2str2))
+ check(idx2str2)
+
def test_dct(self):
df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)],
["vec"])
dct = DCT()
@@ -128,6 +174,7 @@ class FeatureTestsMixin:
si = StringIndexer(inputCol="label1", outputCol="index1")
model = si.fit(df.select("label1"))
self.assertEqual(si.uid, model.uid)
+ self.assertEqual(model.labels, list(model.labelsArray[0]))
# read/write
with tempfile.TemporaryDirectory(prefix="string_indexer") as tmp_dir:
@@ -188,6 +235,7 @@ class FeatureTestsMixin:
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
model = pca.fit(df)
+ self.assertTrue(np.allclose(model.pc.toArray()[0], [-0.44859172,
-0.28423808], atol=1e-4))
self.assertEqual(pca.uid, model.uid)
self.assertEqual(model.getK(), 2)
self.assertTrue(
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index be067ed8972e..0dd1353abbd2 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -646,6 +646,7 @@ private[ml] object MLUtils {
(classOf[Word2VecModel], Set("getVectors", "findSynonyms",
"findSynonymsArray")),
(classOf[CountVectorizerModel], Set("vocabulary")),
(classOf[OneHotEncoderModel], Set("categorySizes")),
+ (classOf[StringIndexerModel], Set("labels", "labelsArray")),
(classOf[IDFModel], Set("idf", "docFreq", "numDocs")))
private def validate(obj: Any, method: String): Unit = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]