zhengruifeng commented on a change in pull request #26413:
[SPARK-16872][ML][PYSPARK] Impl Gaussian Naive Bayes Classifier
URL: https://github.com/apache/spark/pull/26413#discussion_r344432487
##########
File path:
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
##########
@@ -396,15 +545,29 @@ object NaiveBayesModel extends
MLReadable[NaiveBayesModel] {
private val className = classOf[NaiveBayesModel].getName
override def load(path: String): NaiveBayesModel = {
+ implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val (major, minor) =
VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ val modelTypeJson = metadata.getParamValue("modelType")
+ val modelType = Param.jsonDecode[String](compact(render(modelTypeJson)))
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
- val Row(pi: Vector, theta: Matrix) =
MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
- .select("pi", "theta")
- .head()
- val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+ val model = if (major.toInt < 3 || modelType != NaiveBayes.Gaussian) {
Review comment:
I have tested loading a model saved by an older version, and it works fine.
In Spark 2.4.4:
```scala
scala> import org.apache.spark.ml.feature._
import org.apache.spark.ml.feature._
scala> import org.apache.spark.ml.regression._
import org.apache.spark.ml.regression._
scala> import org.apache.spark.ml.classification._
import org.apache.spark.ml.classification._
scala> var df = spark.read.format("libsvm").load("/data1/Datasets/a9a/a9a")
19/11/09 15:05:36 WARN LibSVMFileFormat: 'numFeatures' option not specified,
determining the number of features by going though the input. If you know the
number in advance, please specify it via 'numFeatures' option to avoid the
extra scan.
df: org.apache.spark.sql.DataFrame = [label: double, features: vector]
scala> df.persist()
res0: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label:
double, features: vector]
scala> df.count
res1: Long = 32561
scala> (0 until 8).foreach(_ => df = df.union(df))
scala> df.count
res3: Long = 8335616
scala>
scala> val nb = new NaiveBayes()
nb: org.apache.spark.ml.classification.NaiveBayes = nb_a87b69dac8f6
scala> val model = nb.fit(df)
[Stage 7:==========================================> (201 + 13) /
256]19/11/09 15:06:03 WARN BLAS: Failed to load implementation from:
com.github.fommil.netlib.NativeSystemBLAS
19/11/09 15:06:03 WARN BLAS: Failed to load implementation from:
com.github.fommil.netlib.NativeRefBLAS
model: org.apache.spark.ml.classification.NaiveBayesModel = NaiveBayesModel
(uid=nb_a87b69dac8f6) with 2 classes
scala> model.save("/tmp/nbm_2.4.4")
```
In this PR:
```scala
scala> import org.apache.spark.ml.classification._
import org.apache.spark.ml.classification._
scala> val model = NaiveBayesModel.load("/tmp/nbm_2.4.4")
model: org.apache.spark.ml.classification.NaiveBayesModel = NaiveBayesModel
(uid=nb_a87b69dac8f6) with 2 classes
scala> model.sigma
res0: org.apache.spark.ml.linalg.Matrix = null
scala> model.getModelType
res1: String = multinomial
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]