This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 4c00d639b913 [SPARK-51004][ML][PYTHON][CONNECT] Add supports for IndexString
4c00d639b913 is described below

commit 4c00d639b913ad709b8f826f298ced2ffdc28800
Author: Bobby Wang <[email protected]>
AuthorDate: Mon Jan 27 21:44:28 2025 +0800

    [SPARK-51004][ML][PYTHON][CONNECT] Add supports for IndexString
    
    ### What changes were proposed in this pull request?
    This PR adds support for IndexToString and adds labels/labelsArray to ALLOWED_LIST.
    
    ### Why are the changes needed?
    
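    New feature parity (IndexToString support for Spark Connect) and a bug fix (StringIndexerModel's labels/labelsArray were missing from the allowed list).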
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    ### How was this patch tested?
    CI passes
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49690 from wbo4958/index-str.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit b5deb8dc25b8434f07649887da63f7172e9abdeb)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       |  1 +
 python/pyspark/ml/tests/test_feature.py            | 48 ++++++++++++++++++++++
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  1 +
 3 files changed, 50 insertions(+)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 06375a701010..ec3c5abe6278 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -32,6 +32,7 @@ org.apache.spark.ml.feature.SQLTransformer
 org.apache.spark.ml.feature.StopWordsRemover
 org.apache.spark.ml.feature.FeatureHasher
 org.apache.spark.ml.feature.HashingTF
+org.apache.spark.ml.feature.IndexToString
 
 ########### Model for loading
 # classification
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index ee6f8a78cc4c..3a6d3adf2cbf 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -72,6 +72,7 @@ from pyspark.ml.feature import (
     BucketedRandomProjectionLSHModel,
     MinHashLSH,
     MinHashLSHModel,
+    IndexToString,
 )
 from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
 from pyspark.sql import Row
@@ -80,6 +81,51 @@ from pyspark.testing.mlutils import SparkSessionTestCase
 
 
 class FeatureTestsMixin:
+    def test_index_string(self):
+        dataset = self.spark.createDataFrame(
+            [
+                (0, "a"),
+                (1, "b"),
+                (2, "c"),
+                (3, "a"),
+                (4, "a"),
+                (5, "c"),
+            ],
+            ["id", "label"],
+        )
+
+        indexer = StringIndexer(inputCol="label", 
outputCol="labelIndex").fit(dataset)
+        transformed = indexer.transform(dataset)
+        idx2str = (
+            IndexToString()
+            .setInputCol("labelIndex")
+            .setOutputCol("sameLabel")
+            .setLabels(indexer.labels)
+        )
+
+        def check(t: IndexToString) -> None:
+            self.assertEqual(t.getInputCol(), "labelIndex")
+            self.assertEqual(t.getOutputCol(), "sameLabel")
+            self.assertEqual(t.getLabels(), indexer.labels)
+
+        check(idx2str)
+
+        ret = idx2str.transform(transformed)
+        self.assertEqual(
+            sorted(ret.schema.names), sorted(["id", "label", "labelIndex", 
"sameLabel"])
+        )
+
+        rows = ret.select("label", "sameLabel").collect()
+        for r in rows:
+            self.assertEqual(r.label, r.sameLabel)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="index_string") as d:
+            idx2str.write().overwrite().save(d)
+            idx2str2 = IndexToString.load(d)
+            self.assertEqual(str(idx2str), str(idx2str2))
+            check(idx2str2)
+
     def test_dct(self):
         df = self.spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], 
["vec"])
         dct = DCT()
@@ -128,6 +174,7 @@ class FeatureTestsMixin:
         si = StringIndexer(inputCol="label1", outputCol="index1")
         model = si.fit(df.select("label1"))
         self.assertEqual(si.uid, model.uid)
+        self.assertEqual(model.labels, list(model.labelsArray[0]))
 
         # read/write
         with tempfile.TemporaryDirectory(prefix="string_indexer") as tmp_dir:
@@ -188,6 +235,7 @@ class FeatureTestsMixin:
         pca = PCA(k=2, inputCol="features", outputCol="pca_features")
 
         model = pca.fit(df)
+        self.assertTrue(np.allclose(model.pc.toArray()[0], [-0.44859172, 
-0.28423808], atol=1e-4))
         self.assertEqual(pca.uid, model.uid)
         self.assertEqual(model.getK(), 2)
         self.assertTrue(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index be067ed8972e..0dd1353abbd2 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -646,6 +646,7 @@ private[ml] object MLUtils {
     (classOf[Word2VecModel], Set("getVectors", "findSynonyms", 
"findSynonymsArray")),
     (classOf[CountVectorizerModel], Set("vocabulary")),
     (classOf[OneHotEncoderModel], Set("categorySizes")),
+    (classOf[StringIndexerModel], Set("labels", "labelsArray")),
     (classOf[IDFModel], Set("idf", "docFreq", "numDocs")))
 
   private def validate(obj: Any, method: String): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to