This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ef363b636b2f [SPARK-50869][ML][CONNECT][PYTHON] Support evaluators on ML Connect
ef363b636b2f is described below
commit ef363b636b2f2d66e7342a82f2a924e1857beb63
Author: Bobby Wang <[email protected]>
AuthorDate: Sat Jan 18 10:23:05 2025 +0800
[SPARK-50869][ML][CONNECT][PYTHON] Support evaluators on ML Connect
### What changes were proposed in this pull request?
This PR adds support for the following Evaluators on ML Connect (a usage sketch follows the list):
- org.apache.spark.ml.evaluation.RegressionEvaluator
- org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
- org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
- org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
- org.apache.spark.ml.evaluation.ClusteringEvaluator
- org.apache.spark.ml.evaluation.RankingEvaluator
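
For illustration, a minimal sketch (not part of the patch) of running one of these evaluators against a Spark Connect session; the remote address and data are illustrative:

```python
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import RegressionEvaluator

# Connect to a Spark Connect server; "local[2]" is an illustrative remote.
spark = SparkSession.builder.remote("local[2]").getOrCreate()

df = spark.createDataFrame(
    [(2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)],
    ["prediction", "label"],
)

# evaluate() is dispatched to the server through the new MlCommand.Evaluate message.
evaluator = RegressionEvaluator(metricName="rmse")
print(evaluator.evaluate(df))
```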
### Why are the changes needed?
For parity with Spark Classic.
### Does this PR introduce _any_ user-facing change?
Yes, evaluators are now supported on ML Connect, including evaluate and read/write; a hedged sketch of the persistence round-trip follows.
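
A minimal sketch, assuming `spark` is a Spark Connect session as above and using an illustrative path:

```python
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(metricName="mae", predictionCol="raw")
evaluator.write().overwrite().save("/tmp/reg_evaluator")  # illustrative path

loaded = RegressionEvaluator.load("/tmp/reg_evaluator")
assert loaded.getMetricName() == "mae"
assert loaded.getPredictionCol() == "raw"
```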
### How was this patch tested?
The newly added tests pass.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49547 from wbo4958/evaluator.ml.connect.
Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
.../org.apache.spark.ml.evaluation.Evaluator | 26 ++
python/pyspark/ml/connect/readwrite.py | 17 ++
python/pyspark/ml/evaluation.py | 3 +-
.../ml/tests/connect/test_parity_evaluation.py | 49 ++++
python/pyspark/ml/tests/test_evaluation.py | 307 +++++++++++++++++++--
python/pyspark/ml/util.py | 32 +++
python/pyspark/ml/wrapper.py | 2 +-
python/pyspark/sql/connect/proto/ml_pb2.py | 34 +--
python/pyspark/sql/connect/proto/ml_pb2.pyi | 49 +++-
.../src/main/protobuf/spark/connect/ml.proto | 11 +
.../apache/spark/sql/connect/ml/MLHandler.scala | 62 +++--
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 25 ++
.../org.apache.spark.ml.evaluation.Evaluator | 21 ++
.../spark/sql/connect/ml/MLBackendSuite.scala | 86 +++++-
.../org/apache/spark/sql/connect/ml/MLHelper.scala | 73 ++++-
.../org/apache/spark/sql/connect/ml/MLSuite.scala | 67 ++++-
17 files changed, 778 insertions(+), 87 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f41767c75324..cb5cf0a85a4b 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1117,6 +1117,7 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_parity_classification",
+ "pyspark.ml.tests.connect.test_parity_evaluation",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy,
pandas, and pyarrow and
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
new file mode 100644
index 000000000000..1f347edb23ed
--- /dev/null
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark Connect ML uses ServiceLoader to discover the supported Spark ML evaluators.
+# So register your evaluator here when adding a new one.
+
+org.apache.spark.ml.evaluation.RegressionEvaluator
+org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
+org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
+org.apache.spark.ml.evaluation.ClusteringEvaluator
+org.apache.spark.ml.evaluation.RankingEvaluator
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index feb92e0a36f8..1f514c653aa0 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -38,6 +38,7 @@ class RemoteMLWriter(MLWriter):
def save(self, path: str) -> None:
from pyspark.ml.wrapper import JavaModel, JavaEstimator
+ from pyspark.ml.evaluation import JavaEvaluator
from pyspark.sql.connect.session import SparkSession
session = SparkSession.getActiveSession()
@@ -69,6 +70,19 @@ class RemoteMLWriter(MLWriter):
should_overwrite=self.shouldOverwrite,
options=self.optionMap,
)
+ elif isinstance(self._instance, JavaEvaluator):
+ evaluator = cast("JavaEvaluator", self._instance)
+ params = serialize_ml_params(evaluator, session.client)
+ assert isinstance(evaluator._java_obj, str)
+ writer = pb2.MlCommand.Write(
+ operator=pb2.MlOperator(
+                    name=evaluator._java_obj, uid=evaluator.uid, type=pb2.MlOperator.EVALUATOR
+ ),
+ params=params,
+ path=path,
+ should_overwrite=self.shouldOverwrite,
+ options=self.optionMap,
+ )
else:
+            raise NotImplementedError(f"Unsupported writing for {self._instance}")
@@ -85,6 +99,7 @@ class RemoteMLReader(MLReader[RL]):
def load(self, path: str) -> RL:
from pyspark.sql.connect.session import SparkSession
from pyspark.ml.wrapper import JavaModel, JavaEstimator
+ from pyspark.ml.evaluation import JavaEvaluator
session = SparkSession.getActiveSession()
assert session is not None
@@ -99,6 +114,8 @@ class RemoteMLReader(MLReader[RL]):
ml_type = pb2.MlOperator.MODEL
elif issubclass(self._clazz, JavaEstimator):
ml_type = pb2.MlOperator.ESTIMATOR
+ elif issubclass(self._clazz, JavaEvaluator):
+ ml_type = pb2.MlOperator.EVALUATOR
else:
raise ValueError(f"Unsupported reading for
{java_qualified_class_name}")
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c6445c7f0241..b2b2d32c31f0 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -31,7 +31,7 @@ from pyspark.ml.param.shared import (
HasWeightCol,
)
from pyspark.ml.common import inherit_doc
-from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_evaluate
from pyspark.sql.dataframe import DataFrame
if TYPE_CHECKING:
@@ -128,6 +128,7 @@ class JavaEvaluator(JavaParams, Evaluator, metaclass=ABCMeta):
implementations.
"""
+ @try_remote_evaluate
def _evaluate(self, dataset: DataFrame) -> float:
"""
Evaluates the output.
diff --git a/python/pyspark/ml/tests/connect/test_parity_evaluation.py b/python/pyspark/ml/tests/connect/test_parity_evaluation.py
new file mode 100644
index 000000000000..9f78313a318e
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_evaluation.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin
+from pyspark.sql import SparkSession
+
+
+class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase):
+ def setUp(self) -> None:
+ self.spark = SparkSession.builder.remote(
+ os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+ ).getOrCreate()
+
+ def test_assert_remote_mode(self):
+ from pyspark.sql import is_remote
+
+ self.assertTrue(is_remote())
+
+ def tearDown(self) -> None:
+ self.spark.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_evaluation import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py
index 3c5ae3fbe7d1..39e88fa47057 100644
--- a/python/pyspark/ml/tests/test_evaluation.py
+++ b/python/pyspark/ml/tests/test_evaluation.py
@@ -14,18 +14,298 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import tempfile
import unittest
import numpy as np
-from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator
+from pyspark.ml.evaluation import (
+ ClusteringEvaluator,
+ RegressionEvaluator,
+ BinaryClassificationEvaluator,
+ MulticlassClassificationEvaluator,
+ MultilabelClassificationEvaluator,
+ RankingEvaluator,
+)
from pyspark.ml.linalg import Vectors
-from pyspark.sql import Row
-from pyspark.testing.mlutils import SparkSessionTestCase
+from pyspark.sql import Row, SparkSession
+
+
+class EvaluatorTestsMixin:
+ def test_ranking_evaluator(self):
+ scoreAndLabels = [
+            ([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]),
+            ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
+            ([1.0, 2.0, 3.0, 4.0, 5.0], []),
+        ]
+        dataset = self.spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
+
+ # Initialize RankingEvaluator
+ evaluator = RankingEvaluator().setPredictionCol("prediction")
+
+        # Evaluate the dataset using the default metric (mean average precision)
+ mean_average_precision = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(mean_average_precision, 0.3550, atol=1e-4))
+
+ # Evaluate the dataset using precisionAtK for k=2
+ precision_at_k = evaluator.evaluate(
+ dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2}
+ )
+ self.assertTrue(np.allclose(precision_at_k, 0.3333, atol=1e-4))
+
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+ # Save the evaluator
+ evaluator.write().overwrite().save(tmp_dir)
+ # Load the saved evaluator
+ evaluator2 = RankingEvaluator.load(tmp_dir)
+ self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+ self.assertEqual(str(evaluator), str(evaluator2))
+
+ def test_multilabel_classification_evaluator(self):
+ dataset = self.spark.createDataFrame(
+ [
+ ([0.0, 1.0], [0.0, 2.0]),
+ ([0.0, 2.0], [0.0, 1.0]),
+ ([], [0.0]),
+ ([2.0], [2.0]),
+ ([2.0, 0.0], [2.0, 0.0]),
+ ([0.0, 1.0, 2.0], [0.0, 1.0]),
+ ([1.0], [1.0, 2.0]),
+ ],
+ ["prediction", "label"],
+ )
+
+        evaluator = MultilabelClassificationEvaluator().setPredictionCol("prediction")
+
+ # Evaluate the dataset using the default metric (f1 measure by default)
+ f1_score = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(f1_score, 0.6380, atol=1e-4))
+ # Evaluate the dataset using accuracy
+        accuracy = evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
+ self.assertTrue(np.allclose(accuracy, 0.5476, atol=1e-4))
+
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+ # Save the evaluator
+ evaluator.write().overwrite().save(tmp_dir)
+ # Load the saved evaluator
+ evaluator2 = MultilabelClassificationEvaluator.load(tmp_dir)
+ self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+ self.assertEqual(str(evaluator), str(evaluator2))
+
+ def test_multiclass_classification_evaluator(self):
+ dataset = self.spark.createDataFrame(
+ [
+ (0.0, 0.0, 1.0, [0.1, 0.8, 0.1]),
+ (0.0, 1.0, 1.0, [0.3, 0.4, 0.3]),
+ (0.0, 0.0, 1.0, [0.9, 0.05, 0.05]),
+ (1.0, 0.0, 1.0, [0.5, 0.2, 0.3]),
+ (1.0, 1.0, 1.0, [0.2, 0.7, 0.1]),
+ (1.0, 1.0, 1.0, [0.1, 0.3, 0.6]),
+ (1.0, 1.0, 1.0, [0.2, 0.1, 0.7]),
+ (2.0, 2.0, 1.0, [0.3, 0.2, 0.5]),
+ (2.0, 0.0, 1.0, [0.6, 0.2, 0.2]),
+ ],
+ ["prediction", "label", "weight", "probability"],
+ )
+
+        evaluator = MulticlassClassificationEvaluator().setPredictionCol("prediction")
+
+ f1_score = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(f1_score, 0.6613, atol=1e-4))
+
+ # Evaluate the dataset using accuracy
+        accuracy = evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
+ self.assertTrue(np.allclose(accuracy, 0.6666, atol=1e-4))
+
+ # Evaluate the true positive rate for label 1.0
+ true_positive_rate_label_1 = evaluator.evaluate(
+ dataset, {evaluator.metricName: "truePositiveRateByLabel",
evaluator.metricLabel: 1.0}
+ )
+ self.assertEqual(true_positive_rate_label_1, 0.75)
+
+ # Set the metric to Hamming loss
+ evaluator.setMetricName("hammingLoss")
+
+ # Evaluate the dataset using Hamming loss
+ hamming_loss = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(hamming_loss, 0.3333, atol=1e-4))
+
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+ # Save the evaluator
+ evaluator.write().overwrite().save(tmp_dir)
+ # Load the saved evaluator
+ evaluator2 = MulticlassClassificationEvaluator.load(tmp_dir)
+ self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+ self.assertEqual(str(evaluator), str(evaluator2))
+
+ # Initialize MulticlassClassificationEvaluator with weight column
+ evaluator = MulticlassClassificationEvaluator(
+ predictionCol="prediction", weightCol="weight"
+ )
+
+ # Evaluate the dataset with weights using default metric (f1 score)
+ weighted_f1_score = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(weighted_f1_score, 0.6613, atol=1e-4))
+
+ # Evaluate the dataset with weights using accuracy
+        weighted_accuracy = evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
+ self.assertTrue(np.allclose(weighted_accuracy, 0.6666, atol=1e-4))
+
+ evaluator = MulticlassClassificationEvaluator(
+ predictionCol="prediction", probabilityCol="probability"
+ )
+ # Set the metric to log loss
+ evaluator.setMetricName("logLoss")
+ # Evaluate the dataset using log loss
+ log_loss = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(log_loss, 1.0093, atol=1e-4))
+
+ def test_binary_classification_evaluator(self):
+ # Define score and labels data
+ data = map(
+ lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1], x[2]),
+ [
+ (0.1, 0.0, 1.0),
+ (0.1, 1.0, 0.9),
+ (0.4, 0.0, 0.7),
+ (0.6, 0.0, 0.9),
+ (0.6, 1.0, 1.0),
+ (0.6, 1.0, 0.3),
+ (0.8, 1.0, 1.0),
+ ],
+ )
+ dataset = self.spark.createDataFrame(data, ["raw", "label", "weight"])
+
+ evaluator = BinaryClassificationEvaluator().setRawPredictionCol("raw")
+ auc_roc = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(auc_roc, 0.7083, atol=1e-4))
+
+ # Evaluate the dataset using the areaUnderPR metric
+        auc_pr = evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
+ self.assertTrue(np.allclose(auc_pr, 0.8339, atol=1e-4))
+
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+ # Save the evaluator
+ evaluator.write().overwrite().save(tmp_dir)
+ # Load the saved evaluator
+ evaluator2 = BinaryClassificationEvaluator.load(tmp_dir)
+ self.assertEqual(evaluator2.getRawPredictionCol(), "raw")
+ self.assertEqual(str(evaluator), str(evaluator2))
+
+ evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw",
weightCol="weight")
+
+        # Evaluate the dataset with weights using the default metric (areaUnderROC)
+ auc_roc_weighted = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(auc_roc_weighted, 0.7025, atol=1e-4))
+
+ # Evaluate the dataset with weights using the areaUnderPR metric
+        auc_pr_weighted = evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
+ self.assertTrue(np.allclose(auc_pr_weighted, 0.8221, atol=1e-4))
+
+ # Get the number of bins used to compute areaUnderROC
+ num_bins = evaluator.getNumBins()
+ self.assertEqual(num_bins, 1000)
+
+ def test_clustering_evaluator(self):
+ # Define feature and predictions data
+ data = map(
+ lambda x: (Vectors.dense(x[0]), x[1], x[2]),
+ [
+ ([0.0, 0.5], 0.0, 2.5),
+ ([0.5, 0.0], 0.0, 2.5),
+ ([10.0, 11.0], 1.0, 2.5),
+ ([10.5, 11.5], 1.0, 2.5),
+ ([1.0, 1.0], 0.0, 2.5),
+ ([8.0, 6.0], 1.0, 2.5),
+ ],
+ )
+ dataset = self.spark.createDataFrame(data, ["features", "prediction",
"weight"])
+ evaluator = ClusteringEvaluator().setPredictionCol("prediction")
+ score = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(score, 0.9079, atol=1e-4))
+
+ evaluator.setWeightCol("weight")
+
+ # Evaluate the dataset with weights
+ score_with_weight = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(score_with_weight, 0.9079, atol=1e-4))
+
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+ # Save the evaluator
+ evaluator.write().overwrite().save(tmp_dir)
+ # Load the saved evaluator
+ evaluator2 = ClusteringEvaluator.load(tmp_dir)
+ self.assertEqual(evaluator2.getPredictionCol(), "prediction")
+
+ def test_clustering_evaluator_with_cosine_distance(self):
+ featureAndPredictions = map(
+ lambda x: (Vectors.dense(x[0]), x[1]),
+ [
+ ([1.0, 1.0], 1.0),
+ ([10.0, 10.0], 1.0),
+ ([1.0, 0.5], 2.0),
+ ([10.0, 4.4], 2.0),
+ ([-1.0, 1.0], 3.0),
+ ([-100.0, 90.0], 3.0),
+ ],
+ )
+        dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
+        evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine")
+        self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
+        self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5))
+
+ def test_regression_evaluator(self):
+ dataset = self.spark.createDataFrame(
+ [
+ (-28.98343821, -27.0, 1.0),
+ (20.21491975, 21.5, 0.8),
+ (-25.98418959, -22.0, 1.0),
+ (30.69731842, 33.0, 0.6),
+ (74.69283752, 71.0, 0.2),
+ ],
+ ["raw", "label", "weight"],
+ )
+
+ evaluator = RegressionEvaluator()
+ evaluator.setPredictionCol("raw")
+
+ # Evaluate dataset with default metric (RMSE)
+ rmse = evaluator.evaluate(dataset)
+ self.assertTrue(np.allclose(rmse, 2.8424, atol=1e-4))
+ # Evaluate dataset with R2 metric
+ r2 = evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
+ self.assertTrue(np.allclose(r2, 0.9939, atol=1e-4))
+ # Evaluate dataset with MAE metric
+ mae = evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
+ self.assertTrue(np.allclose(mae, 2.6496, atol=1e-4))
+ # read/write
+ with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+ # Save the evaluator
+ evaluator.write().overwrite().save(tmp_dir)
+ # Load the saved evaluator
+ evaluator2 = RegressionEvaluator.load(tmp_dir)
+ self.assertEqual(evaluator2.getPredictionCol(), "raw")
+
+ evaluator_with_weights = RegressionEvaluator(predictionCol="raw",
weightCol="weight")
+ weighted_rmse = evaluator_with_weights.evaluate(dataset)
+ self.assertTrue(np.allclose(weighted_rmse, 2.7405, atol=1e-4))
+ through_origin = evaluator_with_weights.getThroughOrigin()
+ self.assertEqual(through_origin, False)
+
+
+class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase):
+ def setUp(self) -> None:
+ self.spark = SparkSession.builder.master("local[4]").getOrCreate()
+
+ def tearDown(self) -> None:
+ self.spark.stop()
-class EvaluatorTests(SparkSessionTestCase):
def test_evaluate_invalid_type(self):
evaluator = RegressionEvaluator(metricName="r2")
df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
@@ -47,23 +327,6 @@ class EvaluatorTests(SparkSessionTestCase):
self.assertEqual(evaluator._java_obj.getMetricName(), "r2")
self.assertEqual(evaluatorCopy._java_obj.getMetricName(), "mae")
- def test_clustering_evaluator_with_cosine_distance(self):
- featureAndPredictions = map(
- lambda x: (Vectors.dense(x[0]), x[1]),
- [
- ([1.0, 1.0], 1.0),
- ([10.0, 10.0], 1.0),
- ([1.0, 0.5], 2.0),
- ([10.0, 4.4], 2.0),
- ([-1.0, 1.0], 3.0),
- ([-100.0, 90.0], 3.0),
- ],
- )
-        dataset = self.spark.createDataFrame(featureAndPredictions, ["features", "prediction"])
-        evaluator = ClusteringEvaluator(predictionCol="prediction", distanceMeasure="cosine")
-        self.assertEqual(evaluator.getDistanceMeasure(), "cosine")
-        self.assertTrue(np.isclose(evaluator.evaluate(dataset), 0.992671213, atol=1e-5))
-
if __name__ == "__main__":
from pyspark.ml.tests.test_evaluation import * # noqa: F401
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index d161db04fad2..74a07ec365b3 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -49,6 +49,7 @@ if TYPE_CHECKING:
from pyspark.core.context import SparkContext
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.ml.wrapper import JavaWrapper, JavaEstimator
+ from pyspark.ml.evaluation import JavaEvaluator
T = TypeVar("T")
RW = TypeVar("RW", bound="BaseReadWrite")
@@ -324,6 +325,37 @@ def try_remote_not_supporting(f: FuncT) -> FuncT:
return cast(FuncT, wrapped)
+def try_remote_evaluate(f: FuncT) -> FuncT:
+ """Mark the evaluate function in Evaluator."""
+
+ @functools.wraps(f)
+ def wrapped(self: "JavaEvaluator", dataset: "ConnectDataFrame") -> Any:
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+ import pyspark.sql.connect.proto as pb2
+            from pyspark.ml.connect.serialize import serialize_ml_params, deserialize
+
+ client = dataset.sparkSession.client
+ input = dataset._plan.plan(client)
+ assert isinstance(self._java_obj, str)
+ evaluator = pb2.MlOperator(
+                name=self._java_obj, uid=self.uid, type=pb2.MlOperator.EVALUATOR
+ )
+ command = pb2.Command()
+ command.ml_command.evaluate.CopyFrom(
+ pb2.MlCommand.Evaluate(
+ evaluator=evaluator,
+ params=serialize_ml_params(self, client),
+ dataset=input,
+ )
+ )
+ (_, properties, _) = client.execute_command(command)
+ return deserialize(properties)
+ else:
+ return f(self, dataset)
+
+ return cast(FuncT, wrapped)
+
+
def _jvm() -> "JavaGateway":
"""
Returns the JVM view associated with SparkContext. Must be called
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 7b5c4f783a65..8fb38fc7ce18 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -353,7 +353,7 @@ class JavaParams(JavaWrapper, Params, metaclass=ABCMeta):
if extra is None:
extra = dict()
that = super(JavaParams, self).copy(extra)
- if self._java_obj is not None:
+ if self._java_obj is not None and not isinstance(self._java_obj, str):
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
return that
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py
index c29d33db547c..8e8bc34a7a97 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xc6\x07\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\ [...]
+    b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1 [...]
)
_globals = globals()
@@ -54,19 +54,21 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_MLCOMMAND"]._serialized_start = 137
- _globals["_MLCOMMAND"]._serialized_end = 1103
- _globals["_MLCOMMAND_FIT"]._serialized_start = 415
- _globals["_MLCOMMAND_FIT"]._serialized_end = 577
- _globals["_MLCOMMAND_DELETE"]._serialized_start = 579
- _globals["_MLCOMMAND_DELETE"]._serialized_end = 638
- _globals["_MLCOMMAND_WRITE"]._serialized_start = 641
- _globals["_MLCOMMAND_WRITE"]._serialized_end = 1009
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 943
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1001
- _globals["_MLCOMMAND_READ"]._serialized_start = 1011
- _globals["_MLCOMMAND_READ"]._serialized_end = 1092
- _globals["_MLCOMMANDRESULT"]._serialized_start = 1106
- _globals["_MLCOMMANDRESULT"]._serialized_end = 1480
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1299
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1465
+ _globals["_MLCOMMAND"]._serialized_end = 1338
+ _globals["_MLCOMMAND_FIT"]._serialized_start = 480
+ _globals["_MLCOMMAND_FIT"]._serialized_end = 642
+ _globals["_MLCOMMAND_DELETE"]._serialized_start = 644
+ _globals["_MLCOMMAND_DELETE"]._serialized_end = 703
+ _globals["_MLCOMMAND_WRITE"]._serialized_start = 706
+ _globals["_MLCOMMAND_WRITE"]._serialized_end = 1074
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1008
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1066
+ _globals["_MLCOMMAND_READ"]._serialized_start = 1076
+ _globals["_MLCOMMAND_READ"]._serialized_end = 1157
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1160
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1327
+ _globals["_MLCOMMANDRESULT"]._serialized_start = 1341
+ _globals["_MLCOMMANDRESULT"]._serialized_end = 1715
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1534
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1700
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 6b950e4c67bb..e8ae0be8dded 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -224,11 +224,49 @@ class MlCommand(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["operator",
b"operator", "path", b"path"]
) -> None: ...
+ class Evaluate(google.protobuf.message.Message):
+ """Command for evaluator.evaluate(dataset)"""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ EVALUATOR_FIELD_NUMBER: builtins.int
+ PARAMS_FIELD_NUMBER: builtins.int
+ DATASET_FIELD_NUMBER: builtins.int
+ @property
+    def evaluator(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlOperator:
+ """Evaluator information"""
+ @property
+ def params(self) -> pyspark.sql.connect.proto.ml_common_pb2.MlParams:
+ """parameters of the Evaluator"""
+ @property
+ def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
+ """the evaluating dataset"""
+ def __init__(
+ self,
+ *,
+        evaluator: pyspark.sql.connect.proto.ml_common_pb2.MlOperator | None = ...,
+        params: pyspark.sql.connect.proto.ml_common_pb2.MlParams | None = ...,
+        dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "dataset", b"dataset", "evaluator", b"evaluator", "params",
b"params"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "dataset", b"dataset", "evaluator", b"evaluator", "params",
b"params"
+ ],
+ ) -> None: ...
+
FIT_FIELD_NUMBER: builtins.int
FETCH_FIELD_NUMBER: builtins.int
DELETE_FIELD_NUMBER: builtins.int
WRITE_FIELD_NUMBER: builtins.int
READ_FIELD_NUMBER: builtins.int
+ EVALUATE_FIELD_NUMBER: builtins.int
@property
def fit(self) -> global___MlCommand.Fit: ...
@property
@@ -239,6 +277,8 @@ class MlCommand(google.protobuf.message.Message):
def write(self) -> global___MlCommand.Write: ...
@property
def read(self) -> global___MlCommand.Read: ...
+ @property
+ def evaluate(self) -> global___MlCommand.Evaluate: ...
def __init__(
self,
*,
@@ -247,6 +287,7 @@ class MlCommand(google.protobuf.message.Message):
delete: global___MlCommand.Delete | None = ...,
write: global___MlCommand.Write | None = ...,
read: global___MlCommand.Read | None = ...,
+ evaluate: global___MlCommand.Evaluate | None = ...,
) -> None: ...
def HasField(
self,
@@ -255,6 +296,8 @@ class MlCommand(google.protobuf.message.Message):
b"command",
"delete",
b"delete",
+ "evaluate",
+ b"evaluate",
"fetch",
b"fetch",
"fit",
@@ -272,6 +315,8 @@ class MlCommand(google.protobuf.message.Message):
b"command",
"delete",
b"delete",
+ "evaluate",
+ b"evaluate",
"fetch",
b"fetch",
"fit",
@@ -284,7 +329,9 @@ class MlCommand(google.protobuf.message.Message):
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command", b"command"]
- ) -> typing_extensions.Literal["fit", "fetch", "delete", "write", "read"] | None: ...
+ ) -> (
+ typing_extensions.Literal["fit", "fetch", "delete", "write", "read",
"evaluate"] | None
+ ): ...
global___MlCommand = MlCommand
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 3198bbbe9c11..20a5cafebb36 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -35,6 +35,7 @@ message MlCommand {
Delete delete = 3;
Write write = 4;
Read read = 5;
+ Evaluate evaluate = 6;
}
// Command for estimator.fit(dataset)
@@ -79,6 +80,16 @@ message MlCommand {
// Load the ML instance from the input path
string path = 2;
}
+
+ // Command for evaluator.evaluate(dataset)
+ message Evaluate {
+ // Evaluator information
+ MlOperator evaluator = 1;
+ // parameters of the Evaluator
+ MlParams params = 2;
+ // the evaluating dataset
+ Relation dataset = 3;
+ }
}
// The result of MlCommand
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index e53276dcf7df..cc89079aeca3 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -21,8 +21,8 @@ import scala.jdk.CollectionConverters.CollectionHasAsScala
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.util.{MLWritable, Summary}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
@@ -172,18 +172,29 @@ private[connect] object MLHandler extends Logging {
// save an estimator/evaluator/transformer
case proto.MlCommand.Write.TypeCase.OPERATOR =>
val writer = mlCommand.getWrite
-          if (writer.getOperator.getType == proto.MlOperator.OperatorType.ESTIMATOR) {
-            val estimator =
-              MLUtils.getEstimator(sessionHolder, writer.getOperator, Some(writer.getParams))
-            estimator match {
-              case m: MLWritable => MLUtils.write(m, mlCommand.getWrite)
-              case other => throw MlUnsupportedException(s"Estimator $other is not writable")
-            }
-          } else {
-            throw MlUnsupportedException(s"${writer.getOperator.getName} not supported")
-          }
+ val operatorType = writer.getOperator.getType
+ val operatorName = writer.getOperator.getName
+ val params = Some(writer.getParams)
+
+ operatorType match {
+ case proto.MlOperator.OperatorType.ESTIMATOR =>
+              val estimator = MLUtils.getEstimator(sessionHolder, writer.getOperator, params)
+              estimator match {
+                case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite)
+                case other => throw MlUnsupportedException(s"Estimator $other is not writable")
+              }
+
+            case proto.MlOperator.OperatorType.EVALUATOR =>
+              val evaluator = MLUtils.getEvaluator(sessionHolder, writer.getOperator, params)
+              evaluator match {
+                case writable: MLWritable => MLUtils.write(writable, mlCommand.getWrite)
+                case other => throw MlUnsupportedException(s"Evaluator $other is not writable")
+              }
-          case other => throw MlUnsupportedException(s"$other not supported")
+            case _ =>
+              throw MlUnsupportedException(s"Operator $operatorName is not supported")
+          }
+        case other => throw MlUnsupportedException(s"$other write not supported")
}
proto.MlCommandResult.newBuilder().build()
@@ -205,21 +216,36 @@ private[connect] object MLHandler extends Logging {
.setParams(Serializer.serializeParams(model)))
.build()
-        } else if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR) {
-          val estimator = MLUtils.load(sessionHolder, name, path).asInstanceOf[Estimator[_]]
+        } else if (operator.getType == proto.MlOperator.OperatorType.ESTIMATOR ||
+          operator.getType == proto.MlOperator.OperatorType.EVALUATOR) {
+          val operator = MLUtils.load(sessionHolder, name, path).asInstanceOf[Params]
proto.MlCommandResult
.newBuilder()
.setOperatorInfo(
proto.MlCommandResult.MlOperatorInfo
.newBuilder()
.setName(name)
- .setUid(estimator.uid)
- .setParams(Serializer.serializeParams(estimator)))
+ .setUid(operator.uid)
+ .setParams(Serializer.serializeParams(operator)))
.build()
} else {
- throw MlUnsupportedException(s"${operator.getType} not supported")
+ throw MlUnsupportedException(s"${operator.getType} read not
supported")
}
+ case proto.MlCommand.CommandCase.EVALUATE =>
+ val evalCmd = mlCommand.getEvaluate
+ val evalProto = evalCmd.getEvaluator
+ assert(evalProto.getType == proto.MlOperator.OperatorType.EVALUATOR)
+
+        val dataset = MLUtils.parseRelationProto(evalCmd.getDataset, sessionHolder)
+        val evaluator =
+          MLUtils.getEvaluator(sessionHolder, evalProto, Some(evalCmd.getParams))
+ val metric = evaluator.evaluate(dataset)
+ proto.MlCommandResult
+ .newBuilder()
+ .setParam(LiteralValueProtoConverter.toLiteralProto(metric))
+ .build()
+
case other => throw MlUnsupportedException(s"$other not supported")
}
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 4dbfb063eabb..e6e78f15b61f 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -26,6 +26,7 @@ import org.apache.commons.lang3.reflect.MethodUtils.invokeMethod
import org.apache.spark.connect.proto
import org.apache.spark.ml.{Estimator, Transformer}
+import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{MLReadable, MLWritable}
@@ -326,6 +327,30 @@ private[ml] object MLUtils {
getInstance[Transformer](name, uid, transformers, Some(params))
}
+ /**
+ * Get the Evaluator instance according to the proto information
+ *
+ * @param sessionHolder
+ * session holder to hold the Spark Connect session state
+ * @param operator
+ * MlOperator information
+ * @param params
+ * The optional parameters of the evaluator
+ * @return
+ * the evaluator
+ */
+ def getEvaluator(
+ sessionHolder: SessionHolder,
+ operator: proto.MlOperator,
+ params: Option[proto.MlParams]): Evaluator = {
+ val name = replaceOperator(sessionHolder, operator.getName)
+ val uid = operator.getUid
+
+    // Load the evaluators via ServiceLoader every time
+ val evaluators = loadOperators(classOf[Evaluator])
+ getInstance[Evaluator](name, uid, evaluators, params)
+ }
+
/**
* Call "load" function on the ML operator given the operator name
*
diff --git a/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
new file mode 100644
index 000000000000..502464a1424c
--- /dev/null
+++ b/sql/connect/server/src/test/resources/META-INF/services/org.apache.spark.ml.evaluation.Evaluator
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark Connect ML uses ServiceLoader to discover the supported Spark ML evaluators.
+# So register your evaluator here when adding a new one.
+
+org.apache.spark.sql.connect.ml.MyRegressionEvaluator
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
index f4db4077d1f3..7cd95f9f657d 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLBackendSuite.scala
@@ -57,12 +57,28 @@ class MLBackendSuite extends MLHelper {
.build())
}
- test("ML backend: estimator read/write") {
+ test("ML backend: estimator works") {
withSparkConf(
      Connect.CONNECT_ML_BACKEND_CLASSES.key -> "org.apache.spark.sql.connect.ml.MyMlBackend") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val fitCommand = proto.MlCommand
+ .newBuilder()
+ .setFit(
+ proto.MlCommand.Fit
+ .newBuilder()
+ .setDataset(createLocalRelationProto)
+ .setEstimator(getLogisticRegressionBuilder)
+ .setParams(getMaxIterBuilder))
+ .build()
+ val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand)
+ val modelId = fitResult.getOperatorInfo.getObjRef.getId
+      assert(sessionHolder.mlCache.get(modelId).isInstanceOf[MyLogisticRegressionModel])
+      val model = sessionHolder.mlCache.get(modelId).asInstanceOf[MyLogisticRegressionModel]
+ assert(model.intercept == 3.5f)
+ assert(model.coefficients == 4.6f)
+
// read/write
val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
try {
@@ -103,7 +119,7 @@ class MLBackendSuite extends MLHelper {
}
}
- test("ML backend: model read/write") {
+ test("ML backend: model works") {
withSparkConf(
      Connect.CONNECT_ML_BACKEND_CLASSES.key -> "org.apache.spark.sql.connect.ml.MyMlBackend") {
@@ -159,32 +175,72 @@ class MLBackendSuite extends MLHelper {
assert(
ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
== 101010)
+
+ // Fetch double attribute
+      val fakeAttrCmd = fetchCommand(ret.getOperatorInfo.getObjRef.getId, "predictRaw")
+ val fakeAttrRet = MLHandler.handleMlCommand(sessionHolder, fakeAttrCmd)
+ assert(fakeAttrRet.getParam.getDouble === 1.11)
} finally {
Utils.deleteRecursively(tempDir)
}
}
}
- test("ML backend") {
+ test("ML backend: evaluator works") {
withSparkConf(
      Connect.CONNECT_ML_BACKEND_CLASSES.key -> "org.apache.spark.sql.connect.ml.MyMlBackend") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
- val fitCommand = proto.MlCommand
+
+ val evalCmd = proto.MlCommand
.newBuilder()
- .setFit(
- proto.MlCommand.Fit
+ .setEvaluate(
+ proto.MlCommand.Evaluate
.newBuilder()
- .setDataset(createLocalRelationProto)
- .setEstimator(getLogisticRegressionBuilder)
- .setParams(getMaxIterBuilder))
+ .setDataset(createRegressionEvaluationLocalRelationProto)
+ .setEvaluator(getRegressorEvaluator))
.build()
- val fitResult = MLHandler.handleMlCommand(sessionHolder, fitCommand)
- val modelId = fitResult.getOperatorInfo.getObjRef.getId
-      assert(sessionHolder.mlCache.get(modelId).isInstanceOf[MyLogisticRegressionModel])
-      val model = sessionHolder.mlCache.get(modelId).asInstanceOf[MyLogisticRegressionModel]
- assert(model.intercept == 3.5f)
- assert(model.coefficients == 4.6f)
+ val evalResult = MLHandler.handleMlCommand(sessionHolder, evalCmd)
+ assert(evalResult.getParam.getDouble == 1.11)
+
+ // read/write
+ val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+ try {
+        val path = new File(tempDir, Identifiable.randomUID("Evaluator")).getPath
+ val writeCmd = proto.MlCommand
+ .newBuilder()
+ .setWrite(
+ proto.MlCommand.Write
+ .newBuilder()
+ .setOperator(getRegressorEvaluator)
+ .setParams(getMetricName)
+ .setPath(path)
+ .setShouldOverwrite(true))
+ .build()
+ MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+ val readCmd = proto.MlCommand
+ .newBuilder()
+ .setRead(
+ proto.MlCommand.Read
+ .newBuilder()
+ .setOperator(getRegressorEvaluator)
+ .setPath(path))
+ .build()
+
+ val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
+        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("fakeParam"))
+        assert(ret.getOperatorInfo.getParams.getParamsMap.containsKey("metricName"))
+        assert(
+          ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
+          == "mae")
+        assert(
+          ret.getOperatorInfo.getParams.getParamsMap.get("fakeParam").getInteger
+          == 101010)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
}
}
+
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
index c778603eeece..844e85fa03b6 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLHelper.scala
@@ -22,8 +22,9 @@ import java.util.Optional
import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
-import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.ml.param.shared.HasMaxIter
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, MLReadable, MLReader}
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connect.planner.SparkConnectPlanTest
import org.apache.spark.sql.connect.plugin.MLBackendPlugin
-import org.apache.spark.sql.types.{FloatType, Metadata, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, Metadata, StructField, StructType}
trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
@@ -55,6 +56,46 @@ trait MLHelper extends SparkFunSuite with SparkConnectPlanTest {
}
    createLocalRelationProto(DataTypeUtils.toAttributes(schema), inputRows, "UTC", Some(schema))
}
+
+ def createRegressionEvaluationLocalRelationProto: proto.Relation = {
+ // The test refers to
+    // https://github.com/apache/spark/blob/master/python/pyspark/ml/evaluation.py#L331
+ val rows = Seq(
+ InternalRow(-28.98343821, -27.0),
+ InternalRow(20.21491975, 21.5),
+ InternalRow(-25.98418959, -22.0),
+ InternalRow(30.69731842, 33.0),
+ InternalRow(74.69283752, 71.0))
+ val schema = StructType(Seq(StructField("raw", DoubleType),
StructField("label", DoubleType)))
+ val inputRows = rows.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+ createLocalRelationProto(schema, inputRows)
+ }
+
+ def getRegressorEvaluator: proto.MlOperator.Builder =
+ proto.MlOperator
+ .newBuilder()
+ .setName("org.apache.spark.ml.evaluation.RegressionEvaluator")
+ .setUid("RegressionEvaluator")
+ .setType(proto.MlOperator.OperatorType.EVALUATOR)
+
+ def getMetricName: proto.MlParams.Builder =
+ proto.MlParams
+ .newBuilder()
+ .putParams("metricName",
proto.Expression.Literal.newBuilder().setString("mae").build())
+
+ def fetchCommand(modelId: String, method: String): proto.MlCommand = {
+ proto.MlCommand
+ .newBuilder()
+ .setFetch(
+ proto.Fetch
+ .newBuilder()
+ .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
+ .addMethods(proto.Fetch.Method.newBuilder().setMethod(method)))
+ .build()
+ }
}
class MyMlBackend extends MLBackendPlugin {
@@ -65,6 +106,8 @@ class MyMlBackend extends MLBackendPlugin {
Optional.of("org.apache.spark.sql.connect.ml.MyLogisticRegression")
case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
Optional.of("org.apache.spark.sql.connect.ml.MyLogisticRegressionModel")
+ case "org.apache.spark.ml.evaluation.RegressionEvaluator" =>
+ Optional.of("org.apache.spark.sql.connect.ml.MyRegressionEvaluator")
case _ => Optional.empty()
}
}
@@ -74,6 +117,29 @@ trait HasFakedParam extends Params {
final val fakeParam: IntParam = new IntParam(this, "fakeParam", "faked
parameter")
}
+class MyRegressionEvaluator(override val uid: String)
+ extends Evaluator
+ with DefaultParamsWritable
+ with HasFakedParam {
+
+ def this() = this(Identifiable.randomUID("MyRegressionEvaluator"))
+
+ // keep same as RegressionEvaluator
+ val metricName: Param[String] = {
+ new Param(this, "metricName", "metric name in evaluation
(mse|rmse|r2|mae|var)")
+ }
+
+ set(fakeParam, 101010)
+
+ override def evaluate(dataset: Dataset[_]): Double = 1.11
+
+ override def copy(extra: ParamMap): Evaluator = defaultCopy(extra)
+}
+
+object MyRegressionEvaluator extends DefaultParamsReadable[MyRegressionEvaluator] {
+ override def load(path: String): MyRegressionEvaluator = super.load(path)
+}
+
class MyLogisticRegressionModel(
override val uid: String,
val intercept: Float,
@@ -95,6 +161,9 @@ class MyLogisticRegressionModel(
dataset.toDF()
}
+ // fake a function
+ def predictRaw: Double = 1.11
+
override def transformSchema(schema: StructType): StructType = schema
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index b537f1c702bd..49dcd7dbe9ad 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -132,17 +132,6 @@ class MLSuite extends MLHelper {
assert(fakedML.getDouble === 1.0)
}
- def fetchCommand(modelId: String, method: String): proto.MlCommand = {
- proto.MlCommand
- .newBuilder()
- .setFetch(
- proto.Fetch
- .newBuilder()
- .setObjRef(proto.ObjectRef.newBuilder().setId(modelId))
- .addMethods(proto.Fetch.Method.newBuilder().setMethod(method)))
- .build()
- }
-
def trainLogisticRegressionModel(sessionHolder: SessionHolder): String = {
val fitCommand = proto.MlCommand
.newBuilder()
@@ -353,4 +342,60 @@ class MLSuite extends MLHelper {
thrown.message.contains("org.apache.spark.sql.connect.ml.NotImplementingMLReadble
" +
"must implement MLReadable"))
}
+
+ test("RegressionEvaluator works") {
+ val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+
+ val evalCmd = proto.MlCommand
+ .newBuilder()
+ .setEvaluate(
+ proto.MlCommand.Evaluate
+ .newBuilder()
+ .setDataset(createRegressionEvaluationLocalRelationProto)
+ .setEvaluator(getRegressorEvaluator)
+ .setParams(
+ proto.MlParams
+ .newBuilder()
+ .putParams(
+ "predictionCol",
+              proto.Expression.Literal.newBuilder().setString("raw").build())))
+ .build()
+ val evalResult = MLHandler.handleMlCommand(sessionHolder, evalCmd)
+ assert(
+ evalResult.getParam.getDouble > 2.841 &&
+ evalResult.getParam.getDouble < 2.843)
+
+ // read/write
+ val tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
+ try {
+      val path = new File(tempDir, Identifiable.randomUID("RegressionEvaluator")).getPath
+ val writeCmd = proto.MlCommand
+ .newBuilder()
+ .setWrite(
+ proto.MlCommand.Write
+ .newBuilder()
+ .setOperator(getRegressorEvaluator)
+ .setParams(getMetricName)
+ .setPath(path)
+ .setShouldOverwrite(true))
+ .build()
+ MLHandler.handleMlCommand(sessionHolder, writeCmd)
+
+ val readCmd = proto.MlCommand
+ .newBuilder()
+ .setRead(
+ proto.MlCommand.Read
+ .newBuilder()
+ .setOperator(getRegressorEvaluator)
+ .setPath(path))
+ .build()
+
+ val ret = MLHandler.handleMlCommand(sessionHolder, readCmd)
+ assert(
+ ret.getOperatorInfo.getParams.getParamsMap.get("metricName").getString
==
+ "mae")
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]