Repository: spark
Updated Branches:
  refs/heads/branch-2.2 cb64064dc -> 708f68c8a


[SPARK-20669][ML] LoR.family and LDA.optimizer should be case insensitive

## What changes were proposed in this pull request?
Make the `family` param in LogisticRegression (LoR) and the `optimizer` param in LDA case-insensitive.

## How was this patch tested?
Added case-insensitivity tests to `LogisticRegressionSuite` and `LDASuite`.

yanboliang

Author: Zheng RuiFeng <[email protected]>

Closes #17910 from zhengruifeng/lr_family_lowercase.

(cherry picked from commit 9970aa0962ec253a6e838aea26a627689dc5b011)
Signed-off-by: Yanbo Liang <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/708f68c8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/708f68c8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/708f68c8

Branch: refs/heads/branch-2.2
Commit: 708f68c8a4f0a57b8774af9d3e3cd010cd8aff5d
Parents: cb64064
Author: Zheng RuiFeng <[email protected]>
Authored: Mon May 15 23:21:44 2017 +0800
Committer: Yanbo Liang <[email protected]>
Committed: Mon May 15 23:21:57 2017 +0800

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  4 +--
 .../org/apache/spark/ml/clustering/LDA.scala    | 30 ++++++++++----------
 .../LogisticRegressionSuite.scala               | 11 +++++++
 .../apache/spark/ml/clustering/LDASuite.scala   | 10 +++++++
 4 files changed, 38 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 42dc7fb..0534872 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
   final val family: Param[String] = new Param(this, "family",
     "The name of family which is a description of the label distribution to be used in the " +
       s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
-    ParamValidators.inArray[String](supportedFamilyNames))
+    (value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))
 
   /** @group getParam */
   @Since("2.1.0")
@@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") (
       case None => histogram.length
     }
 
-    val isMultinomial = $(family) match {
+    val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
       case "binomial" =>
         require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
         s"outcome classes but found $numClasses.")

http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index e3026c8..3da29b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
   @Since("1.6.0")
   final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" +
     " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "),
-    (o: String) =>
-      ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT)))
+    (value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT)))
 
   /** @group getParam */
   @Since("1.6.0")
@@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
           s" ${getDocConcentration.length}, but k = $getK.  docConcentration must be an array of" +
           s" length either 1 (scalar) or k (num topics).")
       }
-      getOptimizer match {
+      getOptimizer.toLowerCase(Locale.ROOT) match {
         case "online" =>
           require(getDocConcentration.forall(_ >= 0),
             "For Online LDA optimizer, docConcentration values must be >= 0.  Found values: " +
@@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
       }
     }
     if (isSet(topicConcentration)) {
-      getOptimizer match {
+      getOptimizer.toLowerCase(Locale.ROOT) match {
         case "online" =>
           require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" +
             s" must be >= 0.  Found value: $getTopicConcentration")
@@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
     SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
   }
 
-  private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
-    case "online" =>
-      new OldOnlineLDAOptimizer()
-        .setTau0($(learningOffset))
-        .setKappa($(learningDecay))
-        .setMiniBatchFraction($(subsamplingRate))
-        .setOptimizeDocConcentration($(optimizeDocConcentration))
-    case "em" =>
-      new OldEMLDAOptimizer()
-        .setKeepLastCheckpoint($(keepLastCheckpoint))
-  }
+  private[clustering] def getOldOptimizer: OldLDAOptimizer =
+    getOptimizer.toLowerCase(Locale.ROOT) match {
+      case "online" =>
+        new OldOnlineLDAOptimizer()
+          .setTau0($(learningOffset))
+          .setKappa($(learningDecay))
+          .setMiniBatchFraction($(subsamplingRate))
+          .setOptimizeDocConcentration($(optimizeDocConcentration))
+      case "em" =>
+        new OldEMLDAOptimizer()
+          .setKeepLastCheckpoint($(keepLastCheckpoint))
+    }
 }
 
 private object LDAParams {

http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index bf6bfe3..1ffd8dc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2582,6 +2582,17 @@ class LogisticRegressionSuite
         assert(expected.coefficients.toArray === actual.coefficients.toArray)
       }
   }
+
+  test("string params should be case-insensitive") {
+    val lr = new LogisticRegression()
+    Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset),
+      ("mulTinomIAl", smallMultinomialDataset)).foreach { case (family, data) =>
+      lr.setFamily(family)
+      assert(lr.getFamily === family)
+      val model = lr.fit(data)
+      assert(model.getFamily === family)
+    }
+  }
 }
 
 object LogisticRegressionSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/708f68c8/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index b4fe63a..e73bbc1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -313,4 +313,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
 
     assert(model.getCheckpointFiles.isEmpty)
   }
+
+  test("string params should be case-insensitive") {
+    val lda = new LDA()
+    Seq("eM", "oNLinE").foreach { optimizer =>
+      lda.setOptimizer(optimizer)
+      assert(lda.getOptimizer === optimizer)
+      val model = lda.fit(dataset)
+      assert(model.getOptimizer === optimizer)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to