felixcheung commented on a change in pull request #24939: [SPARK-18569][ML][R]
Support RFormula arithmetic, I() and spark functions
URL: https://github.com/apache/spark/pull/24939#discussion_r303294331
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
##########
@@ -614,3 +652,80 @@ private object VectorAttributeRewriter extends
MLReadable[VectorAttributeRewrite
}
}
}
+
+/**
+ * Utility transformer that adds string expressions to a dataframe as new
+ * columns, evaluating each one with the `expr` spark function.
+ *
+ * The expressions are folded over in the iteration order of the set; each
+ * one is added through `withColumn`, with the expression text itself used
+ * as the column name.
+ *
+ * Schema derivation does not inspect the expressions directly: it obtains a
+ * SparkSession, builds a one-row dummy DataFrame carrying the input schema,
+ * applies the same expressions to it and returns the resulting schema.
+ * NOTE(review): the dummy row is `Row.empty`, which presumably does not
+ * match a non-empty schema; this looks like it relies on the dummy
+ * DataFrame never being evaluated -- TODO confirm.
+ *
+ * @param uid identifier of this transformer instance
+ * @param exprsToSelect set of string expressions to be added as columns to
+ *                      the dataframe. The name of each column is identical
+ *                      to its expression
+ */
+private class ExprSelector(
+ override val uid: String,
+ val exprsToSelect: Set[String])
+ extends Transformer with MLWritable {
+
+ def this(exprsToSelect: Set[String]) =
+ this(Identifiable.randomUID("exprSelector"), exprsToSelect) // random uid prefixed with "exprSelector"
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true) // derive the output schema first; the result is discarded
+ selectExprs(dataset.toDF) // then append one column per expression
+ }
+
+ private def selectExprs(dataframe: DataFrame): DataFrame = {
+ exprsToSelect.foldLeft(dataframe) { case (ds, col) => // `col` is the expression string, not a Column
+ ds.withColumn(col, expr(col)) // column name == expression text
+ }
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ val spark = SparkSession.builder().getOrCreate() // reuses or creates a session just to compute the schema
+ val dummyRDD = spark.sparkContext.parallelize(Seq(Row.empty)) // single placeholder row
+ val dummyDF = spark.createDataFrame(dummyRDD, schema) // dummy frame carrying the input schema
+ selectExprs(dummyDF).schema // output schema = input schema plus the expression columns
+ }
+
+ override def copy(extra: ParamMap): ExprSelector = defaultCopy(extra) // NOTE(review): no uid-only constructor is declared here; confirm defaultCopy can instantiate this class
+
+ override def write: MLWriter = new ExprSelector.ExprSelectorWriter(this) // writer is defined in the companion object
+}
+
+private object ExprSelector extends MLReadable[ExprSelector] {
+
+ override def read: MLReader[ExprSelector] = new ExprSelectorReader
+
+ override def load(path: String): ExprSelector = super.load(path)
+
+ /** [[MLWriter]] instance for [[ExprSelector]] */
+ private[ExprSelector] class ExprSelectorWriter(instance: ExprSelector)
extends MLWriter {
+
+ private case class Data(exprsToSelect: Seq[String])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: exprsToSelect
+ val data = Data(instance.exprsToSelect.toSeq)
+ val dataPath = new Path(path, "data").toString
Review comment:
Could this use a more unique name? A path called just "data" could easily conflict with something else.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]