Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21081#discussion_r181847784
--- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
---
@@ -123,8 +128,15 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val predictUDF = udf((vector: Vector) => predict(vector))
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ // val predictUDF = udf((vector: Vector) => predict(vector))
+ if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
+ val predictUDF = udf((vector: Vector) => predict(vector))
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ } else {
+ val predictUDF = udf((vector: Seq[_]) =>
+ predict(Vectors.dense(vector.asInstanceOf[Seq[Double]].toArray)))
--- End diff --
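This may not work with arrays of FloatType: the `asInstanceOf[Seq[Double]]` cast is unchecked because of type erasure, so an array column whose elements are Float would pass the cast but likely fail at runtime (e.g., with a ClassCastException) when the elements are unboxed as Double in `toArray`.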
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]