Github user harsha2010 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/6996#discussion_r33283398
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala ---
    @@ -91,74 +60,44 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
       private val f2jBLAS = new F2jBLAS
     
       /** @group setParam */
    -  def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
    -
    -  /** @group setParam */
    -  def setEstimatorParamMaps(value: Array[ParamMap]): this.type = 
set(estimatorParamMaps, value)
    -
    -  /** @group setParam */
    -  def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
    -
    -  /** @group setParam */
       def setNumFolds(value: Int): this.type = set(numFolds, value)
     
    -  override def fit(dataset: DataFrame): CrossValidatorModel = {
    +  override protected[ml] def validationLogic(
    +      dataset: DataFrame,
    +      est: Estimator[_],
    +      eval: Evaluator,
    +      epm: Array[ParamMap],
    +      numModels: Int): Array[Double] = {
    +
         val schema = dataset.schema
         transformSchema(schema, logging = true)
         val sqlCtx = dataset.sqlContext
    -    val est = $(estimator)
    -    val eval = $(evaluator)
    -    val epm = $(estimatorParamMaps)
    -    val numModels = epm.length
    +
         val metrics = new Array[Double](epm.length)
         val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
    +
         splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
           val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
           val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
    -      // multi-model training
           logDebug(s"Train split $splitIndex with multiple sets of parameters.")
    -      val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
    -      trainingDataset.unpersist()
    +      val newMetrics = measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
    --- End diff --
    
    Perhaps a more descriptive name than newMetrics would be metricsPerSplit? Or something better.
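    
    For illustration, a rough sketch of how the loop body might read with that rename. The metricsPerSplit name is just the suggestion above; the unpersist calls and the accumulation into metrics are assumed from the surrounding code and are not visible in this hunk:
    
        splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
          val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
          val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
          logDebug(s"Train split $splitIndex with multiple sets of parameters.")
          // one metric per ParamMap, measured on this split only
          val metricsPerSplit = measureModels(trainingDataset, validationDataset, est, eval, epm, numModels)
          trainingDataset.unpersist()
          validationDataset.unpersist()
          // fold this split's results into the running totals
          // (assumed step, not shown in this hunk)
          var i = 0
          while (i < numModels) {
            metrics(i) += metricsPerSplit(i)
            i += 1
          }
        }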

