Repository: spark Updated Branches: refs/heads/master 376d90d55 -> 0c8444cf6
[SPARK-14657][SPARKR][ML] RFormula w/o intercept should output reference category when encoding string terms ## What changes were proposed in this pull request? Please see [SPARK-14657](https://issues.apache.org/jira/browse/SPARK-14657) for details of this bug. I searched online and tested some other cases, and found that when we fit an R glm model (or other models powered by R formula) without an intercept on a dataset that includes string/categorical features, one of the categories in the first categorical feature is used as the reference category, so no category is dropped for that feature. I think we should keep the semantics of Spark RFormula consistent with R formula. ## How was this patch tested? Added standard unit tests. cc mengxr Author: Yanbo Liang <yblia...@gmail.com> Closes #12414 from yanboliang/spark-14657. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0c8444cf Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0c8444cf Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0c8444cf Branch: refs/heads/master Commit: 0c8444cf6d0620cd219ddcf5f50b12ff648639e9 Parents: 376d90d Author: Yanbo Liang <yblia...@gmail.com> Authored: Thu Jun 29 10:32:32 2017 +0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Thu Jun 29 10:32:32 2017 +0800 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/RFormula.scala | 10 ++- .../apache/spark/ml/feature/RFormulaSuite.scala | 83 ++++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0c8444cf/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 
1fad0a6..4b44878 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -205,12 +205,20 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) }.toMap // Then we handle one-hot encoding and interactions between terms. + var keepReferenceCategory = false val encodedTerms = resolvedFormula.terms.map { case Seq(term) if dataset.schema(term).dataType == StringType => val encodedCol = tmpColumn("onehot") - encoderStages += new OneHotEncoder() + var encoder = new OneHotEncoder() .setInputCol(indexed(term)) .setOutputCol(encodedCol) + // Formula w/o intercept, one of the categories in the first category feature is + // being used as reference category, we will not drop any category for that feature. + if (!hasIntercept && !keepReferenceCategory) { + encoder = encoder.setDropLast(false) + keepReferenceCategory = true + } + encoderStages += encoder prefixesToRewrite(encodedCol + "_") = term + "_" encodedCol case Seq(term) => http://git-wip-us.apache.org/repos/asf/spark/blob/0c8444cf/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 41d0062..23570d6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -213,6 +213,89 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result.collect() === expected.collect()) } + test("formula w/o intercept, we should output reference category when encoding string terms") { + /* + R code: + + df <- data.frame(id = c(1, 2, 3, 4), + a = c("foo", "bar", "bar", "baz"), + b = c("zq", "zz", "zz", "zz"), + c = c(4, 4, 5, 5)) + model.matrix(id ~ a + 
b + c - 1, df) + + abar abaz afoo bzz c + 1 0 0 1 0 4 + 2 1 0 0 1 4 + 3 1 0 0 1 5 + 4 0 1 0 1 5 + + model.matrix(id ~ a:b + c - 1, df) + + c abar:bzq abaz:bzq afoo:bzq abar:bzz abaz:bzz afoo:bzz + 1 4 0 0 1 0 0 0 + 2 4 0 0 0 1 0 0 + 3 5 0 0 0 1 0 0 + 4 5 0 0 0 0 1 0 + */ + val original = Seq((1, "foo", "zq", 4), (2, "bar", "zz", 4), (3, "bar", "zz", 5), + (4, "baz", "zz", 5)).toDF("id", "a", "b", "c") + + val formula1 = new RFormula().setFormula("id ~ a + b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model1 = formula1.fit(original) + val result1 = model1.transform(original) + val resultSchema1 = model1.transformSchema(original.schema) + // Note the column order is different between R and Spark. + val expected1 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(5, Array(0, 4), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.dense(0.0, 0.0, 1.0, 1.0, 4.0), 2.0), + (3, "bar", "zz", 5, Vectors.dense(0.0, 0.0, 1.0, 1.0, 5.0), 3.0), + (4, "baz", "zz", 5, Vectors.dense(0.0, 1.0, 0.0, 1.0, 5.0), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result1.schema.toString == resultSchema1.toString) + assert(result1.collect() === expected1.collect()) + + val attrs1 = AttributeGroup.fromStructField(result1.schema("features")) + val expectedAttrs1 = new AttributeGroup( + "features", + Array[Attribute]( + new BinaryAttribute(Some("a_foo"), Some(1)), + new BinaryAttribute(Some("a_baz"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(3)), + new BinaryAttribute(Some("b_zz"), Some(4)), + new NumericAttribute(Some("c"), Some(5)))) + assert(attrs1 === expectedAttrs1) + + // There is no impact for string terms interaction. + val formula2 = new RFormula().setFormula("id ~ a:b + c - 1") + .setStringIndexerOrderType(StringIndexer.alphabetDesc) + val model2 = formula2.fit(original) + val result2 = model2.transform(original) + val resultSchema2 = model2.transformSchema(original.schema) + // Note the column order is different between R and Spark. 
+ val expected2 = Seq( + (1, "foo", "zq", 4, Vectors.sparse(7, Array(1, 6), Array(1.0, 4.0)), 1.0), + (2, "bar", "zz", 4, Vectors.sparse(7, Array(4, 6), Array(1.0, 4.0)), 2.0), + (3, "bar", "zz", 5, Vectors.sparse(7, Array(4, 6), Array(1.0, 5.0)), 3.0), + (4, "baz", "zz", 5, Vectors.sparse(7, Array(2, 6), Array(1.0, 5.0)), 4.0) + ).toDF("id", "a", "b", "c", "features", "label") + assert(result2.schema.toString == resultSchema2.toString) + assert(result2.collect() === expected2.collect()) + + val attrs2 = AttributeGroup.fromStructField(result2.schema("features")) + val expectedAttrs2 = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_foo:b_zz"), Some(1)), + new NumericAttribute(Some("a_foo:b_zq"), Some(2)), + new NumericAttribute(Some("a_baz:b_zz"), Some(3)), + new NumericAttribute(Some("a_baz:b_zq"), Some(4)), + new NumericAttribute(Some("a_bar:b_zz"), Some(5)), + new NumericAttribute(Some("a_bar:b_zq"), Some(6)), + new NumericAttribute(Some("c"), Some(7)))) + assert(attrs2 === expectedAttrs2) + } + test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") val original = --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org