Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/19208#discussion_r148908525
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala ---
@@ -187,6 +191,50 @@ class CrossValidatorSuite
.compareParamMaps(cv.getEstimatorParamMaps,
cv2.getEstimatorParamMaps)
}
+ test("CrossValidator expose sub models") {
+ val lr = new LogisticRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.001, 1000.0))
+ .addGrid(lr.maxIter, Array(0, 3))
+ .build()
+ val eval = new BinaryClassificationEvaluator
+ val numFolds = 3
+ val subPath = new File(tempDir, "testCrossValidatorSubModels")
+ val persistSubModelsPath = new File(subPath, "subModels").toString
+
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(numFolds)
+ .setParallelism(1)
+ .setCollectSubModels(true)
+
+ val cvModel = cv.fit(dataset)
+
+ assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds)
+ cvModel.subModels.foreach(array => assert(array.length ==
lrParamMaps.length))
+
+ // Test the default value for option "persistSubModel" to be "true"
+ val savingPathWithSubModels = new File(subPath, "cvModel3").getPath
+ cvModel.save(savingPathWithSubModels)
+ val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
+ assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds)
+ cvModel3.subModels.foreach(array => assert(array.length ==
lrParamMaps.length))
+
+ val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath
+ cvModel.write.option("persistSubModels",
"false").save(savingPathWithoutSubModels)
+ val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
--- End diff --
Shall we also try saving cvModel2 (which was loaded from a save made without
sub-models) with persistSubModels = true and check that an exception is thrown?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]