Repository: spark
Updated Branches:
  refs/heads/master ff7cc45f5 -> 54d13bed8


[SPARK-14159][ML] Fixed bug in StringIndexer + related issue in RFormula

## What changes were proposed in this pull request?

StringIndexerModel.transform sets the name in the output column's metadata to
inputCol. It should use outputCol instead. Fixing this exposes a problem with
the metadata produced by RFormula.

Fix in RFormula: I added the StringIndexer output columns to prefixesToRewrite,
and I modified VectorAttributeRewriter to find and replace every occurrence of
each "prefix" in an attribute name, since attribute names can accumulate
multiple prefixes from StringIndexer + Interaction.

Note that the name "prefixes" is no longer accurate, since substrings anywhere
in an attribute name (not just at its start) may now be replaced.

## How was this patch tested?

Added a unit test that failed before this fix.

Author: Joseph K. Bradley <[email protected]>

Closes #11965 from jkbradley/StringIndexer-fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/54d13bed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/54d13bed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/54d13bed

Branch: refs/heads/master
Commit: 54d13bed87fcf2968f77e1f1153e85184ec91d78
Parents: ff7cc45
Author: Joseph K. Bradley <[email protected]>
Authored: Fri Mar 25 16:00:09 2016 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Fri Mar 25 16:00:09 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/feature/RFormula.scala | 15 ++++++---------
 .../org/apache/spark/ml/feature/StringIndexer.scala  |  7 +++----
 .../apache/spark/ml/feature/StringIndexerSuite.scala | 13 +++++++++++++
 3 files changed, 22 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/54d13bed/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index e7ca7ad..12a76db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -125,6 +125,7 @@ class RFormula(override val uid: String)
           encoderStages += new StringIndexer()
             .setInputCol(term)
             .setOutputCol(indexCol)
+          prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
           (term, term)
@@ -229,7 +230,7 @@ class RFormulaModel private[feature](
   override def copy(extra: ParamMap): RFormulaModel = copyValues(
     new RFormulaModel(uid, resolvedFormula, pipelineModel))
 
-  override def toString: String = s"RFormulaModel(${resolvedFormula}) 
(uid=$uid)"
+  override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
 
   private def transformLabel(dataset: DataFrame): DataFrame = {
     val labelName = resolvedFormula.label
@@ -400,14 +401,10 @@ private class VectorAttributeRewriter(
       val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
       val attrs = group.attributes.get.map { attr =>
         if (attr.name.isDefined) {
-          val name = attr.name.get
-          val replacement = prefixesToRewrite.filter { case (k, _) => 
name.startsWith(k) }
-          if (replacement.nonEmpty) {
-            val (k, v) = replacement.headOption.get
-            attr.withName(v + name.stripPrefix(k))
-          } else {
-            attr
+          val name = prefixesToRewrite.foldLeft(attr.name.get) { case 
(curName, (from, to)) =>
+            curName.replace(from, to)
           }
+          attr.withName(name)
         } else {
           attr
         }
@@ -416,7 +413,7 @@ private class VectorAttributeRewriter(
     }
     val otherCols = dataset.columns.filter(_ != vectorCol).map(dataset.col)
     val rewrittenCol = dataset.col(vectorCol).as(vectorCol, metadata)
-    dataset.select((otherCols :+ rewrittenCol): _*)
+    dataset.select(otherCols :+ rewrittenCol : _*)
   }
 
   override def transformSchema(schema: StructType): StructType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/54d13bed/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index c579a0d..faa0f6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -161,15 +161,14 @@ class StringIndexerModel (
     }
 
     val metadata = NominalAttribute.defaultAttr
-      .withName($(inputCol)).withValues(labels).toMetadata()
+      .withName($(outputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
-    val filteredDataset = (getHandleInvalid) match {
-      case "skip" => {
+    val filteredDataset = getHandleInvalid match {
+      case "skip" =>
         val filterer = udf { label: String =>
           labelToIndex.contains(label)
         }
         dataset.where(filterer(dataset($(inputCol))))
-      }
       case _ => dataset
     }
     filteredDataset.select(col("*"),

http://git-wip-us.apache.org/repos/asf/spark/blob/54d13bed/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d40e69d..2c3255e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -210,4 +210,17 @@ class StringIndexerSuite
       .setLabels(Array("a", "b", "c"))
     testDefaultReadWrite(t)
   }
+
+  test("StringIndexer metadata") {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, 
"a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val attrs =
+      NominalAttribute.decodeStructField(transformed.schema("labelIndex"), 
preserveName = true)
+    assert(attrs.name.nonEmpty && attrs.name.get === "labelIndex")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to