This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ef363b636b2f [SPARK-50869][ML][CONNECT][PYTHON] Support evaluators on ML Connect
ef363b636b2f is described below

commit ef363b636b2f2d66e7342a82f2a924e1857beb63
Author: Bobby Wang <[email protected]>
AuthorDate: Sat Jan 18 10:23:05 2025 +0800

    [SPARK-50869][ML][CONNECT][PYTHON] Support evaluators on ML Connect
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for the following Evaluators on ML Connect:
    
    - org.apache.spark.ml.evaluation.RegressionEvaluator
    - org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    - org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    - org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
    - org.apache.spark.ml.evaluation.ClusteringEvaluator
    - org.apache.spark.ml.evaluation.RankingEvaluator
    
    ### Why are the changes needed?
    For parity with Spark Classic.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new evaluators are supported on ML Connect.
    
    ### How was this patch tested?
    The newly added tests pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49547 from wbo4958/evaluator.ml.connect.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 .../org.apache.spark.ml.evaluation.Evaluator       |  26 ++
 python/pyspark/ml/connect/readwrite.py             |  17 ++
 python/pyspark/ml/evaluation.py                    |   3 +-
 .../ml/tests/connect/test_parity_evaluation.py     |  49 ++++
 python/pyspark/ml/tests/test_evaluation.py         | 307 +++++++++++++++++++--
 python/pyspark/ml/util.py                          |  32 +++
 python/pyspark/ml/wrapper.py                       |   2 +-
 python/pyspark/sql/connect/proto/ml_pb2.py         |  34 +--
 python/pyspark/sql/connect/proto/ml_pb2.pyi        |  49 +++-
 .../src/main/protobuf/spark/connect/ml.proto       |  11 +
 .../apache/spark/sql/connect/ml/MLHandler.scala    |  62 +++--
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  25 ++
 .../org.apache.spark.ml.evaluation.Evaluator       |  21 ++
 .../spark/sql/connect/ml/MLBackendSuite.scala      |  86 +++++-
 .../org/apache/spark/sql/connect/ml/MLHelper.scala |  73 ++++-
 .../org/apache/spark/sql/connect/ml/MLSuite.scala  |  67 ++++-
 17 files changed, 778 insertions(+), 87 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f41767c75324..cb5cf0a85a4b 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1117,6 +1117,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_connect_pipeline",
         "pyspark.ml.tests.connect.test_connect_tuning",
         "pyspark.ml.tests.connect.test_parity_classification",
+        "pyspark.ml.tests.connect.test_parity_evaluation",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git 
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
 
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
new file mode 100644
index 000000000000..1f347edb23ed
--- /dev/null
+++ 
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark Connect ML uses ServiceLoader to find out the supported Spark ML evaluators.
+# So register the supported evaluator here if you're trying to add a new one.
+
+org.apache.spark.ml.evaluation.RegressionEvaluator
+org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
+org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
+org.apache.spark.ml.evaluation.ClusteringEvaluator
+org.apache.spark.ml.evaluation.RankingEvaluator
diff --git a/python/pyspark/ml/connect/readwrite.py 
b/python/pyspark/ml/connect/readwrite.py
index feb92e0a36f8..1f514c653aa0 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -38,6 +38,7 @@ class RemoteMLWriter(MLWriter):
 
     def save(self, path: str) -> None:
         from pyspark.ml.wrapper import JavaModel, JavaEstimator
+        from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.sql.connect.session import SparkSession
 
         session = SparkSession.getActiveSession()
@@ -69,6 +70,19 @@ class RemoteMLWriter(MLWriter):
                 should_overwrite=self.shouldOverwrite,
                 options=self.optionMap,
             )
