Repository: spark Updated Branches: refs/heads/branch-2.2 cb64064dc -> 708f68c8a
[SPARK-20669][ML] LoR.family and LDA.optimizer should be case insensitive ## What changes were proposed in this pull request? Make param `family` in LoR and `optimizer` in LDA case-insensitive. ## How was this patch tested? Updated tests. cc @yanboliang Author: Zheng RuiFeng <[email protected]> Closes #17910 from zhengruifeng/lr_family_lowercase. (cherry picked from commit 9970aa0962ec253a6e838aea26a627689dc5b011) Signed-off-by: Yanbo Liang <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/708f68c8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/708f68c8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/708f68c8 Branch: refs/heads/branch-2.2 Commit: 708f68c8a4f0a57b8774af9d3e3cd010cd8aff5d Parents: cb64064 Author: Zheng RuiFeng <[email protected]> Authored: Mon May 15 23:21:44 2017 +0800 Committer: Yanbo Liang <[email protected]> Committed: Mon May 15 23:21:57 2017 +0800 ---------------------------------------------------------------------- .../ml/classification/LogisticRegression.scala | 4 +-- .../org/apache/spark/ml/clustering/LDA.scala | 30 ++++++++++---------- .../LogisticRegressionSuite.scala | 11 +++++++ .../apache/spark/ml/clustering/LDASuite.scala | 10 +++++++ 4 files changed, 38 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 42dc7fb..0534872 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala 
@@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the label distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - ParamValidators.inArray[String](supportedFamilyNames)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("2.1.0") @@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") ( case None => histogram.length } - val isMultinomial = $(family) match { + val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match { case "binomial" => require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " + s"outcome classes but found $numClasses.") http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index e3026c8..3da29b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + " algorithm used to estimate the LDA model. 
Supported: " + supportedOptimizers.mkString(", "), - (o: String) => - ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT))) + (value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT))) /** @group getParam */ @Since("1.6.0") @@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" + s" length either 1 (scalar) or k (num topics).") } - getOptimizer match { + getOptimizer.toLowerCase(Locale.ROOT) match { case "online" => require(getDocConcentration.forall(_ >= 0), "For Online LDA optimizer, docConcentration values must be >= 0. Found values: " + @@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM } } if (isSet(topicConcentration)) { - getOptimizer match { + getOptimizer.toLowerCase(Locale.ROOT) match { case "online" => require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" + s" must be >= 0. 
Found value: $getTopicConcentration") @@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } - private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match { - case "online" => - new OldOnlineLDAOptimizer() - .setTau0($(learningOffset)) - .setKappa($(learningDecay)) - .setMiniBatchFraction($(subsamplingRate)) - .setOptimizeDocConcentration($(optimizeDocConcentration)) - case "em" => - new OldEMLDAOptimizer() - .setKeepLastCheckpoint($(keepLastCheckpoint)) - } + private[clustering] def getOldOptimizer: OldLDAOptimizer = + getOptimizer.toLowerCase(Locale.ROOT) match { + case "online" => + new OldOnlineLDAOptimizer() + .setTau0($(learningOffset)) + .setKappa($(learningDecay)) + .setMiniBatchFraction($(subsamplingRate)) + .setOptimizeDocConcentration($(optimizeDocConcentration)) + case "em" => + new OldEMLDAOptimizer() + .setKeepLastCheckpoint($(keepLastCheckpoint)) + } } private object LDAParams { http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bf6bfe3..1ffd8dc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2582,6 +2582,17 @@ class LogisticRegressionSuite assert(expected.coefficients.toArray === actual.coefficients.toArray) } } + + test("string params should be case-insensitive") { + val lr = new LogisticRegression() + Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset), + ("mulTinomIAl", 
smallMultinomialDataset)).foreach { case (family, data) => + lr.setFamily(family) + assert(lr.getFamily === family) + val model = lr.fit(data) + assert(model.getFamily === family) + } + } } object LogisticRegressionSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index b4fe63a..e73bbc1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -313,4 +313,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getCheckpointFiles.isEmpty) } + + test("string params should be case-insensitive") { + val lda = new LDA() + Seq("eM", "oNLinE").foreach { optimizer => + lda.setOptimizer(optimizer) + assert(lda.getOptimizer === optimizer) + val model = lda.fit(dataset) + assert(model.getOptimizer === optimizer) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
