Repository: spark Updated Branches: refs/heads/master 9525c563d -> 9fe38aba1
[SPARK-11108][ML] OneHotEncoder should support other numeric types Adding support for other numeric types: * Integer * Short * Long * Float * Decimal Author: sethah <seth.hendrickso...@gmail.com> Closes #9777 from sethah/SPARK-11108. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9fe38aba Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9fe38aba Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9fe38aba Branch: refs/heads/master Commit: 9fe38aba1f70a4cb19ec1f9df4814fce0b267b54 Parents: 9525c56 Author: sethah <seth.hendrickso...@gmail.com> Authored: Thu Mar 10 13:17:41 2016 +0200 Committer: Nick Pentreath <nick.pentre...@gmail.com> Committed: Thu Mar 10 13:17:41 2016 +0200 ---------------------------------------------------------------------- .../apache/spark/ml/feature/OneHotEncoder.scala | 9 ++++-- .../spark/ml/feature/OneHotEncoderSuite.scala | 29 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9fe38aba/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e9df161..fa5013d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} -import org.apache.spark.sql.types.{DoubleType, StructType} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} /** * :: Experimental :: @@ -70,7 +70,8 @@ class OneHotEncoder(override val uid: String) extends Transformer val inputColName = $(inputCol) val outputColName = $(outputCol) - SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + require(schema(inputColName).dataType.isInstanceOf[NumericType], + s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") val inputFields = schema.fields require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") @@ -133,7 +134,9 @@ class OneHotEncoder(override val uid: String) extends Transformer val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) .aggregate(0.0)( (m, x) => { - assert(x >=0.0 && x == x.toInt, + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") + assert(x >= 0.0 && x == x.toInt, s"Values from column $inputColName must be indices, but got $x.") math.max(m, x) }, http://git-wip-us.apache.org/repos/asf/spark/blob/9fe38aba/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index e238b33..49803ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -111,4 +112,32 @@ class OneHotEncoderSuite .setDropLast(false) testDefaultReadWrite(t) } + + test("OneHotEncoder with varying types") { + val df = stringIndexed() + val dfWithTypes = df + .withColumn("shortLabel", df("labelIndex").cast(ShortType)) + .withColumn("longLabel", df("labelIndex").cast(LongType)) + .withColumn("intLabel", df("labelIndex").cast(IntegerType)) + .withColumn("floatLabel", df("labelIndex").cast(FloatType)) + .withColumn("decimalLabel", df("labelIndex").cast(DecimalType(10, 0))) + val cols = Array("labelIndex", "shortLabel", "longLabel", "intLabel", + "floatLabel", "decimalLabel") + for (col <- cols) { + val encoder = new OneHotEncoder() + .setInputCol(col) + .setOutputCol("labelVec") + .setDropLast(false) + val encoded = encoder.transform(dfWithTypes) + + val output = encoded.select("id", "labelVec").rdd.map { r => + val vec = r.getAs[Vector](1) + (r.getInt(0), vec(0), vec(1), vec(2)) + }.collect().toSet + // a -> 0, b -> 2, c -> 1 + val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0), + (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0)) + assert(output === expected) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org