Repository: spark Updated Branches: refs/heads/master 098be27ad -> ce2b056d3
[SPARK-10686] [ML] Add quantilesCol to AFTSurvivalRegression By default ```quantilesCol``` should be empty. If ```quantileProbabilities``` is set, we should append quantiles as a new column (of type Vector). Author: Yanbo Liang <yblia...@gmail.com> Closes #8836 from yanboliang/spark-10686. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce2b056d Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce2b056d Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce2b056d Branch: refs/heads/master Commit: ce2b056d35c0c75d5c162b93680ee2d84152e911 Parents: 098be27 Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Sep 23 15:26:02 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed Sep 23 15:26:02 2015 -0700 ---------------------------------------------------------------------- .../ml/regression/AFTSurvivalRegression.scala | 51 +++++++++++--- .../regression/AFTSurvivalRegressionSuite.scala | 74 +++++++++++++------- 2 files changed, 91 insertions(+), 34 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ce2b056d/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 5b25db6..717caac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel */ private[regression] trait AFTSurvivalRegressionParams extends Params with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter - with HasTol with HasFitIntercept { + with HasTol with HasFitIntercept with Logging { /** * Param for censor column name. @@ -59,21 +59,35 @@ private[regression] trait AFTSurvivalRegressionParams extends Params /** * Param for quantile probabilities array. - * Values of the quantile probabilities array should be in the range [0, 1]. + * Values of the quantile probabilities array should be in the range [0, 1] + * and the array should be non-empty. * @group param */ @Since("1.6.0") final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, "quantileProbabilities", "quantile probabilities array", - (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)) && t.length > 0) /** @group getParam */ @Since("1.6.0") def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + setDefault(quantileProbabilities -> Array(0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99)) - /** Checks whether the input has quantile probabilities array. */ - protected[regression] def hasQuantileProbabilities: Boolean = { - isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + /** + * Param for quantiles column name. + * This column will output quantiles of corresponding quantileProbabilities if it is set. + * @group param + */ + @Since("1.6.0") + final val quantilesCol: Param[String] = new Param(this, "quantilesCol", "quantiles column name") + + /** @group getParam */ + @Since("1.6.0") + def getQuantilesCol: String = $(quantilesCol) + + /** Checks whether the input has quantiles column name. */ + protected[regression] def hasQuantilesCol: Boolean = { + isDefined(quantilesCol) && $(quantilesCol) != "" } /** @@ -90,6 +104,9 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) } + if (hasQuantilesCol) { + SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) + } SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } } @@ -124,6 +141,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("1.6.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantilesCol(value: String): this.type = set(quantilesCol, value) + /** * Set if we should fit the intercept * Default is true. @@ -243,10 +268,12 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + /** @group setParam */ + @Since("1.6.0") + def setQuantilesCol(value: String): this.type = set(quantilesCol, value) + @Since("1.6.0") def predictQuantiles(features: Vector): Vector = { - require(hasQuantileProbabilities, - "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") // scale parameter for the Weibull distribution of lifetime val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) // shape parameter for the Weibull distribution of lifetime @@ -266,7 +293,13 @@ class AFTSurvivalRegressionModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + if (hasQuantilesCol) { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol)))) + } else { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } } @Since("1.6.0") http://git-wip-us.apache.org/repos/asf/spark/blob/ce2b056d/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index ca7140a..359f310 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -22,8 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vectors} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -59,16 +58,20 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(aftr.getFitIntercept) assert(aftr.getMaxIter === 100) assert(aftr.getTol === 1E-6) - val model = aftr.fit(datasetUnivariate) + val model = aftr.setQuantileProbabilities(Array(0.1, 0.8)) + .setQuantilesCol("quantiles") + .fit(datasetUnivariate) // copied model must have the same parent. MLTestingUtils.checkCopy(model) model.transform(datasetUnivariate) - .select("label", "prediction") + .select("label", "prediction", "quantiles") .collect() assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") + assert(model.getQuantileProbabilities === Array(0.1, 0.8)) + assert(model.getQuantilesCol === "quantiles") assert(model.intercept !== 0.0) assert(model.hasParent) } @@ -108,7 +111,10 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex } test("aft survival regression with univariate") { - val trainer = new AFTSurvivalRegression + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") val model = trainer.fit(datasetUnivariate) /* @@ -159,23 +165,25 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 0.1879174 2.6801195 14.5779394 */ val features = Vectors.dense(6.559282795753792) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 4.494763 val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetUnivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) } } test("aft survival regression with multivariate") { - val trainer = new AFTSurvivalRegression + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") val model = trainer.fit(datasetMultivariate) /* @@ -227,23 +235,26 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 0.5287044 3.3285858 10.7517072 */ val features = Vectors.dense(2.233396950271428, -2.5321374085997683) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 4.761219 val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) } } test("aft survival regression w/o intercept") { - val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + .setFitIntercept(false) val model = trainer.fit(datasetMultivariate) /* @@ -294,18 +305,31 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 1.452103 25.506077 158.428600 */ val features = Vectors.dense(2.233396950271428, -2.5321374085997683) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 44.54465 val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) + } + } + + test("aft survival regression w/o quantiles column") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + val outputDf = model.transform(datasetUnivariate) + + assert(outputDf.schema.fieldNames.contains("quantiles") === false) + + outputDf.select("features", "prediction") + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction ~== model.predict(features) relTol 1E-5) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org