Repository: spark Updated Branches: refs/heads/master ab197308a -> 3eb52092b
[SPARK-22974][ML] Attach attributes to output column of CountVectorModel ## What changes were proposed in this pull request? The output column from `CountVectorizerModel` lacks attributes, so a later transformer like `Interaction` can raise an error because no attributes are available. ## How was this patch tested? Added test. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #20313 from viirya/SPARK-22974. Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: DB Tsai <d_t...@apple.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3eb52092 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3eb52092 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3eb52092 Branch: refs/heads/master Commit: 3eb52092b3aa9d7d2fc1e50ac237d47bfb3b9e92 Parents: ab19730 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Tue Aug 14 05:05:16 2018 +0000 Committer: DB Tsai <d_t...@apple.com> Committed: Tue Aug 14 05:05:16 2018 +0000 ---------------------------------------------------------------------- .../apache/spark/ml/feature/CountVectorizer.scala | 5 ++++- .../spark/ml/feature/CountVectorizerSuite.scala | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3eb52092/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 10c48c3..dc8eb82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path import 
org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.linalg.{Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -317,7 +318,9 @@ class CountVectorizerModel( Vectors.sparse(dictBr.value.size, effectiveCounts) } - dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) + val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]] + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + dataset.withColumn($(outputCol), vectorizer(col($(inputCol))), metadata) } @Since("1.5.0") http://git-wip-us.apache.org/repos/asf/spark/blob/3eb52092/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 6121766..bca580d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -289,4 +289,20 @@ class CountVectorizerSuite extends MLTest with DefaultReadWriteTest { val newInstance = testDefaultReadWrite(instance) assert(newInstance.vocabulary === instance.vocabulary) } + + test("SPARK-22974: CountVectorModel should attach proper attribute to output column") { + val df = spark.createDataFrame(Seq( + (0, 1.0, Array("a", "b", "c")), + (1, 2.0, Array("a", "b", "b", "c", "a", "d")) + )).toDF("id", "features1", "words") + + val cvm = new CountVectorizerModel(Array("a", "b", "c")) + .setInputCol("words") + .setOutputCol("features2") + + val df1 = cvm.transform(df) + val interaction = new 
Interaction().setInputCols(Array("features1", "features2")) + .setOutputCol("features") + interaction.transform(df1) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org