This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 178b1c7f6968 [SPARK-50874][ML][PYTHON][CONNECT] Support `LinearRegression` on connect
178b1c7f6968 is described below
commit 178b1c7f69680eeff5da7e1ff783cb2d9548f946
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Jan 18 12:56:12 2025 +0800
[SPARK-50874][ML][PYTHON][CONNECT] Support `LinearRegression` on connect
### What changes were proposed in this pull request?
Support LinearRegression on connect
### Why are the changes needed?
feature parity for connect
### Does this PR introduce _any_ user-facing change?
yes, new feature
### How was this patch tested?
added tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49553 from zhengruifeng/ml_regression.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 93e56c26acec8455ec532fcedea9e22fd05a288a)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
.../services/org.apache.spark.ml.Estimator | 4 +
.../spark/ml/regression/LinearRegression.scala | 2 +-
python/pyspark/ml/regression.py | 3 +
.../ml/tests/connect/test_parity_regression.py | 49 ++++++
python/pyspark/ml/tests/test_regression.py | 180 +++++++++++++++++++++
python/pyspark/ml/util.py | 1 +
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 16 +-
8 files changed, 255 insertions(+), 2 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index cb5cf0a85a4b..045328a05d8e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -696,6 +696,7 @@ pyspark_ml = Module(
"pyspark.ml.tests.connect.test_legacy_mode_pipeline",
"pyspark.ml.tests.connect.test_legacy_mode_tuning",
"pyspark.ml.tests.test_classification",
+ "pyspark.ml.tests.test_regression",
],
excluded_python_implementations=[
"PyPy"  # Skip these tests under PyPy since they require numpy and it isn't available there
@@ -1117,6 +1118,7 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_parity_classification",
+ "pyspark.ml.tests.connect.test_parity_regression",
"pyspark.ml.tests.connect.test_parity_evaluation",
],
excluded_python_implementations=[
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index f010e48e8cab..b9a69ed55094 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -23,3 +23,7 @@ org.apache.spark.ml.classification.LogisticRegression
org.apache.spark.ml.classification.DecisionTreeClassifier
org.apache.spark.ml.classification.RandomForestClassifier
org.apache.spark.ml.classification.GBTClassifier
+
+
+# regression
+org.apache.spark.ml.regression.LinearRegression
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index abac9db8df02..4f74dd734e8f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -903,7 +903,7 @@ class LinearRegressionSummary private[regression] (
val labelCol: String,
val featuresCol: String,
private val privateModel: LinearRegressionModel,
- private val diagInvAtWA: Array[Double]) extends Serializable {
+ private val diagInvAtWA: Array[Double]) extends Summary with Serializable {
@transient private val metrics = {
val weightCol =
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index d7cc27e27427..a3ab3ea67557 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -44,6 +44,7 @@ from pyspark.ml.param.shared import (
HasLoss,
HasVarianceCol,
)
+from pyspark.ml.util import try_remote_attribute_relation
from pyspark.ml.tree import (
_DecisionTreeModel,
_DecisionTreeParams,
@@ -517,6 +518,7 @@ class LinearRegressionSummary(JavaWrapper):
@property
@since("2.0.0")
+ @try_remote_attribute_relation
def predictions(self) -> DataFrame:
"""
Dataframe outputted by the model's `transform` method.
@@ -651,6 +653,7 @@ class LinearRegressionSummary(JavaWrapper):
@property
@since("2.0.0")
+ @try_remote_attribute_relation
def residuals(self) -> DataFrame:
"""
Residuals (label - predicted value)
diff --git a/python/pyspark/ml/tests/connect/test_parity_regression.py b/python/pyspark/ml/tests/connect/test_parity_regression.py
new file mode 100644
index 000000000000..67187bb74bd5
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_regression.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark.ml.tests.test_regression import RegressionTestsMixin
+from pyspark.sql import SparkSession
+
+
+class RegressionParityTests(RegressionTestsMixin, unittest.TestCase):
+ def setUp(self) -> None:
+ self.spark = SparkSession.builder.remote(
+ os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+ ).getOrCreate()
+
+ def test_assert_remote_mode(self):
+ from pyspark.sql import is_remote
+
+ self.assertTrue(is_remote())
+
+ def tearDown(self) -> None:
+ self.spark.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_regression import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py
new file mode 100644
index 000000000000..305e2514a382
--- /dev/null
+++ b/python/pyspark/ml/tests/test_regression.py
@@ -0,0 +1,180 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.sql import SparkSession
+from pyspark.ml.regression import (
+ LinearRegression,
+ LinearRegressionModel,
+ LinearRegressionSummary,
+ LinearRegressionTrainingSummary,
+)
+
+
+class RegressionTestsMixin:
+ @property
+ def df(self):
+ return (
+ self.spark.createDataFrame(
+ [
+ (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+ ],
+ ["label", "weight", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("weight")
+ )
+
+ def test_linear_regression(self):
+ df = self.df
+ lr = LinearRegression(
+ regParam=0.0,
+ maxIter=2,
+ solver="normal",
+ weightCol="weight",
+ )
+ self.assertEqual(lr.getRegParam(), 0)
+ self.assertEqual(lr.getMaxIter(), 2)
+ self.assertEqual(lr.getSolver(), "normal")
+ self.assertEqual(lr.getWeightCol(), "weight")
+
+ # Estimator save & load
+ with tempfile.TemporaryDirectory(prefix="linear_regression") as d:
+ lr.write().overwrite().save(d)
+ lr2 = LinearRegression.load(d)
+ self.assertEqual(str(lr), str(lr2))
+
+ model = lr.fit(df)
+ self.assertEqual(model.numFeatures, 2)
+ self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4))
+ self.assertTrue(np.allclose(model.intercept, -0.35, atol=1e-4))
+        self.assertTrue(np.allclose(model.coefficients, [0.65, 0.1125], atol=1e-4))
+
+ output = model.transform(df)
+ expected_cols = [
+ "label",
+ "weight",
+ "features",
+ "prediction",
+ ]
+ self.assertEqual(output.columns, expected_cols)
+ self.assertEqual(output.count(), 4)
+
+ self.assertTrue(
+            np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 0.21249999999999963, atol=1e-4)
+ )
+
+ # Model summary
+ summary = model.summary
+ self.assertTrue(isinstance(summary, LinearRegressionSummary))
+ self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary))
+ self.assertEqual(summary.predictions.columns, expected_cols)
+ self.assertEqual(summary.predictions.count(), 4)
+ self.assertEqual(summary.residuals.columns, ["residuals"])
+ self.assertEqual(summary.residuals.count(), 4)
+
+ self.assertEqual(summary.degreesOfFreedom, 1)
+ self.assertEqual(summary.numInstances, 4)
+ self.assertEqual(summary.objectiveHistory, [0.0])
+ self.assertTrue(
+ np.allclose(
+ summary.coefficientStandardErrors,
+ [1.2859821149611763, 0.6248749874975031, 3.1645497310044184],
+ atol=1e-4,
+ )
+ )
+ self.assertTrue(
+ np.allclose(
+                summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
+ )
+ )
+ self.assertTrue(
+ np.allclose(
+ summary.pValues,
+ [0.7020630236843428, 0.8866003086182783, 0.9298746994547682],
+ atol=1e-4,
+ )
+ )
+ self.assertTrue(
+ np.allclose(
+ summary.tValues,
+ [0.5054502643838291, 0.1800360108036021, -0.11060025272186746],
+ atol=1e-4,
+ )
+ )
+        self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4))
+        self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
+        self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4))
+        self.assertTrue(np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
+        self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4))
+        self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4))
+
+ summary2 = model.evaluate(df)
+ self.assertTrue(isinstance(summary2, LinearRegressionSummary))
+ self.assertFalse(isinstance(summary2, LinearRegressionTrainingSummary))
+ self.assertEqual(summary2.predictions.columns, expected_cols)
+ self.assertEqual(summary2.predictions.count(), 4)
+ self.assertEqual(summary2.residuals.columns, ["residuals"])
+ self.assertEqual(summary2.residuals.count(), 4)
+
+ self.assertEqual(summary2.degreesOfFreedom, 1)
+ self.assertEqual(summary2.numInstances, 4)
+        self.assertTrue(
+            np.allclose(
+                summary2.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
+            )
+        )
+        self.assertTrue(np.allclose(summary2.explainedVariance, 0.07997500000000031, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.meanSquaredError, 0.20212500000000005, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.r2, 0.4427212572373862, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.r2adj, -0.6718362282878414, atol=1e-4))
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="linear_regression_model") as d:
+ model.write().overwrite().save(d)
+ model2 = LinearRegressionModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
+
+class RegressionTests(RegressionTestsMixin, unittest.TestCase):
+ def setUp(self) -> None:
+ self.spark = SparkSession.builder.master("local[4]").getOrCreate()
+
+ def tearDown(self) -> None:
+ self.spark.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_regression import * # noqa: F401,F403
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 74a07ec365b3..6006d131b5c0 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -580,6 +580,7 @@ class GeneralJavaMLWritable(JavaMLWritable):
(Private) Mixin for ML instances that provide
:py:class:`GeneralJavaMLWriter`.
"""
+ @try_remote_write
def write(self) -> GeneralJavaMLWriter:
"""Returns an GeneralMLWriter instance for this ML instance."""
return GeneralJavaMLWriter(self)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index e6e78f15b61f..145afd90e77e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -394,6 +394,7 @@ private[ml] object MLUtils {
"featureImportances", // Tree models
"predictRaw", // ClassificationModel
"predictProbability", // ProbabilisticClassificationModel
+ "scale", // LinearRegressionModel
"coefficients",
"intercept",
"coefficientMatrix",
@@ -428,7 +429,20 @@ private[ml] object MLUtils {
"probabilityCol",
"featuresCol", // LogisticRegressionSummary
"objectiveHistory",
- "totalIterations" // _TrainingSummary
+ "coefficientStandardErrors", // _TrainingSummary
+ "degreesOfFreedom", // LinearRegressionSummary
+ "devianceResiduals", // LinearRegressionSummary
+ "explainedVariance", // LinearRegressionSummary
+ "meanAbsoluteError", // LinearRegressionSummary
+ "meanSquaredError", // LinearRegressionSummary
+ "numInstances", // LinearRegressionSummary
+ "pValues", // LinearRegressionSummary
+ "r2", // LinearRegressionSummary
+ "r2adj", // LinearRegressionSummary
+ "residuals", // LinearRegressionSummary
+ "rootMeanSquaredError", // LinearRegressionSummary
+ "tValues", // LinearRegressionSummary
+ "totalIterations" // LinearRegressionSummary
)
def invokeMethodAllowed(obj: Object, methodName: String): Object = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]