Repository: incubator-airflow Updated Branches: refs/heads/master bd010048b -> 05e1861e2
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/api/experimental/test_endpoints.py ---------------------------------------------------------------------- diff --git a/tests/www_rbac/api/experimental/test_endpoints.py b/tests/www_rbac/api/experimental/test_endpoints.py new file mode 100644 index 0000000..32f1da2 --- /dev/null +++ b/tests/www_rbac/api/experimental/test_endpoints.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta +import json +import unittest +from urllib.parse import quote_plus + +from airflow import configuration +from airflow.api.common.experimental.trigger_dag import trigger_dag +from airflow.models import DagBag, DagRun, Pool, TaskInstance +from airflow.settings import Session +from airflow.utils.timezone import datetime, utcnow +from airflow.www_rbac import app as application + + +class TestApiExperimental(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(TestApiExperimental, cls).setUpClass() + session = Session() + session.query(DagRun).delete() + session.query(TaskInstance).delete() + session.commit() + session.close() + + def setUp(self): + super(TestApiExperimental, self).setUp() + configuration.load_test_config() + app, _ = application.create_app(testing=True) + self.app = app.test_client() + + def tearDown(self): + session = Session() + session.query(DagRun).delete() + session.query(TaskInstance).delete() + session.commit() + session.close() + super(TestApiExperimental, self).tearDown() + + def test_task_info(self): + url_template = '/api/experimental/dags/{}/tasks/{}' + + response = self.app.get( + url_template.format('example_bash_operator', 'runme_0') + ) + self.assertIn('"email"', response.data.decode('utf-8')) + self.assertNotIn('error', response.data.decode('utf-8')) + self.assertEqual(200, response.status_code) + + response = self.app.get( + url_template.format('example_bash_operator', 'DNE') + ) + self.assertIn('error', response.data.decode('utf-8')) + self.assertEqual(404, response.status_code) + + response = self.app.get( + url_template.format('DNE', 'DNE') + ) + self.assertIn('error', response.data.decode('utf-8')) + self.assertEqual(404, response.status_code) + + def test_trigger_dag(self): + url_template = '/api/experimental/dags/{}/dag_runs' + response = self.app.post( + url_template.format('example_bash_operator'), + data=json.dumps({'run_id': 'my_run' + utcnow().isoformat()}), + content_type="application/json" + ) + + self.assertEqual(200, response.status_code) + + response = self.app.post( + url_template.format('does_not_exist_dag'), + data=json.dumps({}), + content_type="application/json" + ) + self.assertEqual(404, response.status_code) + + def test_trigger_dag_for_date(self): + url_template = '/api/experimental/dags/{}/dag_runs' + dag_id = 'example_bash_operator' + hour_from_now = utcnow() + timedelta(hours=1) + execution_date = datetime(hour_from_now.year, + hour_from_now.month, + hour_from_now.day, + hour_from_now.hour) + datetime_string = execution_date.isoformat() + + # Test Correct execution + response = self.app.post( + url_template.format(dag_id), + data=json.dumps({'execution_date': datetime_string}), + content_type="application/json" + ) + self.assertEqual(200, response.status_code) + + dagbag = DagBag() + dag = dagbag.get_dag(dag_id) + dag_run = dag.get_dagrun(execution_date) + self.assertTrue(dag_run, + 'Dag Run not found for execution date {}' + .format(execution_date)) + + # Test error for nonexistent dag + response = self.app.post( + url_template.format('does_not_exist_dag'), + data=json.dumps({'execution_date': execution_date.isoformat()}), + content_type="application/json" + ) + self.assertEqual(404, response.status_code) + + # Test error for bad datetime format + response = self.app.post( + url_template.format(dag_id), + data=json.dumps({'execution_date': 'not_a_datetime'}), + content_type="application/json" + ) + self.assertEqual(400, response.status_code) + + def test_task_instance_info(self): + url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}' + dag_id = 'example_bash_operator' + task_id = 'also_run_this' + execution_date = utcnow().replace(microsecond=0) + datetime_string = quote_plus(execution_date.isoformat()) + wrong_datetime_string = quote_plus( + datetime(1990, 1, 1, 1, 1, 1).isoformat() + ) + + # Create DagRun + trigger_dag(dag_id=dag_id, + run_id='test_task_instance_info_run', + execution_date=execution_date) + + # Test Correct execution + response = self.app.get( + url_template.format(dag_id, datetime_string, task_id) + ) + self.assertEqual(200, response.status_code) + self.assertIn('state', response.data.decode('utf-8')) + self.assertNotIn('error', response.data.decode('utf-8')) + + # Test error for nonexistent dag + response = self.app.get( + url_template.format('does_not_exist_dag', datetime_string, + task_id), + ) + self.assertEqual(404, response.status_code) + self.assertIn('error', response.data.decode('utf-8')) + + # Test error for nonexistent task + response = self.app.get( + url_template.format(dag_id, datetime_string, 'does_not_exist_task') + ) + self.assertEqual(404, response.status_code) + self.assertIn('error', response.data.decode('utf-8')) + + # Test error for nonexistent dag run (wrong execution_date) + response = self.app.get( + url_template.format(dag_id, wrong_datetime_string, task_id) + ) + self.assertEqual(404, response.status_code) + self.assertIn('error', response.data.decode('utf-8')) + + # Test error for bad datetime format + response = self.app.get( + url_template.format(dag_id, 'not_a_datetime', task_id) + ) + self.assertEqual(400, response.status_code) + self.assertIn('error', response.data.decode('utf-8')) + + +class TestPoolApiExperimental(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(TestPoolApiExperimental, cls).setUpClass() + session = Session() + session.query(Pool).delete() + session.commit() + session.close() + + def setUp(self): + super(TestPoolApiExperimental, self).setUp() + configuration.load_test_config() + app, _ = application.create_app(testing=True) + self.app = app.test_client() + self.session = Session() + self.pools = [] + for i in range(2): + name = 'experimental_%s' % (i + 1) + pool = Pool( + pool=name, + slots=i, + description=name, + ) + self.session.add(pool) + self.pools.append(pool) + self.session.commit() + self.pool = self.pools[0] + + def tearDown(self): + self.session.query(Pool).delete() + self.session.commit() + self.session.close() + super(TestPoolApiExperimental, self).tearDown() + + def _get_pool_count(self): + response = self.app.get('/api/experimental/pools') + self.assertEqual(response.status_code, 200) + return len(json.loads(response.data.decode('utf-8'))) + + def test_get_pool(self): + response = self.app.get( + '/api/experimental/pools/{}'.format(self.pool.pool), + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.data.decode('utf-8')), + self.pool.to_json()) + + def test_get_pool_non_existing(self): + response = self.app.get('/api/experimental/pools/foo') + self.assertEqual(response.status_code, 404) + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], + "Pool 'foo' doesn't exist") + + def test_get_pools(self): + response = self.app.get('/api/experimental/pools') + self.assertEqual(response.status_code, 200) + pools = json.loads(response.data.decode('utf-8')) + self.assertEqual(len(pools), 2) + for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])): + self.assertDictEqual(pool, self.pools[i].to_json()) + + def test_create_pool(self): + response = self.app.post( + '/api/experimental/pools', + data=json.dumps({ + 'name': 'foo', + 'slots': 1, + 'description': '', + }), + content_type='application/json', + ) + self.assertEqual(response.status_code, 200) + pool = json.loads(response.data.decode('utf-8')) + self.assertEqual(pool['pool'], 'foo') + self.assertEqual(pool['slots'], 1) + self.assertEqual(pool['description'], '') + self.assertEqual(self._get_pool_count(), 3) + + def test_create_pool_with_bad_name(self): + for name in ('', ' '): + response = self.app.post( + '/api/experimental/pools', + data=json.dumps({ + 'name': name, + 'slots': 1, + 'description': '', + }), + content_type='application/json', + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + json.loads(response.data.decode('utf-8'))['error'], + "Pool name shouldn't be empty", + ) + self.assertEqual(self._get_pool_count(), 2) + + def test_delete_pool(self): + response = self.app.delete( + '/api/experimental/pools/{}'.format(self.pool.pool), + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(json.loads(response.data.decode('utf-8')), + self.pool.to_json()) + self.assertEqual(self._get_pool_count(), 1) + + def test_delete_pool_non_existing(self): + response = self.app.delete( + '/api/experimental/pools/foo', + ) + self.assertEqual(response.status_code, 404) + self.assertEqual(json.loads(response.data.decode('utf-8'))['error'], + "Pool 'foo' doesn't exist") + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/api/experimental/test_kerberos_endpoints.py ---------------------------------------------------------------------- diff --git a/tests/www_rbac/api/experimental/test_kerberos_endpoints.py b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py new file mode 100644 index 0000000..1ab2b25 --- /dev/null +++ b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import mock +import os +import socket +import unittest + +from datetime import datetime + +from airflow import configuration +from airflow.api.auth.backend.kerberos_auth import client_auth +from airflow.www_rbac import app as application + + [email protected]('KRB5_KTNAME' not in os.environ, + 'Skipping Kerberos API tests due to missing KRB5_KTNAME') +class ApiKerberosTests(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + try: + configuration.conf.add_section("api") + except Exception: + pass + configuration.conf.set("api", + "auth_backend", + "airflow.api.auth.backend.kerberos_auth") + try: + configuration.conf.add_section("kerberos") + except Exception: + pass + configuration.conf.set("kerberos", + "keytab", + os.environ['KRB5_KTNAME']) + + self.app, _ = application.create_app(testing=True) + + def test_trigger_dag(self): + with self.app.test_client() as c: + url_template = '/api/experimental/dags/{}/dag_runs' + response = c.post( + url_template.format('example_bash_operator'), + data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), + content_type="application/json" + ) + self.assertEqual(401, response.status_code) + + response.url = 'http://{}'.format(socket.getfqdn()) + + class Request(): + headers = {} + + response.request = Request() + response.content = '' + response.raw = mock.MagicMock() + response.connection = mock.MagicMock() + response.connection.send = mock.MagicMock() + + # disable mutual authentication for testing + client_auth.mutual_authentication = 3 + + # case can influence the results + client_auth.hostname_override = socket.getfqdn() + + client_auth.handle_response(response) + self.assertIn('Authorization', response.request.headers) + + response2 = c.post( + url_template.format('example_bash_operator'), + data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), + content_type="application/json", + headers=response.request.headers + ) + self.assertEqual(200, response2.status_code) + + def test_unauthorized(self): + with self.app.test_client() as c: + url_template = '/api/experimental/dags/{}/dag_runs' + response = c.post( + url_template.format('example_bash_operator'), + data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())), + content_type="application/json" + ) + + self.assertEqual(401, response.status_code) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log ---------------------------------------------------------------------- diff --git a/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log b/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log new file mode 100644 index 0000000..bc10ef7 --- /dev/null +++ b/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log @@ -0,0 +1 @@ +Log for testing. http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_security.py ---------------------------------------------------------------------- diff --git a/tests/www_rbac/test_security.py b/tests/www_rbac/test_security.py new file mode 100644 index 0000000..81a3fad --- /dev/null +++ b/tests/www_rbac/test_security.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import logging + +from flask import Flask +from flask_appbuilder import AppBuilder, SQLA, Model, has_access, expose +from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_appbuilder.views import ModelView, BaseView + +from sqlalchemy import Column, Integer, String, Date, Float + +from airflow.www_rbac.security import init_role + +logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s') +logging.getLogger().setLevel(logging.DEBUG) +log = logging.getLogger(__name__) + + +class SomeModel(Model): + id = Column(Integer, primary_key=True) + field_string = Column(String(50), unique=True, nullable=False) + field_integer = Column(Integer()) + field_float = Column(Float()) + field_date = Column(Date()) + + def __repr__(self): + return str(self.field_string) + + +class SomeModelView(ModelView): + datamodel = SQLAInterface(SomeModel) + base_permissions = ['can_list', 'can_show', 'can_add', 'can_edit', 'can_delete'] + list_columns = ['field_string', 'field_integer', 'field_float', 'field_date'] + + +class SomeBaseView(BaseView): + route_base = '' + + @expose('/some_action') + @has_access + def some_action(self): + return "action!" + + +class TestSecurity(unittest.TestCase): + def setUp(self): + self.app = Flask(__name__) + self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + self.app.config['SECRET_KEY'] = 'secret_key' + self.app.config['CSRF_ENABLED'] = False + self.app.config['WTF_CSRF_ENABLED'] = False + self.db = SQLA(self.app) + self.appbuilder = AppBuilder(self.app, self.db.session) + self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews") + self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews") + + role_admin = self.appbuilder.sm.find_role('Admin') + self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', '[email protected]', + role_admin, 'general') + log.debug("Complete setup!") + + def tearDown(self): + self.appbuilder = None + self.app = None + self.db = None + log.debug("Complete teardown!") + + def test_init_role_baseview(self): + role_name = 'MyRole1' + role_perms = ['can_some_action'] + role_vms = ['SomeBaseView'] + init_role(self.appbuilder.sm, role_name, role_vms, role_perms) + role = self.appbuilder.sm.find_role(role_name) + self.assertIsNotNone(role) + self.assertEqual(len(role_perms), len(role.permissions)) + + def test_init_role_modelview(self): + role_name = 'MyRole2' + role_perms = ['can_list', 'can_show', 'can_add', 'can_edit', 'can_delete'] + role_vms = ['SomeModelView'] + init_role(self.appbuilder.sm, role_name, role_vms, role_perms) + role = self.appbuilder.sm.find_role(role_name) + self.assertIsNotNone(role) + self.assertEqual(len(role_perms), len(role.permissions)) + + def test_invalid_perms(self): + role_name = 'MyRole3' + role_perms = ['can_foo'] + role_vms = ['SomeBaseView'] + with self.assertRaises(Exception) as context: + init_role(self.appbuilder.sm, role_name, role_vms, role_perms) + self.assertEqual("The following permissions are not valid: ['can_foo']", + str(context.exception)) + + def test_invalid_vms(self): + role_name = 'MyRole4' + role_perms = ['can_some_action'] + role_vms = ['NonExistentBaseView'] + with self.assertRaises(Exception) as context: + init_role(self.appbuilder.sm, role_name, role_vms, role_perms) + self.assertEqual("The following view menus are not valid: " + "['NonExistentBaseView']", + str(context.exception)) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_utils.py ---------------------------------------------------------------------- diff --git a/tests/www_rbac/test_utils.py b/tests/www_rbac/test_utils.py new file mode 100644 index 0000000..b13a02f --- /dev/null +++ b/tests/www_rbac/test_utils.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from xml.dom import minidom + +from airflow.www_rbac import utils + + +class UtilsTest(unittest.TestCase): + + def setUp(self): + super(UtilsTest, self).setUp() + + def test_normal_variable_should_not_be_hidden(self): + self.assertFalse(utils.should_hide_value_for_key("key")) + + def test_sensitive_variable_should_be_hidden(self): + self.assertTrue(utils.should_hide_value_for_key("google_api_key")) + + def test_sensitive_variable_should_be_hidden_ic(self): + self.assertTrue(utils.should_hide_value_for_key("GOOGLE_API_KEY")) + + def check_generate_pages_html(self, current_page, total_pages, + window=7, check_middle=False): + extra_links = 4 # first, prev, next, last + html_str = utils.generate_pages(current_page, total_pages) + + # dom parser has issues with special « and » + html_str = html_str.replace('«', '') + html_str = html_str.replace('»', '') + dom = minidom.parseString(html_str) + self.assertIsNotNone(dom) + + ulist = dom.getElementsByTagName('ul')[0] + ulist_items = ulist.getElementsByTagName('li') + self.assertEqual(min(window, total_pages) + extra_links, len(ulist_items)) + + def get_text(nodelist): + rc = [] + for node in nodelist: + if node.nodeType == node.TEXT_NODE: + rc.append(node.data) + return ''.join(rc) + + page_items = ulist_items[2:-2] + mid = int(len(page_items) / 2) + for i, item in enumerate(page_items): + a_node = item.getElementsByTagName('a')[0] + href_link = a_node.getAttribute('href') + node_text = get_text(a_node.childNodes) + if node_text == str(current_page + 1): + if check_middle: + self.assertEqual(mid, i) + self.assertEqual('javascript:void(0)', a_node.getAttribute('href')) + self.assertIn('active', item.getAttribute('class')) + else: + link_str = '?page=' + str(int(node_text) - 1) + self.assertEqual(link_str, href_link) + + def test_generate_pager_current_start(self): + self.check_generate_pages_html(current_page=0, + total_pages=6) + + def test_generate_pager_current_middle(self): + self.check_generate_pages_html(current_page=10, + total_pages=20, + check_middle=True) + + def test_generate_pager_current_end(self): + self.check_generate_pages_html(current_page=38, + total_pages=39) + + def test_params_no_values(self): + """Should return an empty string if no params are passed""" + self.assertEquals('', utils.get_params()) + + def test_params_search(self): + self.assertEqual('search=bash_', + utils.get_params(search='bash_')) + + def test_params_showPaused_true(self): + """Should detect True as default for showPaused""" + self.assertEqual('', + utils.get_params(showPaused=True)) + + def test_params_showPaused_false(self): + self.assertEqual('showPaused=False', + utils.get_params(showPaused=False)) + + def test_params_all(self): + """Should return params string ordered by param key""" + self.assertEqual('page=3&search=bash_&showPaused=False', + utils.get_params(showPaused=False, page=3, search='bash_')) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_validators.py ---------------------------------------------------------------------- diff --git a/tests/www_rbac/test_validators.py b/tests/www_rbac/test_validators.py new file mode 100644 index 0000000..38a6142 --- /dev/null +++ b/tests/www_rbac/test_validators.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import unittest + +from airflow.www_rbac import validators + + +class TestGreaterEqualThan(unittest.TestCase): + + def setUp(self): + super(TestGreaterEqualThan, self).setUp() + self.form_field_mock = mock.MagicMock(data='2017-05-06') + self.form_field_mock.gettext.side_effect = lambda msg: msg + self.other_field_mock = mock.MagicMock(data='2017-05-05') + self.other_field_mock.gettext.side_effect = lambda msg: msg + self.other_field_mock.label.text = 'other field' + self.form_stub = {'other_field': self.other_field_mock} + self.form_mock = mock.MagicMock(spec_set=dict) + self.form_mock.__getitem__.side_effect = self.form_stub.__getitem__ + + def _validate(self, fieldname=None, message=None): + if fieldname is None: + fieldname = 'other_field' + + validator = validators.GreaterEqualThan(fieldname=fieldname, + message=message) + + return validator(self.form_mock, self.form_field_mock) + + def test_field_not_found(self): + self.assertRaisesRegexp( + validators.ValidationError, + "^Invalid field name 'some'.$", + self._validate, + fieldname='some', + ) + + def test_form_field_is_none(self): + self.form_field_mock.data = None + + self.assertIsNone(self._validate()) + + def test_other_field_is_none(self): + self.other_field_mock.data = None + + self.assertIsNone(self._validate()) + + def test_both_fields_are_none(self): + self.form_field_mock.data = None + self.other_field_mock.data = None + + self.assertIsNone(self._validate()) + + def test_validation_pass(self): + self.assertIsNone(self._validate()) + + def test_validation_raises(self): + self.form_field_mock.data = '2017-05-04' + + self.assertRaisesRegexp( + validators.ValidationError, + "^Field must be greater than or equal to other field.$", + self._validate, + ) + + def test_validation_raises_custom_message(self): + self.form_field_mock.data = '2017-05-04' + + self.assertRaisesRegexp( + validators.ValidationError, + "^This field must be greater than or equal to MyField.$", + self._validate, + message="This field must be greater than or equal to MyField.", + ) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_views.py ---------------------------------------------------------------------- diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py new file mode 100644 index 0000000..4f71df9 --- /dev/null +++ b/tests/www_rbac/test_views.py @@ -0,0 +1,405 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import unittest +import urllib +from werkzeug.test import Client +from flask._compat import PY2 +from flask_appbuilder.security.sqla.models import User as ab_user +from airflow import models +from airflow import configuration as conf +from airflow.settings import Session +from airflow.utils import timezone +from airflow.utils.state import State +from airflow.www_rbac import app as application + + +class TestBase(unittest.TestCase): + def setUp(self): + conf.load_test_config() + self.app, self.appbuilder = application.create_app(testing=True) + self.app.config['WTF_CSRF_ENABLED'] = False + self.client = self.app.test_client() + self.session = Session() + self.login() + + def login(self): + sm_session = self.appbuilder.sm.get_session() + self.user = sm_session.query(ab_user).first() + if not self.user: + role_admin = self.appbuilder.sm.find_role('Admin') + self.appbuilder.sm.add_user( + username='test', + first_name='test', + last_name='test', + email='[email protected]', + role=role_admin, + password='test') + return self.client.post('/login/', data=dict( + username='test', + password='test' + ), follow_redirects=True) + + def logout(self): + return self.client.get('/logout/') + + def clear_table(self, model): + self.session.query(model).delete() + self.session.commit() + self.session.close() + + def check_content_in_response(self, text, resp, resp_code=200): + resp_html = resp.data.decode('utf-8') + self.assertEqual(resp_code, resp.status_code) + if isinstance(text, list): + for kw in text: + self.assertIn(kw, resp_html) + else: + self.assertIn(text, resp_html) + + def percent_encode(self, obj): + if PY2: + return urllib.quote_plus(str(obj)) + else: + return urllib.parse.quote_plus(str(obj)) + + +class TestConnectionModelView(TestBase): + def setUp(self): + super(TestConnectionModelView, self).setUp() + self.connection = { + 'conn_id': 'test_conn', + 'conn_type': 'http', + 'host': 'localhost', + 'port': 8080, + 'username': 'root', + 'password': 'admin' + } + + def tearDown(self): + self.clear_table(models.Connection) + super(TestConnectionModelView, self).tearDown() + + def test_create_connection(self): + resp = self.client.post('/connection/add', + data=self.connection, + follow_redirects=True) + self.check_content_in_response('Added Row', resp) + + +class TestVariableModelView(TestBase): + def setUp(self): + super(TestVariableModelView, self).setUp() + self.variable = { + 'key': 'test_key', + 'val': 'text_val', + 'is_encrypted': True + } + + def tearDown(self): + self.clear_table(models.Variable) + super(TestVariableModelView, self).tearDown() + + def test_can_handle_error_on_decrypt(self): + + # create valid variable + resp = self.client.post('/variable/add', + data=self.variable, + follow_redirects=True) + self.assertEqual(resp.status_code, 200) + v = self.session.query(models.Variable).first() + self.assertEqual(v.key, 'test_key') + self.assertEqual(v.val, 'text_val') + + # update the variable with a wrong value, given that is encrypted + Var = models.Variable + (self.session.query(Var) + .filter(Var.key == self.variable['key']) + .update({ + 'val': 'failed_value_not_encrypted' + }, synchronize_session=False)) + self.session.commit() + + # retrieve Variables page, should not fail and contain the Invalid + # label for the variable + resp = self.client.get('/variable/list', follow_redirects=True) + self.check_content_in_response( + '<span class="label label-danger">Invalid</span>', resp) + + def test_xss_prevention(self): + xss = "/variable/list/<img%20src=''%20onerror='alert(1);'>" + + resp = self.client.get( + xss, + follow_redirects=True, + ) + self.assertEqual(resp.status_code, 404) + self.assertNotIn("<img src='' onerror='alert(1);'>", + resp.data.decode("utf-8")) + + def test_import_variables(self): + self.assertEqual(self.session.query(models.Variable).count(), 0) + + content = ('{"str_key": "str_value", "int_key": 60,' + '"list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}') + try: + # python 3+ + bytes_content = io.BytesIO(bytes(content, encoding='utf-8')) + except TypeError: + # python 2.7 + bytes_content = io.BytesIO(bytes(content)) + + resp = self.client.post('/variable/varimport', + data={'file': (bytes_content, 'test.json')}, + follow_redirects=True) + self.check_content_in_response('4 variable(s) successfully updated.', resp) + + +class TestPoolModelView(TestBase): + def setUp(self): + super(TestPoolModelView, self).setUp() + self.pool = { + 'pool': 'test-pool', + 'slots': 777, + 'description': 'test-pool-description', + } + + def tearDown(self): + self.clear_table(models.Pool) + super(TestPoolModelView, self).tearDown() + + def test_create_pool_with_same_name(self): + # create test pool + resp = self.client.post('/pool/add', + data=self.pool, + follow_redirects=True) + self.check_content_in_response('Added Row', resp) + + # create pool with the same name + resp = self.client.post('/pool/add', + data=self.pool, + follow_redirects=True) + self.check_content_in_response('Already exists.', resp) + + def test_create_pool_with_empty_name(self): + + self.pool['pool'] = '' + resp = self.client.post('/pool/add', + data=self.pool, + follow_redirects=True) + self.check_content_in_response('This field is required.', resp) + + +class TestMountPoint(unittest.TestCase): + def setUp(self): + application.app = None + super(TestMountPoint, self).setUp() + conf.load_test_config() + conf.set("webserver", "base_url", "http://localhost:8080/test") + config = dict() + config['WTF_CSRF_METHODS'] = [] + app = application.cached_app(config=config, testing=True) + self.client = Client(app) + + def test_mount(self): + resp, _, _ = self.client.get('/', follow_redirects=True) + txt = b''.join(resp) + self.assertEqual(b"Apache Airflow is not at this location", txt) + + resp, _, _ = self.client.get('/test/home', follow_redirects=True) + resp_html = b''.join(resp) + self.assertIn(b"DAGs", resp_html) + + +class TestAirflowBaseViews(TestBase): + default_date = timezone.datetime(2018, 3, 1) + run_id = "test_{}".format(models.DagRun.id_for_date(default_date)) + + def setUp(self): + super(TestAirflowBaseViews, self).setUp() + self.cleanup_dagruns() + self.prepare_dagruns() + + def cleanup_dagruns(self): + DR = models.DagRun + dag_ids = ['example_bash_operator', + 'example_subdag_operator', + 'example_xcom'] + (self.session + .query(DR) + .filter(DR.dag_id.in_(dag_ids)) + .filter(DR.run_id == self.run_id) + .delete(synchronize_session='fetch')) + self.session.commit() + + def prepare_dagruns(self): + dagbag = models.DagBag(include_examples=True) + self.bash_dag = dagbag.dags['example_bash_operator'] + self.sub_dag = dagbag.dags['example_subdag_operator'] + self.xcom_dag = dagbag.dags['example_xcom'] + + self.bash_dagrun = self.bash_dag.create_dagrun( + run_id=self.run_id, + execution_date=self.default_date, + start_date=timezone.utcnow(), + state=State.RUNNING) + + self.sub_dagrun = self.sub_dag.create_dagrun( + run_id=self.run_id, + execution_date=self.default_date, + start_date=timezone.utcnow(), + state=State.RUNNING) + + self.xcom_dagrun = self.xcom_dag.create_dagrun( + run_id=self.run_id, + execution_date=self.default_date, + start_date=timezone.utcnow(), + state=State.RUNNING) + + def test_index(self): + resp = self.client.get('/', follow_redirects=True) + self.check_content_in_response('DAGs', resp) + + def test_health(self): + resp = self.client.get('health') + self.check_content_in_response('The server is healthy!', resp) + + def test_home(self): + resp = self.client.get('home', follow_redirects=True) + self.check_content_in_response('DAGs', resp) + + def test_task(self): + url = ('task?task_id=runme_0&dag_id=example_bash_operator&execution_date={}' + .format(self.percent_encode(self.default_date))) + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('Task Instance Details', resp) + + def test_xcom(self): + url = ('xcom?task_id=runme_0&dag_id=example_bash_operator&execution_date={}' + .format(self.percent_encode(self.default_date))) + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('XCom', resp) + + def test_rendered(self): + url = ('rendered?task_id=runme_0&dag_id=example_bash_operator&execution_date={}' + .format(self.percent_encode(self.default_date))) + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('Rendered Template', resp) + + def test_pickle_info(self): + url = 'pickle_info?dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.assertEqual(resp.status_code, 200) + + def test_blocked(self): + url = 'blocked' + resp = self.client.get(url, follow_redirects=True) + self.assertEqual(200, resp.status_code) + + def test_dag_stats(self): + resp = self.client.get('dag_stats', follow_redirects=True) + self.assertEqual(resp.status_code, 200) + + def test_task_stats(self): + resp = self.client.get('task_stats', follow_redirects=True) + self.assertEqual(resp.status_code, 200) + + def test_dag_details(self): + url = 'dag_details?dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('DAG details', resp) + + def test_graph(self): + url = 'graph?dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('runme_1', resp) + + def test_tree(self): + url = 'tree?dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('runme_1', resp) + + def test_duration(self): + url = 'duration?days=30&dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('example_bash_operator', resp) + + def test_tries(self): + url = 'tries?days=30&dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('example_bash_operator', resp) + + def test_landing_times(self): + url = 'landing_times?days=30&dag_id=test_example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('example_bash_operator', resp) + + def test_gantt(self): + url = 'gantt?dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('example_bash_operator', resp) + + def test_code(self): + url = 'code?dag_id=example_bash_operator' + resp = self.client.get(url, follow_redirects=True) + self.check_content_in_response('example_bash_operator', resp) + + def test_paused(self): + url = 'paused?dag_id=example_bash_operator&is_paused=false' + resp = self.client.post(url, follow_redirects=True) + self.check_content_in_response('OK', resp) + + def test_success(self): + + url = ('success?task_id=run_this_last&dag_id=example_bash_operator&' + 'execution_date={}&upstream=false&downstream=false&future=false&past=false' + .format(self.percent_encode(self.default_date))) + resp = self.client.get(url) + self.check_content_in_response('Wait a minute', resp) + + def test_clear(self): + url = ('clear?task_id=runme_1&dag_id=example_bash_operator&' + 'execution_date={}&upstream=false&downstream=false&future=false&past=false' + .format(self.percent_encode(self.default_date))) + resp = self.client.get(url) + self.check_content_in_response(['example_bash_operator', 'Wait a minute'], resp) + + def test_run(self): + url = ('run?task_id=runme_0&dag_id=example_bash_operator&ignore_all_deps=false&' + 'ignore_ti_state=true&execution_date={}' + .format(self.percent_encode(self.default_date))) + resp = self.client.get(url) + self.check_content_in_response('', resp, resp_code=302) + + def test_refresh(self): + resp = self.client.get('refresh?dag_id=example_bash_operator') + self.check_content_in_response('', resp, resp_code=302) + + +class TestConfigurationView(TestBase): + def test_configuration(self): + resp = self.client.get('configuration', follow_redirects=True) + self.check_content_in_response( + ['Airflow Configuration', 'Running Configuration'], resp) + + +class TestVersionView(TestBase): + def test_version(self): + resp = self.client.get('version', follow_redirects=True) + self.check_content_in_response('Version Info', resp) + + +if __name__ == '__main__': + unittest.main()
