viirya commented on a change in pull request #28704:
URL: https://github.com/apache/spark/pull/28704#discussion_r434754145



##########
File path: mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
##########
@@ -248,6 +248,19 @@ object MLUtils extends Logging {
     }.toArray
   }
 
+  /**
+   * Version of `kFold()` taking a fold column name.
+   */
+  @Since("3.1.0")
+  def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
+    val foldCol = df.col(foldColName)
+    val dfWithMod = df.withColumn(foldColName, pmod(foldCol, lit(numFolds)))
+    (0 until numFolds).map { fold =>
+      (dfWithMod.filter(col(foldColName) =!= fold).drop(foldColName).rdd,
+        dfWithMod.filter(col(foldColName) === fold).drop(foldColName).rdd)

Review comment:
       Generally, I think valid user-specified folds should not be missing any
fold number. As I mentioned in the discussion above, adding a fold number check
is possible, but I don't want to do it because of performance concerns. I lean
more towards making users responsible for it if they want to specify fold
numbers themselves.
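To make the trade-off concrete, here is a minimal plain-Scala sketch (no Spark) that models rows as hypothetical `(value, fold)` pairs. It mirrors the split in the diff: fold numbers are first normalized with a non-negative modulus, as Spark's `pmod` does for a positive divisor, and then fold `k` becomes the validation set while every other row goes to training. If a fold number never occurs, that fold's validation set is empty and its training set is the whole dataset. `FoldCheckSketch` and `missingFolds` are illustrative names, not part of `MLUtils`. On a real DataFrame the equivalent check would need an extra job over the fold column (for example a distinct over the normalized values), which is the cost being avoided.

```scala
object FoldCheckSketch {
  // Non-negative modulus for n > 0, matching pmod semantics: result is in [0, n).
  def pmod(a: Int, n: Int): Int = ((a % n) + n) % n

  // Collection analogue of the diff: for each fold k,
  // training = rows whose normalized fold != k, validation = rows whose normalized fold == k.
  def kFold[A](rows: Seq[(A, Int)], numFolds: Int): Array[(Seq[A], Seq[A])] = {
    val withMod = rows.map { case (v, f) => (v, pmod(f, numFolds)) }
    (0 until numFolds).map { k =>
      (withMod.collect { case (v, f) if f != k => v },
       withMod.collect { case (v, f) if f == k => v })
    }.toArray
  }

  // Hypothetical optional validation: fold numbers in [0, numFolds) that no row maps to.
  // Computing this requires looking at every fold value, i.e. an extra pass over the data.
  def missingFolds(folds: Seq[Int], numFolds: Int): Set[Int] =
    (0 until numFolds).toSet -- folds.map(pmod(_, numFolds)).toSet
}
```

For example, with `numFolds = 3` and fold values `0, 1, 3, -2`, the normalized folds are `0, 1, 0, 1`, so fold `2` has no rows: its validation set is empty and its "training" set contains everything, which is the situation the comment expects users to avoid.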



