Repository: spark Updated Branches: refs/heads/master 6f20a92ca -> 0d4a69527
[SPARK-17745][ML][PYSPARK] update NB python api - add weight col parameter ## What changes were proposed in this pull request? update python api for NaiveBayes: add weight col parameter. ## How was this patch tested? doctests added. Author: WeichenXu <weichenxu...@outlook.com> Closes #15406 from WeichenXu123/nb_python_update. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0d4a6952 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0d4a6952 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0d4a6952 Branch: refs/heads/master Commit: 0d4a695279c514c76aa0e9288c70ac7aaef91b03 Parents: 6f20a92 Author: WeichenXu <weichenxu...@outlook.com> Authored: Wed Oct 12 19:52:57 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Wed Oct 12 19:52:57 2016 -0700 ---------------------------------------------------------------------- python/pyspark/ml/classification.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0d4a6952/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ea60fab..3f763a1 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -981,7 +981,7 @@ class GBTClassificationModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWrita @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. `Multinomial NB @@ -995,23 +995,23 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([ - ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), - ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), - ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) - >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])), + ... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight") >>> model = nb.fit(df) >>> model.pi - DenseVector([-0.51..., -0.91...]) + DenseVector([-0.81..., -0.58...]) >>> model.theta - DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1) >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() >>> result = model.transform(test0).head() >>> result.prediction 1.0 >>> result.probability - DenseVector([0.42..., 0.57...]) + DenseVector([0.32..., 0.67...]) >>> result.rawPrediction - DenseVector([-1.60..., -1.32...]) + DenseVector([-1.72..., -0.99...]) >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -1045,11 +1045,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial", thresholds=None): + modelType="multinomial", thresholds=None, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial", thresholds=None) + modelType="multinomial", thresholds=None, weightCol=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -1062,11 +1062,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial", thresholds=None): + modelType="multinomial", thresholds=None, weightCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial", thresholds=None) + modelType="multinomial", thresholds=None, weightCol=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org