GitHub user mengxr commented on a diff in the pull request: https://github.com/apache/spark/pull/21195#discussion_r186556808 --- Diff: mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala --- @@ -247,4 +247,21 @@ object MLTestingUtils extends SparkFunSuite { } models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} } + + /** + * Helper function for testing different input types for features. Given a DataFrame, generate + * three output DataFrames: one having vector feature column with float precision, one having + * double array feature column with float precision, and one having float array feature column. + */ + def generateArrayFeatureDataset(dataset: Dataset[_]): (Dataset[_], Dataset[_], Dataset[_]) = { + val toFloatVectorUDF = udf { (features: Vector) => features.toArray.map(_.toFloat).toVector} + val toDoubleArrayUDF = udf { (features: Vector) => features.toArray} + val toFloatArrayUDF = udf { (features: Vector) => features.toArray.map(_.toFloat)} + val newDataset = dataset.withColumn("features", toFloatVectorUDF(col("features"))) + val newDatasetD = dataset.withColumn("features", toDoubleArrayUDF(col("features"))) --- End diff -- This doesn't truncate the values to single (float) precision: `newDatasetD` is derived from the original `dataset`, and `toDoubleArrayUDF` only converts the vector to a double array. As a result, the docstring's "double array feature column with float precision" does not hold. Did you mean to use `newDataset` instead of `dataset` here?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org