+        elif isinstance(self._instance, JavaEvaluator):
+            evaluator = cast("JavaEvaluator", self._instance)
+            params = serialize_ml_params(evaluator, session.client)
+            assert isinstance(evaluator._java_obj, str)
+            writer = pb2.MlCommand.Write(
+                operator=pb2.MlOperator(
+                    name=evaluator._java_obj, uid=evaluator.uid, 
type=pb2.MlOperator.EVALUATOR
+                ),
+                params=params,
+                path=path,
+                should_overwrite=self.shouldOverwrite,
+                options=self.optionMap,
+            )
         else:
             raise NotImplementedError(f"Unsupported writing for 
{self._instance}")
 
@@ -85,6 +99,7 @@ class RemoteMLReader(MLReader[RL]):
     def load(self, path: str) -> RL:
         from pyspark.sql.connect.session import SparkSession
         from pyspark.ml.wrapper import JavaModel, JavaEstimator
+        from pyspark.ml.evaluation import JavaEvaluator
 
         session = SparkSession.getActiveSession()
         assert session is not None
@@ -99,6 +114,8 @@ class RemoteMLReader(MLReader[RL]):
             ml_type = pb2.MlOperator.MODEL
         elif issubclass(self._clazz, JavaEstimator):
             ml_type = pb2.MlOperator.ESTIMATOR
+        elif issubclass(self._clazz, JavaEvaluator):
+            ml_type = pb2.MlOperator.EVALUATOR
         else:
             raise ValueError(f"Unsupported reading for 
{java_qualified_class_name}")
 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c6445c7f0241..b2b2d32c31f0 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -31,7 +31,7 @@ from pyspark.ml.param.shared import (
     HasWeightCol,
 )
 from pyspark.ml.common import inherit_doc
-from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_evaluate
 from pyspark.sql.dataframe import DataFrame
 
 if TYPE_CHECKING:
@@ -128,6 +128,7 @@ class JavaEvaluator(JavaParams, Evaluator, 
metaclass=ABCMeta):
     implementations.
     """
 
+    @try_remote_evaluate
     def _evaluate(self, dataset: DataFrame) -> float:
         """
         Evaluates the output.
diff --git a/python/pyspark/ml/tests/connect/test_parity_evaluation.py 
b/python/pyspark/ml/tests/connect/test_parity_evaluation.py
new file mode 100644
index 000000000000..9f78313a318e
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_evaluation.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin
+from pyspark.sql import SparkSession
+
+
+class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()
+
+    def test_assert_remote_mode(self):
+        from pyspark.sql import is_remote
+
+        self.assertTrue(is_remote())
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_evaluation import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_evaluation.py 
b/python/pyspark/ml/tests/test_evaluation.py
index 3c5ae3fbe7d1..39e88fa47057 100644
--- a/python/pyspark/ml/tests/test_evaluation.py
+++ b/python/pyspark/ml/tests/test_evaluation.py
@@ -14,18 +14,298 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import tempfile
 import unittest
 
 import numpy as np
 
-from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator
+from pyspark.ml.evaluation import (
+    ClusteringEvaluator,
+    RegressionEvaluator,
+    BinaryClassificationEvaluator,
+    MulticlassClassificationEvaluator,
+    MultilabelClassificationEvaluator,
+    RankingEvaluator,
+)
 from pyspark.ml.linalg import Vectors
-from pyspark.sql import Row
-from pyspark.testing.mlutils import SparkSessionTestCase
+from pyspark.sql import Row, SparkSession
+
+
+class EvaluatorTestsMixin:
+    def test_ranking_evaluator(self):
+        scoreAndLabels = [
+            ([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0], [1.0, 2.0, 
3.0, 4.0, 5.0]),
+            ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 
3.0]),
+            ([1.0, 2.0, 3.0, 4.0, 5.0], []),
+        ]
+        dataset = self.spark.createDataFrame(scoreAndLabels, ["prediction", 
"label"])
+
+        # Initialize RankingEvaluator
+        evaluator = RankingEvaluator().setPredictionCol("prediction")
+
+        # Evaluate the dataset using the default metric (mean average 
precision)
+        mean_average_precision = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(mean_average_precision, 0.3550, atol=1e-4))
+
+        # Evaluate the dataset using precisionAtK for k=2
+        precision_at_k = evaluator.evaluate(
+            dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2}
+        )
+        self.assertTrue(np.allclose(precision_at_k, 0.3333, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            evaluator.write().overwrite().save(tmp_dir)
+            # Load the saved evaluator
+            evaluator2 = RankingEvaluator.load(tmp_dir)
+            self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+            self.assertEqual(str(evaluator), str(evaluator2))
+
+    def test_multilabel_classification_evaluator(self):
+        dataset = self.spark.createDataFrame(
+            [
+                ([0.0, 1.0], [0.0, 2.0]),
+                ([0.0, 2.0], [0.0, 1.0]),
+                ([], [0.0]),
+                ([2.0], [2.0]),
+                ([2.0, 0.0], [2.0, 0.0]),
+                ([0.0, 1.0, 2.0], [0.0, 1.0]),
+                ([1.0], [1.0, 2.0]),
+            ],
+            ["prediction", "label"],
+        )
+
+        evaluator = 
MultilabelClassificationEvaluator().setPredictionCol("prediction")
+
+        # Evaluate the dataset using the default metric (f1 measure by default)
+        f1_score = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(f1_score, 0.6380, atol=1e-4))
+        # Evaluate the dataset using accuracy
+        accuracy = evaluator.evaluate(dataset, {evaluator.metricName: 
"accuracy"})
+        self.assertTrue(np.allclose(accuracy, 0.5476, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            evaluator.write().overwrite().save(tmp_dir)
+            # Load the saved evaluator
+            evaluator2 = MultilabelClassificationEvaluator.load(tmp_dir)
+            self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+            self.assertEqual(str(evaluator), str(evaluator2))
+
+    def test_multiclass_classification_evaluator(self):
+        dataset = self.spark.createDataFrame(
+            [
+                (0.0, 0.0, 1.0, [0.1, 0.8, 0.1]),
+                (0.0, 1.0, 1.0, [0.3, 0.4, 0.3]),
+                (0.0, 0.0, 1.0, [0.9, 0.05, 0.05]),
+                (1.0, 0.0, 1.0, [0.5, 0.2, 0.3]),
+                (1.0, 1.0, 1.0, [0.2, 0.7, 0.1]),
+                (1.0, 1.0, 1.0, [0.1, 0.3, 0.6]),
+                (1.0, 1.0, 1.0, [0.2, 0.1, 0.7]),
+                (2.0, 2.0, 1.0, [0.3, 0.2, 0.5]),
+                (2.0, 0.0, 1.0, [0.6, 0.2, 0.2]),
+            ],
+            ["prediction", "label", "weight", "probability"],
+        )
+
+        evaluator = 
MulticlassClassificationEvaluator().setPredictionCol("prediction")
+
+        f1_score = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(f1_score, 0.6613, atol=1e-4))
+
+        # Evaluate the dataset using accuracy
+        accuracy = evaluator.evaluate(dataset, {evaluator.metricName: 
"accuracy"})
+        self.assertTrue(np.allclose(accuracy, 0.6666, atol=1e-4))
+
+        # Evaluate the true positive rate for label 1.0
+        true_positive_rate_label_1 = evaluator.evaluate(
+            dataset, {evaluator.metricName: "truePositiveRateByLabel", 
evaluator.metricLabel: 1.0}
+        )
+        self.assertEqual(true_positive_rate_label_1, 0.75)
+
+        # Set the metric to Hamming loss
+        evaluator.setMetricName("hammingLoss")
+
+        # Evaluate the dataset using Hamming loss
+        hamming_loss = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(hamming_loss, 0.3333, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            evaluator.write().overwrite().save(tmp_dir)
+            # Load the saved evaluator
+            evaluator2 = MulticlassClassificationEvaluator.load(tmp_dir)
+            self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+            self.assertEqual(str(evaluator), str(evaluator2))
+
+        # Initialize MulticlassClassificationEvaluator with weight column
+        evaluator = MulticlassClassificationEvaluator(
+            predictionCol="prediction", weightCol="weight"
+        )
+
+        # Evaluate the dataset with weights using default metric (f1 score)
+        weighted_f1_score = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(weighted_f1_score, 0.6613, atol=1e-4))
+
+        # Evaluate the dataset with weights using accuracy
+        weighted_accuracy = evaluator.evaluate(dataset, {evaluator.metricName: 
"accuracy"})
+        self.assertTrue(np.allclose(weighted_accuracy, 0.6666, atol=1e-4))
+
+        evaluator = MulticlassClassificationEvaluator(
+            predictionCol="prediction", probabilityCol="probability"
+        )
+        # Set the metric to log loss
+        evaluator.setMetricName("logLoss")
+        # Evaluate the dataset using log loss
+        log_loss = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(log_loss, 1.0093, atol=1e-4))
+
+    def test_binary_classification_evaluator(self):
+        # Define score and labels data
+        data = map(
+            lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
+            [
+                (0.1, 0.0, 1.0),
+                (0.1, 1.0, 0.9),
+                (0.4, 0.0, 0.7),
+                (0.6, 0.0, 0.9),
+                (0.6, 1.0, 1.0),
+                (0.6, 1.0, 0.3),
+                (0.8, 1.0, 1.0),
+            ],
+        )
+        dataset = self.spark.createDataFrame(data, ["raw", "label", "weight"])
+
+        evaluator = BinaryClassificationEvaluator().setRawPredictionCol("raw")
+        auc_roc = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(auc_roc, 0.7083, atol=1e-4))
+
+        # Evaluate the dataset using the areaUnderPR metric
+        auc_pr = evaluator.evaluate(dataset, {evaluator.metricName: 
"areaUnderPR"})
+        self.assertTrue(np.allclose(auc_pr, 0.8339, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            evaluator.write().overwrite().save(tmp_dir)
+            # Load the saved evaluator
+            evaluator2 = BinaryClassificationEvaluator.load(tmp_dir)
+            self.assertEqual(evaluator2.getRawPredictionCol(), "raw")
+            self.assertEqual(str(evaluator), str(evaluator2))
+
+        evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw", 
weightCol="weight")
+
+        # Evaluate the dataset with weights using the default metric 
(areaUnderROC)
+        auc_roc_weighted = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(auc_roc_weighted, 0.7025, atol=1e-4))
+
+        # Evaluate the dataset with weights using the areaUnderPR metric
+        auc_pr_weighted = evaluator.evaluate(dataset, {evaluator.metricName: 
"areaUnderPR"})
+        self.assertTrue(np.allclose(auc_pr_weighted, 0.8221, atol=1e-4))
+
+        # Get the number of bins used to compute areaUnderROC
+        num_bins = evaluator.getNumBins()
+        self.assertEqual(num_bins, 1000)
+
+    def test_clustering_evaluator(self):
+        # Define feature and predictions data
+        data = map(
+            lambda x: (Vectors.dense(x[0]), x[1], x[2]),
+            [
+                ([0.0, 0.5], 0.0, 2.5),
+                ([0.5, 0.0], 0.0, 2.5),
+                ([10.0, 11.0], 1.0, 2.5),
+                ([10.5, 11.5], 1.0, 2.5),
+                ([1.0, 1.0], 0.0, 2.5),
+                ([8.0, 6.0], 1.0, 2.5),
+            ],
+        )
+        dataset = self.spark.createDataFrame(data, ["features", "prediction", 
"weight"])
 
