[ 
https://issues.apache.org/jira/browse/SPARK-7127?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14541326#comment-14541326
 ] 

Bryan Cutler commented on SPARK-7127:
-------------------------------------

Hi [~josephkb],

I've been working to incorporate the broadcast model over mapped 
partitions, but I'm a little stuck on adding the predictions column 
to the input DataFrame.  Here is what I have so far, if you wouldn't mind 
taking a look.  I'm sure there are more elegant ways to do this, but I can't 
figure out why this doesn't work (it throws a SQL AnalysisException because 
the prediction column is missing in the withColumn call):

{noformat}
override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    if ($(predictionCol).nonEmpty) {

      val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)

      val predictions = dataset.select($(featuresCol)).mapPartitions(
        iter => {
          val modelValue = bcastModel.value
          iter.map {
            case Row(features: Vector) =>
              Row(modelValue.predict(features))
          }
        })

      val schema = StructType(Seq(StructField($(predictionCol), DoubleType, true)))
      val predDF = dataset.sqlContext.createDataFrame(predictions, schema)

      dataset.withColumn($(predictionCol), predDF.col($(predictionCol)))

    } else {
      this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
        " since no output columns were set.")
      dataset
    }
}
{noformat}


> Broadcast spark.ml tree ensemble models for predict
> ---------------------------------------------------
>
>                 Key: SPARK-7127
>                 URL: https://issues.apache.org/jira/browse/SPARK-7127
>             Project: Spark
>          Issue Type: Improvement
>          Components: ML
>    Affects Versions: 1.4.0
>            Reporter: Joseph K. Bradley
>            Priority: Minor
>
> GBTRegressor/Classifier and RandomForestRegressor/Classifier should broadcast 
> models and then predict.  This will mean overriding transform().
> Note: Try to reduce duplicated code via the TreeEnsembleModel abstraction.



