Repository: spark Updated Branches: refs/heads/master 540bdee93 -> fe16fd0b8
[SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF to generate new label at binary reduction Currently OneVsRest uses a UDF to generate the new binary label during training. Considering that [SPARK-7321](https://issues.apache.org/jira/browse/SPARK-7321) has been merged, we can use ```when ... otherwise``` instead, which is more efficient. Author: Yanbo Liang <yblia...@gmail.com> Closes #8519 from yanboliang/spark-10349. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fe16fd0b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fe16fd0b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fe16fd0b Branch: refs/heads/master Commit: fe16fd0b8b717f01151bc659ec3299dab091c97a Parents: 540bdee Author: Yanbo Liang <yblia...@gmail.com> Authored: Mon Aug 31 16:06:38 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Mon Aug 31 16:06:38 2015 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/OneVsRest.scala | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fe16fd0b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c62e132..debc164 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString val initUDF = udf { () => Map[Int, Double]() } - val mapType = 
MapType(IntegerType, DoubleType, valueContainsNull = false) val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. @@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. val models = Range(0, numClasses).par.map { index => - val labelUDF = udf { (label: Double) => - if (label.toInt == index) 1.0 else 0.0 - } - // generate new label metadata for the binary problem. - // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val trainingDataset = - multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) + val trainingDataset = multiclassLabeled.withColumn( + labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org