+        evaluator = ClusteringEvaluator().setPredictionCol("prediction")
+        score = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(score, 0.9079, atol=1e-4))
+
+        evaluator.setWeightCol("weight")
+
+        # Evaluate the dataset with weights
+        score_with_weight = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(score_with_weight, 0.9079, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            evaluator.write().overwrite().save(tmp_dir)
+            # Load the saved evaluator
+            evaluator2 = ClusteringEvaluator.load(tmp_dir)
+            self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+
+    def test_clustering_evaluator_with_cosine_distance(self):
+        featureAndPredictions = map(
+            lambda x: (Vectors.dense(x[0]), x[1]),
+            [
+                ([1.0, 1.0], 1.0),
+                ([10.0, 10.0], 1.0),
+                ([1.0, 0.5], 2.0),
+                ([10.0, 4.4], 2.0),
+                ([-1.0, 1.0], 3.0),
+                ([-100.0, 90.0], 3.0),
+            ],
+        )
+        dataset = self.spark.createDataFrame(featureAndPredictions, 
["features", "prediction"])
+        evaluator = ClusteringEvaluator(predictionCol="prediction", 
distanceMeasure="cosine")
+        self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
+        self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, 
atol=1e-5))
+
+    def test_regression_evaluator(self):
+        dataset = self.spark.createDataFrame(
+            [
+                (-28.98343821, -27.0, 1.0),
+                (20.21491975, 21.5, 0.8),
+                (-25.98418959, -22.0, 1.0),
+                (30.69731842, 33.0, 0.6),
+                (74.69283752, 71.0, 0.2),
+            ],
+            ["raw", "label", "weight"],
+        )
+
+        evaluator = RegressionEvaluator()
+        evaluator.setPredictionCol("raw")
+
+        # Evaluate dataset with default metric (RMSE)
+        rmse = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(rmse, 2.8424, atol=1e-4))
+        # Evaluate dataset with R2 metric
+        r2 = evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
+        self.assertTrue(np.allclose(r2, 0.9939, atol=1e-4))
+        # Evaluate dataset with MAE metric
+        mae = evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
+        self.assertTrue(np.allclose(mae, 2.6496, atol=1e-4))
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            evaluator.write().overwrite().save(tmp_dir)
+            # Load the saved evaluator
+            evaluator2 = RegressionEvaluator.load(tmp_dir)
+            self.assertEqual(evaluator2.getPredictionCol(), "raw")
+
+        evaluator_with_weights = RegressionEvaluator(predictionCol="raw", 
weightCol="weight")
+        weighted_rmse = evaluator_with_weights.evaluate(dataset)
+        self.assertTrue(np.allclose(weighted_rmse, 2.7405, atol=1e-4))
+        through_origin = evaluator_with_weights.getThroughOrigin()
+        self.assertEqual(through_origin, False)
+
+
+class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
 
