Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/19746#discussion_r156750447
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala ---
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+class VectorSizeHintSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ test("Test Param Validators") {
+ intercept[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue"))
+ intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
+ }
+
+ test("Adding size to column of vectors.") {
+
+ val size = 3
+ val vectorColName = "vector"
+ val denseVector = Vectors.dense(1, 2, 3)
+ val sparseVector = Vectors.sparse(size, Array(), Array())
+
+ val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
+ val dataFrame = data.toDF(vectorColName)
+ assert(
+ AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
+ "Transformer did not add expected size data.")
+
+ for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+ val transformer = new VectorSizeHint()
+ .setInputCol(vectorColName)
+ .setSize(size)
+ .setHandleInvalid(handleInvalid)
+ val withSize = transformer.transform(dataFrame)
+ assert(
+ AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
+ "Transformer did not add expected size data.")
+ withSize.collect
+ }
+ }
+
+ test("Size hint preserves attributes.") {
+
+ val size = 3
+ val vectorColName = "vector"
+ val data = Seq((1, 2, 3), (2, 3, 3))
+ val dataFrame = data.toDF("x", "y", "z")
+
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z"))
+ .setOutputCol(vectorColName)
+ val dataFrameWithMetadata = assembler.transform(dataFrame)
+ val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName))
+
+ for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+ val transformer = new VectorSizeHint()
+ .setInputCol(vectorColName)
+ .setSize(size)
+ .setHandleInvalid(handleInvalid)
+ val withSize = transformer.transform(dataFrameWithMetadata)
+
+ val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
+ assert(newGroup.size === size, "Transformer did not add expected size data.")
+ assert(
+ newGroup.attributes.get.deep === group.attributes.get.deep,
+ "SizeHintTransformer did not preserve attributes.")
+ withSize.collect
+ }
+ }
+
+ test("Size miss-match between current and target size raises an error.")
{
+ val size = 4
+ val vectorColName = "vector"
+ val data = Seq((1, 2, 3), (2, 3, 3))
+ val dataFrame = data.toDF("x", "y", "z")
+
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z"))
+ .setOutputCol(vectorColName)
+ val dataFrameWithMetadata = assembler.transform(dataFrame)
+
+ for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+ val transformer = new VectorSizeHint()
+ .setInputCol(vectorColName)
+ .setSize(size)
+ .setHandleInvalid(handleInvalid)
+ intercept[SparkException](transformer.transform(dataFrameWithMetadata))
+ }
+ }
+
+ test("Handle invalid does the right thing.") {
+
+ val vector = Vectors.dense(1, 2, 3)
+ val short = Vectors.dense(2)
+ val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector")
+ val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")
+
+ val sizeHint = new VectorSizeHint()
+ .setInputCol("vector")
+ .setHandleInvalid("error")
+ .setSize(3)
+
+ intercept[SparkException](sizeHint.transform(dataWithNull).collect)
+ intercept[SparkException](sizeHint.transform(dataWithShort).collect)
+
+ sizeHint.setHandleInvalid("skip")
+ assert(sizeHint.transform(dataWithNull).count() === 1)
+ assert(sizeHint.transform(dataWithShort).count() === 1)
+ }
--- End diff ---
Did you have any thoughts on how to test `keep`/`optimistic`? I could verify
that the invalid data is not removed, but that's a little weird to test: it
amounts to ensuring that this option allows the column to get into a "bad
state" where the metadata doesn't match the contents. Is that what you had in
mind?
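
For concreteness, a minimal sketch of what such a test could look like,
assuming `"optimistic"` is the param value this PR ends up accepting and that
it passes mismatched rows through unchanged:

```scala
test("Handle invalid with optimistic keeps mismatched rows.") {
  val vector = Vectors.dense(1, 2, 3)
  val short = Vectors.dense(2)
  val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")

  val sizeHint = new VectorSizeHint()
    .setInputCol("vector")
    .setHandleInvalid("optimistic")  // assumed param value for this sketch
    .setSize(3)

  // No rows are dropped and no error is raised, even though the second
  // vector has size 1, so the size metadata (3) no longer matches the data.
  assert(sizeHint.transform(dataWithShort).count() === 2)
}
```

This only pins down the "bad state" behavior described above, which is why
I'm not sure it's a meaningful test.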