Repository: spark Updated Branches: refs/heads/branch-2.0 8b6742a37 -> 80b8711b3
[SPARK-15738][PYSPARK][ML] Adding Pyspark ml RFormula __str__ method similar to Scala API ## What changes were proposed in this pull request? Adds `__str__` to RFormula and RFormulaModel so that they show the set formula param and the resolved formula, respectively. This is already present in the Scala API and was found missing in PySpark during the Spark 2.0 coverage review. ## How was this patch tested? Ran the pyspark-ml tests locally. Author: Bryan Cutler <cutl...@gmail.com> Closes #13481 from BryanCutler/pyspark-ml-rformula_str-SPARK-15738. (cherry picked from commit 7d7a0a5e0749909e97d90188707cc9220a1bb73a) Signed-off-by: Yanbo Liang <yblia...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80b8711b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80b8711b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80b8711b Branch: refs/heads/branch-2.0 Commit: 80b8711b342c5a569fe89d7ffbdd552653b9b6ec Parents: 8b6742a Author: Bryan Cutler <cutl...@gmail.com> Authored: Fri Jun 10 11:27:30 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Fri Jun 10 14:01:55 2016 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../org/apache/spark/ml/feature/RFormulaParser.scala | 14 +++++++++++++- python/pyspark/ml/feature.py | 12 ++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/80b8711b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2916b6d..a7ca0fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -182,7 +182,7 @@ class RFormula(override val uid: String) override def copy(extra: ParamMap): RFormula = defaultCopy(extra) - override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" + override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)" } @Since("2.0.0") http://git-wip-us.apache.org/repos/asf/spark/blob/80b8711b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 19aecff..2dd565a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -126,7 +126,19 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * @param hasIntercept whether the formula specifies fitting with an intercept. */ private[ml] case class ResolvedRFormula( - label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) + label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) { + + override def toString: String = { + val ts = terms.map { + case t if t.length > 1 => + s"${t.mkString("{", ",", "}")}" + case t => + t.mkString + } + val termStr = ts.mkString("[", ",", "]") + s"ResolvedRFormula(label=$label, terms=$termStr, hasIntercept=$hasIntercept)" + } +} /** * R formula terms. 
See the R formula docs here for more information: http://git-wip-us.apache.org/repos/asf/spark/blob/80b8711b/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index bfb2fb7..ca77ac3 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2528,6 +2528,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM True >>> loadedRF.getLabelCol() == rf.getLabelCol() True + >>> str(loadedRF) + 'RFormula(y ~ x + s) (uid=...)' >>> modelPath = temp_path + "/rFormulaModel" >>> model.save(modelPath) >>> loadedModel = RFormulaModel.load(modelPath) @@ -2542,6 +2544,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM |0.0|0.0| a|[0.0,1.0]| 0.0| +---+---+---+---------+-----+ ... + >>> str(loadedModel) + 'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)' .. versionadded:: 1.5.0 """ @@ -2586,6 +2590,10 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM def _create_model(self, java_model): return RFormulaModel(java_model) + def __str__(self): + formulaStr = self.getFormula() if self.isDefined(self.formula) else "" + return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid) + class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ @@ -2597,6 +2605,10 @@ class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): .. versionadded:: 1.5.0 """ + def __str__(self): + resolvedFormula = self._call_java("resolvedFormula") + return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid) + @inherit_doc class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org