This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 89d42dc [SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM 89d42dc is described below commit 89d42dc6d38c9508b7009652323d6b343742c5b8 Author: zhengruifeng <ruife...@foxmail.com> AuthorDate: Thu Feb 21 22:21:28 2019 -0600 [SPARK-25097][ML] Support prediction on single instance in KMeans/BiKMeans/GMM ## What changes were proposed in this pull request? expose method `predict` in KMeans/BiKMeans/GMM ## How was this patch tested? added testsuites Closes #22087 from zhengruifeng/clu_pre_instance. Authored-by: zhengruifeng <ruife...@foxmail.com> Signed-off-by: Sean Owen <sean.o...@databricks.com> --- .../spark/ml/clustering/BisectingKMeans.scala | 6 ++--- .../spark/ml/clustering/GaussianMixture.scala | 6 +++-- .../org/apache/spark/ml/clustering/KMeans.scala | 7 +++--- .../spark/ml/clustering/BisectingKMeansSuite.scala | 7 ++++++ .../spark/ml/clustering/GaussianMixtureSuite.scala | 10 ++++++++ .../apache/spark/ml/clustering/KMeansSuite.scala | 7 ++++++ .../scala/org/apache/spark/ml/util/MLTest.scala | 28 +++++++++++++++++++++- 7 files changed, 61 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index d846f17..03afdbe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.Vector @@ -30,7 +29,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} @@ -118,7 +117,8 @@ class BisectingKMeansModel private[ml] ( validateAndTransformSchema(schema) } - private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + @Since("3.0.0") + def predict(features: Vector): Int = parentModel.predict(features) @Since("2.0.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index c27ba55..3d6d1e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -121,12 +121,14 @@ class GaussianMixtureModel private[ml] ( validateAndTransformSchema(schema) } - private[clustering] def predict(features: Vector): Int = { + @Since("3.0.0") + def predict(features: Vector): Int = { val r = predictProbability(features) r.argmax } - private[clustering] def predictProbability(features: Vector): Vector = { + @Since("3.0.0") + def predictProbability(features: Vector): Vector = { val probs: Array[Double] = GaussianMixtureModel.computeProbabilities(features.asBreeze.toDenseVector, gaussians, weights) Vectors.dense(probs) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 319747d..b48a966 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, PipelineStage} import org.apache.spark.ml.linalg.Vector @@ -32,8 +31,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.VectorImplicits._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -139,7 +137,8 @@ class KMeansModel private[ml] ( validateAndTransformSchema(schema) } - private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + @Since("3.0.0") + def predict(features: Vector): Int = parentModel.predict(features) @Since("2.0.0") def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 461f8b8..5708097 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -205,6 +205,13 @@ class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { assert(trueCost ~== doubleArrayCost absTol 1e-6) assert(trueCost ~== floatArrayCost absTol 1e-6) } + + test("prediction on single instance") { + val bikm = new BisectingKMeans().setSeed(123L) + val model = bikm.fit(dataset) + testClusteringModelSinglePrediction(model, model.predict, dataset, + model.getFeaturesCol, model.getPredictionCol) + } } object BisectingKMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 13bed9d..11fdd3a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -268,6 +268,16 @@ class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { assert(trueLikelihood ~== doubleLikelihood absTol 1e-6) assert(trueLikelihood ~== floatLikelihood absTol 1e-6) } + + test("prediction on single instance") { + val gmm = new GaussianMixture().setSeed(123L) + val model = gmm.fit(dataset) + testClusteringModelSinglePrediction(model, model.predict, dataset, + model.getFeaturesCol, model.getPredictionCol) + + testClusteringModelSingleProbabilisticPrediction(model, model.predictProbability, dataset, + model.getFeaturesCol, model.getProbabilityCol) + } } object GaussianMixtureSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 4f47d91..b377582 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -244,6 +244,13 @@ class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTes } testPMMLWrite(sc, kmeansModel, checkModel) } + + test("prediction on single instance") { + val kmeans = new KMeans().setSeed(123L) + val model = kmeans.fit(dataset) + testClusteringModelSinglePrediction(model, model.predict, dataset, + model.getFeaturesCol, model.getPredictionCol) + } } object KMeansSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 514fa7f..c23b6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -23,7 +23,7 @@ import org.scalatest.Suite import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext} import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK -import org.apache.spark.ml.{PredictionModel, Transformer} +import org.apache.spark.ml.{Model, PredictionModel, Transformer} import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row} import org.apache.spark.sql.execution.streaming.MemoryStream @@ -156,4 +156,30 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => assert(prediction === model.predict(features)) } } + + def testClusteringModelSinglePrediction( + model: Model[_], + transform: Vector => Int, + dataset: Dataset[_], + input: String, + output: String): Unit = { + model.transform(dataset).select(input, output) + .collect().foreach { + case Row(features: Vector, prediction: Int) => + assert(prediction === transform(features)) + } + } + + def testClusteringModelSingleProbabilisticPrediction( + model: Model[_], + transform: Vector => Vector, + dataset: Dataset[_], + input: String, + output: String): Unit = { + model.transform(dataset).select(input, output) + .collect().foreach { + case Row(features: Vector, prediction: Vector) => + assert(prediction === transform(features)) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org