viirya commented on a change in pull request #28704:
URL: https://github.com/apache/spark/pull/28704#discussion_r434754145
##########
File path: mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
##########
@@ -248,6 +248,19 @@ object MLUtils extends Logging {
}.toArray
}
+ /**
+ * Version of `kFold()` taking a fold column name.
+ */
+ @Since("3.1.0")
+ def kFold(df: DataFrame, numFolds: Int, foldColName: String): Array[(RDD[Row], RDD[Row])] = {
+ val foldCol = df.col(foldColName)
+ val dfWithMod = df.withColumn(foldColName, pmod(foldCol, lit(numFolds)))
+ (0 until numFolds).map { fold =>
+ (dfWithMod.filter(col(foldColName) =!= fold).drop(foldColName).rdd,
+ dfWithMod.filter(col(foldColName) === fold).drop(foldColName).rdd)
Review comment:
Generally, I think valid user-specified folds should not miss any fold
number. I mentioned in the discussion above that adding a fold number check is
possible, but I don't want to do it because of performance concerns. I lean
more towards letting users be responsible for it if they want to specify fold
numbers themselves.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]