GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21081#discussion_r183558056
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala ---
@@ -199,6 +201,47 @@ class KMeansSuite extends SparkFunSuite with
MLlibTestSparkContext with DefaultR
assert(e.getCause.getMessage.contains("Cosine distance is not
defined"))
}
+ test("KMean with Array input") {
+ val featuresColNameD = "array_double_features"
+ val featuresColNameF = "array_float_features"
+
+ val doubleUDF = udf { (features: Vector) =>
+ val featureArray = Array.fill[Double](features.size)(0.0)
+ features.foreachActive((idx, value) => featureArray(idx) =
value.toFloat)
+ featureArray
+ }
+ val floatUDF = udf { (features: Vector) =>
+ val featureArray = Array.fill[Float](features.size)(0.0f)
+ features.foreachActive((idx, value) => featureArray(idx) =
value.toFloat)
+ featureArray
+ }
+
+ val newdatasetD = dataset.withColumn(featuresColNameD,
doubleUDF(col("features")))
+ .drop("features")
+ val newdatasetF = dataset.withColumn(featuresColNameF,
floatUDF(col("features")))
+ .drop("features")
+
+ assert(newdatasetD.schema(featuresColNameD).dataType.equals(new
ArrayType(DoubleType, false)))
+ assert(newdatasetF.schema(featuresColNameF).dataType.equals(new
ArrayType(FloatType, false)))
+
+ val kmeansD = new
KMeans().setK(k).setFeaturesCol(featuresColNameD).setSeed(1)
--- End diff --
Also call `setMaxIter(1)` on these `KMeans` estimators to make this test a little faster.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]