Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21081#discussion_r182269644
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala ---
    @@ -123,7 +128,21 @@ class KMeansModel private[ml] (
       @Since("2.0.0")
       override def transform(dataset: Dataset[_]): DataFrame = {
         transformSchema(dataset.schema, logging = true)
    -    val predictUDF = udf((vector: Vector) => predict(vector))
    +    // val predictUDF = udf((vector: Vector) => predict(vector))
    +    val predictUDF = if (dataset.schema($(featuresCol)).dataType.equals(new VectorUDT)) {
    +      udf((vector: Vector) => predict(vector))
    --- End diff --
    
    Side note: I realized that "predict" will cause the whole model to be
    serialized and sent to workers. But that's actually OK, since we do need
    to send most of the model data to make predictions, and there's no clean
    way to send just the model weights. So I think my previous comment about
    copying "numClasses" to a local variable was not necessary. Don't bother
    reverting the change, though.
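The trade-off above comes from how closures capture state: a lambda that calls an instance method such as `predict` captures `this`, so the whole enclosing object is serialized along with it, whereas copying a field into a local `val` first means the lambda captures only that value. The sketch below illustrates this with a hypothetical `ToyModel` (not Spark's `KMeansModel`) that is deliberately not `Serializable`, so accidental capture shows up as a `NotSerializableException` instead of silently shipping the object. It assumes Scala 2.12 or later, where lambdas are compiled as serializable and capture only what their bodies reference.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for a fitted model; NOT Spark's KMeansModel.
// Intentionally not Serializable, so capturing `this` makes serialization fail.
class ToyModel(val centers: Array[Double], val numClasses: Int) {
  // Index of the nearest center (1-D for simplicity).
  def predict(x: Double): Int =
    centers.indices.minBy(i => math.abs(centers(i) - x))

  // The lambda calls an instance method, so it captures `this` (the whole model).
  def predictFn: Double => Int = (x: Double) => predict(x)

  // Copy the field to a local val first: the lambda captures only the Int.
  def numClassesFn: Double => Int = {
    val k = numClasses
    (_: Double) => k
  }
}

object ClosureCaptureDemo {
  // Returns true if Java serialization of `obj` (and everything it captures) succeeds.
  def serializes(obj: AnyRef): Boolean =
    try {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      out.writeObject(obj)
      out.close()
      true
    } catch {
      case _: NotSerializableException => false
    }

  def main(args: Array[String]): Unit = {
    val model = new ToyModel(Array(0.0, 10.0), numClasses = 2)
    println(s"predictFn serializable: ${serializes(model.predictFn)}")
    println(s"numClassesFn serializable: ${serializes(model.numClassesFn)}")
  }
}
```

Spark ML models are themselves serializable, so in `transform` the capture succeeds and the full model is shipped to workers; as the comment notes, that is acceptable here because prediction needs most of the model's data anyway.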

