Repository: incubator-airflow Updated Branches: refs/heads/master e3e6aa719 -> ae5c53b6c
[AIRFLOW-1276] Forbid event creation with end_date earlier than start_date Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/d5d02ff7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/d5d02ff7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/d5d02ff7 Branch: refs/heads/master Commit: d5d02ff7257621af40728ed4089d9fe812b98daf Parents: e3e6aa7 Author: Stanislav Kudriashev <[email protected]> Authored: Mon Jun 5 12:56:25 2017 +0300 Committer: Stanislav Kudriashev <[email protected]> Committed: Tue Jun 6 13:03:44 2017 +0300 ---------------------------------------------------------------------- airflow/www/validators.py | 54 +++++++++++++++++++++++ airflow/www/views.py | 18 +++++++- tests/www/test_validators.py | 92 +++++++++++++++++++++++++++++++++++++++ tests/www/test_views.py | 59 ++++++++++++++++++++++++- 4 files changed, 220 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d5d02ff7/airflow/www/validators.py ---------------------------------------------------------------------- diff --git a/airflow/www/validators.py b/airflow/www/validators.py new file mode 100644 index 0000000..4a72983 --- /dev/null +++ b/airflow/www/validators.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from wtforms.validators import EqualTo +from wtforms.validators import ValidationError + + +class GreaterEqualThan(EqualTo): + """Validates that a field's value is greater than or equal to another's. + + :param fieldname: + The name of the other field to compare to. + :param message: + Error message to raise in case of a validation error. Can be + interpolated with `%(other_label)s` and `%(other_name)s` to provide a + more helpful error. + """ + + def __call__(self, form, field): + try: + other = form[self.fieldname] + except KeyError: + raise ValidationError( + field.gettext("Invalid field name '%s'.") % self.fieldname + ) + + if field.data is None or other.data is None: + return + + if field.data < other.data: + d = { + 'other_label': hasattr(other, 'label') and other.label.text + or self.fieldname, + 'other_name': self.fieldname, + } + message = self.message + if message is None: + message = field.gettext('Field must be greater than or equal ' + 'to %(other_label)s.') % d + else: + message = message % d + + raise ValidationError(message) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d5d02ff7/airflow/www/views.py ---------------------------------------------------------------------- diff --git a/airflow/www/views.py b/airflow/www/views.py index 4fd52fe..e16d201 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -75,6 +75,7 @@ from airflow.utils import logging as log_utils from airflow.utils.dates import infer_time_unit, scale_time_units from airflow.www import utils as wwwutils from airflow.www.forms import DateTimeForm, DateTimeWithNumRunsForm +from airflow.www.validators import GreaterEqualThan from airflow.configuration import AirflowConfigException QUERY_LIMIT = 100000 @@ -2148,9 +2149,22 @@ class KnowEventView(wwwutils.DataProfilingMixin, AirflowModelView): 'start_date', 'end_date', 'reported_by', - 'description') + 'description', + ) + form_args = { + 'end_date': { + 'validators': [ + GreaterEqualThan(fieldname='start_date'), + ] + } + } column_list = ( -
'label', 'event_type', 'start_date', 'end_date', 'reported_by') + 'label', + 'event_type', + 'start_date', + 'end_date', + 'reported_by', + ) column_default_sort = ("start_date", True) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d5d02ff7/tests/www/test_validators.py ---------------------------------------------------------------------- diff --git a/tests/www/test_validators.py b/tests/www/test_validators.py new file mode 100644 index 0000000..14e964e --- /dev/null +++ b/tests/www/test_validators.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import mock +import unittest + +from airflow.www import validators + + +class TestGreaterEqualThan(unittest.TestCase): + + def setUp(self): + super(TestGreaterEqualThan, self).setUp() + self.form_field_mock = mock.MagicMock(data='2017-05-06') + self.form_field_mock.gettext.side_effect = lambda msg: msg + self.other_field_mock = mock.MagicMock(data='2017-05-05') + self.other_field_mock.gettext.side_effect = lambda msg: msg + self.other_field_mock.label.text = 'other field' + self.form_stub = {'other_field': self.other_field_mock} + self.form_mock = mock.MagicMock(spec_set=dict) + self.form_mock.__getitem__.side_effect = self.form_stub.__getitem__ + + def _validate(self, fieldname=None, message=None): + if fieldname is None: + fieldname = 'other_field' + + validator = validators.GreaterEqualThan(fieldname=fieldname, + message=message) + + return validator(self.form_mock, self.form_field_mock) + + def test_field_not_found(self): + self.assertRaisesRegexp( + validators.ValidationError, + r"^Invalid field name 'some'\.$", + self._validate, + fieldname='some', + ) + + def test_form_field_is_none(self): + self.form_field_mock.data = None + + self.assertIsNone(self._validate()) + + def test_other_field_is_none(self): + self.other_field_mock.data = None + + self.assertIsNone(self._validate()) + + def test_both_fields_are_none(self): + self.form_field_mock.data = None + self.other_field_mock.data = None + + self.assertIsNone(self._validate()) + + def test_validation_pass(self): + self.assertIsNone(self._validate()) + + def test_validation_raises(self): + self.form_field_mock.data = '2017-05-04' + + self.assertRaisesRegexp( + validators.ValidationError, + r"^Field must be greater than or equal to other field\.$", + self._validate, + ) + + def test_validation_raises_custom_message(self): + self.form_field_mock.data = '2017-05-04' + + self.assertRaisesRegexp( + validators.ValidationError, + r"^This field must be greater than or equal to MyField\.$", + self._validate, +
message="This field must be greater than or equal to MyField.", + ) + + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d5d02ff7/tests/www/test_views.py ---------------------------------------------------------------------- diff --git a/tests/www/test_views.py b/tests/www/test_views.py index a8823e6..3e2fe61 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -15,11 +15,68 @@ import unittest from airflow import configuration -from airflow.models import Pool +from airflow.models import KnownEvent, Pool from airflow.settings import Session from airflow.www import app as application +class TestKnownEventView(unittest.TestCase): + + CREATE_ENDPOINT = '/admin/knownevent/new/?url=/admin/knownevent/' + + @classmethod + def setUpClass(cls): + super(TestKnownEventView, cls).setUpClass() + session = Session() + session.query(KnownEvent).delete() + session.commit() + session.close() + + def setUp(self): + super(TestKnownEventView, self).setUp() + configuration.load_test_config() + app = application.create_app(testing=True) + app.config['WTF_CSRF_METHODS'] = [] + self.app = app.test_client() + self.session = Session() + self.known_event = { + 'label': 'event-label', + 'event_type': 1, + 'start_date': '2017-06-05 12:00:00', + 'end_date': '2017-06-05 13:00:00', + 'reported_by': 'airflow', + 'description': '', + } + + def tearDown(self): + self.session.query(KnownEvent).delete() + self.session.commit() + self.session.close() + super(TestKnownEventView, self).tearDown() + + def test_create_known_event(self): + response = self.app.post( + self.CREATE_ENDPOINT, + data=self.known_event, + follow_redirects=True, + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(self.session.query(KnownEvent).count(), 1) + + def test_create_known_event_with_end_date_earlier_than_start_date(self): + self.known_event['end_date'] = '2017-06-05 11:00:00' + response = self.app.post( + self.CREATE_ENDPOINT, +
data=self.known_event, + follow_redirects=True, + ) + self.assertIn( + 'Field must be greater than or equal to Start Date.', + response.data.decode('utf-8'), + ) + self.assertEqual(self.session.query(KnownEvent).count(), 0) + + class TestPoolModelView(unittest.TestCase): CREATE_ENDPOINT = '/admin/pool/new/?url=/admin/pool/'
