This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new c97ce86226de [SPARK-51005][ML][PYTHON][CONNECT] Support VectorIndexer and ElementwiseProduct on Connect
c97ce86226de is described below

commit c97ce86226deaaa6c353d26459b59b4e046829ea
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 28 09:33:37 2025 +0900

    [SPARK-51005][ML][PYTHON][CONNECT] Support VectorIndexer and ElementwiseProduct on Connect
    
    ### What changes were proposed in this pull request?
    Support VectorIndexer and ElementwiseProduct on Connect
    
    ### Why are the changes needed?
    For feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49691 from zhengruifeng/ml_connect_index.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 9d8f39b2007b3db30da0da87d64bf31b19fb2e6d)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../services/org.apache.spark.ml.Estimator         |  1 +
 .../services/org.apache.spark.ml.Transformer       |  2 +
 .../apache/spark/ml/feature/VectorIndexer.scala    |  2 +
 python/pyspark/ml/tests/test_feature.py            | 64 ++++++++++++++++++++++
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  1 +
 5 files changed, 70 insertions(+)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index 26ea0fe5a00e..44f59ef3b07e 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -60,6 +60,7 @@ org.apache.spark.ml.feature.ChiSqSelector
 org.apache.spark.ml.feature.UnivariateFeatureSelector
 org.apache.spark.ml.feature.VarianceThresholdSelector
 org.apache.spark.ml.feature.StringIndexer
+org.apache.spark.ml.feature.VectorIndexer
 org.apache.spark.ml.feature.PCA
 org.apache.spark.ml.feature.IDF
 org.apache.spark.ml.feature.Word2Vec
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index ec3c5abe6278..c64f93866caf 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -31,6 +31,7 @@ org.apache.spark.ml.feature.RegexTokenizer
 org.apache.spark.ml.feature.SQLTransformer
 org.apache.spark.ml.feature.StopWordsRemover
 org.apache.spark.ml.feature.FeatureHasher
+org.apache.spark.ml.feature.ElementwiseProduct
 org.apache.spark.ml.feature.HashingTF
 org.apache.spark.ml.feature.IndexToString
 
@@ -78,6 +79,7 @@ org.apache.spark.ml.feature.ChiSqSelectorModel
 org.apache.spark.ml.feature.UnivariateFeatureSelectorModel
 org.apache.spark.ml.feature.VarianceThresholdSelectorModel
 org.apache.spark.ml.feature.StringIndexerModel
+org.apache.spark.ml.feature.VectorIndexerModel
 org.apache.spark.ml.feature.PCAModel
 org.apache.spark.ml.feature.IDFModel
 org.apache.spark.ml.feature.Word2VecModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index b2323d2b706f..1adc8a1bdae7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -297,6 +297,8 @@ class VectorIndexerModel private[ml] (
 
   import VectorIndexerModel._
 
+  private[ml] def this() = this(Identifiable.randomUID("vecIdx"), -1, Map.empty)
+
   /** Java-friendly version of [[categoryMaps]] */
   @Since("1.4.0")
   def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 3a6d3adf2cbf..1c1c703bd221 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -31,6 +31,7 @@ from pyspark.ml.feature import (
     OneHotEncoder,
     OneHotEncoderModel,
     FeatureHasher,
+    ElementwiseProduct,
     HashingTF,
     IDF,
     IDFModel,
@@ -60,6 +61,8 @@ from pyspark.ml.feature import (
     StopWordsRemover,
     StringIndexer,
     StringIndexerModel,
+    VectorIndexer,
+    VectorIndexerModel,
     TargetEncoder,
     TargetEncoderModel,
     VectorSizeHint,
@@ -223,6 +226,67 @@ class FeatureTestsMixin:
         sorted_value = sorted([v for _, v in result])
         self.assertEqual(sorted_value, [0.0, 1.0])
 
+    def test_vector_indexer(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                (Vectors.dense([-1.0, 0.0]),),
+                (Vectors.dense([0.0, 1.0]),),
+                (Vectors.dense([0.0, 2.0]),),
+            ],
+            ["a"],
+        )
+
+        indexer = VectorIndexer(maxCategories=2, inputCol="a")
+        indexer.setOutputCol("indexed")
+        self.assertEqual(indexer.getMaxCategories(), 2)
+        self.assertEqual(indexer.getInputCol(), "a")
+        self.assertEqual(indexer.getOutputCol(), "indexed")
+
+        model = indexer.fit(df)
+        self.assertEqual(indexer.uid, model.uid)
+        self.assertEqual(model.numFeatures, 2)
+
+        output = model.transform(df)
+        self.assertEqual(output.columns, ["a", "indexed"])
+        self.assertEqual(output.count(), 3)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="vector_indexer") as d:
+            indexer.write().overwrite().save(d)
+            indexer2 = VectorIndexer.load(d)
+            self.assertEqual(str(indexer), str(indexer2))
+            self.assertEqual(indexer2.getOutputCol(), "indexed")
+
+            model.write().overwrite().save(d)
+            model2 = VectorIndexerModel.load(d)
+            self.assertEqual(str(model), str(model2))
+            self.assertEqual(model2.getOutputCol(), "indexed")
+
+    def test_elementwise_product(self):
+        spark = self.spark
+        df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"])
+
+        ep = ElementwiseProduct()
+        ep.setScalingVec(Vectors.dense([1.0, 2.0, 3.0]))
+        ep.setInputCol("values")
+        ep.setOutputCol("eprod")
+
+        self.assertEqual(ep.getScalingVec(), Vectors.dense([1.0, 2.0, 3.0]))
+        self.assertEqual(ep.getInputCol(), "values")
+        self.assertEqual(ep.getOutputCol(), "eprod")
+
+        output = ep.transform(df)
+        self.assertEqual(output.columns, ["values", "eprod"])
+        self.assertEqual(output.count(), 1)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="elementwise_product") as d:
+            ep.write().overwrite().save(d)
+            ep2 = ElementwiseProduct.load(d)
+            self.assertEqual(str(ep), str(ep2))
+            self.assertEqual(ep2.getOutputCol(), "eprod")
+
     def test_pca(self):
         df = self.spark.createDataFrame(
             [
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 06b382707848..12aae8947507 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -638,6 +638,7 @@ private[ml] object MLUtils {
     (classOf[MaxAbsScalerModel], Set("maxAbs")),
     (classOf[MinMaxScalerModel], Set("originalMax", "originalMin")),
     (classOf[RobustScalerModel], Set("range", "median")),
+    (classOf[VectorIndexerModel], Set("numFeatures")),
     (classOf[ChiSqSelectorModel], Set("selectedFeatures")),
     (classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
     (classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to