Repository: spark
Updated Branches:
  refs/heads/branch-2.1 94987987a -> 8520d7c6d


[SPARK-21306][ML] OneVsRest should support setWeightCol

## What changes were proposed in this pull request?

Add a `setWeightCol` method to OneVsRest.

`weightCol` is ignored if the classifier doesn't inherit the `HasWeightCol` trait.

## How was this patch tested?

+ [x] Add a unit test.

Author: Yan Facai (颜发才) <facai....@gmail.com>

Closes #18554 from facaiy/BUG/oneVsRest_missing_weightCol.

(cherry picked from commit a5a3189974ea4628e9489eb50099a5432174e80c)
Signed-off-by: Yanbo Liang <yblia...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8520d7c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8520d7c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8520d7c6

Branch: refs/heads/branch-2.1
Commit: 8520d7c6d5e880dea3c1a8a874148c07222b4b4b
Parents: 9498798
Author: Yan Facai (颜发才) <facai....@gmail.com>
Authored: Fri Jul 28 10:10:35 2017 +0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Fri Jul 28 10:15:59 2017 +0800

----------------------------------------------------------------------
 .../spark/ml/classification/OneVsRest.scala     | 39 ++++++++++++++++++--
 .../ml/classification/OneVsRestSuite.scala      | 10 +++++
 python/pyspark/ml/classification.py             | 27 +++++++++++---
 python/pyspark/ml/tests.py                      | 14 +++++++
 4 files changed, 81 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8520d7c6/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index e58b30d..c4a8f1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -34,6 +34,7 @@ import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams with 
