Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20686#discussion_r174656617
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala ---
@@ -84,26 +84,29 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")
- def validateResults(df: DataFrame): Unit = {
- df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) =>
+ def validateResults(rows: Seq[Row]): Unit = {
+ rows.foreach { case Row(vec1: Vector, vec2: Vector) =>
assert(vec1 === vec2)
}
- val resultMetadata = AttributeGroup.fromStructField(df.schema("result"))
- val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected"))
+ val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result"))
+ val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected"))
assert(resultMetadata.numAttributes === expectedMetadata.numAttributes)
resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) =>
assert(a === b)
}
}
vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty)
- validateResults(vectorSlicer.transform(df))
+ testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")(
--- End diff --
I see, makes sense.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]