-class EvaluatorTests(SparkSessionTestCase):
     def test_evaluate_invalid_type(self):
         evaluator = RegressionEvaluator(metricName="r2")
         df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
@@ -47,23 +327,6 @@ class EvaluatorTests(SparkSessionTestCase):
         self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
         self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae")
 
-    def test_clustering_evaluator_with_cosine_distance(self):
-        featureAndPredictions = map(
-            lambda x: (Vectors.dense(x[0]), x[1]),
-            [
-                ([1.0, 1.0], 1.0),
-                ([10.0, 10.0], 1.0),
-                ([1.0, 0.5], 2.0),
-                ([10.0, 4.4], 2.0),
-                ([-1.0, 1.0], 3.0),
-                ([-100.0, 90.0], 3.0),
-            ],
-        )
-        dataset = self.spark.createDataFrame(featureAndPredictions, 
["features", "prediction"])
-        evaluator = ClusteringEvaluator(predictionCol="prediction", 
distanceMeasure="cosine")
-        self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
-        self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, 
atol=1e-5))
-
 
 if __name__ == "__main__":
     from pyspark.ml.tests.test_evaluation import *  # noqa: F401
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d161db04fad2..74a07ec365b3 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -49,6 +49,7 @@ if TYPE_CHECKING:
     from pyspark.core.context import SparkContext
     from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
     from pyspark.ml.wrapper import JavaWrapper, JavaEstimator
+    from pyspark.ml.evaluation import JavaEvaluator
 
 T = TypeVar("T")
 RW = TypeVar("RW", bound="BaseReadWrite")
