Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21081#discussion_r181841557
--- Diff: mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala ---
@@ -194,6 +195,34 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(e.getCause.getMessage.contains("Cosine distance is not
defined"))
}
+ test("KMean with Array input") {
+ val featuresColName = "array_model_features"
+
+ val arrayUDF = udf { (features: Vector) =>
+ features.toArray
+ }
+ val newdataset = dataset.withColumn(featuresColName,
arrayUDF(col("features")) )
+
+ val kmeans = new KMeans()
+ .setFeaturesCol(featuresColName)
+
+ assert(kmeans.getK === 2)
+ assert(kmeans.getFeaturesCol === featuresColName)
+ assert(kmeans.getPredictionCol === "prediction")
+ assert(kmeans.getMaxIter === 20)
+ assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
+ assert(kmeans.getInitSteps === 2)
+ assert(kmeans.getTol === 1e-4)
+ assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN)
+ val model = kmeans.setMaxIter(1).fit(newdataset)
+
+ MLTestingUtils.checkCopyAndUids(kmeans, model)
--- End diff --
ditto for hasSummary and copying
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]