This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new cb7ae0407d4 [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT
cb7ae0407d4 is described below

commit cb7ae0407d440feb6c228b1265af50c0006e21e9
Author:     Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Mar 11 08:45:54 2023 -0600

    [SPARK-42747][ML] Fix incorrect internal status of LoR and AFT

    ### What changes were proposed in this pull request?
    Add a hook `onParamChange` in `Params.{set, setDefault, clear}`, so that subclasses can update their internal status within it.

    ### Why are the changes needed?
    In 3.1, we added internal auxiliary variables in LoR and AFT to optimize prediction/transformation.

    In LoR, when users call `model.{setThreshold, setThresholds}`, the internal status is correctly updated. But users can still call `model.set(model.threshold, value)`, in which case the status is not updated. And when users call `model.clear(model.threshold)`, the status should be updated with the default threshold value 0.5.

    For example:

    ```
    import org.apache.spark.ml.linalg._
    import org.apache.spark.ml.classification._

    val df = Seq(
      (1.0, 1.0, Vectors.dense(0.0, 5.0)),
      (0.0, 2.0, Vectors.dense(1.0, 2.0)),
      (1.0, 3.0, Vectors.dense(2.0, 1.0)),
      (0.0, 4.0, Vectors.dense(3.0, 3.0))).toDF("label", "weight", "features")

    val lor = new LogisticRegression().setWeightCol("weight")
    val model = lor.fit(df)
    val vec = Vectors.dense(0.0, 5.0)

    val p0 = model.predict(vec)                           // returns 0.0
    model.setThreshold(0.05)                              // changes the internal status
    val p1 = model.set(model.threshold, 0.5).predict(vec) // returns 1.0, but should be 0.0
    val p2 = model.clear(model.threshold).predict(vec)    // returns 1.0, but should be 0.0
    ```

    What makes it even worse is that `pyspark.ml` always sets params via `model.set(model.threshold, value)`, so the internal status easily gets out of sync; see the example in [SPARK-42747](https://issues.apache.org/jira/browse/SPARK-42747).

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Added unit tests.

    Closes #40367 from zhengruifeng/ml_param_hook.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Sean Owen <sro...@gmail.com>
    (cherry picked from commit 5a702f22f49ca6a1b6220ac645e3fce70ec5189d)
    Signed-off-by: Sean Owen <sro...@gmail.com>
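For illustration, the hook pattern this patch introduces can be reduced to a minimal, self-contained sketch. Note that `SimpleParams` and `CachedModel` below are hypothetical stand-ins for `Params` and `LogisticRegressionModel`, not code from this patch; the point is that all three mutation paths (`set`, `setDefault`, `clear`) funnel through one `onParamChange` hook, so cached state derived from a param cannot go stale regardless of which setter API a caller uses:

```
// Minimal sketch of the onParamChange hook pattern.
// SimpleParams and CachedModel are illustrative, not Spark classes.
object ParamHookSketch {

  trait SimpleParams {
    private val paramMap = scala.collection.mutable.Map.empty[String, Any]
    private val defaultMap = scala.collection.mutable.Map.empty[String, Any]

    // Every mutation path fires the hook, mirroring
    // Params.{set, setDefault, clear} in this patch.
    def set(name: String, value: Any): this.type = {
      paramMap(name) = value
      onParamChange(name)
      this
    }

    def setDefault(name: String, value: Any): this.type = {
      defaultMap(name) = value
      onParamChange(name)
      this
    }

    def clear(name: String): this.type = {
      paramMap.remove(name)
      onParamChange(name) // cached state can fall back to the default
      this
    }

    def get(name: String): Option[Any] =
      paramMap.get(name).orElse(defaultMap.get(name))

    // No-op by default; subclasses override it to refresh cached state.
    protected def onParamChange(name: String): Unit = {}
  }

  // A model that caches a value derived from its "threshold" param.
  class CachedModel extends SimpleParams {
    private var cachedRawThreshold = 0.0 // log(0.5 / (1 - 0.5))
    setDefault("threshold", 0.5)

    override protected def onParamChange(name: String): Unit = {
      if (name == "threshold") {
        val t = get("threshold").map(_.asInstanceOf[Double]).getOrElse(0.5)
        cachedRawThreshold = math.log(t / (1.0 - t))
      }
    }

    def rawThreshold: Double = cachedRawThreshold
  }

  def main(args: Array[String]): Unit = {
    val m = new CachedModel
    m.set("threshold", 0.05)
    println(m.rawThreshold) // ~ -2.944, updated without any setThreshold call
    m.clear("threshold")
    println(m.rawThreshold) // 0.0, back in sync with the default of 0.5
  }
}
```

This is also why the PySpark case is covered: `pyspark.ml` sets params through the generic `set` path, which now fires the same hook.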
---
 .../ml/classification/LogisticRegression.scala     | 54 +++++++++-------------
 .../scala/org/apache/spark/ml/param/params.scala   | 16 +++----
 .../ml/regression/AFTSurvivalRegression.scala      | 26 ++++++-----
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  2 +-
 .../classification/LogisticRegressionSuite.scala   | 21 +++++++++
 .../ml/regression/AFTSurvivalRegressionSuite.scala | 13 ++++++
 python/pyspark/ml/tests/test_algorithms.py         | 35 ++++++++++++++
 7 files changed, 113 insertions(+), 54 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3ad1e2c17db..adf77eb6113 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1112,46 +1112,36 @@ class LogisticRegressionModel private[spark] (
     _intercept
   }
 
-  private lazy val _intercept = interceptVector(0)
-  private lazy val _interceptVector = interceptVector.toDense
-  private lazy val _binaryThresholdArray = {
-    val array = Array(Double.NaN, Double.NaN)
-    updateBinaryThresholds(array)
-    array
-  }
-  private def _threshold: Double = _binaryThresholdArray(0)
-  private def _rawThreshold: Double = _binaryThresholdArray(1)
-
-  private def updateBinaryThresholds(array: Array[Double]): Unit = {
-    if (!isMultinomial) {
-      val _threshold = getThreshold
-      array(0) = _threshold
-      if (_threshold == 0.0) {
-        array(1) = Double.NegativeInfinity
-      } else if (_threshold == 1.0) {
-        array(1) = Double.PositiveInfinity
+  private val _interceptVector = if (isMultinomial) interceptVector.toDense else null
+  private val _intercept = if (!isMultinomial) interceptVector(0) else Double.NaN
+  // Array(0.5, 0.0) is the value for default threshold (0.5) and thresholds (unset)
+  private var _binaryThresholds: Array[Double] = if (!isMultinomial) Array(0.5, 0.0) else null
+
+  private[ml] override def onParamChange(param: Param[_]): Unit = {
+    if (!isMultinomial && (param.name == "threshold" || param.name == "thresholds")) {
+      if (isDefined(threshold) || isDefined(thresholds)) {
+        val _threshold = getThreshold
+        if (_threshold == 0.0) {
+          _binaryThresholds = Array(_threshold, Double.NegativeInfinity)
+        } else if (_threshold == 1.0) {
+          _binaryThresholds = Array(_threshold, Double.PositiveInfinity)
+        } else {
+          _binaryThresholds = Array(_threshold, math.log(_threshold / (1.0 - _threshold)))
+        }
       } else {
-        array(1) = math.log(_threshold / (1.0 - _threshold))
+        _binaryThresholds = null
       }
     }
   }
 
   @Since("1.5.0")
-  override def setThreshold(value: Double): this.type = {
-    super.setThreshold(value)
-    updateBinaryThresholds(_binaryThresholdArray)
-    this
-  }
+  override def setThreshold(value: Double): this.type = super.setThreshold(value)
 
   @Since("1.5.0")
   override def getThreshold: Double = super.getThreshold
 
   @Since("1.5.0")
-  override def setThresholds(value: Array[Double]): this.type = {
-    super.setThresholds(value)
-    updateBinaryThresholds(_binaryThresholdArray)
-    this
-  }
+  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
 
   @Since("1.5.0")
   override def getThresholds: Array[Double] = super.getThresholds
@@ -1223,7 +1213,7 @@ class LogisticRegressionModel private[spark] (
       super.predict(features)
     } else {
       // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
-      if (score(features) > _threshold) 1 else 0
+      if (score(features) > _binaryThresholds(0)) 1 else 0
     }
 
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
@@ -1265,7 +1255,7 @@ class LogisticRegressionModel private[spark] (
       super.raw2prediction(rawPrediction)
     } else {
       // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
-      if (rawPrediction(1) > _rawThreshold) 1.0 else 0.0
+      if (rawPrediction(1) > _binaryThresholds(1)) 1.0 else 0.0
     }
   }
 
@@ -1274,7 +1264,7 @@ class LogisticRegressionModel private[spark] (
       super.probability2prediction(probability)
     } else {
       // Note: We should use _threshold instead of $(threshold) since getThreshold is overridden.
-      if (probability(1) > _threshold) 1.0 else 0.0
+      if (probability(1) > _binaryThresholds(0)) 1.0 else 0.0
     }
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index f12c1f995b7..52840e04eae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -726,6 +726,7 @@ trait Params extends Identifiable with Serializable {
   protected final def set(paramPair: ParamPair[_]): this.type = {
     shouldOwn(paramPair.param)
     paramMap.put(paramPair)
+    onParamChange(paramPair.param)
     this
   }
 
@@ -743,6 +744,7 @@ trait Params extends Identifiable with Serializable {
   final def clear(param: Param[_]): this.type = {
     shouldOwn(param)
     paramMap.remove(param)
+    onParamChange(param)
     this
   }
 
@@ -767,8 +769,9 @@ trait Params extends Identifiable with Serializable {
    * this method gets called.
    * @param value  the default value
    */
-  protected final def setDefault[T](param: Param[T], value: T): this.type = {
+  protected[ml] final def setDefault[T](param: Param[T], value: T): this.type = {
     defaultParamMap.put(param -> value)
+    onParamChange(param)
     this
   }
 
@@ -870,7 +873,7 @@ trait Params extends Identifiable with Serializable {
     params.foreach { param =>
       // copy default Params
       if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
-        to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
+        to.setDefault(to.getParam(param.name), defaultParamMap(param))
       }
       // copy explicitly set Params
       if (map.contains(param) && to.hasParam(param.name)) {
@@ -879,15 +882,8 @@ trait Params extends Identifiable with Serializable {
     }
     to
   }
-}
 
-private[ml] object Params {
-  /**
-   * Sets a default param value for a `Params`.
-   */
-  private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
-    params.defaultParamMap.put(param -> value)
-  }
+  private[ml] def onParamChange(param: Param[_]): Unit = {}
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index c48fe680e80..5ac58431f17 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -379,25 +379,29 @@ class AFTSurvivalRegressionModel private[ml] (
 
   /** @group setParam */
   @Since("1.6.0")
-  def setQuantileProbabilities(value: Array[Double]): this.type = {
-    set(quantileProbabilities, value)
-    _quantiles(0) = $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale))
-    this
-  }
+  def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
 
   /** @group setParam */
   @Since("1.6.0")
   def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
 
-  private lazy val _quantiles = {
-    Array($(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
-  }
+  private var _quantiles: Vector = _
+
+  private[ml] override def onParamChange(param: Param[_]): Unit = {
+    if (param.name == "quantileProbabilities") {
+      if (isDefined(quantileProbabilities)) {
+        _quantiles = Vectors.dense(
+          $(quantileProbabilities).map(q => math.exp(math.log(-math.log1p(-q)) * scale)))
+      } else {
+        _quantiles = null
+      }
+    }
+  }
 
   private def lambda2Quantiles(lambda: Double): Vector = {
-    val quantiles = _quantiles(0).clone()
-    var i = 0
-    while (i < quantiles.length) { quantiles(i) *= lambda; i += 1 }
-    Vectors.dense(quantiles)
+    val quantiles = _quantiles.copy
+    BLAS.scal(lambda, quantiles)
+    quantiles
   }
 
   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index fec05ccf15c..5e38b0aba95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -563,7 +563,7 @@ private[ml] object DefaultParamsReader {
         val param = instance.getParam(paramName)
         val value = param.jsonDecode(compact(render(jsonValue)))
         if (isDefault) {
-          Params.setDefault(instance, param, value)
+          instance.setDefault(param, value)
         } else {
           instance.set(param, value)
         }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 15405371a27..15f2e63bc85 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2994,6 +2994,27 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
     val expected = "LogisticRegressionModel: uid=logReg, numClasses=2, numFeatures=3"
     assert(model.toString === expected)
   }
+
+  test("test internal thresholds") {
+    val df = Seq(
+      (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+      (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+      (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+      (0.0, 4.0, Vectors.dense(3.0, 3.0))
+    ).toDF("label", "weight", "features")
+
+    val lor = new LogisticRegression().setWeightCol("weight")
+    val model = lor.fit(df)
+    val vec = Vectors.dense(0.0, 5.0)
+
+    val p0 = model.predict(vec)
+    model.setThreshold(0.05)
+    val p1 = model.set(model.threshold, 0.5).predict(vec)
+    val p2 = model.clear(model.threshold).predict(vec)
+
+    assert(p0 === p1)
+    assert(p0 === p2)
+  }
 }
 
 object LogisticRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index c8f692654e4..c91f9dea705 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -481,6 +481,19 @@ class AFTSurvivalRegressionSuite extends MLTest with DefaultReadWriteTest {
       }
     }
   }
+
+  test("test internal quantiles") {
+    val quantileProbabilities = Array(0.1, 0.5, 0.9)
+    val aft = new AFTSurvivalRegression().setQuantilesCol("quantiles")
+    val model = aft.fit(datasetUnivariate)
+    val vec = Vectors.dense(6.559282795753792)
+
+    val p1 = model.setQuantileProbabilities(quantileProbabilities).predictQuantiles(vec)
+    model.setQuantileProbabilities(Array(0.2, 0.3, 0.9))
+    val p2 = model.set(model.quantileProbabilities, quantileProbabilities).predictQuantiles(vec)
+
+    assert(p1 === p2)
+  }
 }
 
 object AFTSurvivalRegressionSuite {
diff --git a/python/pyspark/ml/tests/test_algorithms.py b/python/pyspark/ml/tests/test_algorithms.py
index accdddb29c0..fb2507fe085 100644
--- a/python/pyspark/ml/tests/test_algorithms.py
+++ b/python/pyspark/ml/tests/test_algorithms.py
@@ -83,6 +83,41 @@ class LogisticRegressionTest(SparkSessionTestCase):
             np.allclose(model.interceptVector.toArray(), [-0.9057, -1.1392, -0.0033], atol=1e-4)
         )
 
+    def test_logistic_regression_with_threshold(self):
+
+        df = self.spark.createDataFrame(
+            [
+                (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+                (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+            ],
+            ["label", "weight", "features"],
+        )
+
+        lor = LogisticRegression(weightCol="weight")
+        model = lor.fit(df)
+
+        # status changes 1
+        for t in [0.0, 0.1, 0.2, 0.5, 1.0]:
+            model.setThreshold(t).transform(df)
+
+        # status changes 2
+        [model.setThreshold(t).predict(Vectors.dense(0.0, 5.0)) for t in [0.0, 0.1, 0.2, 0.5, 1.0]]
+
+        self.assertEqual(
+            [row.prediction for row in model.setThreshold(0.0).transform(df).collect()],
+            [1.0, 1.0, 1.0, 1.0],
+        )
+        self.assertEqual(
+            [row.prediction for row in model.setThreshold(0.5).transform(df).collect()],
+            [0.0, 1.0, 1.0, 0.0],
+        )
+        self.assertEqual(
+            [row.prediction for row in model.setThreshold(1.0).transform(df).collect()],
+            [0.0, 0.0, 0.0, 0.0],
+        )
+
 
 class MultilayerPerceptronClassifierTest(SparkSessionTestCase):
     def test_raw_and_probability_prediction(self):

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org