viirya commented on a change in pull request #28704:
URL: https://github.com/apache/spark/pull/28704#discussion_r437150737



##########
File path: mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
##########
@@ -248,6 +248,19 @@ object MLUtils extends Logging {
     }.toArray
   }
 
+  /**
+   * Version of `kFold()` taking a fold column name.
+   */
+  @Since("3.1.0")
+  def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
+    val foldCol = df.col(foldColName)
+    val dfWithMod = df.withColumn(foldColName, pmod(foldCol, lit(numFolds)))
+    (0 until numFolds).map { fold =>
+      (dfWithMod.filter(col(foldColName) =!= fold).drop(foldColName).rdd,
+        dfWithMod.filter(col(foldColName) === fold).drop(foldColName).rdd)

Review comment:
       @zhengruifeng Is the check you suggested the same as the one @huaxingao suggested, i.e., checking for empty datasets? Or do you mean checking that user-specified fold numbers are in the range [0, numFolds)?
   
   For the latter, I now take the value mod numFolds, which I think should be enough to guarantee valid fold numbers.
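   To illustrate why the mod should be enough, here is a minimal plain-Scala sketch of the arithmetic that `pmod(foldCol, lit(numFolds))` performs, assuming integer fold values. `Math.floorMod` stands in for Spark's `pmod` (the two agree for a positive divisor), and `FoldAssignment`/`normalizeFold` are hypothetical names, not part of this PR. Plain `%` keeps the sign of the dividend, whereas the positive modulus always lands in [0, numFolds).

```scala
// Hypothetical sketch: plain-Scala analogue of Spark's pmod(foldCol, lit(numFolds)).
// Assumption: fold values are Ints; Math.floorMod matches pmod for numFolds > 0.
object FoldAssignment {
  // Maps any user-specified fold value (negative or >= numFolds)
  // into the range [0, numFolds).
  def normalizeFold(value: Int, numFolds: Int): Int = {
    require(numFolds > 0, s"numFolds must be positive but got $numFolds")
    Math.floorMod(value, numFolds)
  }

  def main(args: Array[String]): Unit = {
    val numFolds = 3
    val userFolds = Seq(-4, -1, 0, 2, 3, 7)
    // Plain `%` keeps the sign of the dividend: -1, -1, 0, 2, 0, 1
    println(userFolds.map(_ % numFolds).mkString(", "))
    // Positive modulus is always in [0, numFolds): 2, 2, 0, 2, 0, 1
    println(userFolds.map(normalizeFold(_, numFolds)).mkString(", "))
  }
}
```

   Because every normalized value falls in [0, numFolds), each row is assigned to exactly one of the `(0 until numFolds)` validation splits.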

Review comment:
       Though I think users who supply their own fold numbers should be confident enough in them :), adding checks also sounds good to me. I will add some in the next commit.
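   A minimal sketch of what such checks could look like, written in plain Scala over per-row fold ids so it runs without Spark. `FoldChecks`/`checkFolds` are hypothetical names, not the checks actually added in the PR; in the Spark version the per-fold counts could come from something like `dfWithMod.groupBy(foldColName).count()`. With numFolds >= 2, requiring every fold to be non-empty also guarantees that every training split is non-empty, since each training split contains all the other folds.

```scala
// Hypothetical sketch of fold validation; not the PR's implementation.
// Assumption: foldIds holds one (already mod-ed) fold id per row.
object FoldChecks {
  def checkFolds(foldIds: Seq[Int], numFolds: Int): Unit = {
    // At least two folds are needed for non-trivial train/validation splits.
    require(numFolds >= 2, s"numFolds must be at least 2 but got $numFolds")

    // Count rows per fold id.
    val counts: Map[Int, Int] =
      foldIds.groupBy(identity).map { case (k, v) => k -> v.size }

    // Redundant after pmod, but guards against unnormalized input.
    val outOfRange = counts.keys.filter(k => k < 0 || k >= numFolds).toSeq.sorted
    require(outOfRange.isEmpty,
      s"Fold ids out of range [0, $numFolds): ${outOfRange.mkString(", ")}")

    // Every fold must have rows, otherwise its validation set is empty.
    val emptyFolds = (0 until numFolds).filterNot(counts.contains)
    require(emptyFolds.isEmpty,
      s"Validation set is empty for fold(s): ${emptyFolds.mkString(", ")}")
  }
}
```

   For example, `FoldChecks.checkFolds(Seq(0, 0, 2), 3)` fails because no row is assigned to fold 1.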




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
