Repository: spark
Updated Branches:
  refs/heads/branch-1.6 c3ed9504d -> 496496b25


[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula - 1.6 backport

Backport of [https://github.com/apache/spark/pull/11965] for branch-1.6.
There were no merge conflicts.

## What changes were proposed in this pull request?

StringIndexerModel.transform sets the name in the output column's metadata to 
inputCol.  It should use outputCol instead.  Fixing this causes a problem with 
the metadata produced by RFormula.

Fix in RFormula: I added the StringIndexer columns to prefixesToRewrite, and I 
modified VectorAttributeRewriter to find and replace all "prefixes" since 
attributes collect multiple prefixes from StringIndexer + Interaction.

Note that the name "prefixes" is no longer accurate, since the strings are now 
replaced wherever they occur in an attribute name, not only at its start.

## How was this patch tested?

Added a unit test ("StringIndexer metadata" in StringIndexerSuite) that failed 
before this fix.

Author: Joseph K. Bradley <[email protected]>

Closes #12595 from jkbradley/StringIndexer-fix-1.6.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/496496b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/496496b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/496496b2

Branch: refs/heads/branch-1.6
Commit: 496496b2548387ad2889eaaf884ab039096436ef
Parents: c3ed950
Author: Joseph K. Bradley <[email protected]>
Authored: Tue Apr 26 14:00:39 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Tue Apr 26 14:00:39 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/feature/RFormula.scala | 15 ++++++---------
 .../org/apache/spark/ml/feature/StringIndexer.scala  |  7 +++----
 .../apache/spark/ml/feature/StringIndexerSuite.scala | 13 +++++++++++++
 3 files changed, 22 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/496496b2/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 5c43a41..564c867 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -101,6 +101,7 @@ class RFormula(override val uid: String) extends 
Estimator[RFormulaModel] with R
           encoderStages += new StringIndexer()
             .setInputCol(term)
             .setOutputCol(indexCol)
+          prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
           (term, term)
@@ -198,7 +199,7 @@ class RFormulaModel private[feature](
   override def copy(extra: ParamMap): RFormulaModel = copyValues(
     new RFormulaModel(uid, resolvedFormula, pipelineModel))
 
-  override def toString: String = s"RFormulaModel(${resolvedFormula}) 
(uid=$uid)"
+  override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
 
   private def transformLabel(dataset: DataFrame): DataFrame = {
     val labelName = resolvedFormula.label
@@ -268,14 +269,10 @@ private class VectorAttributeRewriter(
       val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
       val attrs = group.attributes.get.map { attr =>
         if (attr.name.isDefined) {
-          val name = attr.name.get
-          val replacement = prefixesToRewrite.filter { case (k, _) => 
name.startsWith(k) }
-          if (replacement.nonEmpty) {
-            val (k, v) = replacement.headOption.get
-            attr.withName(v + name.stripPrefix(k))
-          } else {
-            attr
+          val name = prefixesToRewrite.foldLeft(attr.name.get) { case 
(curName, (from, to)) =>
+            curName.replace(from, to)
           }
+          attr.withName(name)
         } else {
           attr
         }
@@ -284,7 +281,7 @@ private class VectorAttributeRewriter(
     }
     val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
     val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
-    dataset.select((otherCols :+ rewrittenCol): _*)
+    dataset.select(otherCols :+ rewrittenCol : _*)
   }
 
   override def transformSchema(schema: StructType): StructType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/496496b2/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index b3413a1..a843cc8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -160,15 +160,14 @@ class StringIndexerModel (
     }
 
     val metadata = NominalAttribute.defaultAttr
-      .withName($(inputCol)).withValues(labels).toMetadata()
+      .withName($(outputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val filteredDataset = (getHandleInvalid) match {
-      case "skip" => {
+    val filteredDataset = getHandleInvalid match {
+      case "skip" =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
         dataset.where(filterer(dataset($(inputCol))))
-      }
       case _ => dataset
     }
     filteredDataset.select(col("*"),

http://git-wip-us.apache.org/repos/asf/spark/blob/496496b2/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 26f4613..6ba4aaa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -210,4 +210,17 @@ class StringIndexerSuite
       .setLabels(Array("a", "b", "c"))
     testDefaultReadWrite(t)
   }
+
+  test("StringIndexer metadata") {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, 
"a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val attrs =
+      NominalAttribute.decodeStructField(transformed.schema("labelIndex"), 
preserveName = true)
+    assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to