This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 178b1c7f6968 [SPARK-50874][ML][PYTHON][CONNECT] Support 
`LinearRegression` on connect
178b1c7f6968 is described below

commit 178b1c7f69680eeff5da7e1ff783cb2d9548f946
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Jan 18 12:56:12 2025 +0800

    [SPARK-50874][ML][PYTHON][CONNECT] Support `LinearRegression` on connect
    
    ### What changes were proposed in this pull request?
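    Support `LinearRegression` on Spark Connect: register the estimator, allow its model and summary attributes on the Connect server, and add tests.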
    
    ### Why are the changes needed?
    To bring Spark Connect to feature parity with classic PySpark ML.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `LinearRegression` can now be used with Spark Connect.
    
    ### How was this patch tested?
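    Added `pyspark.ml.tests.test_regression` and its Spark Connect parity suite `pyspark.ml.tests.connect.test_parity_regression`.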
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49553 from zhengruifeng/ml_regression.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 93e56c26acec8455ec532fcedea9e22fd05a288a)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   2 +
 .../services/org.apache.spark.ml.Estimator         |   4 +
 .../spark/ml/regression/LinearRegression.scala     |   2 +-
 python/pyspark/ml/regression.py                    |   3 +
 .../ml/tests/connect/test_parity_regression.py     |  49 ++++++
 python/pyspark/ml/tests/test_regression.py         | 180 +++++++++++++++++++++
 python/pyspark/ml/util.py                          |   1 +
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  16 +-
 8 files changed, 255 insertions(+), 2 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index cb5cf0a85a4b..045328a05d8e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -696,6 +696,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.connect.test_legacy_mode_pipeline",
         "pyspark.ml.tests.connect.test_legacy_mode_tuning",
         "pyspark.ml.tests.test_classification",
+        "pyspark.ml.tests.test_regression",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy and it isn't available there
@@ -1117,6 +1118,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_connect_pipeline",
         "pyspark.ml.tests.connect.test_connect_tuning",
         "pyspark.ml.tests.connect.test_parity_classification",
+        "pyspark.ml.tests.connect.test_parity_regression",
         "pyspark.ml.tests.connect.test_parity_evaluation",
     ],
     excluded_python_implementations=[
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index f010e48e8cab..b9a69ed55094 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -23,3 +23,7 @@ org.apache.spark.ml.classification.LogisticRegression
 org.apache.spark.ml.classification.DecisionTreeClassifier
 org.apache.spark.ml.classification.RandomForestClassifier
 org.apache.spark.ml.classification.GBTClassifier
+
+
+# regression
+org.apache.spark.ml.regression.LinearRegression
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index abac9db8df02..4f74dd734e8f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -903,7 +903,7 @@ class LinearRegressionSummary private[regression] (
     val labelCol: String,
     val featuresCol: String,
     private val privateModel: LinearRegressionModel,
-    private val diagInvAtWA: Array[Double]) extends Serializable {
+    private val diagInvAtWA: Array[Double]) extends Summary with Serializable {
 
   @transient private val metrics = {
     val weightCol =
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index d7cc27e27427..a3ab3ea67557 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -44,6 +44,7 @@ from pyspark.ml.param.shared import (
     HasLoss,
     HasVarianceCol,
 )
+from pyspark.ml.util import try_remote_attribute_relation  # for LinearRegressionSummary DataFrames
 from pyspark.ml.tree import (
     _DecisionTreeModel,
     _DecisionTreeParams,
@@ -517,6 +518,7 @@ class LinearRegressionSummary(JavaWrapper):
 
     @property
     @since("2.0.0")
+    @try_remote_attribute_relation
     def predictions(self) -> DataFrame:
         """
         Dataframe outputted by the model's `transform` method.
@@ -651,6 +653,7 @@ class LinearRegressionSummary(JavaWrapper):
 
     @property
     @since("2.0.0")
+    @try_remote_attribute_relation
     def residuals(self) -> DataFrame:
         """
         Residuals (label - predicted value)
diff --git a/python/pyspark/ml/tests/connect/test_parity_regression.py b/python/pyspark/ml/tests/connect/test_parity_regression.py
new file mode 100644
index 000000000000..67187bb74bd5
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_regression.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+from pyspark.ml.tests.test_regression import RegressionTestsMixin
+from pyspark.sql import SparkSession
+
+
+class RegressionParityTests(RegressionTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()
+
+    def test_assert_remote_mode(self):
+        from pyspark.sql import is_remote
+
+        self.assertTrue(is_remote())
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_regression import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py
new file mode 100644
index 000000000000..305e2514a382
--- /dev/null
+++ b/python/pyspark/ml/tests/test_regression.py
@@ -0,0 +1,180 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.sql import SparkSession
+from pyspark.ml.regression import (
+    LinearRegression,
+    LinearRegressionModel,
+    LinearRegressionSummary,
+    LinearRegressionTrainingSummary,
+)
+
+
+class RegressionTestsMixin:
+    @property
+    def df(self):
+        return (
+            self.spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
+
+    def test_linear_regression(self):
+        df = self.df
+        lr = LinearRegression(
+            regParam=0.0,
+            maxIter=2,
+            solver="normal",
+            weightCol="weight",
+        )
+        self.assertEqual(lr.getRegParam(), 0)
+        self.assertEqual(lr.getMaxIter(), 2)
+        self.assertEqual(lr.getSolver(), "normal")
+        self.assertEqual(lr.getWeightCol(), "weight")
+
+        # Estimator save & load
+        with tempfile.TemporaryDirectory(prefix="linear_regression") as d:
+            lr.write().overwrite().save(d)
+            lr2 = LinearRegression.load(d)
+            self.assertEqual(str(lr), str(lr2))
+
+        model = lr.fit(df)
+        self.assertEqual(model.numFeatures, 2)
+        self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4))
+        self.assertTrue(np.allclose(model.intercept, -0.35, atol=1e-4))
+        self.assertTrue(np.allclose(model.coefficients, [0.65, 0.1125], atol=1e-4))
+
+        output = model.transform(df)
+        expected_cols = [
+            "label",
+            "weight",
+            "features",
+            "prediction",
+        ]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 4)
+
+        self.assertTrue(
+            np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 0.21249999999999963, atol=1e-4)
+        )
+
+        # Model summary
+        summary = model.summary
+        self.assertIsInstance(summary, LinearRegressionSummary)
+        self.assertIsInstance(summary, LinearRegressionTrainingSummary)
+        self.assertEqual(summary.predictions.columns, expected_cols)
+        self.assertEqual(summary.predictions.count(), 4)
+        self.assertEqual(summary.residuals.columns, ["residuals"])
+        self.assertEqual(summary.residuals.count(), 4)
+
+        self.assertEqual(summary.degreesOfFreedom, 1)
+        self.assertEqual(summary.numInstances, 4)
+        self.assertEqual(summary.objectiveHistory, [0.0])
+        self.assertTrue(
+            np.allclose(
+                summary.coefficientStandardErrors,
+                [1.2859821149611763, 0.6248749874975031, 3.1645497310044184],
+                atol=1e-4,
+            )
+        )
+        self.assertTrue(
+            np.allclose(
+                summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
+            )
+        )
+        self.assertTrue(
+            np.allclose(
+                summary.pValues,
+                [0.7020630236843428, 0.8866003086182783, 0.9298746994547682],
+                atol=1e-4,
+            )
+        )
+        self.assertTrue(
+            np.allclose(
+                summary.tValues,
+                [0.5054502643838291, 0.1800360108036021, -0.11060025272186746],
+                atol=1e-4,
+            )
+        )
+        self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4))
+        self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
+        self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4))
+        self.assertTrue(np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
+        self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4))
+        self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4))
+
+        summary2 = model.evaluate(df)
+        self.assertIsInstance(summary2, LinearRegressionSummary)
+        self.assertNotIsInstance(summary2, LinearRegressionTrainingSummary)
+        self.assertEqual(summary2.predictions.columns, expected_cols)
+        self.assertEqual(summary2.predictions.count(), 4)
+        self.assertEqual(summary2.residuals.columns, ["residuals"])
+        self.assertEqual(summary2.residuals.count(), 4)
+
+        self.assertEqual(summary2.degreesOfFreedom, 1)
+        self.assertEqual(summary2.numInstances, 4)
+        self.assertTrue(
+            np.allclose(
+                summary2.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4
+            )
+        )
+        self.assertTrue(np.allclose(summary2.explainedVariance, 0.07997500000000031, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.meanAbsoluteError, 0.4200000000000002, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.meanSquaredError, 0.20212500000000005, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.rootMeanSquaredError, 0.44958314025327956, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.r2, 0.4427212572373862, atol=1e-4))
+        self.assertTrue(np.allclose(summary2.r2adj, -0.6718362282878414, atol=1e-4))
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="linear_regression_model") as d:
+            model.write().overwrite().save(d)
+            model2 = LinearRegressionModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
+
+class RegressionTests(RegressionTestsMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_regression import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 74a07ec365b3..6006d131b5c0 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -580,6 +580,7 @@ class GeneralJavaMLWritable(JavaMLWritable):
     (Private) Mixin for ML instances that provide :py:class:`GeneralJavaMLWriter`.
     """
 
+    @try_remote_write
     def write(self) -> GeneralJavaMLWriter:
         """Returns an GeneralMLWriter instance for this ML instance."""
         return GeneralJavaMLWriter(self)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index e6e78f15b61f..145afd90e77e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -394,6 +394,7 @@ private[ml] object MLUtils {
     "featureImportances", // Tree models
     "predictRaw", // ClassificationModel
     "predictProbability", // ProbabilisticClassificationModel
+    "scale", // LinearRegressionModel
     "coefficients",
     "intercept",
     "coefficientMatrix",
@@ -428,7 +429,20 @@ private[ml] object MLUtils {
     "probabilityCol",
     "featuresCol", // LogisticRegressionSummary
     "objectiveHistory",
-    "totalIterations" // _TrainingSummary
+    "coefficientStandardErrors", // LinearRegressionSummary
+    "degreesOfFreedom", // LinearRegressionSummary
+    "devianceResiduals", // LinearRegressionSummary
+    "explainedVariance", // LinearRegressionSummary
+    "meanAbsoluteError", // LinearRegressionSummary
+    "meanSquaredError", // LinearRegressionSummary
+    "numInstances", // LinearRegressionSummary
+    "pValues", // LinearRegressionSummary
+    "r2", // LinearRegressionSummary
+    "r2adj", // LinearRegressionSummary
+    "residuals", // LinearRegressionSummary
+    "rootMeanSquaredError", // LinearRegressionSummary
+    "tValues", // LinearRegressionSummary
+    "totalIterations" // _TrainingSummary
   )
 
   def invokeMethodAllowed(obj: Object, methodName: String): Object = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to