This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new f8c471344c36 [SPARK-51089][ML][PYTHON][CONNECT] Support 
`VectorIndexerModel.categoryMaps` on connect
f8c471344c36 is described below

commit f8c471344c3688928b4e91ed6852400d80e6c22c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 21:45:42 2025 +0800

    [SPARK-51089][ML][PYTHON][CONNECT] Support 
`VectorIndexerModel.categoryMaps` on connect
    
    ### What changes were proposed in this pull request?
    Support `VectorIndexerModel.categoryMaps` on Connect by exposing the category maps through a DataFrame-based attribute relation (`categoryMapsDF`).
    
    ### Why are the changes needed?
    For feature parity between Spark Classic and Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the `VectorIndexerModel.categoryMaps` API is now supported on Connect.
    
    ### How was this patch tested?
    Added unit tests (see `python/pyspark/ml/tests/test_feature.py`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49803 from zhengruifeng/ml_connect_fix_cate_map.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit c7f6e37e519270e7e43987534445a5130dfdd56d)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/ml/feature/VectorIndexer.scala     | 11 ++++++++++-
 python/pyspark/ml/feature.py                            | 17 +++++++++++++++--
 python/pyspark/ml/tests/test_feature.py                 |  3 +++
 .../scala/org/apache/spark/sql/connect/ml/MLUtils.scala |  2 +-
 4 files changed, 29 insertions(+), 4 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 1adc8a1bdae7..5063f15302a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -32,9 +32,10 @@ import org.apache.spark.ml.linalg.{DenseVector, 
SparseVector, Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.{OpenHashSet, Utils}
 
 /** Private trait for params for VectorIndexer and VectorIndexerModel */
@@ -306,6 +307,14 @@ class VectorIndexerModel private[ml] (
       .asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]]
   }
 
+  private[spark] def categoryMapsDF: DataFrame = {
+    val data = categoryMaps.iterator.flatMap {
+      case (idx, map) => map.iterator.map(t => (idx, t._1, t._2))
+    }.toArray.toImmutableArraySeq
+    SparkSession.builder().getOrCreate().createDataFrame(data)
+      .toDF("featureIndex", "originalValue", "categoryIndex")
+  }
+
   /**
    * Pre-computed feature attributes, with some missing info.
    * In transform(), set attribute name and other info, if available.
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 43cb00bf5305..6d1ddf5e51c4 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -6016,13 +6016,26 @@ class VectorIndexerModel(
 
     @property
     @since("1.4.0")
-    def categoryMaps(self) -> Dict[int, Tuple[float, int]]:
+    def categoryMaps(self) -> Dict[int, Dict[float, int]]:
         """
         Feature value index.  Keys are categorical feature indices (column 
indices).
         Values are maps from original features values to 0-based category 
indices.
         If a feature is not in this map, it is treated as continuous.
         """
-        return self._call_java("javaCategoryMaps")
+
+        @try_remote_attribute_relation
+        def categoryMapsDF(m: VectorIndexerModel) -> DataFrame:
+            return m._call_java("categoryMapsDF")
+
+        res: Dict[int, Dict[float, int]] = {}
+        for row in categoryMapsDF(self).collect():
+            featureIndex = int(row.featureIndex)
+            originalValue = float(row.originalValue)
+            categoryIndex = int(row.categoryIndex)
+            if featureIndex not in res:
+                res[featureIndex] = {}
+            res[featureIndex][originalValue] = categoryIndex
+        return res
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/tests/test_feature.py 
b/python/pyspark/ml/tests/test_feature.py
index bbcb54b8a442..a3e580ec7220 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -275,6 +275,9 @@ class FeatureTestsMixin:
         self.assertEqual(indexer.uid, model.uid)
         self.assertEqual(model.numFeatures, 2)
 
+        categoryMaps = model.categoryMaps
+        self.assertEqual(categoryMaps, {0: {0.0: 0, -1.0: 1}}, categoryMaps)
+
         output = model.transform(df)
         self.assertEqual(output.columns, ["a", "indexed"])
         self.assertEqual(output.count(), 3)
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 8138a8560ddb..e69f670226e4 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -640,7 +640,7 @@ private[ml] object MLUtils {
     (classOf[MaxAbsScalerModel], Set("maxAbs")),
     (classOf[MinMaxScalerModel], Set("originalMax", "originalMin")),
     (classOf[RobustScalerModel], Set("range", "median")),
-    (classOf[VectorIndexerModel], Set("numFeatures")),
+    (classOf[VectorIndexerModel], Set("numFeatures", "categoryMapsDF")),
     (classOf[ChiSqSelectorModel], Set("selectedFeatures")),
     (classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
     (classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to