Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21081#discussion_r182269723

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---
    @@ -123,7 +128,21 @@ class KMeansModel private[ml] (
       @Since("2.0.0")
       override def transform(dataset: Dataset[_]): DataFrame = {
         transformSchema(dataset.schema, logging = true)
    -    val predictUDF = udf((vector: Vector) => predict(vector))
    +    // val predictUDF = udf((vector: Vector) => predict(vector))
    +    val predictUDF = if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
    +      udf((vector: Vector) => predict(vector))
    +    }
    +    else {
    +      udf((vector: Seq[_]) => {
    +        val featureArray = Array.fill[Double](vector.size)(0.0)
    --- End diff --

    Here's what I meant:
    ```
    val predictUDF = featuresDataType match {
      case _: VectorUDT => udf((vector: Vector) => predict(vector))
      case fdt: ArrayType => fdt.elementType match {
        case _: FloatType => ???
        case _: DoubleType => ???
      }
    }
    ```
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org