Repository: spark
Updated Branches:
  refs/heads/branch-1.5 2c7f8da58 -> 2b6b1d12f


[SPARK-9922] [ML] rename StringIndexerInverse to IndexToString

What `StringIndexerInverse` does is not strictly tied to 
`StringIndexer`, and the name does not clearly describe the transformation. 
Renaming it to `IndexToString` might be better.

~~I also changed `invert` to `inverse` without arguments. `inputCol` and 
`outputCol` could be set after.~~
I also removed `invert`.

jkbradley holdenk

Author: Xiangrui Meng <m...@databricks.com>

Closes #8152 from mengxr/SPARK-9922.

(cherry picked from commit 6c5858bc65c8a8602422b46bfa9cf0a1fb296b88)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2b6b1d12
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2b6b1d12
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2b6b1d12

Branch: refs/heads/branch-1.5
Commit: 2b6b1d12fb6bd0bd86988babc4c807856011f246
Parents: 2c7f8da
Author: Xiangrui Meng <m...@databricks.com>
Authored: Thu Aug 13 16:52:17 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Aug 13 16:54:06 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala | 34 ++++++--------
 .../spark/ml/feature/StringIndexerSuite.scala   | 47 ++++++++++++++------
 2 files changed, 47 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2b6b1d12/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 569c834..b87e154 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, 
NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, 
StructType}
@@ -58,6 +58,8 @@ private[feature] trait StringIndexerBase extends Params with 
HasInputCol with Ha
  * If the input column is numeric, we cast it to string and index the string 
values.
  * The indices are in [0, numLabels), ordered by label frequencies.
  * So the most frequent label gets index 0.
+ *
+ * @see [[IndexToString]] for the inverse transformation
  */
 @Experimental
 class StringIndexer(override val uid: String) extends 
Estimator[StringIndexerModel]
@@ -152,34 +154,24 @@ class StringIndexerModel private[ml] (
     val copied = new StringIndexerModel(uid, labels)
     copyValues(copied, extra).setParent(parent)
   }
-
-  /**
-   * Return a model to perform the inverse transformation.
-   * Note: By default we keep the original columns during this transformation, 
so the inverse
-   * should only be used on new columns such as predicted labels.
-   */
-  def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
-    new StringIndexerInverse()
-      .setInputCol(inputCol)
-      .setOutputCol(outputCol)
-      .setLabels(labels)
-  }
 }
 
 /**
  * :: Experimental ::
- * Transform a provided column back to the original input types using either 
the metadata
- * on the input column, or if provided using the labels supplied by the user.
- * Note: By default we keep the original columns during this transformation,
- * so the inverse should only be used on new columns such as predicted labels.
+ * A [[Transformer]] that maps a column of string indices back to a new column 
of corresponding
+ * string values using either the ML attributes of the input column, or if 
provided using the labels
+ * supplied by the user.
+ * All original columns are kept during transformation.
+ *
+ * @see [[StringIndexer]] for converting strings into indices
  */
 @Experimental
-class StringIndexerInverse private[ml] (
+class IndexToString private[ml] (
   override val uid: String) extends Transformer
     with HasInputCol with HasOutputCol {
 
   def this() =
-    this(Identifiable.randomUID("strIdxInv"))
+    this(Identifiable.randomUID("idxToStr"))
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -239,7 +231,7 @@ class StringIndexerInverse private[ml] (
     }
     val indexer = udf { index: Double =>
       val idx = index.toInt
-      if (0 <= idx && idx < values.size) {
+      if (0 <= idx && idx < values.length) {
         values(idx)
       } else {
         throw new SparkException(s"Unseen index: $index ??")
@@ -250,7 +242,7 @@ class StringIndexerInverse private[ml] (
       indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
   }
 
-  override def copy(extra: ParamMap): StringIndexerInverse = {
+  override def copy(extra: ParamMap): IndexToString = {
     defaultCopy(extra)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2b6b1d12/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 124af62..4a12e0b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.ml.attribute.{Attribute, 
NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions.col
 
 class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -52,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     // a -> 0, b -> 2, c -> 1
     val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 
1.0))
     assert(output === expected)
-    // convert reverse our transform
-    val reversed = indexer.invert("labelIndex", "label2")
-      .transform(transformed)
-      .select("id", "label2")
-    assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
-      reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
-    // Check invert using only metadata
-    val inverse2 = new StringIndexerInverse()
-      .setInputCol("labelIndex")
-      .setOutputCol("label2")
-    val reversed2 = inverse2.transform(transformed).select("id", "label2")
-    assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
-      reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
   }
 
   test("StringIndexer with a numeric input column") {
@@ -93,4 +82,36 @@ class StringIndexerSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val df = sqlContext.range(0L, 10L)
     assert(indexerModel.transform(df).eq(df))
   }
+
+  test("IndexToString params") {
+    val idxToStr = new IndexToString()
+    ParamsSuite.checkParams(idxToStr)
+  }
+
+  test("IndexToString.transform") {
+    val labels = Array("a", "b", "c")
+    val df0 = sqlContext.createDataFrame(Seq(
+      (0, "a"), (1, "b"), (2, "c"), (0, "a")
+    )).toDF("index", "expected")
+
+    val idxToStr0 = new IndexToString()
+      .setInputCol("index")
+      .setOutputCol("actual")
+      .setLabels(labels)
+    idxToStr0.transform(df0).select("actual", "expected").collect().foreach {
+      case Row(actual, expected) =>
+        assert(actual === expected)
+    }
+
+    val attr = NominalAttribute.defaultAttr.withValues(labels)
+    val df1 = df0.select(col("index").as("indexWithAttr", attr.toMetadata()), 
col("expected"))
+
+    val idxToStr1 = new IndexToString()
+      .setInputCol("indexWithAttr")
+      .setOutputCol("actual")
+    idxToStr1.transform(df1).select("actual", "expected").collect().foreach {
+      case Row(actual, expected) =>
+        assert(actual === expected)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to