stale[bot] closed pull request #3249: [AIRFLOW-2354] Change task instance run validation to not exclude das… URL: https://github.com/apache/incubator-airflow/pull/3249
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/www/views.py b/airflow/www/views.py index 5dda0362cc..6a5cec1e0c 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -944,14 +944,16 @@ def run(self): try: from airflow.executors import GetDefaultExecutor - from airflow.executors.celery_executor import CeleryExecutor + from airflow.executors.local_executor import LocalExecutor + from airflow.executors.sequential_executor import SequentialExecutor executor = GetDefaultExecutor() - if not isinstance(executor, CeleryExecutor): - flash("Only works with the CeleryExecutor, sorry", "error") + if isinstance(executor, LocalExecutor) or \ + isinstance(executor, SequentialExecutor): + flash("Doesn't work with the LocalExecutor or SequentialExecutor, sorry", + "error") return redirect(origin) except ImportError: - # in case CeleryExecutor cannot be imported it is not active either - flash("Only works with the CeleryExecutor, sorry", "error") + flash("Error when attempting to validate the executor", "error") return redirect(origin) ti = models.TaskInstance(task=task, execution_date=execution_date) diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py index f064c14c47..8d1ecfaf42 100644 --- a/airflow/www_rbac/views.py +++ b/airflow/www_rbac/views.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -591,14 +591,16 @@ def run(self): try: from airflow.executors import GetDefaultExecutor - from airflow.executors.celery_executor import CeleryExecutor + from airflow.executors.local_executor import LocalExecutor + from airflow.executors.sequential_executor import SequentialExecutor executor = GetDefaultExecutor() - if not isinstance(executor, CeleryExecutor): - flash("Only works with the CeleryExecutor, sorry", "error") + if isinstance(executor, LocalExecutor) or \ + isinstance(executor, SequentialExecutor): + flash("Doesn't work with the LocalExecutor or SequentialExecutor, sorry", + "error") return redirect(origin) except ImportError: - # in case CeleryExecutor cannot be imported it is not active either - flash("Only works with the CeleryExecutor, sorry", "error") + flash("Error when attempting to validate the executor", "error") return redirect(origin) ti = models.TaskInstance(task=task, execution_date=execution_date) diff --git a/tests/www/test_views.py b/tests/www/test_views.py index 3b2892d10c..18dcc70142 100644 --- a/tests/www/test_views.py +++ b/tests/www/test_views.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,6 +29,7 @@ from urllib.parse import quote_plus from werkzeug.test import Client +from mock import Mock from airflow import models, configuration, settings from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG @@ -490,5 +491,80 @@ def test_mount(self): self.assertIn(b"DAGs", resp_html) +class TestRunTaskInstanceView(unittest.TestCase): + DAG_ID = 'example_bash_operator' + TASK_ID = 'runme_0' + DEFAULT_DATE = datetime(2018, 3, 1) + ENDPOINT = "/admin/airflow/run?task_id={task_id}&dag_id={dag_id}" \ + "&execution_date={execution_date}".format(dag_id=DAG_ID, + task_id=TASK_ID, + execution_date=DEFAULT_DATE) + EXPECTED_FLASH_ERROR = 'Doesn\'t work with the LocalExecutor or ' \ + 'SequentialExecutor, sorry' + EXPECTED_FLASH_MESSAGE = 'to the message queue, it should start any moment now.' + + @classmethod + def setUpClass(cls): + super(TestRunTaskInstanceView, cls).setUpClass() + session = Session() + session.query(TaskInstance).delete() + session.commit() + session.close() + + def setUp(self): + super(TestRunTaskInstanceView, self).setUp() + + def tearDown(self): + super(TestRunTaskInstanceView, self).tearDown() + + @classmethod + def tearDownClass(cls): + session = Session() + session.query(TaskInstance).delete() + session.commit() + session.close() + super(TestRunTaskInstanceView, cls).tearDownClass() + + def _create_app(self, executor): + from airflow import executors + executor = executors._get_executor(executor) + mock_executor = Mock(executor.__class__) + executors.GetDefaultExecutor = Mock(return_value=mock_executor) + + configuration.load_test_config() + app = application.create_app(testing=True) + app.config['WTF_CSRF_METHODS'] = [] + return app.test_client() + + def test_run_task_instance_with_sequential_executor(self): + app = self._create_app('SequentialExecutor') + + response = app.get( + TestRunTaskInstanceView.ENDPOINT, + ) + with app.session_transaction() as session: + flash_error = dict(session['_flashes']).get('error') + + self.assertEqual(response.status_code, 302) + self.assertIsNotNone(flash_error) + self.assertEqual(flash_error, self.EXPECTED_FLASH_ERROR) + + def test_run_task_instance_with_celery_executor(self): + app = self._create_app('CeleryExecutor') + + response = app.get( + TestRunTaskInstanceView.ENDPOINT, + ) + with app.session_transaction() as session: + flash_d = dict(session['_flashes']) + flash_error = flash_d.get('error') + flash_message = flash_d.get('message') + + self.assertEqual(response.status_code, 302) + self.assertIsNone(flash_error) + self.assertIsNotNone(flash_message) + self.assertTrue(self.EXPECTED_FLASH_MESSAGE in flash_message) + + if __name__ == '__main__': unittest.main() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services