Repository: spark Updated Branches: refs/heads/branch-2.1 6f366fbbf -> 2394ae235
[MINOR] Correct validateAndTransformSchema in GaussianMixture and AFTSurvivalRegression ## What changes were proposed in this pull request? SchemaUtils.appendColumn returns a new schema instead of modifying the one passed in, so the result of the line SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) was discarded and only the last line had any effect. A temporary variable now holds the intermediate schema so that both columns, predictionCol and probabilityCol, are appended in GaussianMixture; AFTSurvivalRegression had the same problem with quantilesCol and is fixed the same way. ## How was this patch tested? Manually. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Cédric Pelvet <[email protected]> Closes #18980 from sharp-pixel/master. (cherry picked from commit 73e04ecc4f29a0fe51687ed1337c61840c976f89) Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2394ae23 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2394ae23 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2394ae23 Branch: refs/heads/branch-2.1 Commit: 2394ae23530259c6de398e57c4e625a68eda842a Parents: 6f366fb Author: Cédric Pelvet <[email protected]> Authored: Sun Aug 20 11:05:54 2017 +0100 Committer: Sean Owen <[email protected]> Committed: Sun Aug 20 11:06:14 2017 +0100 ---------------------------------------------------------------------- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 4 ++-- .../apache/spark/ml/regression/AFTSurvivalRegression.scala | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2394ae23/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ac56845..82aadea 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -62,8 +62,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } } http://git-wip-us.apache.org/repos/asf/spark/blob/2394ae23/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index af68e7b..fd74921 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) SchemaUtils.checkNumericType(schema, $(labelCol)) } - if (hasQuantilesCol) { + + val schemaWithQuantilesCol = if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } else schema + + SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType) } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
