Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21081#discussion_r181847784
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---
    @@ -123,8 +128,15 @@ class KMeansModel private[ml] (
       @Since("2.0.0")
       override def transform(dataset: Dataset[_]): DataFrame = {
         transformSchema(dataset.schema, logging = true)
    -    val predictUDF = udf((vector: Vector) => predict(vector))
    -    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +    // val predictUDF = udf((vector: Vector) => predict(vector))
    +    if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
    +      val predictUDF = udf((vector: Vector) => predict(vector))
    +      dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
    +    } else {
    +      val predictUDF = udf((vector: Seq[_]) =>
    +        predict(Vectors.dense(vector.asInstanceOf[Seq[Double]].toArray)))
    --- End diff --
    
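    This may not work with arrays of FloatType: the `asInstanceOf[Seq[Double]]`
    cast is unchecked because of type erasure, so the elements of an
    `ArrayType(FloatType)` column would still be boxed `Float` values and
    `.toArray` would likely throw a `ClassCastException` when unboxing them as
    `Double`.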


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to