This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new dd771ccd2845 [SPARK-50901][ML][PYTHON][CONNECT] Support Transformer `VectorAssembler`
dd771ccd2845 is described below

commit dd771ccd28450c25159d5b4d391cd7acbe3e32da
Author: Bobby Wang <[email protected]>
AuthorDate: Wed Jan 22 09:28:55 2025 +0800

    [SPARK-50901][ML][PYTHON][CONNECT] Support Transformer `VectorAssembler`
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for non-model transformers in ML on Connect. Currently, VectorAssembler is fully supported.
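
    A minimal usage sketch (assuming `spark` is a Spark Connect session, e.g.
    created via `SparkSession.builder.remote(...)`; the data values are
    illustrative only):

        from pyspark.ml.feature import VectorAssembler

        df = spark.createDataFrame([(1, 5.0, 6.0, 7.0)], ["index", "a", "b", "c"])
        assembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
        # The transform runs on the Connect server; results stream back as usual.
        assembler.transform(df).show()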
    
    ### Why are the changes needed?
    
    For feature parity between Spark Connect and Spark Classic ML.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the `VectorAssembler` transformer is now supported on Connect.
    
    ### How was this patch tested?
    The newly added tests pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49588 from wbo4958/transformer.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 3450184d332c3ff6203a200df0dddeced7ec9fd4)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       |   1 +
 python/pyspark/ml/connect/readwrite.py             |  48 +++---
 python/pyspark/ml/tests/test_feature.py            |  44 ++++++
 .../apache/spark/sql/connect/ml/MLHandler.scala    |  16 +-
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  24 +++
 .../services/org.apache.spark.ml.Transformer       |   1 +
 .../spark/sql/connect/ml/MLBackendSuite.scala      | 159 +++++++-------------
 .../org/apache/spark/sql/connect/ml/MLHelper.scala | 160 +++++++++++++++++++-
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  | 161 +++++++--------------
 9 files changed, 374 insertions(+), 240 deletions(-)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 4b029ae610d7..a25c03ed2b8e 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -17,6 +17,7 @@
 
 # Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
 # So register the supported transformer here if you're trying to add a new one.
+########### Transformers
 org.apache.spark.ml.feature.VectorAssembler
 
 ########### Model for loading
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 1f514c653aa0..41ae66d32108 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import cast, Type, TYPE_CHECKING
+from typing import cast, Type, TYPE_CHECKING, Union
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -37,7 +37,7 @@ class RemoteMLWriter(MLWriter):
         raise RuntimeError("Accessing SparkContext is not supported on Connect")
 
     def save(self, path: str) -> None:
-        from pyspark.ml.wrapper import JavaModel, JavaEstimator
+        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.sql.connect.session import SparkSession
 
@@ -57,35 +57,29 @@ class RemoteMLWriter(MLWriter):
                 should_overwrite=self.shouldOverwrite,
                 options=self.optionMap,
             )
-        elif isinstance(self._instance, JavaEstimator):
-            estimator = cast("JavaEstimator", self._instance)
-            params = serialize_ml_params(estimator, session.client)
-            assert isinstance(estimator._java_obj, str)
-            writer = pb2.MlCommand.Write(
-                operator=pb2.MlOperator(
-                    name=estimator._java_obj, uid=estimator.uid, type=pb2.MlOperator.ESTIMATOR
-                ),
-                params=params,
-                path=path,
-                should_overwrite=self.shouldOverwrite,
-                options=self.optionMap,
-            )
-        elif isinstance(self._instance, JavaEvaluator):
-            evaluator = cast("JavaEvaluator", self._instance)
-            params = serialize_ml_params(evaluator, session.client)
-            assert isinstance(evaluator._java_obj, str)
+        else:
+            operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
+            if isinstance(self._instance, JavaEstimator):
+                ml_type = pb2.MlOperator.ESTIMATOR
+                operator = cast("JavaEstimator", self._instance)
+            elif isinstance(self._instance, JavaEvaluator):
+                ml_type = pb2.MlOperator.EVALUATOR
+                operator = cast("JavaEvaluator", self._instance)
+            elif isinstance(self._instance, JavaTransformer):
+                ml_type = pb2.MlOperator.TRANSFORMER
+                operator = cast("JavaTransformer", self._instance)
+            else:
+                raise NotImplementedError(f"Unsupported writing for {self._instance}")
+
+            params = serialize_ml_params(operator, session.client)
+            assert isinstance(operator._java_obj, str)
             writer = pb2.MlCommand.Write(
-                operator=pb2.MlOperator(
-                    name=evaluator._java_obj, uid=evaluator.uid, type=pb2.MlOperator.EVALUATOR
-                ),
+                operator=pb2.MlOperator(name=operator._java_obj, uid=operator.uid, type=ml_type),
                 params=params,
                 path=path,
                 should_overwrite=self.shouldOverwrite,
                 options=self.optionMap,
             )
-        else:
-            raise NotImplementedError(f"Unsupported writing for {self._instance}")
-
         command = pb2.Command()
         command.ml_command.write.CopyFrom(writer)
         session.client.execute_command(command)
@@ -98,7 +92,7 @@ class RemoteMLReader(MLReader[RL]):
 
     def load(self, path: str) -> RL:
         from pyspark.sql.connect.session import SparkSession
-        from pyspark.ml.wrapper import JavaModel, JavaEstimator
+        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
 
         session = SparkSession.getActiveSession()
@@ -116,6 +110,8 @@ class RemoteMLReader(MLReader[RL]):
             ml_type = pb2.MlOperator.ESTIMATOR
         elif issubclass(self._clazz, JavaEvaluator):
             ml_type = pb2.MlOperator.EVALUATOR
+        elif issubclass(self._clazz, JavaTransformer):
+            ml_type = pb2.MlOperator.TRANSFORMER
         else:
             raise ValueError(f"Unsupported reading for {java_qualified_class_name}")
 
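The refactoring above collapses the separate estimator/evaluator branches and the
new transformer branch into one dispatch that only varies the proto operator type.
A minimal sketch of the persistence round trip this enables for a transformer,
assuming an active Spark Connect session (it mirrors the new test below):

    import tempfile
    from pyspark.ml.feature import VectorAssembler

    assembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Issues an MlCommand.Write with operator type TRANSFORMER.
        assembler.write().overwrite().save(tmp_dir)
        # Issues an MlCommand.Read; the server resolves the class by name.
        loaded = VectorAssembler.load(tmp_dir)
        assert loaded.getInputCols() == ["a", "b", "c"]
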
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index a46fdd22e2bc..51c7a3631e1b 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -40,6 +40,7 @@ from pyspark.ml.feature import (
     StringIndexerModel,
     TargetEncoder,
     VectorSizeHint,
+    VectorAssembler,
 )
 from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
 from pyspark.sql import Row
@@ -48,6 +49,49 @@ from pyspark.testing.mlutils import check_params, SparkSessionTestCase
 
 
 class FeatureTestsMixin:
+    def test_vector_assembler(self):
+        # Create a DataFrame
+        df = (
+            self.spark.createDataFrame(
+                [
+                    (1, 5.0, 6.0, 7.0),
+                    (2, 1.0, 2.0, None),
+                    (3, 3.0, float("nan"), 4.0),
+                ],
+                ["index", "a", "b", "c"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("index")
+        )
+
+        # Initialize VectorAssembler
+        vec_assembler = VectorAssembler(outputCol="features").setInputCols(["a", "b", "c"])
+        output = vec_assembler.transform(df)
+        self.assertEqual(output.columns, ["index", "a", "b", "c", "features"])
+        self.assertEqual(output.head().features, Vectors.dense([5.0, 6.0, 7.0]))
+
+        # Set custom parameters and transform the DataFrame
+        params = {vec_assembler.inputCols: ["b", "a"], vec_assembler.outputCol: "vector"}
+        self.assertEqual(
+            vec_assembler.transform(df, params).head().vector, Vectors.dense([6.0, 5.0])
+        )
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="read_write") as tmp_dir:
+            vec_assembler.write().overwrite().save(tmp_dir)
+            vec_assembler2 = VectorAssembler.load(tmp_dir)
+            self.assertEqual(str(vec_assembler), str(vec_assembler2))
+
+        # Initialize a new VectorAssembler with handleInvalid="keep"
+        vec_assembler3 = VectorAssembler(
+            inputCols=["a", "b", "c"], outputCol="features", handleInvalid="keep"
+        )
+        self.assertEqual(vec_assembler3.transform(df).count(), 3)
+
+        # Update handleInvalid to "skip" and transform the DataFrame
+        vec_assembler3.setParams(handleInvalid="skip")
+        self.assertEqual(vec_assembler3.transform(df).count(), 1)
+
     def test_standard_scaler(self):
         df = (
             self.spark.createDataFrame(
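
The two count assertions at the end of the new test hinge on VectorAssembler's
handleInvalid param: with "keep", rows containing null/NaN survive (invalid
entries are kept as NaN in the assembled vector), while "skip" drops any row
with an invalid value in the input columns, leaving only the fully valid row.
A standalone sketch, assuming an active session `spark`:

    from pyspark.ml.feature import VectorAssembler

    df = spark.createDataFrame(
        [(5.0, 6.0, 7.0), (1.0, 2.0, None), (3.0, float("nan"), 4.0)],
        ["a", "b", "c"],
    )
    assembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    # "keep" carries invalid values through as NaN; "skip" filters those rows out.
    assert assembler.setHandleInvalid("keep").transform(df).count() == 3
    assert assembler.setHandleInvalid("skip").transform(df).count() == 1
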
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index c66a2e7004b9..ea6303937bc3 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -191,6 +191,15 @@ private[connect] object MLHandler extends Logging {
                   case other => throw MlUnsupportedException(s"Evaluator $other is not writable")
                 }
 
+              case proto.MlOperator.OperatorType.TRANSFORMER =>
+                val transformer =
+                  MLUtils.getTransformer(sessionHolder, writer.getOperator, params)
+                transformer match {
+                  case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite)
+                  case other =>
+                    throw MlUnsupportedException(s"Transformer $other is not writable")
+                }
+
               case _ =>
                 throw MlUnsupportedException(s"Operator $operatorName is not supported")
             }
@@ -217,12 +226,15 @@ private[connect] object MLHandler extends Logging {
             .build()
 
         } else if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR ||
-          operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
+          operator.getType == proto.MlOperator.OperatorType.EVALUATOR ||
+          operator.getType == proto.MlOperator.OperatorType.TRANSFORMER) {
           val mlOperator = {
             if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR) {
               MLUtils.loadEstimator(sessionHolder, name, path).asInstanceOf[Params]
-            } else {
+            } else if (operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
               MLUtils.loadEvaluator(sessionHolder, name, path).asInstanceOf[Params]
+            } else {
+              MLUtils.loadTransformer(sessionHolder, name, path).asInstanceOf[Params]
             }
           }
           proto.MlCommandResult
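
For reference, the writer.getOperator handled by the new TRANSFORMER branch is
built on the Python side in readwrite.py above. A condensed, hypothetical sketch
of that client-side message construction (placeholder name/uid/path; params
omitted for brevity):

    import pyspark.sql.connect.proto as pb2

    writer = pb2.MlCommand.Write(
        operator=pb2.MlOperator(
            name="org.apache.spark.ml.feature.VectorAssembler",
            uid="VectorAssembler_0001",  # placeholder uid
            type=pb2.MlOperator.TRANSFORMER,
        ),
        path="/tmp/vector_assembler",  # placeholder path
        should_overwrite=True,
    )
    command = pb2.Command()
    command.ml_command.write.CopyFrom(writer)
    # session.client.execute_command(command) delivers this to MLHandler.
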
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 04dbb60cb1ed..34a0317f55af 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -338,6 +338,30 @@ private[ml] object MLUtils {
     getInstance[Transformer](name, uid, transformers, Some(params))
   }
 
+  /**
+   * Get the Transformer instance according to the proto information
+   *
+   * @param sessionHolder
+   *   session holder to hold the Spark Connect session state
+   * @param operator
+   *   MlOperator information
+   * @param params
+   *   The optional parameters of the transformer
+   * @return
+   *   the transformer
+   */
+  def getTransformer(
+      sessionHolder: SessionHolder,
+      operator: proto.MlOperator,
+      params: Option[proto.MlParams]): Transformer = {
+    val name = replaceOperator(sessionHolder, operator.getName)
+    val uid = operator.getUid
+
+    // Load the transformers via ServiceLoader every time
+    val transformers = loadOperators(classOf[Transformer])
+    getInstance[Transformer](name, uid, transformers, params)
+  }
+
   /**
    * Get the Evaluator instance according to the proto information
    *
diff --git a/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
index 92d3a7018054..e74b087fa8da 100644
--- a/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,3 +18,4 @@
 # Spark Connect ML uses ServiceLoader to find out the supported Spark Ml estimators.
 # So register the supported estimator here if you're trying to add a new one.
 org.apache.spark.sql.connect.ml.MyLogisticRegressionModel
+org.apache.spark.sql.connect.ml.MyVectorAssembler
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
index 7cd95f9f657d..5b2b5e6dd793 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.connect.ml
 
-import java.io.File
+import scala.jdk.CollectionConverters.ListHasAsScala
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.connect.proto
-import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.connect.SparkConnectTestUtils
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.util.Utils
@@ -79,43 +78,12 @@ class MLBackendSuite extends MLHelper {
       assert(model.intercept == 3.5f)
       assert(model.coefficients == 4.6f)
 
-      // read/write
-      val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
-      try {
-        val path = new File(tempDir, Identifiable.randomUID("LogisticRegression")).getPath
-        val writeCmd = proto.MlCommand
-          .newBuilder()
-          .setWrite(
-            proto.MlCommand.Write
-              .newBuilder()
-              .setOperator(getLogisticRegressionBuilder)
-              .setParams(getMaxIterBuilder)
-              .setPath(path)
-              .setShouldOverwrite(true))
-          .build()
-        MLHandler.handleMlCommand(sessionHolder, writeCmd)
+      val ret = readWrite(sessionHolder, getLogisticRegressionBuilder, getMaxIterBuilder)
 
-        val readCmd = proto.MlCommand
-          .newBuilder()
-          .setRead(
-            proto.MlCommand.Read
-              .newBuilder()
-              .setOperator(getLogisticRegressionBuilder)
-              .setPath(path))
-          .build()
-
-        val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
-        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("maxIter"))
-        assert(
-          ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger
-            == 2)
-        assert(
-          ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
-            == 101010)
-      } finally {
-        Utils.deleteRecursively(tempDir)
-      }
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("maxIter"))
+      assert(ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger == 2)
+      assert(ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger == 101010)
     }
   }
 
@@ -138,37 +106,11 @@ class MLBackendSuite extends MLHelper {
               .setParams(getMaxIterBuilder))
           .build()
         val fitRet = MLHandler.handleMlCommand(sessionHolder, fitCommand)
-        val modelId = fitRet.getOperatorInfo.getObjRef.getId
-
-        // Write a model
-        val path = new File(tempDir, Identifiable.randomUID("LogisticRegression")).getPath
-        val writeCmd = proto.MlCommand
-          .newBuilder()
-          .setWrite(
-            proto.MlCommand.Write
-              .newBuilder()
-              .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
-              .setPath(path)
-              .setShouldOverwrite(true))
-          .build()
-        MLHandler.handleMlCommand(sessionHolder, writeCmd)
-
-        // read a model
-        val readCmd = proto.MlCommand
-          .newBuilder()
-          .setRead(
-            proto.MlCommand.Read
-              .newBuilder()
-              .setOperator(proto.MlOperator
-                .newBuilder()
-                .setName("org.apache.spark.ml.classification.LogisticRegressionModel")
-                .setType(proto.MlOperator.OperatorType.MODEL))
-              .setPath(path))
-          .build()
 
-        val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
-        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("maxIter"))
+        val ret = readWrite(
+          sessionHolder,
+          fitRet.getOperatorInfo.getObjRef.getId,
+          "org.apache.spark.ml.classification.LogisticRegressionModel")
         assert(
           ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger
             == 2)
@@ -203,43 +145,54 @@ class MLBackendSuite extends MLHelper {
       val evalResult = MLHandler.handleMlCommand(sessionHolder, evalCmd)
       assert(evalResult.getParam.getDouble == 1.11)
 
-      // read/write
-      val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
-      try {
-        val path = new File(tempDir, Identifiable.randomUID("Evaluator")).getPath
-        val writeCmd = proto.MlCommand
-          .newBuilder()
-          .setWrite(
-            proto.MlCommand.Write
-              .newBuilder()
-              .setOperator(getRegressorEvaluator)
-              .setParams(getMetricName)
-              .setPath(path)
-              .setShouldOverwrite(true))
-          .build()
-        MLHandler.handleMlCommand(sessionHolder, writeCmd)
+      val ret = readWrite(sessionHolder, getRegressorEvaluator, getMetricName)
 
-        val readCmd = proto.MlCommand
-          .newBuilder()
-          .setRead(
-            proto.MlCommand.Read
-              .newBuilder()
-              .setOperator(getRegressorEvaluator)
-              .setPath(path))
-          .build()
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
+      assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("metricName"))
+      assert(
+        ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
+          == "mae")
+      assert(
+        ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
+          == 101010)
+    }
+  }
 
-        val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
-        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("metricName"))
-        assert(
-          ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
-            == "mae")
-        assert(
-          ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
-            == 101010)
-      } finally {
-        Utils.deleteRecursively(tempDir)
-      }
+  test("ML backend: transformer works") {
+    withSparkConf(
+      Connect.CONNECT_ML_BACKEND_CLASSES.key ->
+        "org.apache.spark.sql.connect.ml.MyMlBackend") {
+      val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+      val transformerRelation = proto.MlRelation
+        .newBuilder()
+        .setTransform(
+          proto.MlRelation.Transform
+            .newBuilder()
+            .setTransformer(getVectorAssembler)
+            .setParams(getVectorAssemblerParams)
+            .setInput(createMultiColumnLocalRelationProto))
+        .build()
+
+      val transRet = MLHandler.transformMLRelation(transformerRelation, sessionHolder)
+      // MyVectorAssembler overrides transform to append a "new" column
+      Seq("a", "b", "c", "new").foreach(n => assert(transRet.schema.names.contains(n)))
+
+      val ret = readWrite(sessionHolder, getVectorAssembler, getVectorAssemblerParams)
+      assert(
+        ret.getOperatorInfo.getParams.getParamsMap.get("handleInvalid").getString
+          == "skip")
+      assert(
+        ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
+          == 101010)
+      assert(
+        ret.getOperatorInfo.getParams.getParamsMap
+          .get("inputCols")
+          .getArray
+          .getElementsList
+          .asScala
+          .map(_.getString)
+          .toArray sameElements Array("a", "b", "c"))
     }
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
index ef5b8a59a58b..9383794b38dc 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark.sql.connect.ml
 
+import java.io.File
 import java.util.Optional
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
-import org.apache.spark.ml.param.shared.HasMaxIter
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasMaxIter, HasOutputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLReadable, MLReader}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -33,7 +34,10 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanTest
 import org.apache.spark.sql.connect.plugin.MLBackendPlugin
-import org.apache.spark.sql.types.{DoubleType, FloatType, Metadata, StructField, StructType}
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, Metadata, StructField, StructType}
+import org.apache.spark.util.Utils
 
 trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
 
@@ -74,6 +78,32 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
     createLocalRelationProto(schema, inputRows)
   }
 
+  def createMultiColumnLocalRelationProto: proto.Relation = {
+    val rows = Seq(InternalRow(1, 0, 3))
+    val schema = StructType(
+      Seq(
+        StructField("a", IntegerType),
+        StructField("b", IntegerType),
+        StructField("c", IntegerType)))
+    val inputRows = rows.map { row =>
+      val proj = UnsafeProjection.create(schema)
+      proj(row).copy()
+    }
+    createLocalRelationProto(schema, inputRows)
+  }
+
+  def getLogisticRegression: proto.MlOperator.Builder =
+    proto.MlOperator
+      .newBuilder()
+      .setName("org.apache.spark.ml.classification.LogisticRegression")
+      .setUid("LogisticRegression")
+      .setType(proto.MlOperator.OperatorType.ESTIMATOR)
+
+  def getMaxIter: proto.MlParams.Builder =
+    proto.MlParams
+      .newBuilder()
+      .putParams("maxIter", 
proto.Expression.Literal.newBuilder().setInteger(2).build())
+
   def getRegressorEvaluator: proto.MlOperator.Builder =
     proto.MlOperator
       .newBuilder()
@@ -96,6 +126,109 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
           .addMethods(proto.Fetch.Method.newBuilder().setMethod(method)))
       .build()
   }
+
+  def getArrayStrings: proto.Expression.Literal =
+    proto.Expression.Literal
+      .newBuilder()
+      .setArray(
+        proto.Expression.Literal.Array
+          .newBuilder()
+          .setElementType(proto.DataType
+            .newBuilder()
+            .setString(proto.DataType.String.getDefaultInstance)
+            .build())
+          .addElements(proto.Expression.Literal.newBuilder().setString("a"))
+          .addElements(proto.Expression.Literal.newBuilder().setString("b"))
+          .addElements(proto.Expression.Literal.newBuilder().setString("c"))
+          .build())
+      .build()
+
+  def getVectorAssembler: proto.MlOperator.Builder =
+    proto.MlOperator
+      .newBuilder()
+      .setUid("vec")
+      .setName("org.apache.spark.ml.feature.VectorAssembler")
+      .setType(proto.MlOperator.OperatorType.TRANSFORMER)
+
+  def getVectorAssemblerParams: proto.MlParams.Builder =
+    proto.MlParams
+      .newBuilder()
+      .putParams("handleInvalid", 
proto.Expression.Literal.newBuilder().setString("skip").build())
+      .putParams("outputCol", 
proto.Expression.Literal.newBuilder().setString("features").build())
+      .putParams("inputCols", getArrayStrings)
+
+  def readWrite(
+      sessionHolder: SessionHolder,
+      operator: proto.MlOperator.Builder,
+      params: proto.MlParams.Builder): proto.MlCommandResult = {
+    // read/write
+    val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+    try {
+      val path = new File(tempDir, Identifiable.randomUID("test")).getPath
+      val writeCmd = proto.MlCommand
+        .newBuilder()
+        .setWrite(
+          proto.MlCommand.Write
+            .newBuilder()
+            .setOperator(operator)
+            .setParams(params)
+            .setPath(path)
+            .setShouldOverwrite(true))
+        .build()
+      MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+      val readCmd = proto.MlCommand
+        .newBuilder()
+        .setRead(
+          proto.MlCommand.Read
+            .newBuilder()
+            .setOperator(operator)
+            .setPath(path))
+        .build()
+
+      MLHandler.handleMlCommand(sessionHolder, readCmd)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  def readWrite(
+      sessionHolder: SessionHolder,
+      modelId: String,
+      clsName: String): proto.MlCommandResult = {
+    val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+    try {
+      val path = new File(tempDir, Identifiable.randomUID("test")).getPath
+      val writeCmd = proto.MlCommand
+        .newBuilder()
+        .setWrite(
+          proto.MlCommand.Write
+            .newBuilder()
+            .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+            .setPath(path)
+            .setShouldOverwrite(true))
+        .build()
+      MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+      val readCmd = proto.MlCommand
+        .newBuilder()
+        .setRead(
+          proto.MlCommand.Read
+            .newBuilder()
+            .setOperator(
+              proto.MlOperator
+                .newBuilder()
+                .setName(clsName)
+                .setType(proto.MlOperator.OperatorType.MODEL))
+            .setPath(path))
+        .build()
+
+      MLHandler.handleMlCommand(sessionHolder, readCmd)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
 }
 
 class MyMlBackend extends MLBackendPlugin {
@@ -108,6 +241,8 @@ class MyMlBackend extends MLBackendPlugin {
         Optional.of("org.apache.spark.sql.connect.ml.MyLogisticRegressionModel")
       case "org.apache.spark.ml.evaluation.RegressionEvaluator" =>
         Optional.of("org.apache.spark.sql.connect.ml.MyRegressionEvaluator")
+      case "org.apache.spark.ml.feature.VectorAssembler" =>
+        Optional.of("org.apache.spark.sql.connect.ml.MyVectorAssembler")
       case _ => Optional.empty()
     }
   }
@@ -117,6 +252,25 @@ trait HasFakedParam extends Params {
   final val fakeParam: IntParam = new IntParam(this, "fakeParam", "faked parameter")
 }
 
+class MyVectorAssembler(override val uid: String)
+    extends Transformer
+    with HasInputCols
+    with HasOutputCol
+    with HasHandleInvalid
+    with HasFakedParam
+    with DefaultParamsWritable {
+  set(fakeParam, 101010)
+  private[spark] def this() = this(Identifiable.randomUID("MyVectorAssembler"))
+  override def transform(dataset: Dataset[_]): DataFrame =
+    dataset.withColumn("new", lit(1))
+  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
+  override def transformSchema(schema: StructType): StructType = schema
+}
+
+object MyVectorAssembler extends DefaultParamsReadable[MyVectorAssembler] {
+  override def load(path: String): MyVectorAssembler = super.load(path)
+}
+
 class MyRegressionEvaluator(override val uid: String)
     extends Evaluator
     with DefaultParamsWritable
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index aee0759d0d3a..c3ab6248be8f 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.sql.connect.ml
 
-import java.io.File
+import scala.jdk.CollectionConverters.ListHasAsScala
 
 import org.apache.spark.connect.proto
 import org.apache.spark.ml.classification.LogisticRegressionModel
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.connect.SparkConnectTestUtils
 import org.apache.spark.sql.connect.service.SessionHolder
-import org.apache.spark.util.Utils
 
 trait FakeArrayParams extends Params {
   final val arrayString: StringArrayParam =
@@ -76,21 +75,7 @@ class MLSuite extends MLHelper {
       .putParams("double", 
proto.Expression.Literal.newBuilder().setDouble(1.0).build())
       .putParams("int", 
proto.Expression.Literal.newBuilder().setInteger(10).build())
       .putParams("float", 
proto.Expression.Literal.newBuilder().setFloat(10.0f).build())
-      .putParams(
-        "arrayString",
-        proto.Expression.Literal
-          .newBuilder()
-          .setArray(
-            proto.Expression.Literal.Array
-              .newBuilder()
-              .setElementType(proto.DataType
-                .newBuilder()
-                .setString(proto.DataType.String.getDefaultInstance)
-                .build())
-              .addElements(proto.Expression.Literal.newBuilder().setString("hello"))
-              .addElements(proto.Expression.Literal.newBuilder().setString("world"))
-              .build())
-          .build())
+      .putParams("arrayString", getArrayStrings)
       .putParams(
         "arrayInt",
         proto.Expression.Literal
@@ -127,7 +112,7 @@ class MLSuite extends MLHelper {
     assert(fakedML.getFloat === 10.0)
     assert(fakedML.getArrayInt === Array(1, 2))
     assert(fakedML.getArrayDouble === Array(11.0, 12.0))
-    assert(fakedML.getArrayString === Array("hello", "world"))
+    assert(fakedML.getArrayString === Array("a", "b", "c"))
     assert(fakedML.getBoolean === true)
     assert(fakedML.getDouble === 1.0)
   }
@@ -139,29 +124,21 @@ class MLSuite extends MLHelper {
         proto.MlCommand.Fit
           .newBuilder()
           .setDataset(createLocalRelationProto)
-          .setEstimator(
-            proto.MlOperator
-              .newBuilder()
-              .setName("org.apache.spark.ml.classification.LogisticRegression")
-              .setUid("LogisticRegression")
-              .setType(proto.MlOperator.OperatorType.ESTIMATOR))
-          .setParams(
-            proto.MlParams
-              .newBuilder()
-              .putParams(
-                "maxIter",
-                proto.Expression.Literal
-                  .newBuilder()
-                  .setInteger(2)
-                  .build())))
+          .setEstimator(getLogisticRegression)
+          .setParams(getMaxIter))
       .build()
     val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand)
     fitResult.getOperatorInfo.getObjRef.getId
   }
 
+  // Estimator/Model works
   test("LogisticRegression works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
 
+    // estimator read/write
+    val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter)
+    assert(ret.getOperatorInfo.getParams.getParamsMap.get("maxIter").getInteger == 2)
+
     def verifyModel(modelId: String, hasSummary: Boolean = false): Unit = {
       val model = sessionHolder.mlCache.get(modelId)
       // Model is cached
@@ -248,48 +225,15 @@ class MLSuite extends MLHelper {
       }
     }
 
-    try {
-      val modelId = trainLogisticRegressionModel(sessionHolder)
-
-      verifyModel(modelId, true)
-
-      // read/write
-      val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
-      try {
-        val path = new File(tempDir, Identifiable.randomUID("LogisticRegression")).getPath
-        val writeCmd = proto.MlCommand
-          .newBuilder()
-          .setWrite(
-            proto.MlCommand.Write
-              .newBuilder()
-              .setPath(path)
-              .setObjRef(proto.ObjectRef.newBuilder().setId(modelId)))
-          .build()
-        MLHandler.handleMlCommand(sessionHolder, writeCmd)
-
-        val readCmd = proto.MlCommand
-          .newBuilder()
-          .setRead(
-            proto.MlCommand.Read
-              .newBuilder()
-              .setOperator(
-                proto.MlOperator
-                  .newBuilder()
-                  .setName("org.apache.spark.ml.classification.LogisticRegressionModel")
-                  .setType(proto.MlOperator.OperatorType.MODEL))
-              .setPath(path))
-          .build()
-
-        val readResult = MLHandler.handleMlCommand(sessionHolder, readCmd)
-        verifyModel(readResult.getOperatorInfo.getObjRef.getId)
-
-      } finally {
-        Utils.deleteRecursively(tempDir)
-      }
-
-    } finally {
-      sessionHolder.mlCache.clear()
-    }
+    val modelId = trainLogisticRegressionModel(sessionHolder)
+    verifyModel(modelId, hasSummary = true)
+
+    // model read/write
+    val ret1 = readWrite(
+      sessionHolder,
+      modelId,
+      "org.apache.spark.ml.classification.LogisticRegressionModel")
+    verifyModel(ret1.getOperatorInfo.getObjRef.getId)
   }
 
   test("Exception: Unsupported ML operator") {
@@ -365,37 +309,42 @@ class MLSuite extends MLHelper {
       evalResult.getParam.getDouble > 2.841 &&
         evalResult.getParam.getDouble < 2.843)
 
-    // read/write
-    val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
-    try {
-      val path = new File(tempDir, Identifiable.randomUID("RegressionEvaluator")).getPath
-      val writeCmd = proto.MlCommand
-        .newBuilder()
-        .setWrite(
-          proto.MlCommand.Write
-            .newBuilder()
-            .setOperator(getRegressorEvaluator)
-            .setParams(getMetricName)
-            .setPath(path)
-            .setShouldOverwrite(true))
-        .build()
-      MLHandler.handleMlCommand(sessionHolder, writeCmd)
+    val ret = readWrite(sessionHolder, getRegressorEvaluator, getMetricName)
+    assert(
+      ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString ==
+        "mae")
+  }
 
-      val readCmd = proto.MlCommand
-        .newBuilder()
-        .setRead(
-          proto.MlCommand.Read
-            .newBuilder()
-            .setOperator(getRegressorEvaluator)
-            .setPath(path))
-        .build()
+  // Transformer works
+  test("VectorAssembler works") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
 
-      val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
-      assert(
-        ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString ==
-          "mae")
-    } finally {
-      Utils.deleteRecursively(tempDir)
-    }
+    val transformerRelation = proto.MlRelation
+      .newBuilder()
+      .setTransform(
+        proto.MlRelation.Transform
+          .newBuilder()
+          .setTransformer(getVectorAssembler)
+          .setParams(getVectorAssemblerParams)
+          .setInput(createMultiColumnLocalRelationProto))
+      .build()
+
+    val transRet = MLHandler.transformMLRelation(transformerRelation, sessionHolder)
+    Seq("a", "b", "c", "features").foreach(n => assert(transRet.schema.names.contains(n)))
+    assert(transRet.schema("features").dataType.isInstanceOf[VectorUDT])
+    val rows = transRet.collect()
+    assert(rows.mkString(",") === "[1,0,3,[1.0,0.0,3.0]]")
+
+    val ret = readWrite(sessionHolder, getVectorAssembler, getVectorAssemblerParams)
+    assert(ret.getOperatorInfo.getParams.getParamsMap.get("outputCol").getString == "features")
+    assert(ret.getOperatorInfo.getParams.getParamsMap.get("handleInvalid").getString == "skip")
+    assert(
+      ret.getOperatorInfo.getParams.getParamsMap
+        .get("inputCols")
+        .getArray
+        .getElementsList
+        .asScala
+        .map(_.getString)
+        .toArray sameElements Array("a", "b", "c"))
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
