This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 40b8dfab316e [SPARK-50924][SPARK-50926][ML][PYTHON][CONNECT] Support AFTSurvivalRegression and IsotonicRegression on Connect
40b8dfab316e is described below
commit 40b8dfab316ed30f379297b20acfa90859ef8fd0
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 27 16:23:06 2025 +0800
[SPARK-50924][SPARK-50926][ML][PYTHON][CONNECT] Support AFTSurvivalRegression and IsotonicRegression on Connect
### What changes were proposed in this pull request?
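Support `AFTSurvivalRegression` and `IsotonicRegression` (and their models) on Spark Connect ML:
- register the estimators and models in the `META-INF/services` `Estimator`/`Transformer` lists;
- add private no-arg constructors to `AFTSurvivalRegressionModel` and `IsotonicRegressionModel`;
- allow-list the model attributes and methods that clients may call in `MLUtils`.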
### Why are the changes needed?
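Feature parity: make these regression algorithms available through Spark Connect, as they already are in classic PySpark ML.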
### Does this PR introduce _any_ user-facing change?
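Yes. `AFTSurvivalRegression` and `IsotonicRegression` can now be fit, used for prediction, saved and loaded via Spark Connect.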
### How was this patch tested?
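Added tests `test_aft_survival` and `test_isotonic_regression` in `python/pyspark/ml/tests/test_regression.py`.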
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49687 from zhengruifeng/ml_connect_aft.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 3ba76bff53695f78538c8059bd8db9f39a22cdf6)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Estimator | 2 +
.../services/org.apache.spark.ml.Transformer | 2 +
.../ml/regression/AFTSurvivalRegression.scala | 3 +
.../spark/ml/regression/IsotonicRegression.scala | 2 +
python/pyspark/ml/tests/test_regression.py | 102 +++++++++++++++++++++
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 6 ++
6 files changed, 117 insertions(+)
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index ef56903de5e0..595355b0c1e4 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -28,6 +28,8 @@ org.apache.spark.ml.classification.RandomForestClassifier
org.apache.spark.ml.classification.GBTClassifier
# regression
+org.apache.spark.ml.regression.AFTSurvivalRegression
+org.apache.spark.ml.regression.IsotonicRegression
org.apache.spark.ml.regression.LinearRegression
org.apache.spark.ml.regression.GeneralizedLinearRegression
org.apache.spark.ml.regression.DecisionTreeRegressor
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index c973a9899878..0375bac51d39 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -44,6 +44,8 @@
org.apache.spark.ml.classification.RandomForestClassificationModel
org.apache.spark.ml.classification.GBTClassificationModel
# regression
+org.apache.spark.ml.regression.AFTSurvivalRegressionModel
+org.apache.spark.ml.regression.IsotonicRegressionModel
org.apache.spark.ml.regression.LinearRegressionModel
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
org.apache.spark.ml.regression.DecisionTreeRegressionModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 6451cbf0329d..d9f7af73ce33 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -371,6 +371,9 @@ class AFTSurvivalRegressionModel private[ml] (
extends RegressionModel[Vector, AFTSurvivalRegressionModel] with
AFTSurvivalRegressionParams
with MLWritable {
+ private[ml] def this() = this(Identifiable.randomUID("aftSurvReg"),
+ Vectors.empty, Double.NaN, Double.NaN)
+
@Since("3.0.0")
override def numFeatures: Int = coefficients.size
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index d624270af89d..53850089a5a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -213,6 +213,8 @@ class IsotonicRegressionModel private[ml] (
private val oldModel: MLlibIsotonicRegressionModel)
extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with
MLWritable {
+ private[ml] def this() = this(Identifiable.randomUID("isoReg"), null)
+
/** @group setParam */
@Since("1.5.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py
index ed357127d983..322a3d70f9e9 100644
--- a/python/pyspark/ml/tests/test_regression.py
+++ b/python/pyspark/ml/tests/test_regression.py
@@ -23,6 +23,10 @@ import numpy as np
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.ml.regression import (
+ AFTSurvivalRegression,
+ AFTSurvivalRegressionModel,
+ IsotonicRegression,
+ IsotonicRegressionModel,
LinearRegression,
LinearRegressionModel,
GeneralizedLinearRegression,
@@ -57,6 +61,104 @@ class RegressionTestsMixin:
.sortWithinPartitions("weight")
)
+ def test_aft_survival(self):
+ spark = self.spark
+ df = spark.createDataFrame(
+ [(1.0, Vectors.dense(1.0), 1.0), (1e-40, Vectors.sparse(1, [],
[]), 0.0)],
+ ["label", "features", "censor"],
+ )
+
+ aft = AFTSurvivalRegression()
+ aft.setMaxIter(1)
+ self.assertEqual(aft.getMaxIter(), 1)
+
+ model = aft.fit(df)
+ self.assertEqual(aft.uid, model.uid)
+ self.assertEqual(model.numFeatures, 1)
+ self.assertTrue(np.allclose(model.intercept, 0.0, atol=1e-4),
model.intercept)
+ self.assertTrue(
+ np.allclose(model.coefficients.toArray(), [0.0], atol=1e-4),
model.coefficients
+ )
+ self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4), model.scale)
+
+ vec = Vectors.dense(6.3)
+ pred = model.predict(vec)
+ self.assertEqual(pred, 1.0)
+ pred = model.predictQuantiles(vec)
+ self.assertTrue(
+ np.allclose(
+ pred,
+ [
+ 0.010050335853501444,
+ 0.051293294387550536,
+ 0.1053605156578263,
+ 0.2876820724517809,
+ 0.6931471805599453,
+ 1.3862943611198906,
+ 2.302585092994046,
+ 2.9957322735539895,
+ 4.60517018598809,
+ ],
+ atol=1e-4,
+ ),
+ pred,
+ )
+
+ output = model.transform(df)
+ expected_cols = ["label", "features", "censor", "prediction"]
+ self.assertEqual(output.columns, expected_cols)
+ self.assertEqual(output.count(), 2)
+
+ # Model save & load
+ with tempfile.TemporaryDirectory(prefix="aft_survival") as d:
+ aft.write().overwrite().save(d)
+ aft2 = AFTSurvivalRegression.load(d)
+ self.assertEqual(str(aft), str(aft2))
+
+ model.write().overwrite().save(d)
+ model2 = AFTSurvivalRegressionModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
+ def test_isotonic_regression(self):
+ spark = self.spark
+ df = spark.createDataFrame(
+ [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
["label", "features"]
+ )
+
+ ir = IsotonicRegression(
+ isotonic=True,
+ featureIndex=0,
+ )
+ self.assertTrue(ir.getIsotonic())
+ self.assertEqual(ir.getFeatureIndex(), 0)
+
+ model = ir.fit(df)
+ self.assertEqual(model.numFeatures, 1)
+ self.assertTrue(
+ np.allclose(model.boundaries.toArray(), [0.0, 1.0], atol=1e-4),
model.boundaries
+ )
+ self.assertTrue(
+ np.allclose(model.predictions.toArray(), [0.0, 1.0], atol=1e-4),
model.predictions
+ )
+
+ pred = model.predict(1.0)
+ self.assertTrue(np.allclose(pred, 1.0, atol=1e-4), pred)
+
+ output = model.transform(df)
+ expected_cols = ["label", "features", "prediction"]
+ self.assertEqual(output.columns, expected_cols)
+ self.assertEqual(output.count(), 2)
+
+ # Model save & load
+ with tempfile.TemporaryDirectory(prefix="isotonic_regression") as d:
+ ir.write().overwrite().save(d)
+ ir2 = IsotonicRegression.load(d)
+ self.assertEqual(str(ir), str(ir2))
+
+ model.write().overwrite().save(d)
+ model2 = IsotonicRegressionModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
def test_linear_regression(self):
df = self.df
lr = LinearRegression(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 86655c7045bf..38181590484b 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -533,6 +533,12 @@ private[ml] object MLUtils {
(classOf[MultilayerPerceptronClassificationModel], Set("weights",
"evaluate")),
// Regression Models
+ (
+ classOf[AFTSurvivalRegressionModel],
+ Set("intercept", "coefficients", "scale", "predictQuantiles")),
+ (
+ classOf[IsotonicRegressionModel],
+ Set("boundaries", "predictions", "numFeatures", "predict")),
(
classOf[GeneralizedLinearRegressionModel],
Set("intercept", "coefficients", "numFeatures", "evaluate")),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]