zhengruifeng commented on a change in pull request #26413: 
[SPARK-16872][ML][PYSPARK] Impl Gaussian Naive Bayes Classifier
URL: https://github.com/apache/spark/pull/26413#discussion_r344432487
 
 

 ##########
 File path: 
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
 ##########
 @@ -396,15 +545,29 @@ object NaiveBayesModel extends 
MLReadable[NaiveBayesModel] {
     private val className = classOf[NaiveBayesModel].getName
 
     override def load(path: String): NaiveBayesModel = {
+      implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion)
+      val modelTypeJson = metadata.getParamValue("modelType")
+      val modelType = Param.jsonDecode[String](compact(render(modelTypeJson)))
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
       val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
-      val Row(pi: Vector, theta: Matrix) = 
MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
-        .select("pi", "theta")
-        .head()
-      val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+      val model = if (major.toInt < 3 || modelType != NaiveBayes.Gaussian) {
 
 Review comment:
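   I have tested loading models saved by an older Spark version, and it works fine: the loaded model keeps the `multinomial` model type, and its `sigma` is null.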
   In Spark 2.4.4:
   ```scala
   scala> import org.apache.spark.ml.feature._
   import org.apache.spark.ml.feature._
   scala> import org.apache.spark.ml.regression._
   import org.apache.spark.ml.regression._
   scala> import org.apache.spark.ml.classification._
   import org.apache.spark.ml.classification._
   scala> var df = spark.read.format("libsvm").load("/data1/Datasets/a9a/a9a")
   19/11/09 15:05:36 WARN LibSVMFileFormat: 'numFeatures' option not specified, 
determining the number of features by going though the input. If you know the 
number in advance, please specify it via 'numFeatures' option to avoid the 
extra scan.
   df: org.apache.spark.sql.DataFrame = [label: double, features: vector]
   
   scala> df.persist()
   res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: 
double, features: vector]
   
   scala> df.count
   res1: Long = 32561
   
   scala> (0 until 8).foreach(_ => df = df.union(df))
   
   scala> df.count
   res3: Long = 8335616                                                         
   
   
   scala> 
   
   scala> val nb = new NaiveBayes()
   nb: org.apache.spark.ml.classification.NaiveBayes = nb_a87b69dac8f6
   
   scala> val model = nb.fit(df)
   [Stage 7:==========================================>           (201 + 13) / 
256]19/11/09 15:06:03 WARN BLAS: Failed to load implementation from: 
com.github.fommil.netlib.NativeSystemBLAS
   19/11/09 15:06:03 WARN BLAS: Failed to load implementation from: 
com.github.fommil.netlib.NativeRefBLAS
   model: org.apache.spark.ml.classification.NaiveBayesModel = NaiveBayesModel 
(uid=nb_a87b69dac8f6) with 2 classes
   
   scala> model.save("/tmp/nbm_2.4.4")
   ```
   
   With this PR:
   ```scala
   
   scala> import org.apache.spark.ml.classification._
   import org.apache.spark.ml.classification._
   scala> val model = NaiveBayesModel.load("/tmp/nbm_2.4.4")
   model: org.apache.spark.ml.classification.NaiveBayesModel = NaiveBayesModel 
(uid=nb_a87b69dac8f6) with 2 classes
   scala> model.sigma
   res0: org.apache.spark.ml.linalg.Matrix = null
   
   scala> model.getModelType
   res1: String = multinomial
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to