Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/12604#discussion_r61016852
  
    --- Diff: python/pyspark/ml/tests.py ---
    @@ -684,6 +689,52 @@ def _compare_pipelines(self, m1, m2):
                 self.assertEqual(len(m1.stages), len(m2.stages))
                 for s1, s2 in zip(m1.stages, m2.stages):
                     self._compare_pipelines(s1, s2)
    +        elif isinstance(m1, OneVsRestParams):
    +            # Check the equality of classifiers (value and parent).
    +            self._compare_pipelines(m1.getClassifier(), m2.getClassifier())
    +            self.assertEqual(m1.classifier.parent, m2.classifier.parent)
    +
    +            # Check the equality of other params (value and parent).
    +            for p in m1.params:
    +                if p.name != "classifier":
    +                    self.assertEqual(m1.getOrDefault(p), 
m2.getOrDefault(p))
    +                    self.assertEqual(p.parent, m2.getParam(p.name).parent)
    +
    +            # Check extra attributes of OneVsRestModel.
    +            if isinstance(m1, OneVsRestModel):
    +                self.assertEqual(len(m1.models), len(m2.models))
    +                for x, y in zip(m1.models, m2.models):
    +                    self._compare_pipelines(x, y)
    +        elif isinstance(m1, ValidatorParams):
    +            # Check the equality of estimators (value and parent).
    +            self._compare_pipelines(m1.getEstimator(), m2.getEstimator())
    +            self.assertEqual(m1.estimator.parent, m2.estimator.parent)
    +
    +            # Check the equality of evaluators (value and parent).
    +            self._compare_pipelines(m1.getEvaluator(), m2.getEvaluator())
    +            self.assertEqual(m1.evaluator.parent, m2.evaluator.parent)
    +
    +            # Check the equality of estimator parameter maps (value and 
parent).
    +            self.assertEqual(len(m1.getEstimatorParamMaps()), 
len(m2.getEstimatorParamMaps()))
    +            for epm1, epm2 in zip(m1.getEstimatorParamMaps(), 
m2.getEstimatorParamMaps()):
    +                self.assertEqual(len(epm1), len(epm2))
    +                for pair in epm1:
    +                    self.assertIn(pair, epm2)
    --- End diff --
    
    Check the value too. If the value is an instance of ```Params```, call ```_compare_pipelines``` recursively on it instead of comparing it directly.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to