Repository: spark Updated Branches: refs/heads/master ecfb3e73f -> ec03866a7
[SPARK-11343][ML] Allow float and double prediction/label columns in RegressionEvaluator mengxr, felixcheung This pull request just relaxes the type of the prediction/label columns to be float and double. Internally, these columns are cast to double. The other evaluators might need to be changed also. Author: Dominik Dahlem <[email protected]> Closes #9296 from dahlem/ddahlem_regression_evaluator_double_predictions_27102015. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec03866a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec03866a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec03866a Branch: refs/heads/master Commit: ec03866a7ef2d0826520755d47c8c9480148a76c Parents: ecfb3e7 Author: Dominik Dahlem <[email protected]> Authored: Mon Nov 2 16:11:42 2015 -0800 Committer: Xiangrui Meng <[email protected]> Committed: Mon Nov 2 16:11:42 2015 -0800 ---------------------------------------------------------------------- .../spark/ml/evaluation/RegressionEvaluator.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ec03866a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 3fd34d8..ba012f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -23,7 +23,8 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics 
import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DoubleType, FloatType} /** * :: Experimental :: @@ -72,10 +73,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui @Since("1.4.0") override def evaluate(dataset: DataFrame): Double = { val schema = dataset.schema - SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + val predictionType = schema($(predictionCol)).dataType + require(predictionType == FloatType || predictionType == DoubleType) + val labelType = schema($(labelCol)).dataType + require(labelType == FloatType || labelType == DoubleType) - val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + val predictionAndLabels = dataset + .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType)) .map { case Row(prediction: Double, label: Double) => (prediction, label) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
