srowen closed pull request #23144: [SPARK-26172][ML][WIP] Unify String Params'
case-insensitivity in ML
URL: https://github.com/apache/spark/pull/23144
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance. Because this is a foreign pull request (from a
fork), the diff would not otherwise be shown, so it is supplied in full
below:
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 27a7db0b2f5d4..db47daa005b2e 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -92,10 +92,10 @@ private[classification] trait LogisticRegressionParams
extends ProbabilisticClas
* @group param
*/
@Since("2.1.0")
- final val family: Param[String] = new Param(this, "family",
+ final val family: Param[String] = new StringParam(this, "family",
"The name of family which is a description of the label distribution to be
used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
- (value: String) =>
supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))
+ ParamValidators.inArray(supportedFamilyNames), StringParamNormalizer.lower)
/** @group getParam */
@Since("2.1.0")
@@ -537,7 +537,7 @@ class LogisticRegression @Since("1.2.0") (
case None => histogram.length
}
- val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
+ val isMultinomial = getFamily match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"Binomial family only
supports 1 or 2 " +
s"outcome classes but found $numClasses.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 4c50f1e3292bc..7407f5653d59c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.param
import java.lang.reflect.Modifier
import java.util.{List => JList}
+import java.util.Locale
import java.util.NoSuchElementException
import scala.annotation.varargs
@@ -524,6 +525,65 @@ class BooleanParam(parent: String, name: String, doc:
String) // No need for isV
}
}
+/**
+ * :: DeveloperApi ::
+ * Specialized version of `Param[String]` for Java whose values are passed through a `normalize` function before validation.
+ */
+@DeveloperApi
+class StringParam(
+ parent: String,
+ name: String,
+ doc: String,
+ isValid: String => Boolean,
+ val normalize: String => String)
+ extends Param[String](parent, name, doc, isValid) {
+
+ def this(parent: Identifiable, name: String, doc: String) =
+ this(parent.uid, name, doc, ParamValidators.alwaysTrue,
StringParamNormalizer.identical)
+
+ def this(parent: Identifiable, name: String, doc: String, isValid: String =>
Boolean) =
+ this(parent.uid, name, doc, isValid, StringParamNormalizer.identical)
+
+ def this(parent: Identifiable, name: String, doc: String, isValid: String =>
Boolean,
+ normalize: String => String) = this(parent.uid, name, doc, isValid,
normalize)
+
+ private[param] override def validate(value: String): Unit = {
+ if (!isValid(normalize(value))) {
+ throw new IllegalArgumentException(
+ s"$parent parameter $name given invalid value $value.")
+ }
+ }
+
+ /** Creates a param pair with the given value (for Java). */
+ override def w(value: String): ParamPair[String] = super.w(value)
+
+ override def jsonEncode(value: String): String = {
+ compact(render(JString(value)))
+ }
+
+ override def jsonDecode(json: String): String = {
+ implicit val formats = DefaultFormats
+ parse(json).extract[String]
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Factory methods for common string normalization functions for
`StringParam.normalize`.
+ */
+@DeveloperApi
+object StringParamNormalizer {
+
+ /** (private[ml]) Default normalizer: always returns the original value unchanged. */
+ private[ml] def identical: String => String = (s: String) => s
+
+ /** (private[ml]) Normalizer that always returns the value in lower case (using Locale.ROOT). */
+ private[ml] def lower: String => String = (s: String) =>
s.toLowerCase(Locale.ROOT)
+
+ /** (private[ml]) Normalizer that always returns the value in upper case (using Locale.ROOT). */
+ private[ml] def upper: String => String = (s: String) =>
s.toUpperCase(Locale.ROOT)
+}
+
/**
* :: DeveloperApi ::
* Specialized version of `Param[Array[String]]` for Java.
@@ -749,7 +809,13 @@ trait Params extends Identifiable with Serializable {
*/
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
- paramMap.put(paramPair)
+ paramPair match {
+ case ParamPair(param: StringParam, value: String) =>
+ paramMap.put(new ParamPair(param, param.normalize(value)))
+
+ case _ =>
+ paramMap.put(paramPair)
+ }
this
}
diff --git
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 84c10e2f85c81..26dd4cf740f2e 100644
---
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.classification
+import java.util.Locale
+
import scala.collection.JavaConverters._
import scala.language.existentials
import scala.util.Random
@@ -251,7 +253,7 @@ class LogisticRegressionSuite extends MLTest with
DefaultReadWriteTest {
test("check summary types for binary and multiclass") {
val lr = new LogisticRegression()
- .setFamily("binomial")
+ .setFamily("Binomial")
.setMaxIter(1)
val blorModel = lr.fit(smallBinaryDataset)
@@ -259,7 +261,7 @@ class LogisticRegressionSuite extends MLTest with
DefaultReadWriteTest {
assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary])
assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
- val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
+ val mlorModel = lr.setFamily("MULtinomial").fit(smallMultinomialDataset)
assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
withClue("cannot get binary summary for multiclass model") {
intercept[RuntimeException] {
@@ -2748,9 +2750,9 @@ class LogisticRegressionSuite extends MLTest with
DefaultReadWriteTest {
Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset),
("mulTinomIAl", smallMultinomialDataset)).foreach { case (family, data)
=>
lr.setFamily(family)
- assert(lr.getFamily === family)
+ assert(lr.getFamily === family.toLowerCase(Locale.ROOT))
val model = lr.fit(data)
- assert(model.getFamily === family)
+ assert(model.getFamily === family.toLowerCase(Locale.ROOT))
}
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]