This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new f8c471344c36 [SPARK-51089][ML][PYTHON][CONNECT] Support
`VectorIndexerModel.categoryMaps` on connect
f8c471344c36 is described below
commit f8c471344c3688928b4e91ed6852400d80e6c22c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 21:45:42 2025 +0800
[SPARK-51089][ML][PYTHON][CONNECT] Support
`VectorIndexerModel.categoryMaps` on connect
### What changes were proposed in this pull request?
Fix `VectorIndexerModel.categoryMaps` on connect
### Why are the changes needed?
feature parity
### Does this PR introduce _any_ user-facing change?
yes, new API supported
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49803 from zhengruifeng/ml_connect_fix_cate_map.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit c7f6e37e519270e7e43987534445a5130dfdd56d)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../org/apache/spark/ml/feature/VectorIndexer.scala | 11 ++++++++++-
python/pyspark/ml/feature.py | 17 +++++++++++++++--
python/pyspark/ml/tests/test_feature.py | 3 +++
.../scala/org/apache/spark/sql/connect/ml/MLUtils.scala | 2 +-
4 files changed, 29 insertions(+), 4 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 1adc8a1bdae7..5063f15302a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -32,9 +32,10 @@ import org.apache.spark.ml.linalg.{DenseVector,
SparseVector, Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.collection.{OpenHashSet, Utils}
/** Private trait for params for VectorIndexer and VectorIndexerModel */
@@ -306,6 +307,14 @@ class VectorIndexerModel private[ml] (
.asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]]
}
+ private[spark] def categoryMapsDF: DataFrame = {
+ val data = categoryMaps.iterator.flatMap {
+ case (idx, map) => map.iterator.map(t => (idx, t._1, t._2))
+ }.toArray.toImmutableArraySeq
+ SparkSession.builder().getOrCreate().createDataFrame(data)
+ .toDF("featureIndex", "originalValue", "categoryIndex")
+ }
+
/**
* Pre-computed feature attributes, with some missing info.
* In transform(), set attribute name and other info, if available.
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 43cb00bf5305..6d1ddf5e51c4 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -6016,13 +6016,26 @@ class VectorIndexerModel(
@property
@since("1.4.0")
- def categoryMaps(self) -> Dict[int, Tuple[float, int]]:
+ def categoryMaps(self) -> Dict[int, Dict[float, int]]:
"""
Feature value index. Keys are categorical feature indices (column
indices).
Values are maps from original features values to 0-based category
indices.
If a feature is not in this map, it is treated as continuous.
"""
- return self._call_java("javaCategoryMaps")
+
+ @try_remote_attribute_relation
+ def categoryMapsDF(m: VectorIndexerModel) -> DataFrame:
+ return m._call_java("categoryMapsDF")
+
+ res: Dict[int, Dict[float, int]] = {}
+ for row in categoryMapsDF(self).collect():
+ featureIndex = int(row.featureIndex)
+ originalValue = float(row.originalValue)
+ categoryIndex = int(row.categoryIndex)
+ if featureIndex not in res:
+ res[featureIndex] = {}
+ res[featureIndex][originalValue] = categoryIndex
+ return res
@inherit_doc
diff --git a/python/pyspark/ml/tests/test_feature.py
b/python/pyspark/ml/tests/test_feature.py
index bbcb54b8a442..a3e580ec7220 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -275,6 +275,9 @@ class FeatureTestsMixin:
self.assertEqual(indexer.uid, model.uid)
self.assertEqual(model.numFeatures, 2)
+ categoryMaps = model.categoryMaps
+ self.assertEqual(categoryMaps, {0: {0.0: 0, -1.0: 1}}, categoryMaps)
+
output = model.transform(df)
self.assertEqual(output.columns, ["a", "indexed"])
self.assertEqual(output.count(), 3)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 8138a8560ddb..e69f670226e4 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -640,7 +640,7 @@ private[ml] object MLUtils {
(classOf[MaxAbsScalerModel], Set("maxAbs")),
(classOf[MinMaxScalerModel], Set("originalMax", "originalMin")),
(classOf[RobustScalerModel], Set("range", "median")),
- (classOf[VectorIndexerModel], Set("numFeatures")),
+ (classOf[VectorIndexerModel], Set("numFeatures", "categoryMapsDF")),
(classOf[ChiSqSelectorModel], Set("selectedFeatures")),
(classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
(classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]