Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21081#discussion_r182217722
--- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
---
@@ -123,7 +128,21 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val predictUDF = udf((vector: Vector) => predict(vector))
+ // val predictUDF = udf((vector: Vector) => predict(vector))
+ val predictUDF = if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
+ udf((vector: Vector) => predict(vector))
+ }
+ else {
+ udf((vector: Seq[_]) => {
+ val featureArray = Array.fill[Double](vector.size)(0.0)
--- End diff --
You shouldn't have to do the conversion in this convoluted (and less
efficient) way. I'd recommend doing a match-case statement on dataset.schema;
I think that will be the most succinct. Then you can handle Vector, Seq of
Float, and Seq of Double separately, without conversions to strings.
Same for the similar cases below.
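
A minimal sketch of the match-case approach suggested above, assuming the features column holds either an ML Vector, an array of doubles, or an array of floats. The object name `PredictUdfSketch`, the helper `predictUdfFor`, and the `predict` parameter (standing in for `KMeansModel.predict`) are hypothetical names used only for illustration; this is not the code that was merged. Because `VectorUDT` is private to Spark, the sketch compares against the public `SQLDataTypes.VectorType` instead.

```scala
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector, Vectors}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType}

object PredictUdfSketch {
  // Hypothetical helper: picks a UDF based on the features column's data type.
  // `predict` is an assumed stand-in for KMeansModel.predict.
  def predictUdfFor(featuresType: DataType, predict: Vector => Int): UserDefinedFunction =
    featuresType match {
      // Column already holds ML vectors: pass them straight through.
      case dt if dt == SQLDataTypes.VectorType =>
        udf((v: Vector) => predict(v))
      // Array of doubles: wrap the values in a dense vector directly.
      case ArrayType(DoubleType, _) =>
        udf((xs: Seq[Double]) => predict(Vectors.dense(xs.toArray)))
      // Array of floats: widen each element to Double, no string round trip.
      case ArrayType(FloatType, _) =>
        udf((xs: Seq[Float]) => predict(Vectors.dense(xs.map(_.toDouble).toArray)))
      case other =>
        throw new IllegalArgumentException(
          s"Unsupported features column type: $other")
    }
}
```

Inside `transform`, this would be called with something like `predictUdfFor(dataset.schema($(featuresCol)).dataType, predict)`, so the type dispatch happens once per call rather than once per row.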
---