Repository: spark Updated Branches: refs/heads/master 256358f66 -> 5d1850d4b
[MINOR][ML] Reorg RFormula params. ## What changes were proposed in this pull request? There are mainly three reasons for this reorg: * Some params are placed in ```RFormulaBase```, while others are placed in ```RFormula```, which is disorganized. * ```RFormulaModel``` should have params ```handleInvalid```, ```formula``` and ```forceIndexLabel```, so that users can get the invalid-value handling policy, the formula, or whether label indexing is forced even if they only have a ```RFormulaModel```. So we need to move these params to ```RFormulaBase```, which is also inherited by ```RFormulaModel```. * ```RFormulaModel``` should support setting a different ```handleInvalid``` during cross validation. ## How was this patch tested? Existing tests. Author: Yanbo Liang <yblia...@gmail.com> Closes #18681 from yanboliang/rformula-reorg. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5d1850d4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5d1850d4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5d1850d4 Branch: refs/heads/master Commit: 5d1850d4b541a8108c934a174097f3c7e10b5315 Parents: 256358f Author: Yanbo Liang <yblia...@gmail.com> Authored: Thu Jul 20 20:07:16 2017 +0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Thu Jul 20 20:07:16 2017 +0800 ---------------------------------------------------------------------- .../org/apache/spark/ml/feature/RFormula.scala | 95 ++++++++++---------- 1 file changed, 47 insertions(+), 48 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5d1850d4/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index c224454..7da3339 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -35,7 +35,51 @@ import org.apache.spark.sql.types._ /** * Base trait for [[RFormula]] and [[RFormulaModel]]. */ -private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol with HasHandleInvalid { + + /** + * R formula parameter. The formula is provided in string form. + * @group param + */ + @Since("1.5.0") + val formula: Param[String] = new Param(this, "formula", "R model formula") + + /** @group getParam */ + @Since("1.5.0") + def getFormula: String = $(formula) + + /** + * Force to index label whether it is numeric or string type. + * Usually we index label only when it is string type. + * If the formula was used by classification algorithms, + * we can force to index label even it is numeric type by setting this param with true. + * Default: false. + * @group param + */ + @Since("2.1.0") + val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel", + "Force to index label whether it is numeric or string") + setDefault(forceIndexLabel -> false) + + /** @group getParam */ + @Since("2.1.0") + def getForceIndexLabel: Boolean = $(forceIndexLabel) + + /** + * Param for how to handle invalid data (unseen or NULL values) in features and label column + * of string type. Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). + * Default: "error" + * @group param + */ + @Since("2.3.0") + final override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "How to handle invalid data (unseen or NULL values) in features and label column of string " + + "type. 
Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) /** * Param for how to order categories of a string FEATURE column used by `StringIndexer`. @@ -68,6 +112,7 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { "The default value is 'frequencyDesc'. When the ordering is set to 'alphabetDesc', " + "RFormula drops the same category as R when encoding strings.", ParamValidators.inArray(StringIndexer.supportedStringOrderType)) + setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc) /** @group getParam */ @Since("2.3.0") @@ -108,20 +153,12 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { @Experimental @Since("1.5.0") class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) - extends Estimator[RFormulaModel] with RFormulaBase with HasHandleInvalid - with DefaultParamsWritable { + extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable { @Since("1.5.0") def this() = this(Identifiable.randomUID("rFormula")) /** - * R formula parameter. The formula is provided in string form. - * @group param - */ - @Since("1.5.0") - val formula: Param[String] = new Param(this, "formula", "R model formula") - - /** * Sets the formula to use for this transformer. Must be called before use. * @group setParam * @param value an R formula in string form (e.g. "y ~ x + z") @@ -129,26 +166,6 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setFormula(value: String): this.type = set(formula, value) - /** @group getParam */ - @Since("1.5.0") - def getFormula: String = $(formula) - - /** - * Param for how to handle invalid data (unseen or NULL values) in features and label column - * of string type. 
Options are 'skip' (filter out rows with invalid data), - * 'error' (throw an error), or 'keep' (put invalid data in a special additional - * bucket, at index numLabels). - * Default: "error" - * @group param - */ - @Since("2.3.0") - override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to " + - "handle invalid data (unseen or NULL values) in features and label column of string type. " + - "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + - "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", - ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) - setDefault(handleInvalid, StringIndexer.ERROR_INVALID) - /** @group setParam */ @Since("2.3.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) @@ -161,23 +178,6 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) - /** - * Force to index label whether it is numeric or string type. - * Usually we index label only when it is string type. - * If the formula was used by classification algorithms, - * we can force to index label even it is numeric type by setting this param with true. - * Default: false. 
- * @group param - */ - @Since("2.1.0") - val forceIndexLabel: BooleanParam = new BooleanParam(this, "forceIndexLabel", - "Force to index label whether it is numeric or string") - setDefault(forceIndexLabel -> false) - - /** @group getParam */ - @Since("2.1.0") - def getForceIndexLabel: Boolean = $(forceIndexLabel) - /** @group setParam */ @Since("2.1.0") def setForceIndexLabel(value: Boolean): this.type = set(forceIndexLabel, value) @@ -185,7 +185,6 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) /** @group setParam */ @Since("2.3.0") def setStringIndexerOrderType(value: String): this.type = set(stringIndexerOrderType, value) - setDefault(stringIndexerOrderType, StringIndexer.frequencyDesc) /** Whether the formula specifies fitting an intercept. */ private[ml] def hasIntercept: Boolean = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org