@@ -324,6 +325,37 @@ def try_remote_not_supporting(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def try_remote_evaluate(f: FuncT) -> FuncT:
+    """Mark the evaluate function in Evaluator."""
+
+    @functools.wraps(f)
+    def wrapped(self: "JavaEvaluator", dataset: "ConnectDataFrame") -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            import pyspark.sql.connect.proto as pb2
+            from pyspark.ml.connect.serialize import serialize_ml_params, 
deserialize
+
+            client = dataset.sparkSession.client
+            input = dataset._plan.plan(client)
+            assert isinstance(self._java_obj, str)
+            evaluator = pb2.MlOperator(
+                name=self._java_obj, uid=self.uid, 
type=pb2.MlOperator.EVALUATOR
+            )
+            command = pb2.Command()
+            command.ml_command.evaluate.CopyFrom(
+                pb2.MlCommand.Evaluate(
+                    evaluator=evaluator,
+                    params=serialize_ml_params(self, client),
+                    dataset=input,
+                )
+            )
+            (_, properties, _) = client.execute_command(command)
+            return deserialize(properties)
+        else:
+            return f(self, dataset)
+
+    return cast(FuncT, wrapped)
+
+
 def _jvm() -> "JavaGateway":
     """
     Returns the JVM view associated with SparkContext. Must be called
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7b5c4f783a65..8fb38fc7ce18 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -353,7 +353,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
         if extra is None:
             extra = dict()
         that = super(JavaParams, self).copy(extra)
-        if self._java_obj is not None:
+        if self._java_obj is not None and not isinstance(self._java_obj, str):
             that._java_obj = self._java_obj.copy(self._empty_java_param_map())
             that._transfer_params_to_java()
         return that
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py 
b/python/pyspark/sql/connect/proto/ml_pb2.py
index c29d33db547c..8e8bc34a7a97 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as 
spark_dot_connect_dot_ml_
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xc6\x07\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\
 [...]
+    
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
 
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1
 [...]
 )
 
 _globals = globals()
@@ -54,19 +54,21 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
     _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
     _globals["_MLCOMMAND"]._serialized_start = 137
-    _globals["_MLCOMMAND"]._serialized_end = 1103
-    _globals["_MLCOMMAND_FIT"]._serialized_start = 415
-    _globals["_MLCOMMAND_FIT"]._serialized_end = 577
-    _globals["_MLCOMMAND_DELETE"]._serialized_start = 579
-    _globals["_MLCOMMAND_DELETE"]._serialized_end = 638
-    _globals["_MLCOMMAND_WRITE"]._serialized_start = 641
-    _globals["_MLCOMMAND_WRITE"]._serialized_end = 1009
-    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 943
-    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1001
-    _globals["_MLCOMMAND_READ"]._serialized_start = 1011
-    _globals["_MLCOMMAND_READ"]._serialized_end = 1092
-    _globals["_MLCOMMANDRESULT"]._serialized_start = 1106
-    _globals["_MLCOMMANDRESULT"]._serialized_end = 1480
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1299
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1465
+    _globals["_MLCOMMAND"]._serialized_end = 1338
+    _globals["_MLCOMMAND_FIT"]._serialized_start = 480
+    _globals["_MLCOMMAND_FIT"]._serialized_end = 642
+    _globals["_MLCOMMAND_DELETE"]._serialized_start = 644
+    _globals["_MLCOMMAND_DELETE"]._serialized_end = 703
+    _globals["_MLCOMMAND_WRITE"]._serialized_start = 706
+    _globals["_MLCOMMAND_WRITE"]._serialized_end = 1074
+    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1008
+    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1066
+    _globals["_MLCOMMAND_READ"]._serialized_start = 1076
+    _globals["_MLCOMMAND_READ"]._serialized_end = 1157
+    _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1160
+    _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1327
+    _globals["_MLCOMMANDRESULT"]._serialized_start = 1341
+    _globals["_MLCOMMANDRESULT"]._serialized_end = 1715
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1534
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1700
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi 
b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 6b950e4c67bb..e8ae0be8dded 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -224,11 +224,49 @@ class MlCommand(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["operator", 
b"operator", "path", b"path"]
         ) -> None: ...
 
+    class Evaluate(google.protobuf.message.Message):
+        """Command for evaluator.evaluate(dataset)"""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        EVALUATOR_FIELD_NUMBER: builtins.int
+        PARAMS_FIELD_NUMBER: builtins.int
+        DATASET_FIELD_NUMBER: builtins.int
+        @property
+        def evaluator(self) -> 
pyspark.sql.connect.proto.ml_common_pb2.MlOperator:
+            """Evaluator information"""
+        @property
+        def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
+            """parameters of the Evaluator"""
+        @property
+        def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
+            """the evaluating dataset"""
+        def __init__(
+            self,
+            *,
+            evaluator: pyspark.sql.connect.proto.ml_common_pb2.MlOperator | 
None = ...,
+            params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = 
...,
+            dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = 
...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "dataset", b"dataset", "evaluator", b"evaluator", "params", 
b"params"
+            ],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "dataset", b"dataset", "evaluator", b"evaluator", "params", 
b"params"
+            ],
+        ) -> None: ...
+
     FIT_FIELD_NUMBER: builtins.int
     FETCH_FIELD_NUMBER: builtins.int
     DELETE_FIELD_NUMBER: builtins.int
     WRITE_FIELD_NUMBER: builtins.int
     READ_FIELD_NUMBER: builtins.int
+    EVALUATE_FIELD_NUMBER: builtins.int
     @property
     def fit(self) -> global___MlCommand.Fit: ...
     @property
@@ -239,6 +277,8 @@ class MlCommand(google.protobuf.message.Message):
     def write(self) -> global___MlCommand.Write: ...
     @property
     def read(self) -> global___MlCommand.Read: ...
+    @property
+    def evaluate(self) -> global___MlCommand.Evaluate: ...
     def __init__(
         self,
         *,
@@ -247,6 +287,7 @@ class MlCommand(google.protobuf.message.Message):
         delete: global___MlCommand.Delete | None = ...,
         write: global___MlCommand.Write | None = ...,
         read: global___MlCommand.Read | None = ...,
+        evaluate: global___MlCommand.Evaluate | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -255,6 +296,8 @@ class MlCommand(google.protobuf.message.Message):
             b"command",
             "delete",
             b"delete",
+            "evaluate",
+            b"evaluate",
             "fetch",
             b"fetch",
             "fit",
@@ -272,6 +315,8 @@ class MlCommand(google.protobuf.message.Message):
             b"command",
             "delete",
             b"delete",
+            "evaluate",
+            b"evaluate",
             "fetch",
             b"fetch",
             "fit",
@@ -284,7 +329,9 @@ class MlCommand(google.protobuf.message.Message):
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["command", b"command"]
-    ) -> typing_extensions.Literal["fit", "fetch", "delete", "write", "read"] 
| None: ...
+    ) -> (
+        typing_extensions.Literal["fit", "fetch", "delete", "write", "read", 
"evaluate"] | None
+    ): ...
 
 global___MlCommand = MlCommand
 
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto 
b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 3198bbbe9c11..20a5cafebb36 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -35,6 +35,7 @@ message MlCommand {
     Delete delete = 3;
     Write write = 4;
     Read read = 5;
+    Evaluate evaluate = 6;
   }
 
   // Command for estimator.fit(dataset)
@@ -79,6 +80,16 @@ message MlCommand {
     // Load the ML instance from the input path
     string path = 2;
   }
+
+  // Command for evaluator.evaluate(dataset)
+  message Evaluate {
+    // Evaluator information
+    MlOperator evaluator = 1;
+    // parameters of the Evaluator
+    MlParams params = 2;
+    // the evaluating dataset
+    Relation dataset = 3;
+  }
 }
 
 // The result of MlCommand
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index e53276dcf7df..cc89079aeca3 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -21,8 +21,8 @@ import scala.jdk.CollectionConverters.CollectionHasAsScala
 
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.util.{MLWritable, Summary}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
@@ -172,18 +172,29 @@ private[connect] object MLHandler extends Logging {
           // save an estimator/evaluator/transformer
           case proto.MlCommand.Write.TypeCase.OPERATOR =>
             val writer = mlCommand.getWrite
-            if (writer.getOperator.getType == 
proto.MlOperator.OperatorType.ESTIMATOR) {
-              val estimator =
-                MLUtils.getEstimator(sessionHolder, writer.getOperator, 
Some(writer.getParams))
-              estimator match {
-                case m: MLWritable => MLUtils.write(m, mlCommand.getWrite)
-                case other => throw MlUnsupportedException(s"Estimator $other 
is not writable")
-              }
-            } else {
-              throw MlUnsupportedException(s"${writer.getOperator.getName} not 
supported")
-            }
+            val operatorType = writer.getOperator.getType
+            val operatorName = writer.getOperator.getName
+            val params = Some(writer.getParams)
+
+            operatorType match {
+              case proto.MlOperator.OperatorType.ESTIMATOR =>
+                val estimator = MLUtils.getEstimator(sessionHolder, 
writer.getOperator, params)
+                estimator match {
+                  case writable: MLWritable => MLUtils.write(writable, 
mlCommand.getWrite)
+                  case other => throw MlUnsupportedException(s"Estimator 
$other is not writable")
+                }
+
+              case proto.MlOperator.OperatorType.EVALUATOR =>
+                val evaluator = MLUtils.getEvaluator(sessionHolder, 
writer.getOperator, params)
+                evaluator match {
+                  case writable: MLWritable => MLUtils.write(writable, 
mlCommand.getWrite)
+                  case other => throw MlUnsupportedException(s"Evaluator 
$other is not writable")
+                }
 
-          case other => throw MlUnsupportedException(s"$other not supported")
+              case _ =>
+                throw MlUnsupportedException(s"Operator $operatorName is not 
supported")
+            }
+          case other => throw MlUnsupportedException(s"$other write not 
supported")
         }
         proto.MlCommandResult.newBuilder().build()
 
@@ -205,21 +216,36 @@ private[connect] object MLHandler extends Logging {
                 .setParams(Serializer.serializeParams(model)))
             .build()
 
-        } else if (operator.getType == 
proto.MlOperator.OperatorType.ESTIMATOR) {
-          val estimator = MLUtils.load(sessionHolder, name, 
path).asInstanceOf[Estimator[_]]
+        } else if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR 
||
+          operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
+          val operator = MLUtils.load(sessionHolder, name, 
path).asInstanceOf[Params]
           proto.MlCommandResult
             .newBuilder()
             .setOperatorInfo(
               proto.MlCommandResult.MlOperatorInfo
                 .newBuilder()
                 .setName(name)
-                .setUid(estimator.uid)
-                .setParams(Serializer.serializeParams(estimator)))
+                .setUid(operator.uid)
+                .setParams(Serializer.serializeParams(operator)))
             .build()
         } else {
-          throw MlUnsupportedException(s"${operator.getType} not supported")
+          throw MlUnsupportedException(s"${operator.getType} read not 
supported")
         }
 
+      case proto.MlCommand.CommandCase.EVALUATE =>
+        val evalCmd = mlCommand.getEvaluate
+        val evalProto = evalCmd.getEvaluator
+        assert(evalProto.getType == proto.MlOperator.OperatorType.EVALUATOR)
+
+        val dataset = MLUtils.parseRelationProto(evalCmd.getDataset, 
sessionHolder)
+        val evaluator =
+          MLUtils.getEvaluator(sessionHolder, evalProto, 
Some(evalCmd.getParams))
+        val metric = evaluator.evaluate(dataset)
+        proto.MlCommandResult
+          .newBuilder()
+          .setParam(LiteralValueProtoConverter.toLiteralProto(metric))
+          .build()
+
       case other => throw MlUnsupportedException(s"$other not supported")
     }
   }
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 4dbfb063eabb..e6e78f15b61f 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -26,6 +26,7 @@ import 
org.apache.commons.lang3.reflect.MethodUtils.invokeMethod
 
 import org.apache.spark.connect.proto
 import org.apache.spark.ml.{Estimator, Transformer}
+import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Params
 import org.apache.spark.ml.util.{MLReadable, MLWritable}
@@ -326,6 +327,30 @@ private[ml] object MLUtils {
     getInstance[Transformer](name, uid, transformers, Some(params))
   }
 
+  /**
+   * Get the Evaluator instance according to the proto information
+   *
+   * @param sessionHolder
+   *   session holder to hold the Spark Connect session state
+   * @param operator
+   *   MlOperator information
+   * @param params
+   *   The optional parameters of the evaluator
+   * @return
+   *   the evaluator
+   */
+  def getEvaluator(
+      sessionHolder: SessionHolder,
+      operator: proto.MlOperator,
+      params: Option[proto.MlParams]): Evaluator = {
+    val name = replaceOperator(sessionHolder, operator.getName)
+    val uid = operator.getUid
+
+    // Load the evaluators by ServiceLoader everytime
+    val evaluators = loadOperators(classOf[Evaluator])
+    getInstance[Evaluator](name, uid, evaluators, params)
+  }
+
   /**
    * Call "load" function on the ML operator given the operator name
    *
diff --git 
a/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
 
b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
new file mode 100644
index 000000000000..502464a1424c
--- /dev/null
+++ 
b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml 
evaluators.
+# So register the supported evaluator here if you're trying to add a new one.
+
+org.apache.spark.sql.connect.ml.MyRegressionEvaluator
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
index f4db4077d1f3..7cd95f9f657d 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
@@ -57,12 +57,28 @@ class MLBackendSuite extends MLHelper {
           .build())
   }
 
-  test("ML backend: estimator read/write") {
+  test("ML backend: estimator works") {
     withSparkConf(
       Connect.CONNECT_ML_BACKEND_CLASSES.key ->
         "org.apache.spark.sql.connect.ml.MyMlBackend") {
       val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
 
+      val fitCommand = proto.MlCommand
+        .newBuilder()
+        .setFit(
+          proto.MlCommand.Fit
+            .newBuilder()
+            .setDataset(createLocalRelationProto)
+            .setEstimator(getLogisticRegressionBuilder)
+            .setParams(getMaxIterBuilder))
+        .build()
+      val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand)
+      val modelId = fitResult.getOperatorInfo.getObjRef.getId
+      
assert(sessionHolder.mlCache.get(modelId).isInstanceOf[MyLogisticRegressionModel])
+      val model = 
sessionHolder.mlCache.get(modelId).asInstanceOf[MyLogisticRegressionModel]
+      assert(model.intercept == 3.5f)
+      assert(model.coefficients == 4.6f)
+
       // read/write
       val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
       try {
@@ -103,7 +119,7 @@ class MLBackendSuite extends MLHelper {
     }
   }
 
-  test("ML backend: model read/write") {
+  test("ML backend: model works") {
     withSparkConf(
       Connect.CONNECT_ML_BACKEND_CLASSES.key ->
         "org.apache.spark.sql.connect.ml.MyMlBackend") {
@@ -159,32 +175,72 @@ class MLBackendSuite extends MLHelper {
         assert(
           
ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
             == 101010)
+
+        // Fetch double attribute
+        val fakeAttrCmd = fetchCommand(ret.getOperatorInfo.getObjRef.getId, 
"predictRaw")
+        val fakeAttrRet = MLHandler.handleMlCommand(sessionHolder, fakeAttrCmd)
+        assert(fakeAttrRet.getParam.getDouble === 1.11)
       } finally {
         Utils.deleteRecursively(tempDir)
       }
     }
   }
 
-  test("ML backend") {
+  test("ML backend: evaluator works") {
     withSparkConf(
       Connect.CONNECT_ML_BACKEND_CLASSES.key ->
         "org.apache.spark.sql.connect.ml.MyMlBackend") {
       val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
-      val fitCommand = proto.MlCommand
+
+      val evalCmd = proto.MlCommand
         .newBuilder()
-        .setFit(
-          proto.MlCommand.Fit
+        .setEvaluate(
+          proto.MlCommand.Evaluate
             .newBuilder()
-            .setDataset(createLocalRelationProto)
-            .setEstimator(getLogisticRegressionBuilder)
-            .setParams(getMaxIterBuilder))
+            .setDataset(createRegressionEvaluationLocalRelationProto)
+            .setEvaluator(getRegressorEvaluator))
         .build()
-      val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand)
-      val modelId = fitResult.getOperatorInfo.getObjRef.getId
-      
assert(sessionHolder.mlCache.get(modelId).isInstanceOf[MyLogisticRegressionModel])
-      val model = 
sessionHolder.mlCache.get(modelId).asInstanceOf[MyLogisticRegressionModel]
-      assert(model.intercept == 3.5f)
-      assert(model.coefficients == 4.6f)
+      val evalResult = MLHandler.handleMlCommand(sessionHolder, evalCmd)
+      assert(evalResult.getParam.getDouble == 1.11)
+
+      // read/write
+      val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+      try {
+        val path = new File(tempDir, 
Identifiable.randomUID("Evaluator")).getPath
+        val writeCmd = proto.MlCommand
+          .newBuilder()
+          .setWrite(
+            proto.MlCommand.Write
+              .newBuilder()
+              .setOperator(getRegressorEvaluator)
+              .setParams(getMetricName)
+              .setPath(path)
+              .setShouldOverwrite(true))
+          .build()
+        MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+        val readCmd = proto.MlCommand
+          .newBuilder()
+          .setRead(
+            proto.MlCommand.Read
+              .newBuilder()
+              .setOperator(getRegressorEvaluator)
+              .setPath(path))
+          .build()
+
+        val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
+        
assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
+        
assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("metricName"))
+        assert(
+          
ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
+            == "mae")
+        assert(
+          
ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
+            == 101010)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
     }
   }
+
 }
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
index c778603eeece..844e85fa03b6 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
@@ -22,8 +22,9 @@ import java.util.Optional
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
-import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
 import org.apache.spark.ml.param.shared.HasMaxIter
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, 
Identifiable, MLReadable, MLReader}
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,7 +33,7 @@ import 
org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanTest
 import org.apache.spark.sql.connect.plugin.MLBackendPlugin
-import org.apache.spark.sql.types.{FloatType, Metadata, StructField, 
StructType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, Metadata, 
StructField, StructType}
 
 trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
 
@@ -55,6 +56,46 @@ trait MLHelper extends SparkFunSuite with 
SparkConnectPlanTest {
     }
     createLocalRelationProto(DataTypeUtils.toAttributes(schema), inputRows, 
"UTC", Some(schema))
   }
+
+  def createRegressionEvaluationLocalRelationProto: proto.Relation = {
+    // The test refers to
+    // 
https://github.com/apache/spark/blob/master/python/pyspark/ml/evaluation.py#L331
+    val rows = Seq(
+      InternalRow(-28.98343821, -27.0),
+      InternalRow(20.21491975, 21.5),
+      InternalRow(-25.98418959, -22.0),
+      InternalRow(30.69731842, 33.0),
+      InternalRow(74.69283752, 71.0))
+    val schema = StructType(Seq(StructField("raw", DoubleType), 
StructField("label", DoubleType)))
+    val inputRows = rows.map { row =>
+      val proj = UnsafeProjection.create(schema)
+      proj(row).copy()
+    }
+    createLocalRelationProto(schema, inputRows)
+  }
+
+  def getRegressorEvaluator: proto.MlOperator.Builder =
+    proto.MlOperator
+      .newBuilder()
+      .setName("org.apache.spark.ml.evaluation.RegressionEvaluator")
+      .setUid("RegressionEvaluator")
+      .setType(proto.MlOperator.OperatorType.EVALUATOR)
+
+  def getMetricName: proto.MlParams.Builder =
+    proto.MlParams
+      .newBuilder()
+      .putParams("metricName", 
proto.Expression.Literal.newBuilder().setString("mae").build())
+
+  def fetchCommand(modelId: String, method: String): proto.MlCommand = {
+    proto.MlCommand
+      .newBuilder()
+      .setFetch(
+        proto.Fetch
+          .newBuilder()
+          .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+          .addMethods(proto.Fetch.Method.newBuilder().setMethod(method)))
+      .build()
+  }
 }
 
 class MyMlBackend extends MLBackendPlugin {
@@ -65,6 +106,8 @@ class MyMlBackend extends MLBackendPlugin {
         Optional.of("org.apache.spark.sql.connect.ml.MyLogisticRegression")
       case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
         
Optional.of("org.apache.spark.sql.connect.ml.MyLogisticRegressionModel")
+      case "org.apache.spark.ml.evaluation.RegressionEvaluator" =>
+        Optional.of("org.apache.spark.sql.connect.ml.MyRegressionEvaluator")
       case _ => Optional.empty()
     }
   }
@@ -74,6 +117,29 @@ trait HasFakedParam extends Params {
   final val fakeParam: IntParam = new IntParam(this, "fakeParam", "faked 
parameter")
 }
 
+class MyRegressionEvaluator(override val uid: String)
+    extends Evaluator
+    with DefaultParamsWritable
+    with HasFakedParam {
+
+  def this() = this(Identifiable.randomUID("MyRegressionEvaluator"))
+
+  // keep same as RegressionEvaluator
+  val metricName: Param[String] = {
+    new Param(this, "metricName", "metric name in evaluation 
(mse|rmse|r2|mae|var)")
+  }
+
+  set(fakeParam, 101010)
+
+  override def evaluate(dataset: Dataset[_]): Double = 1.11
+
+  override def copy(extra: ParamMap): Evaluator = defaultCopy(extra)
+}
+
+object MyRegressionEvaluator extends 
DefaultParamsReadable[MyRegressionEvaluator] {
+  override def load(path: String): MyRegressionEvaluator = super.load(path)
+}
+
 class MyLogisticRegressionModel(
     override val uid: String,
     val intercept: Float,
@@ -95,6 +161,9 @@ class MyLogisticRegressionModel(
     dataset.toDF()
   }
 
+  // fake a function
+  def predictRaw: Double = 1.11
+
   override def transformSchema(schema: StructType): StructType = schema
 }
 
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index b537f1c702bd..49dcd7dbe9ad 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -132,17 +132,6 @@ class MLSuite extends MLHelper {
     assert(fakedML.getDouble === 1.0)
   }
 
-  def fetchCommand(modelId: String, method: String): proto.MlCommand = {
-    proto.MlCommand
-      .newBuilder()
-      .setFetch(
-        proto.Fetch
-          .newBuilder()
-          .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
-          .addMethods(proto.Fetch.Method.newBuilder().setMethod(method)))
-      .build()
-  }
-
   def trainLogisticRegressionModel(sessionHolder: SessionHolder): String = {
     val fitCommand = proto.MlCommand
       .newBuilder()
@@ -353,4 +342,60 @@ class MLSuite extends MLHelper {
       
thrown.message.contains("org.apache.spark.sql.connect.ml.NotImplementingMLReadble
 " +
         "must implement MLReadable"))
   }
+
+  test("RegressionEvaluator works") {
+    val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+    val evalCmd = proto.MlCommand
+      .newBuilder()
+      .setEvaluate(
+        proto.MlCommand.Evaluate
+          .newBuilder()
+          .setDataset(createRegressionEvaluationLocalRelationProto)
+          .setEvaluator(getRegressorEvaluator)
+          .setParams(
+            proto.MlParams
+              .newBuilder()
+              .putParams(
+                "predictionCol",
+                
proto.Expression.Literal.newBuilder().setString("raw").build())))
+      .build()
+    val evalResult = MLHandler.handleMlCommand(sessionHolder, evalCmd)
+    assert(
+      evalResult.getParam.getDouble > 2.841 &&
+        evalResult.getParam.getDouble < 2.843)
+
+    // read/write
+    val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+    try {
+      val path = new File(tempDir, 
Identifiable.randomUID("RegressionEvaluator")).getPath
+      val writeCmd = proto.MlCommand
+        .newBuilder()
+        .setWrite(
+          proto.MlCommand.Write
+            .newBuilder()
+            .setOperator(getRegressorEvaluator)
+            .setParams(getMetricName)
+            .setPath(path)
+            .setShouldOverwrite(true))
+        .build()
+      MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+      val readCmd = proto.MlCommand
+        .newBuilder()
+        .setRead(
+          proto.MlCommand.Read
+            .newBuilder()
+            .setOperator(getRegressorEvaluator)
+            .setPath(path))
+        .build()
+
+      val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
+      assert(
+        ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString 
==
+          "mae")
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to