ClassifierTypeTrait {
+private[ml] trait OneVsRestParams extends PredictorParams
+  with ClassifierTypeTrait with HasWeightCol {
 
   /**
    * param for the base binary classifier that we reduce multiclass 
classification into.
@@ -299,6 +301,18 @@ final class OneVsRest @Since("1.4.0") (
   @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /**
+   * Sets the value of param [[weightCol]].
+   *
+   * This is ignored if weight is not supported by [[classifier]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.3.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, 
getClassifier.featuresDataType)
@@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
     }
     val numClasses = 
MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
 
-    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+    val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
+      getClassifier match {
+        case _: HasWeightCol => true
+        case c =>
+          logWarning(s"weightCol is ignored, as it is not supported by $c 
now.")
+          false
+      }
+    }
+
+    val multiclassLabeled = if (weightColIsUsed) {
+      dataset.select($(labelCol), $(featuresCol), $(weightCol))
+    } else {
+      dataset.select($(labelCol), $(featuresCol))
+    }
 
     // persist if underlying dataset is not persistent.
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
       paramMap.put(classifier.labelCol -> labelColName)
       paramMap.put(classifier.featuresCol -> getFeaturesCol)
       paramMap.put(classifier.predictionCol -> getPredictionCol)
-      classifier.fit(trainingDataset, paramMap)
+      if (weightColIsUsed) {
+        val classifier_ = classifier.asInstanceOf[ClassifierType with 
HasWeightCol]
+        paramMap.put(classifier_.weightCol -> getWeightCol)
+        classifier_.fit(trainingDataset, paramMap)
+      } else {
+        classifier.fit(trainingDataset, paramMap)
+      }
     }.toArray[ClassificationModel[_, _]]
 
     if (handlePersistence) {

http://git-wip-us.apache.org/repos/asf/spark/blob/8520d7c6/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index aacb792..491eca5 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -157,6 +157,16 @@ class OneVsRestSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defau
     assert(output.schema.fieldNames.toSet === Set("label", "features", 
"prediction"))
   }
 
+  test("SPARK-21306: OneVsRest should support setWeightCol") {
+    val dataset2 = dataset.withColumn("weight", lit(1))
+    // classifier inherits hasWeightCol
+    val ova = new OneVsRest().setWeightCol("weight").setClassifier(new 
LogisticRegression())
+    assert(ova.fit(dataset2) !== null)
+    // classifier doesn't inherit hasWeightCol
+    val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new 
DecisionTreeClassifier())
+    assert(ova2.fit(dataset2) !== null)
+  }
+
   test("OneVsRest.copy and OneVsRestModel.copy") {
     val lr = new LogisticRegression()
       .setMaxIter(1)

http://git-wip-us.apache.org/repos/asf/spark/blob/8520d7c6/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py 
b/python/pyspark/ml/classification.py
index 2b47c40..f88be70 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1331,7 +1331,7 @@ class MultilayerPerceptronClassificationModel(JavaModel, 
JavaPredictionModel, Ja
         return self._call_java("weights")
 
 
-class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
+class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, 
HasPredictionCol):
     """
     Parameters for OneVsRest and OneVsRestModel.
     """
@@ -1394,10 +1394,10 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, 
MLWritable):
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", 
predictionCol="prediction",
-                 classifier=None):
+                 classifier=None, weightCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", 
predictionCol="prediction", \
-                 classifier=None)
+                 classifier=None, weightCol=None)
         """
         super(OneVsRest, self).__init__()
         kwargs = self._input_kwargs
@@ -1405,9 +1405,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, 
MLWritable):
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, 
classifier=None):
+    def setParams(self, featuresCol=None, labelCol=None, predictionCol=None,
+                  classifier=None, weightCol=None):
         """
-        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, 
classifier=None):
+        setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \
+                  classifier=None, weightCol=None):
         Sets params for OneVsRest.
         """
         kwargs = self._input_kwargs
@@ -1423,7 +1425,18 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, 
MLWritable):
 
         numClasses = int(dataset.agg({labelCol: 
"max"}).head()["max("+labelCol+")"]) + 1
 
-        multiclassLabeled = dataset.select(labelCol, featuresCol)
+        weightCol = None
+        if (self.isDefined(self.weightCol) and self.getWeightCol()):
+            if isinstance(classifier, HasWeightCol):
+                weightCol = self.getWeightCol()
+            else:
+                warnings.warn("weightCol is ignored, "
+                              "as it is not supported by {} 
now.".format(classifier))
+
+        if weightCol:
+            multiclassLabeled = dataset.select(labelCol, featuresCol, 
weightCol)
+        else:
+            multiclassLabeled = dataset.select(labelCol, featuresCol)
 
         # persist if underlying dataset is not persistent.
         handlePersistence = \
@@ -1439,6 +1452,8 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, 
MLWritable):
             paramMap = dict([(classifier.labelCol, binaryLabelCol),
                             (classifier.featuresCol, featuresCol),
                             (classifier.predictionCol, predictionCol)])
+            if weightCol:
+                paramMap[classifier.weightCol] = weightCol
             return classifier.fit(trainingDataset, paramMap)
 
         # TODO: Parallel training for all classes.

http://git-wip-us.apache.org/repos/asf/spark/blob/8520d7c6/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7152036..9046e9f 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1218,6 +1218,20 @@ class OneVsRestTests(SparkSessionTestCase):
         output = model.transform(df)
         self.assertEqual(output.columns, ["label", "features", "prediction"])
 
+    def test_support_for_weightCol(self):
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0),
+                                         (1.0, Vectors.sparse(2, [], []), 1.0),
+                                         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
+                                        ["label", "features", "weight"])
+        # classifier inherits hasWeightCol
+        lr = LogisticRegression(maxIter=5, regParam=0.01)
+        ovr = OneVsRest(classifier=lr, weightCol="weight")
+        self.assertIsNotNone(ovr.fit(df))
+        # classifier doesn't inherit hasWeightCol
+        dt = DecisionTreeClassifier()
+        ovr2 = OneVsRest(classifier=dt, weightCol="weight")
+        self.assertIsNotNone(ovr2.fit(df))
+
 
 class HashingTFTest(SparkSessionTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to