This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 40b8dfab316e [SPARK-50924][SPARK-50926][ML][PYTHON][CONNECT] Support AFTSurvivalRegression and IsotonicRegression on Connect
40b8dfab316e is described below

commit 40b8dfab316ed30f379297b20acfa90859ef8fd0
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 27 16:23:06 2025 +0800

    [SPARK-50924][SPARK-50926][ML][PYTHON][CONNECT] Support AFTSurvivalRegression and IsotonicRegression on Connect
    
    ### What changes were proposed in this pull request?
     Support AFTSurvivalRegression and IsotonicRegression on Connect
    
    ### Why are the changes needed?
    feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49687 from zhengruifeng/ml_connect_aft.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 3ba76bff53695f78538c8059bd8db9f39a22cdf6)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Estimator         |   2 +
 .../services/org.apache.spark.ml.Transformer       |   2 +
 .../ml/regression/AFTSurvivalRegression.scala      |   3 +
 .../spark/ml/regression/IsotonicRegression.scala   |   2 +
 python/pyspark/ml/tests/test_regression.py         | 102 +++++++++++++++++++++
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |   6 ++
 6 files changed, 117 insertions(+)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index ef56903de5e0..595355b0c1e4 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -28,6 +28,8 @@ org.apache.spark.ml.classification.RandomForestClassifier
 org.apache.spark.ml.classification.GBTClassifier
 
 # regression
+org.apache.spark.ml.regression.AFTSurvivalRegression
+org.apache.spark.ml.regression.IsotonicRegression
 org.apache.spark.ml.regression.LinearRegression
 org.apache.spark.ml.regression.GeneralizedLinearRegression
 org.apache.spark.ml.regression.DecisionTreeRegressor
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index c973a9899878..0375bac51d39 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -44,6 +44,8 @@ org.apache.spark.ml.classification.RandomForestClassificationModel
 org.apache.spark.ml.classification.GBTClassificationModel
 
 # regression
+org.apache.spark.ml.regression.AFTSurvivalRegressionModel
+org.apache.spark.ml.regression.IsotonicRegressionModel
 org.apache.spark.ml.regression.LinearRegressionModel
 org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
 org.apache.spark.ml.regression.DecisionTreeRegressionModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 6451cbf0329d..d9f7af73ce33 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -371,6 +371,9 @@ class AFTSurvivalRegressionModel private[ml] (
   extends RegressionModel[Vector, AFTSurvivalRegressionModel] with 
AFTSurvivalRegressionParams
   with MLWritable {
 
+  private[ml] def this() = this(Identifiable.randomUID("aftSurvReg"),
+    Vectors.empty, Double.NaN, Double.NaN)
+
   @Since("3.0.0")
   override def numFeatures: Int = coefficients.size
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index d624270af89d..53850089a5a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -213,6 +213,8 @@ class IsotonicRegressionModel private[ml] (
     private val oldModel: MLlibIsotonicRegressionModel)
   extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with 
MLWritable {
 
+  private[ml] def this() = this(Identifiable.randomUID("isoReg"), null)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py
index ed357127d983..322a3d70f9e9 100644
--- a/python/pyspark/ml/tests/test_regression.py
+++ b/python/pyspark/ml/tests/test_regression.py
@@ -23,6 +23,10 @@ import numpy as np
 from pyspark.ml.linalg import Vectors
 from pyspark.sql import SparkSession
 from pyspark.ml.regression import (
+    AFTSurvivalRegression,
+    AFTSurvivalRegressionModel,
+    IsotonicRegression,
+    IsotonicRegressionModel,
     LinearRegression,
     LinearRegressionModel,
     GeneralizedLinearRegression,
@@ -57,6 +61,104 @@ class RegressionTestsMixin:
             .sortWithinPartitions("weight")
         )
 
+    def test_aft_survival(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [(1.0, Vectors.dense(1.0), 1.0), (1e-40, Vectors.sparse(1, [], 
[]), 0.0)],
+            ["label", "features", "censor"],
+        )
+
+        aft = AFTSurvivalRegression()
+        aft.setMaxIter(1)
+        self.assertEqual(aft.getMaxIter(), 1)
+
+        model = aft.fit(df)
+        self.assertEqual(aft.uid, model.uid)
+        self.assertEqual(model.numFeatures, 1)
+        self.assertTrue(np.allclose(model.intercept, 0.0, atol=1e-4), 
model.intercept)
+        self.assertTrue(
+            np.allclose(model.coefficients.toArray(), [0.0], atol=1e-4), 
model.coefficients
+        )
+        self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4), model.scale)
+
+        vec = Vectors.dense(6.3)
+        pred = model.predict(vec)
+        self.assertEqual(pred, 1.0)
+        pred = model.predictQuantiles(vec)
+        self.assertTrue(
+            np.allclose(
+                pred,
+                [
+                    0.010050335853501444,
+                    0.051293294387550536,
+                    0.1053605156578263,
+                    0.2876820724517809,
+                    0.6931471805599453,
+                    1.3862943611198906,
+                    2.302585092994046,
+                    2.9957322735539895,
+                    4.60517018598809,
+                ],
+                atol=1e-4,
+            ),
+            pred,
+        )
+
+        output = model.transform(df)
+        expected_cols = ["label", "features", "censor", "prediction"]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 2)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="aft_survival") as d:
+            aft.write().overwrite().save(d)
+            aft2 = AFTSurvivalRegression.load(d)
+            self.assertEqual(str(aft), str(aft2))
+
+            model.write().overwrite().save(d)
+            model2 = AFTSurvivalRegressionModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
+    def test_isotonic_regression(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))], 
["label", "features"]
+        )
+
+        ir = IsotonicRegression(
+            isotonic=True,
+            featureIndex=0,
+        )
+        self.assertTrue(ir.getIsotonic())
+        self.assertEqual(ir.getFeatureIndex(), 0)
+
+        model = ir.fit(df)
+        self.assertEqual(model.numFeatures, 1)
+        self.assertTrue(
+            np.allclose(model.boundaries.toArray(), [0.0, 1.0], atol=1e-4), 
model.boundaries
+        )
+        self.assertTrue(
+            np.allclose(model.predictions.toArray(), [0.0, 1.0], atol=1e-4), 
model.predictions
+        )
+
+        pred = model.predict(1.0)
+        self.assertTrue(np.allclose(pred, 1.0, atol=1e-4), pred)
+
+        output = model.transform(df)
+        expected_cols = ["label", "features", "prediction"]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 2)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="isotonic_regression") as d:
+            ir.write().overwrite().save(d)
+            ir2 = IsotonicRegression.load(d)
+            self.assertEqual(str(ir), str(ir2))
+
+            model.write().overwrite().save(d)
+            model2 = IsotonicRegressionModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
     def test_linear_regression(self):
         df = self.df
         lr = LinearRegression(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 86655c7045bf..38181590484b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -533,6 +533,12 @@ private[ml] object MLUtils {
     (classOf[MultilayerPerceptronClassificationModel], Set("weights", 
"evaluate")),
 
     // Regression Models
+    (
+      classOf[AFTSurvivalRegressionModel],
+      Set("intercept", "coefficients", "scale", "predictQuantiles")),
+    (
+      classOf[IsotonicRegressionModel],
+      Set("boundaries", "predictions", "numFeatures", "predict")),
     (
       classOf[GeneralizedLinearRegressionModel],
       Set("intercept", "coefficients", "numFeatures", "evaluate")),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to