Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/4087#discussion_r27426348
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala ---
@@ -264,16 +373,42 @@ object NaiveBayes {
/**
* Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
*
- * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
- * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
- * document classification. By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * This is the default Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all
+ * kinds of discrete data. For example, by converting documents into TF-IDF vectors, it
+ * can be used for document classification.
*
* @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
* vector or a count vector.
* @param lambda The smoothing parameter
*/
def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
- new NaiveBayes(lambda).run(input)
+ new NaiveBayes(lambda, "Multinomial").run(input)
}
+
+ /**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * The model type can be set to either Multinomial NB ([[http://tinyurl.com/lsdw6p]])
+ * or Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The Multinomial NB can handle
+ * discrete count data and can be called by setting the model type to "multinomial".
+ * For example, it can be used with word counts or TF_IDF vectors of documents.
+ * The Bernoulli model fits presence or absence (0-1) counts. By making every vector a
+ * 0-1 vector and setting the model type to "bernoulli", the fits and predicts as
+ * Bernoulli NB.
+ *
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ * @param lambda The smoothing parameter
+ *
+ * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be
+ * multinomial or bernoulli
+ */
+ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
+ if (supportedModelTypes.contains(modelType)) {
+ new NaiveBayes(lambda, modelType).run(input)
+ } else {
+ throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
--- End diff --
Can you please use require? Since this is an entry point, the parameter
check should throw an IllegalArgumentException (which require does).
Elsewhere, in the internals, we can throw UnknownErrors since those errors
should never actually happen.
```
require(supportedModelTypes.contains(modelType),
  s"NaiveBayes was created with an unknown ModelType: $modelType")
```
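For illustration, here is a minimal standalone sketch of the behavior described above: `require` throws an `IllegalArgumentException` when its condition is false, which suits a public entry point. The object name `ModelTypeCheck`, the `validate` helper, and the contents of `supportedModelTypes` are assumptions made for this example only; they are not taken from the PR.

```scala
import scala.util.{Failure, Success, Try}

object ModelTypeCheck {
  // Hypothetical stand-in for the supportedModelTypes referenced in the diff;
  // the actual values used in the PR are not shown here.
  val supportedModelTypes: Set[String] = Set("multinomial", "bernoulli")

  // Entry-point style parameter check: require throws
  // IllegalArgumentException if the condition is false.
  def validate(modelType: String): String = {
    require(supportedModelTypes.contains(modelType),
      s"NaiveBayes was created with an unknown ModelType: $modelType")
    modelType
  }

  def main(args: Array[String]): Unit = {
    // A supported value passes through unchanged.
    println(validate("bernoulli"))

    // An unsupported value is rejected with IllegalArgumentException.
    Try(validate("gaussian")) match {
      case Failure(e: IllegalArgumentException) => println(s"rejected: ${e.getMessage}")
      case Failure(e)                           => println(s"unexpected failure: $e")
      case Success(v)                           => println(s"unexpectedly accepted: $v")
    }
  }
}
```

Callers can then treat a bad `modelType` as an ordinary argument error, while `UnknownError` stays reserved for internal states that should be unreachable.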