Repository: spark
Updated Branches:
  refs/heads/master c2520f501 -> 6c5858bc6


[SPARK-9922] [ML] rename StringIndexerInverse to IndexToString

What `StringIndexerInverse` does is not strictly tied to `StringIndexer`, and 
the name does not clearly describe the transformation. 
Renaming it to `IndexToString` describes it better.

~~I also changed `invert` to `inverse` without arguments. `inputCol` and 
`outputCol` could be set after.~~
I also removed `StringIndexerModel.invert`.

jkbradley holdenk

Author: Xiangrui Meng <[email protected]>

Closes #8152 from mengxr/SPARK-9922.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c5858bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c5858bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c5858bc

Branch: refs/heads/master
Commit: 6c5858bc65c8a8602422b46bfa9cf0a1fb296b88
Parents: c2520f5
Author: Xiangrui Meng <[email protected]>
Authored: Thu Aug 13 16:52:17 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Thu Aug 13 16:52:17 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala | 34 +++++--------
 .../spark/ml/feature/StringIndexerSuite.scala   | 50 ++++++++++++++------
 2 files changed, 48 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6c5858bc/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 9e4b0f0..9f6e7b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, 
StructType}
@@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
  * If the input column is numeric, we cast it to string and index the string 
values.
  * The indices are in [0, numLabels), ordered by label frequencies.
  * So the most frequent label gets index 0.
+ *
+ * @see [[IndexToString]] for the inverse transformation
  */
 @Experimental
 class StringIndexer(override val uid: String) extends 
Estimator[StringIndexerModel]
@@ -170,34 +172,24 @@ class StringIndexerModel private[ml] (
     val copied = new StringIndexerModel(uid, labels)
     copyValues(copied, extra).setParent(parent)
   }
-
-  /**
-   * Return a model to perform the inverse transformation.
-   * Note: By default we keep the original columns during this transformation, 
so the inverse
-   * should only be used on new columns such as predicted labels.
-   */
-  def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
-    new StringIndexerInverse()
-      .setInputCol(inputCol)
-      .setOutputCol(outputCol)
-      .setLabels(labels)
-  }
 }
 
 /**
  * :: Experimental ::
- * Transform a provided column back to the original input types using either 
the metadata
- * on the input column, or if provided using the labels supplied by the user.
- * Note: By default we keep the original columns during this transformation,
- * so the inverse should only be used on new columns such as predicted labels.
+ * A [[Transformer]] that maps a column of string indices back to a new column 
of corresponding
+ * string values using either the ML attributes of the input column, or if 
provided using the labels
+ * supplied by the user.
+ * All original columns are kept during transformation.
+ *
+ * @see [[StringIndexer]] for converting strings into indices
  */
 @Experimental
-class StringIndexerInverse private[ml] (
+class IndexToString private[ml] (
   override val uid: String) extends Transformer
     with HasInputCol with HasOutputCol {
 
   def this() =
-    this(Identifiable.randomUID("strIdxInv"))
+    this(Identifiable.randomUID("idxToStr"))
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] (
     }
     val indexer = udf { index: Double =>
       val idx = index.toInt
-      if (0 <= idx && idx < values.size) {
+      if (0 <= idx && idx < values.length) {
         values(idx)
       } else {
         throw new SparkException(s"Unseen index: $index ??")
@@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] (
       indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
   }
 
-  override def copy(extra: ParamMap): StringIndexerInverse = {
+  override def copy(extra: ParamMap): IndexToString = {
     defaultCopy(extra)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6c5858bc/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 2d24914..fa918ce 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkException
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
 
 class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
     // a -> 0, b -> 2, c -> 1
     val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 
1.0))
     assert(output === expected)
-    // convert reverse our transform
-    val reversed = indexer.invert("labelIndex", "label2")
-      .transform(transformed)
-      .select("id", "label2")
-    assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
-      reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
-    // Check invert using only metadata
-    val inverse2 = new StringIndexerInverse()
-      .setInputCol("labelIndex")
-      .setOutputCol("label2")
-    val reversed2 = inverse2.transform(transformed).select("id", "label2")
-    assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
-      reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
   }
 
   test("StringIndexerUnseen") {
@@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
     val df = sqlContext.range(0L, 10L)
     assert(indexerModel.transform(df).eq(df))
   }
+
+  test("IndexToString params") {
+    val idxToStr = new IndexToString()
+    ParamsSuite.checkParams(idxToStr)
+  }
+
+  test("IndexToString.transform") {
+    val labels = Array("a", "b", "c")
+    val df0 = sqlContext.createDataFrame(Seq(
+      (0, "a"), (1, "b"), (2, "c"), (0, "a")
+    )).toDF("index", "expected")
+
+    val idxToStr0 = new IndexToString()
+      .setInputCol("index")
+      .setOutputCol("actual")
+      .setLabels(labels)
+    idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
+      case Row(actual, expected) =>
+        assert(actual === expected)
+    }
+
+    val attr = NominalAttribute.defaultAttr.withValues(labels)
+    val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), 
col("expected"))
+
+    val idxToStr1 = new IndexToString()
+      .setInputCol("indexWithAttr")
+      .setOutputCol("actual")
+    idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
+      case Row(actual, expected) =>
+        assert(actual === expected)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to