Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/20686#discussion_r174655509 --- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala --- @@ -84,26 +84,29 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") - def validateResults(df: DataFrame): Unit = { - df.select("result", "expected").collect().foreach { case Row(vec1: Vector, vec2: Vector) => + def validateResults(rows: Seq[Row]): Unit = { + rows.foreach { case Row(vec1: Vector, vec2: Vector) => assert(vec1 === vec2) } - val resultMetadata = AttributeGroup.fromStructField(df.schema("result")) - val expectedMetadata = AttributeGroup.fromStructField(df.schema("expected")) + val resultMetadata = AttributeGroup.fromStructField(rows.head.schema("result")) + val expectedMetadata = AttributeGroup.fromStructField(rows.head.schema("expected")) assert(resultMetadata.numAttributes === expectedMetadata.numAttributes) resultMetadata.attributes.get.zip(expectedMetadata.attributes.get).foreach { case (a, b) => assert(a === b) } } vectorSlicer.setIndices(Array(1, 4)).setNames(Array.empty) - validateResults(vectorSlicer.transform(df)) + testTransformerByGlobalCheckFunc[(Vector, Vector)](df, vectorSlicer, "result", "expected")( --- End diff -- ok.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org