GitHub user sarutak commented on a diff in the pull request:
https://github.com/apache/spark/pull/10381#discussion_r48473237
--- Diff: mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
---
@@ -349,6 +352,41 @@ class ParamsSuite extends SparkFunSuite {
val t3 = t.copy(ParamMap(t.maxIter -> 20))
assert(t3.isSet(t3.maxIter))
}
+
+ test("Filtering ParamMap") {
+ val params1 = new MyParams("my_params1")
+ val params2 = new MyParams("my_params2")
+ val paramMap = ParamMap(
+ params1.intParam -> 1,
+ params2.intParam -> 1,
+ params1.doubleParam -> 0.2,
+ params2.doubleParam -> 0.2)
+ val filteredParamMap = paramMap.filter(params1)
+
+ assert(filteredParamMap.size === 2)
+ filteredParamMap.toSeq.foreach {
+ case ParamPair(p, _) =>
+ assert(p.parent === params1.uid)
+ }
+
+ // At the previous implementation of ParamMap#filter,
+ // mutable.Map#filterKeys was used internally but
+ // the return type of the method is not serializable (see SI-6654).
+ // Now mutable.Map#filter is used instead of filterKeys and the return
+ // type is serializable.
+ // So let's ensure serializability.
+ val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
+ try {
+ objOut.writeObject(filteredParamMap)
+ } catch {
+ case _: NotSerializableException =>
+ fail("The field of ParamMap 'map' may not be serializable. " +
+ "See SI-6654 and the implementation of ParamMap#filter")
+ case e: Exception =>
--- End diff --
Oh, yes, `ByteArrayOutputStream` doesn't need to be closed. I'll simplify it.
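`ByteArrayOutputStream` only wraps an in-memory buffer, and closing it has no effect. The check therefore needs no `finally` block. Below is a minimal sketch of the simplified check, assuming a hypothetical helper `SerializationCheck.isJavaSerializable` that is not part of Spark or ScalaTest:

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical helper for illustration only (not part of Spark or ScalaTest).
object SerializationCheck {
  // Returns true if `obj` can be written with Java serialization.
  // The streams are never closed: ByteArrayOutputStream holds only an
  // in-memory buffer, so there is nothing to release.
  def isJavaSerializable(obj: AnyRef): Boolean = {
    val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      objOut.writeObject(obj)
      true
    } catch {
      // Thrown when obj (or something it references) is not Serializable.
      case _: NotSerializableException => false
    }
  }
}
```

In the test, the `try`/`catch` could then collapse to `assert(SerializationCheck.isJavaSerializable(filteredParamMap))`. Any other I/O failure propagates, and ScalaTest treats an uncaught exception in a test body as a failure.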