zhengruifeng commented on a change in pull request #27882: [SPARK-31127][ML] Add abstract Selector
URL: https://github.com/apache/spark/pull/27882#discussion_r391396858
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/FValueSelector.scala
##########
@@ -154,106 +46,75 @@ private[feature] trait FValueSelectorParams extends Params
* set to 50.
*/
@Since("3.1.0")
-final class FValueSelector @Since("3.1.0") (override val uid: String)
- extends Estimator[FValueSelectorModel] with FValueSelectorParams
- with DefaultParamsWritable {
+final class FValueSelector @Since("3.1.0") (@Since("3.1.0") override val uid:
String) extends
+ Selector[FValueSelectorModel] {
@Since("3.1.0")
def this() = this(Identifiable.randomUID("FValueSelector"))
/** @group setParam */
@Since("3.1.0")
- def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
+ override def setNumTopFeatures(value: Int): this.type =
super.setNumTopFeatures(value)
/** @group setParam */
@Since("3.1.0")
- def setPercentile(value: Double): this.type = set(percentile, value)
+ override def setPercentile(value: Double): this.type =
super.setPercentile(value)
/** @group setParam */
@Since("3.1.0")
- def setFpr(value: Double): this.type = set(fpr, value)
+ override def setFpr(value: Double): this.type = super.setFpr(value)
/** @group setParam */
@Since("3.1.0")
- def setFdr(value: Double): this.type = set(fdr, value)
+ override def setFdr(value: Double): this.type = super.setFdr(value)
/** @group setParam */
@Since("3.1.0")
- def setFwe(value: Double): this.type = set(fwe, value)
+ override def setFwe(value: Double): this.type = super.setFwe(value)
/** @group setParam */
@Since("3.1.0")
- def setSelectorType(value: String): this.type = set(selectorType, value)
+ override def setSelectorType(value: String): this.type =
super.setSelectorType(value)
/** @group setParam */
@Since("3.1.0")
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ override def setFeaturesCol(value: String): this.type =
super.setFeaturesCol(value)
/** @group setParam */
@Since("3.1.0")
- def setOutputCol(value: String): this.type = set(outputCol, value)
+ override def setOutputCol(value: String): this.type =
super.setOutputCol(value)
/** @group setParam */
@Since("3.1.0")
- def setLabelCol(value: String): this.type = set(labelCol, value)
+ override def setLabelCol(value: String): this.type = super.setLabelCol(value)
+ /**
+ * get the SelectionTestResult for every feature against the label
+ */
@Since("3.1.0")
- override def fit(dataset: Dataset[_]): FValueSelectorModel = {
- transformSchema(dataset.schema, logging = true)
- dataset.select(col($(labelCol)).cast(DoubleType),
col($(featuresCol))).rdd.map {
- case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
- }
-
- val testResult = FValueTest.testRegression(dataset, getFeaturesCol,
getLabelCol)
- .zipWithIndex
- val features = $(selectorType) match {
- case "numTopFeatures" =>
- testResult
- .sortBy { case (res, _) => res.pValue }
- .take(getNumTopFeatures)
- case "percentile" =>
- testResult
- .sortBy { case (res, _) => res.pValue }
- .take((testResult.length * getPercentile).toInt)
- case "fpr" =>
- testResult
- .filter { case (res, _) => res.pValue < getFpr }
- case "fdr" =>
- // This uses the Benjamini-Hochberg procedure.
- // https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure
- val tempRes = testResult
- .sortBy { case (res, _) => res.pValue }
- val selected = tempRes
- .zipWithIndex
- .filter { case ((res, _), index) =>
- res.pValue <= getFdr * (index + 1) / testResult.length }
- if (selected.isEmpty) {
- Array.empty[(SelectionTestResult, Int)]
- } else {
- val maxIndex = selected.map(_._2).max
- tempRes.take(maxIndex + 1)
- }
- case "fwe" =>
- testResult
- .filter { case (res, _) => res.pValue < getFwe / testResult.length }
- case errorType =>
- throw new IllegalStateException(s"Unknown Selector Type: $errorType")
- }
- val indices = features.map { case (_, index) => index }
- copyValues(new FValueSelectorModel(uid, indices.sorted)
- .setParent(this))
+ protected[this] override def getSelectionTestResult(dataset: Dataset[_]):
+ Array[SelectionTestResult] = {
+ SelectionTest.fValueTest(dataset, getFeaturesCol, getLabelCol)
}
+ /**
+ * Create a new instance of concrete SelectorModel.
+ * @param indices The indices of the selected features
+ * @param pValues The pValues of the selected features
+ * @param statistics The f value of the selected features
+ * @return A new SelectorModel instance
+ */
@Since("3.1.0")
- override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
- SchemaUtils.checkNumericType(schema, $(labelCol))
- SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+ protected[this] def createSelectorModel(
Review comment:
ditto
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]