Repository: spark
Updated Branches:
  refs/heads/master 5db18ba6e -> 61e05fc58
[SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that both training and predict features have values of 0 or 1

Author: leahmcguire <[email protected]>

Closes #6073 from leahmcguire/binaryCheckNB and squashes the following commits:

b8442c2 [leahmcguire] changed to if else for value checks
911bf83 [leahmcguire] undid reformat
4eedf1e [leahmcguire] moved bernoulli check
9ee9e84 [leahmcguire] fixed style error
3f3b32c [leahmcguire] fixed zero one check so only called in combiner
831fd27 [leahmcguire] got test working
f44bb3c [leahmcguire] removed changes from CV branch
67253f0 [leahmcguire] added check to bernoulli to ensure feature values are zero or one
f191c71 [leahmcguire] fixed name
58d060b [leahmcguire] changed param name and test according to comments
04f0d3c [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/61e05fc5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/61e05fc5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/61e05fc5

Branch: refs/heads/master
Commit: 61e05fc58e1245de871c409b60951745b5db3420
Parents: 5db18ba
Author: leahmcguire <[email protected]>
Authored: Wed May 13 14:13:19 2015 -0700
Committer: Joseph K.
Bradley <[email protected]> Committed: Wed May 13 14:13:19 2015 -0700 ---------------------------------------------------------------------- .../spark/mllib/classification/NaiveBayes.scala | 28 +++++++++++++++-- .../mllib/classification/NaiveBayesSuite.scala | 33 ++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/61e05fc5/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c9b3ff0..b381dc2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { + val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + labels (brzArgmax (brzPi + brzTheta * brzData) ) case "Bernoulli" => + if (!brzData.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + } labels (brzArgmax (brzPi + - (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => // This should never happen. 
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") @@ -293,12 +298,29 @@ class NaiveBayes private ( } } + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case SparseVector(size, indices, values) => + values + case DenseVector(values) => + values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => { - requireNonnegativeValues(v) + if (modelType == "Bernoulli") { + requireZeroOneBernoulliValues(v) + } else { + requireNonnegativeValues(v) + } (1L, v.toBreeze.toDenseVector) }, mergeValue = (c: (Long, BDV[Double]), v: Vector) => { http://git-wip-us.apache.org/repos/asf/spark/blob/61e05fc5/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ea89b17..40a79a1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("detect non zero or one values in Bernoulli") { + val badTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + + 
intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli") + } + + val okTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)) + ) + + val badPredict = Seq( + Vectors.dense(1.0), + Vectors.dense(2.0), + Vectors.dense(1.0), + Vectors.dense(0.0)) + + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli") + intercept[SparkException] { + model.predict(sc.makeRDD(badPredict, 2)).collect() + } + } + test("model save/load: 2.0 to 2.0") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
