Repository: spark Updated Branches: refs/heads/branch-1.6 4c28b4c8f -> 9c0cf22f7
[SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication. Fixes the problem and verifies the fix with a new test in the test suite. Also adds an optional `nullable` (Boolean) parameter to SchemaUtils.appendColumn and removes the duplicated logic between the two SchemaUtils.appendColumn overloads. Author: Grzegorz Chilkiewicz <[email protected]> Closes #10741 from grzegorz-chilkiewicz/master. (cherry picked from commit b1835d727234fdff42aa8cadd17ddcf43b0bed15) Signed-off-by: Joseph K. Bradley <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9c0cf22f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9c0cf22f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9c0cf22f Branch: refs/heads/branch-1.6 Commit: 9c0cf22f7681ae05d894ae05f6a91a9467787519 Parents: 4c28b4c Author: Grzegorz Chilkiewicz <[email protected]> Authored: Tue Feb 2 11:16:24 2016 -0800 Committer: Joseph K. Bradley <[email protected]> Committed: Tue Feb 2 11:16:44 2016 -0800 ---------------------------------------------------------------------- .../apache/spark/ml/feature/StopWordsRemover.scala | 4 +--- .../scala/org/apache/spark/ml/util/SchemaUtils.scala | 8 +++----- .../spark/ml/feature/StopWordsRemoverSuite.scala | 15 +++++++++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/9c0cf22f/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 3188085..d9a9049 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -148,9 +148,7 @@ class 
StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) http://git-wip-us.apache.org/repos/asf/spark/blob/9c0cf22f/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76f6514..7decbbd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -54,12 +54,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/9c0cf22f/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index fb217e0..a5b24c1 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -89,4 +89,19 @@ class StopWordsRemoverSuite .setCaseSensitive(true) testDefaultReadWrite(t) } + + test("StopWordsRemover output column already exists") { + val outputCol = "expected" + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol(outputCol) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("The", "the", "swift"), Seq("swift")) + )).toDF("raw", outputCol) + + val thrown = intercept[IllegalArgumentException] { + testStopWordsRemover(remover, dataSet) + } + assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
