Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21081#discussion_r182269644
--- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
---
@@ -123,7 +128,21 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val predictUDF = udf((vector: Vector) => predict(vector))
+ // val predictUDF = udf((vector: Vector) => predict(vector))
+ val predictUDF = if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
+ udf((vector: Vector) => predict(vector))
--- End diff --
Side note: I realized that calling "predict" inside the UDF will cause the
whole model to be serialized and sent to workers. But that's actually OK,
since we do need to send most of the model data to make predictions and
there's no clean way to send just the model weights. So I think my previous
comment about copying "numClasses" to a local variable was not necessary.
Don't bother reverting the change though.
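To illustrate the serialization point, here is a minimal sketch that uses plain Java serialization rather than Spark. `ToyModel`, `unusedPayload`, `closureViaMethod`, `closureViaLocal`, and `ClosureSizeDemo` are hypothetical names made up for this example; none of them come from KMeans.scala. It assumes Scala 2.12 or later, where a lambda that calls an instance method captures `this`, while a lambda that reads only local values does not.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Hypothetical stand-in for a model; this is NOT Spark's KMeansModel.
// `unusedPayload` plays the role of model state that prediction never reads.
class ToyModel(val centers: Array[Double], val unusedPayload: Array[Byte])
    extends Serializable {

  // Index of the nearest center to x.
  def predict(x: Double): Int =
    centers.indices.minBy(i => math.abs(centers(i) - x))

  // Calls an instance method, so the lambda holds a reference to `this`
  // and serializing it drags the whole model along.
  def closureViaMethod: Double => Int = (x: Double) => predict(x)

  // Copies the one needed field to a local first, so the lambda captures
  // only that array (assumption: Scala 2.12+ lambda encoding).
  def closureViaLocal: Double => Int = {
    val localCenters = centers
    (x: Double) =>
      localCenters.indices.minBy(i => math.abs(localCenters(i) - x))
  }
}

object ClosureSizeDemo {
  // Size in bytes of an object under plain Java serialization.
  def serializedSize(obj: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(obj)
    out.close()
    bytes.size()
  }

  def main(args: Array[String]): Unit = {
    val model = new ToyModel(Array(0.0, 5.0, 10.0), new Array[Byte](100000))
    println(s"via method: ${serializedSize(model.closureViaMethod)} bytes")
    println(s"via local:  ${serializedSize(model.closureViaLocal)} bytes")
  }
}
```

The gap between the two sizes comes only from state that prediction does not need. When nearly all of the model is needed anyway, as with the cluster centers here, copying a single field to a local saves very little, which is the reasoning in the comment above.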
---