Repository: incubator-airflow
Updated Branches:
  refs/heads/master bd010048b -> 05e1861e2


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/api/experimental/test_endpoints.py
----------------------------------------------------------------------
diff --git a/tests/www_rbac/api/experimental/test_endpoints.py b/tests/www_rbac/api/experimental/test_endpoints.py
new file mode 100644
index 0000000..32f1da2
--- /dev/null
+++ b/tests/www_rbac/api/experimental/test_endpoints.py
@@ -0,0 +1,302 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import timedelta
+import json
+import unittest
+from urllib.parse import quote_plus
+
+from airflow import configuration
+from airflow.api.common.experimental.trigger_dag import trigger_dag
+from airflow.models import DagBag, DagRun, Pool, TaskInstance
+from airflow.settings import Session
+from airflow.utils.timezone import datetime, utcnow
+from airflow.www_rbac import app as application
+
+
+class TestApiExperimental(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestApiExperimental, self).setUp()
+        configuration.load_test_config()
+        app, _ = application.create_app(testing=True)
+        self.app = app.test_client()
+
+    def tearDown(self):
+        session = Session()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+        session.commit()
+        session.close()
+        super(TestApiExperimental, self).tearDown()
+
+    def test_task_info(self):
+        url_template = '/api/experimental/dags/{}/tasks/{}'
+
+        response = self.app.get(
+            url_template.format('example_bash_operator', 'runme_0')
+        )
+        self.assertIn('"email"', response.data.decode('utf-8'))
+        self.assertNotIn('error', response.data.decode('utf-8'))
+        self.assertEqual(200, response.status_code)
+
+        response = self.app.get(
+            url_template.format('example_bash_operator', 'DNE')
+        )
+        self.assertIn('error', response.data.decode('utf-8'))
+        self.assertEqual(404, response.status_code)
+
+        response = self.app.get(
+            url_template.format('DNE', 'DNE')
+        )
+        self.assertIn('error', response.data.decode('utf-8'))
+        self.assertEqual(404, response.status_code)
+
+    def test_trigger_dag(self):
+        url_template = '/api/experimental/dags/{}/dag_runs'
+        response = self.app.post(
+            url_template.format('example_bash_operator'),
+            data=json.dumps({'run_id': 'my_run' + utcnow().isoformat()}),
+            content_type="application/json"
+        )
+
+        self.assertEqual(200, response.status_code)
+
+        response = self.app.post(
+            url_template.format('does_not_exist_dag'),
+            data=json.dumps({}),
+            content_type="application/json"
+        )
+        self.assertEqual(404, response.status_code)
+
+    def test_trigger_dag_for_date(self):
+        url_template = '/api/experimental/dags/{}/dag_runs'
+        dag_id = 'example_bash_operator'
+        hour_from_now = utcnow() + timedelta(hours=1)
+        execution_date = datetime(hour_from_now.year,
+                                  hour_from_now.month,
+                                  hour_from_now.day,
+                                  hour_from_now.hour)
+        datetime_string = execution_date.isoformat()
+
+        # Test Correct execution
+        response = self.app.post(
+            url_template.format(dag_id),
+            data=json.dumps({'execution_date': datetime_string}),
+            content_type="application/json"
+        )
+        self.assertEqual(200, response.status_code)
+
+        dagbag = DagBag()
+        dag = dagbag.get_dag(dag_id)
+        dag_run = dag.get_dagrun(execution_date)
+        self.assertTrue(dag_run,
+                        'Dag Run not found for execution date {}'
+                        .format(execution_date))
+
+        # Test error for nonexistent dag
+        response = self.app.post(
+            url_template.format('does_not_exist_dag'),
+            data=json.dumps({'execution_date': execution_date.isoformat()}),
+            content_type="application/json"
+        )
+        self.assertEqual(404, response.status_code)
+
+        # Test error for bad datetime format
+        response = self.app.post(
+            url_template.format(dag_id),
+            data=json.dumps({'execution_date': 'not_a_datetime'}),
+            content_type="application/json"
+        )
+        self.assertEqual(400, response.status_code)
+
+    def test_task_instance_info(self):
+        url_template = '/api/experimental/dags/{}/dag_runs/{}/tasks/{}'
+        dag_id = 'example_bash_operator'
+        task_id = 'also_run_this'
+        execution_date = utcnow().replace(microsecond=0)
+        datetime_string = quote_plus(execution_date.isoformat())
+        wrong_datetime_string = quote_plus(
+            datetime(1990, 1, 1, 1, 1, 1).isoformat()
+        )
+
+        # Create DagRun
+        trigger_dag(dag_id=dag_id,
+                    run_id='test_task_instance_info_run',
+                    execution_date=execution_date)
+
+        # Test Correct execution
+        response = self.app.get(
+            url_template.format(dag_id, datetime_string, task_id)
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertIn('state', response.data.decode('utf-8'))
+        self.assertNotIn('error', response.data.decode('utf-8'))
+
+        # Test error for nonexistent dag
+        response = self.app.get(
+            url_template.format('does_not_exist_dag', datetime_string,
+                                task_id),
+        )
+        self.assertEqual(404, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+
+        # Test error for nonexistent task
+        response = self.app.get(
+            url_template.format(dag_id, datetime_string, 'does_not_exist_task')
+        )
+        self.assertEqual(404, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+
+        # Test error for nonexistent dag run (wrong execution_date)
+        response = self.app.get(
+            url_template.format(dag_id, wrong_datetime_string, task_id)
+        )
+        self.assertEqual(404, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+
+        # Test error for bad datetime format
+        response = self.app.get(
+            url_template.format(dag_id, 'not_a_datetime', task_id)
+        )
+        self.assertEqual(400, response.status_code)
+        self.assertIn('error', response.data.decode('utf-8'))
+
+
+class TestPoolApiExperimental(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestPoolApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(Pool).delete()
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestPoolApiExperimental, self).setUp()
+        configuration.load_test_config()
+        app, _ = application.create_app(testing=True)
+        self.app = app.test_client()
+        self.session = Session()
+        self.pools = []
+        for i in range(2):
+            name = 'experimental_%s' % (i + 1)
+            pool = Pool(
+                pool=name,
+                slots=i,
+                description=name,
+            )
+            self.session.add(pool)
+            self.pools.append(pool)
+        self.session.commit()
+        self.pool = self.pools[0]
+
+    def tearDown(self):
+        self.session.query(Pool).delete()
+        self.session.commit()
+        self.session.close()
+        super(TestPoolApiExperimental, self).tearDown()
+
+    def _get_pool_count(self):
+        response = self.app.get('/api/experimental/pools')
+        self.assertEqual(response.status_code, 200)
+        return len(json.loads(response.data.decode('utf-8')))
+
+    def test_get_pool(self):
+        response = self.app.get(
+            '/api/experimental/pools/{}'.format(self.pool.pool),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.pool.to_json())
+
+    def test_get_pool_non_existing(self):
+        response = self.app.get('/api/experimental/pools/foo')
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Pool 'foo' doesn't exist")
+
+    def test_get_pools(self):
+        response = self.app.get('/api/experimental/pools')
+        self.assertEqual(response.status_code, 200)
+        pools = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(len(pools), 2)
+        for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
+            self.assertDictEqual(pool, self.pools[i].to_json())
+
+    def test_create_pool(self):
+        response = self.app.post(
+            '/api/experimental/pools',
+            data=json.dumps({
+                'name': 'foo',
+                'slots': 1,
+                'description': '',
+            }),
+            content_type='application/json',
+        )
+        self.assertEqual(response.status_code, 200)
+        pool = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(pool['pool'], 'foo')
+        self.assertEqual(pool['slots'], 1)
+        self.assertEqual(pool['description'], '')
+        self.assertEqual(self._get_pool_count(), 3)
+
+    def test_create_pool_with_bad_name(self):
+        for name in ('', '    '):
+            response = self.app.post(
+                '/api/experimental/pools',
+                data=json.dumps({
+                    'name': name,
+                    'slots': 1,
+                    'description': '',
+                }),
+                content_type='application/json',
+            )
+            self.assertEqual(response.status_code, 400)
+            self.assertEqual(
+                json.loads(response.data.decode('utf-8'))['error'],
+                "Pool name shouldn't be empty",
+            )
+        self.assertEqual(self._get_pool_count(), 2)
+
+    def test_delete_pool(self):
+        response = self.app.delete(
+            '/api/experimental/pools/{}'.format(self.pool.pool),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.pool.to_json())
+        self.assertEqual(self._get_pool_count(), 1)
+
+    def test_delete_pool_non_existing(self):
+        response = self.app.delete(
+            '/api/experimental/pools/foo',
+        )
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Pool 'foo' doesn't exist")
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/api/experimental/test_kerberos_endpoints.py
----------------------------------------------------------------------
diff --git a/tests/www_rbac/api/experimental/test_kerberos_endpoints.py b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py
new file mode 100644
index 0000000..1ab2b25
--- /dev/null
+++ b/tests/www_rbac/api/experimental/test_kerberos_endpoints.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import mock
+import os
+import socket
+import unittest
+
+from datetime import datetime
+
+from airflow import configuration
+from airflow.api.auth.backend.kerberos_auth import client_auth
+from airflow.www_rbac import app as application
+
+
[email protected]('KRB5_KTNAME' not in os.environ,
+                 'Skipping Kerberos API tests due to missing KRB5_KTNAME')
+class ApiKerberosTests(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        try:
+            configuration.conf.add_section("api")
+        except Exception:
+            pass
+        configuration.conf.set("api",
+                               "auth_backend",
+                               "airflow.api.auth.backend.kerberos_auth")
+        try:
+            configuration.conf.add_section("kerberos")
+        except Exception:
+            pass
+        configuration.conf.set("kerberos",
+                               "keytab",
+                               os.environ['KRB5_KTNAME'])
+
+        self.app, _ = application.create_app(testing=True)
+
+    def test_trigger_dag(self):
+        with self.app.test_client() as c:
+            url_template = '/api/experimental/dags/{}/dag_runs'
+            response = c.post(
+                url_template.format('example_bash_operator'),
+                data=json.dumps(dict(run_id='my_run' + 
datetime.now().isoformat())),
+                content_type="application/json"
+            )
+            self.assertEqual(401, response.status_code)
+
+            response.url = 'http://{}'.format(socket.getfqdn())
+
+            class Request():
+                headers = {}
+
+            response.request = Request()
+            response.content = ''
+            response.raw = mock.MagicMock()
+            response.connection = mock.MagicMock()
+            response.connection.send = mock.MagicMock()
+
+            # disable mutual authentication for testing
+            client_auth.mutual_authentication = 3
+
+            # case can influence the results
+            client_auth.hostname_override = socket.getfqdn()
+
+            client_auth.handle_response(response)
+            self.assertIn('Authorization', response.request.headers)
+
+            response2 = c.post(
+                url_template.format('example_bash_operator'),
+                data=json.dumps(dict(run_id='my_run' + 
datetime.now().isoformat())),
+                content_type="application/json",
+                headers=response.request.headers
+            )
+            self.assertEqual(200, response2.status_code)
+
+    def test_unauthorized(self):
+        with self.app.test_client() as c:
+            url_template = '/api/experimental/dags/{}/dag_runs'
+            response = c.post(
+                url_template.format('example_bash_operator'),
+                data=json.dumps(dict(run_id='my_run' + 
datetime.now().isoformat())),
+                content_type="application/json"
+            )
+
+            self.assertEqual(401, response.status_code)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log
----------------------------------------------------------------------
diff --git a/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log b/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log
new file mode 100644
index 0000000..bc10ef7
--- /dev/null
+++ b/tests/www_rbac/test_logs/dag_for_testing_log_view/task_for_testing_log_view/2017-09-01T00.00.00/1.log
@@ -0,0 +1 @@
+Log for testing.

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_security.py
----------------------------------------------------------------------
diff --git a/tests/www_rbac/test_security.py b/tests/www_rbac/test_security.py
new file mode 100644
index 0000000..81a3fad
--- /dev/null
+++ b/tests/www_rbac/test_security.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import logging
+
+from flask import Flask
+from flask_appbuilder import AppBuilder, SQLA, Model, has_access, expose
+from flask_appbuilder.models.sqla.interface import SQLAInterface
+from flask_appbuilder.views import ModelView, BaseView
+
+from sqlalchemy import Column, Integer, String, Date, Float
+
+from airflow.www_rbac.security import init_role
+
+logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s')
+logging.getLogger().setLevel(logging.DEBUG)
+log = logging.getLogger(__name__)
+
+
+class SomeModel(Model):
+    """Minimal SQLAlchemy model exposed through SomeModelView in these tests."""
+    id = Column(Integer, primary_key=True)
+    # Required, unique label; also what __repr__ returns.
+    field_string = Column(String(50), unique=True, nullable=False)
+    field_integer = Column(Integer())
+    field_float = Column(Float())
+    field_date = Column(Date())
+
+    def __repr__(self):
+        return str(self.field_string)
+
+
+class SomeModelView(ModelView):
+    datamodel = SQLAInterface(SomeModel)
+    base_permissions = ['can_list', 'can_show', 'can_add', 'can_edit', 
'can_delete']
+    list_columns = ['field_string', 'field_integer', 'field_float', 
'field_date']
+
+
+class SomeBaseView(BaseView):
+    """Plain (non-model) view exposing a single access-controlled action."""
+    # Empty route base, so the action is routed at '/some_action'.
+    route_base = ''
+
+    @expose('/some_action')
+    @has_access
+    def some_action(self):
+        # Guarded by has_access; the role tests grant it through the
+        # 'can_some_action' permission name.
+        return "action!"
+
+
+class TestSecurity(unittest.TestCase):
+    def setUp(self):
+        self.app = Flask(__name__)
+        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
+        self.app.config['SECRET_KEY'] = 'secret_key'
+        self.app.config['CSRF_ENABLED'] = False
+        self.app.config['WTF_CSRF_ENABLED'] = False
+        self.db = SQLA(self.app)
+        self.appbuilder = AppBuilder(self.app, self.db.session)
+        self.appbuilder.add_view(SomeBaseView, "SomeBaseView", 
category="BaseViews")
+        self.appbuilder.add_view(SomeModelView, "SomeModelView", 
category="ModelViews")
+
+        role_admin = self.appbuilder.sm.find_role('Admin')
+        self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', 
'[email protected]',
+                                                role_admin, 'general')
+        log.debug("Complete setup!")
+
+    def tearDown(self):
+        self.appbuilder = None
+        self.app = None
+        self.db = None
+        log.debug("Complete teardown!")
+
+    def test_init_role_baseview(self):
+        role_name = 'MyRole1'
+        role_perms = ['can_some_action']
+        role_vms = ['SomeBaseView']
+        init_role(self.appbuilder.sm, role_name, role_vms, role_perms)
+        role = self.appbuilder.sm.find_role(role_name)
+        self.assertIsNotNone(role)
+        self.assertEqual(len(role_perms), len(role.permissions))
+
+    def test_init_role_modelview(self):
+        role_name = 'MyRole2'
+        role_perms = ['can_list', 'can_show', 'can_add', 'can_edit', 
'can_delete']
+        role_vms = ['SomeModelView']
+        init_role(self.appbuilder.sm, role_name, role_vms, role_perms)
+        role = self.appbuilder.sm.find_role(role_name)
+        self.assertIsNotNone(role)
+        self.assertEqual(len(role_perms), len(role.permissions))
+
+    def test_invalid_perms(self):
+        role_name = 'MyRole3'
+        role_perms = ['can_foo']
+        role_vms = ['SomeBaseView']
+        with self.assertRaises(Exception) as context:
+            init_role(self.appbuilder.sm, role_name, role_vms, role_perms)
+        self.assertEqual("The following permissions are not valid: 
['can_foo']",
+                         str(context.exception))
+
+    def test_invalid_vms(self):
+        role_name = 'MyRole4'
+        role_perms = ['can_some_action']
+        role_vms = ['NonExistentBaseView']
+        with self.assertRaises(Exception) as context:
+            init_role(self.appbuilder.sm, role_name, role_vms, role_perms)
+        self.assertEqual("The following view menus are not valid: "
+                         "['NonExistentBaseView']",
+                         str(context.exception))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_utils.py
----------------------------------------------------------------------
diff --git a/tests/www_rbac/test_utils.py b/tests/www_rbac/test_utils.py
new file mode 100644
index 0000000..b13a02f
--- /dev/null
+++ b/tests/www_rbac/test_utils.py
@@ -0,0 +1,109 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from xml.dom import minidom
+
+from airflow.www_rbac import utils
+
+
+class UtilsTest(unittest.TestCase):
+
+    def setUp(self):
+        super(UtilsTest, self).setUp()
+
+    def test_normal_variable_should_not_be_hidden(self):
+        self.assertFalse(utils.should_hide_value_for_key("key"))
+
+    def test_sensitive_variable_should_be_hidden(self):
+        self.assertTrue(utils.should_hide_value_for_key("google_api_key"))
+
+    def test_sensitive_variable_should_be_hidden_ic(self):
+        self.assertTrue(utils.should_hide_value_for_key("GOOGLE_API_KEY"))
+
+    def check_generate_pages_html(self, current_page, total_pages,
+                                  window=7, check_middle=False):
+        extra_links = 4  # first, prev, next, last
+        html_str = utils.generate_pages(current_page, total_pages)
+
+        # dom parser has issues with special « and »
+        html_str = html_str.replace('«', '')
+        html_str = html_str.replace('»', '')
+        dom = minidom.parseString(html_str)
+        self.assertIsNotNone(dom)
+
+        ulist = dom.getElementsByTagName('ul')[0]
+        ulist_items = ulist.getElementsByTagName('li')
+        self.assertEqual(min(window, total_pages) + extra_links, 
len(ulist_items))
+
+        def get_text(nodelist):
+            rc = []
+            for node in nodelist:
+                if node.nodeType == node.TEXT_NODE:
+                    rc.append(node.data)
+            return ''.join(rc)
+
+        page_items = ulist_items[2:-2]
+        mid = int(len(page_items) / 2)
+        for i, item in enumerate(page_items):
+            a_node = item.getElementsByTagName('a')[0]
+            href_link = a_node.getAttribute('href')
+            node_text = get_text(a_node.childNodes)
+            if node_text == str(current_page + 1):
+                if check_middle:
+                    self.assertEqual(mid, i)
+                self.assertEqual('javascript:void(0)', 
a_node.getAttribute('href'))
+                self.assertIn('active', item.getAttribute('class'))
+            else:
+                link_str = '?page=' + str(int(node_text) - 1)
+                self.assertEqual(link_str, href_link)
+
+    def test_generate_pager_current_start(self):
+        self.check_generate_pages_html(current_page=0,
+                                       total_pages=6)
+
+    def test_generate_pager_current_middle(self):
+        self.check_generate_pages_html(current_page=10,
+                                       total_pages=20,
+                                       check_middle=True)
+
+    def test_generate_pager_current_end(self):
+        self.check_generate_pages_html(current_page=38,
+                                       total_pages=39)
+
+    def test_params_no_values(self):
+        """Should return an empty string if no params are passed"""
+        self.assertEquals('', utils.get_params())
+
+    def test_params_search(self):
+        self.assertEqual('search=bash_',
+                         utils.get_params(search='bash_'))
+
+    def test_params_showPaused_true(self):
+        """Should detect True as default for showPaused"""
+        self.assertEqual('',
+                         utils.get_params(showPaused=True))
+
+    def test_params_showPaused_false(self):
+        self.assertEqual('showPaused=False',
+                         utils.get_params(showPaused=False))
+
+    def test_params_all(self):
+        """Should return params string ordered by param key"""
+        self.assertEqual('page=3&search=bash_&showPaused=False',
+                         utils.get_params(showPaused=False, page=3, 
search='bash_'))
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_validators.py
----------------------------------------------------------------------
diff --git a/tests/www_rbac/test_validators.py b/tests/www_rbac/test_validators.py
new file mode 100644
index 0000000..38a6142
--- /dev/null
+++ b/tests/www_rbac/test_validators.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import mock
+import unittest
+
+from airflow.www_rbac import validators
+
+
+class TestGreaterEqualThan(unittest.TestCase):
+
+    def setUp(self):
+        super(TestGreaterEqualThan, self).setUp()
+        self.form_field_mock = mock.MagicMock(data='2017-05-06')
+        self.form_field_mock.gettext.side_effect = lambda msg: msg
+        self.other_field_mock = mock.MagicMock(data='2017-05-05')
+        self.other_field_mock.gettext.side_effect = lambda msg: msg
+        self.other_field_mock.label.text = 'other field'
+        self.form_stub = {'other_field': self.other_field_mock}
+        self.form_mock = mock.MagicMock(spec_set=dict)
+        self.form_mock.__getitem__.side_effect = self.form_stub.__getitem__
+
+    def _validate(self, fieldname=None, message=None):
+        if fieldname is None:
+            fieldname = 'other_field'
+
+        validator = validators.GreaterEqualThan(fieldname=fieldname,
+                                                message=message)
+
+        return validator(self.form_mock, self.form_field_mock)
+
+    def test_field_not_found(self):
+        self.assertRaisesRegexp(
+            validators.ValidationError,
+            "^Invalid field name 'some'.$",
+            self._validate,
+            fieldname='some',
+        )
+
+    def test_form_field_is_none(self):
+        self.form_field_mock.data = None
+
+        self.assertIsNone(self._validate())
+
+    def test_other_field_is_none(self):
+        self.other_field_mock.data = None
+
+        self.assertIsNone(self._validate())
+
+    def test_both_fields_are_none(self):
+        self.form_field_mock.data = None
+        self.other_field_mock.data = None
+
+        self.assertIsNone(self._validate())
+
+    def test_validation_pass(self):
+        self.assertIsNone(self._validate())
+
+    def test_validation_raises(self):
+        self.form_field_mock.data = '2017-05-04'
+
+        self.assertRaisesRegexp(
+            validators.ValidationError,
+            "^Field must be greater than or equal to other field.$",
+            self._validate,
+        )
+
+    def test_validation_raises_custom_message(self):
+        self.form_field_mock.data = '2017-05-04'
+
+        self.assertRaisesRegexp(
+            validators.ValidationError,
+            "^This field must be greater than or equal to MyField.$",
+            self._validate,
+            message="This field must be greater than or equal to MyField.",
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/05e1861e/tests/www_rbac/test_views.py
----------------------------------------------------------------------
diff --git a/tests/www_rbac/test_views.py b/tests/www_rbac/test_views.py
new file mode 100644
index 0000000..4f71df9
--- /dev/null
+++ b/tests/www_rbac/test_views.py
@@ -0,0 +1,405 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import unittest
+import urllib
+from werkzeug.test import Client
+from flask._compat import PY2
+from flask_appbuilder.security.sqla.models import User as ab_user
+from airflow import models
+from airflow import configuration as conf
+from airflow.settings import Session
+from airflow.utils import timezone
+from airflow.utils.state import State
+from airflow.www_rbac import app as application
+
+
+class TestBase(unittest.TestCase):
+    """Common fixture for the RBAC web UI view tests.
+
+    Builds a test app with CSRF protection disabled, opens a metadata-DB
+    session and logs the test client in as the ``test`` admin user.
+    """
+
+    def setUp(self):
+        conf.load_test_config()
+        self.app, self.appbuilder = application.create_app(testing=True)
+        # Forms are posted directly by the test client, so CSRF tokens
+        # would only get in the way.
+        self.app.config['WTF_CSRF_ENABLED'] = False
+        self.client = self.app.test_client()
+        self.session = Session()
+        self.login()
+
+    def tearDown(self):
+        # Release the metadata-DB session opened in setUp.
+        self.session.close()
+        super(TestBase, self).tearDown()
+
+    def login(self):
+        """Ensure the ``test`` admin user exists, then log in as it.
+
+        Returns the response of the redirect-following login POST.
+        """
+        sm_session = self.appbuilder.sm.get_session()
+        # Look up the exact account used below: querying for *any* user
+        # would skip creation whenever some other account already exists,
+        # and the login would then fail.
+        self.user = (sm_session.query(ab_user)
+                     .filter(ab_user.username == 'test')
+                     .first())
+        if not self.user:
+            role_admin = self.appbuilder.sm.find_role('Admin')
+            self.user = self.appbuilder.sm.add_user(
+                username='test',
+                first_name='test',
+                last_name='test',
+                email='test@fab.org',
+                role=role_admin,
+                password='test')
+        return self.client.post('/login/', data=dict(
+            username='test',
+            password='test'
+        ), follow_redirects=True)
+
+    def logout(self):
+        return self.client.get('/logout/')
+
+    def clear_table(self, model):
+        """Delete every row of ``model``, commit and close the session."""
+        self.session.query(model).delete()
+        self.session.commit()
+        self.session.close()
+
+    def check_content_in_response(self, text, resp, resp_code=200):
+        """Assert ``resp`` has ``resp_code`` and contains ``text``.
+
+        ``text`` may be a single string or a list of strings that must
+        all appear in the UTF-8 decoded response body.
+        """
+        resp_html = resp.data.decode('utf-8')
+        self.assertEqual(resp_code, resp.status_code)
+        if isinstance(text, list):
+            for kw in text:
+                self.assertIn(kw, resp_html)
+        else:
+            self.assertIn(text, resp_html)
+
+    def percent_encode(self, obj):
+        """Return ``str(obj)`` quoted for use in a URL query string."""
+        # ``import urllib`` alone does not guarantee that the
+        # ``urllib.parse`` submodule is loaded on Python 3, so import the
+        # function explicitly for the running interpreter.
+        if PY2:
+            from urllib import quote_plus
+        else:
+            from urllib.parse import quote_plus
+        return quote_plus(str(obj))
+
+
+class TestConnectionModelView(TestBase):
+    """Exercise the RBAC connection model view."""
+
+    def setUp(self):
+        super(TestConnectionModelView, self).setUp()
+        self.connection = dict(
+            conn_id='test_conn',
+            conn_type='http',
+            host='localhost',
+            port=8080,
+            username='root',
+            password='admin',
+        )
+
+    def tearDown(self):
+        self.clear_table(models.Connection)
+        super(TestConnectionModelView, self).tearDown()
+
+    def test_create_connection(self):
+        response = self.client.post(
+            '/connection/add',
+            data=self.connection,
+            follow_redirects=True,
+        )
+        self.check_content_in_response('Added Row', response)
+
+
+class TestVariableModelView(TestBase):
+    """Exercise the RBAC variable view: listing, XSS handling and import."""
+
+    def setUp(self):
+        super(TestVariableModelView, self).setUp()
+        self.variable = {
+            'key': 'test_key',
+            'val': 'text_val',
+            'is_encrypted': True
+        }
+
+    def tearDown(self):
+        self.clear_table(models.Variable)
+        super(TestVariableModelView, self).tearDown()
+
+    def test_can_handle_error_on_decrypt(self):
+
+        # create valid variable
+        resp = self.client.post('/variable/add',
+                                data=self.variable,
+                                follow_redirects=True)
+        self.assertEqual(resp.status_code, 200)
+        v = self.session.query(models.Variable).first()
+        self.assertEqual(v.key, 'test_key')
+        self.assertEqual(v.val, 'text_val')
+
+        # update the variable with a wrong value, given that is encrypted
+        Var = models.Variable
+        (self.session.query(Var)
+            .filter(Var.key == self.variable['key'])
+            .update({
+                'val': 'failed_value_not_encrypted'
+            }, synchronize_session=False))
+        self.session.commit()
+
+        # retrieve Variables page, should not fail and contain the Invalid
+        # label for the variable
+        resp = self.client.get('/variable/list', follow_redirects=True)
+        self.check_content_in_response(
+            '<span class="label label-danger">Invalid</span>', resp)
+
+    def test_xss_prevention(self):
+        xss = "/variable/list/<img%20src=''%20onerror='alert(1);'>"
+
+        resp = self.client.get(
+            xss,
+            follow_redirects=True,
+        )
+        self.assertEqual(resp.status_code, 404)
+        self.assertNotIn("<img src='' onerror='alert(1);'>",
+                         resp.data.decode("utf-8"))
+
+    def test_import_variables(self):
+        self.assertEqual(self.session.query(models.Variable).count(), 0)
+
+        content = ('{"str_key": "str_value", "int_key": 60,'
+                   '"list_key": [1, 2], "dict_key": {"k_a": 2, "k_b": 3}}')
+        # str.encode yields a bytes object on both Python 2 and Python 3,
+        # so no version-specific fallback is needed.
+        bytes_content = io.BytesIO(content.encode('utf-8'))
+
+        resp = self.client.post('/variable/varimport',
+                                data={'file': (bytes_content, 'test.json')},
+                                follow_redirects=True)
+        self.check_content_in_response('4 variable(s) successfully updated.',
+                                       resp)
+
+
+class TestPoolModelView(TestBase):
+    """Exercise form validation in the RBAC pool model view."""
+
+    def setUp(self):
+        super(TestPoolModelView, self).setUp()
+        self.pool = dict(
+            pool='test-pool',
+            slots=777,
+            description='test-pool-description',
+        )
+
+    def tearDown(self):
+        self.clear_table(models.Pool)
+        super(TestPoolModelView, self).tearDown()
+
+    def _add_pool(self):
+        """POST the current ``self.pool`` form and return the response."""
+        return self.client.post('/pool/add', data=self.pool,
+                                follow_redirects=True)
+
+    def test_create_pool_with_same_name(self):
+        # The first submission creates the pool...
+        self.check_content_in_response('Added Row', self._add_pool())
+        # ...and a second one with the same name is rejected.
+        self.check_content_in_response('Already exists.', self._add_pool())
+
+    def test_create_pool_with_empty_name(self):
+        self.pool['pool'] = ''
+        self.check_content_in_response('This field is required.',
+                                       self._add_pool())
+
+
+class TestMountPoint(unittest.TestCase):
+    """Check that the UI is served under the path component of base_url."""
+
+    def setUp(self):
+        # Reset the module-level app so cached_app() below is not handed
+        # an app built with a different base_url.
+        application.app = None
+        super(TestMountPoint, self).setUp()
+        conf.load_test_config()
+        conf.set("webserver", "base_url", "http://localhost:8080/test")
+        config = dict()
+        config['WTF_CSRF_METHODS'] = []
+        app = application.cached_app(config=config, testing=True)
+        self.client = Client(app)
+
+    def tearDown(self):
+        # Do not leak the /test-mounted app or the modified base_url into
+        # tests that run afterwards in the same process.
+        application.app = None
+        conf.load_test_config()
+        super(TestMountPoint, self).tearDown()
+
+    def test_mount(self):
+        resp, _, _ = self.client.get('/', follow_redirects=True)
+        txt = b''.join(resp)
+        self.assertEqual(b"Apache Airflow is not at this location", txt)
+
+        resp, _, _ = self.client.get('/test/home', follow_redirects=True)
+        resp_html = b''.join(resp)
+        self.assertIn(b"DAGs", resp_html)
+
+
+class TestAirflowBaseViews(TestBase):
+    """Smoke tests for the main Airflow views against the example DAGs."""
+
+    default_date = timezone.datetime(2018, 3, 1)
+    run_id = "test_{}".format(models.DagRun.id_for_date(default_date))
+
+    def setUp(self):
+        super(TestAirflowBaseViews, self).setUp()
+        self.cleanup_dagruns()
+        self.prepare_dagruns()
+
+    def cleanup_dagruns(self):
+        """Delete this case's dag runs left behind by an earlier run."""
+        DR = models.DagRun
+        dag_ids = ['example_bash_operator',
+                   'example_subdag_operator',
+                   'example_xcom']
+        (self.session
+             .query(DR)
+             .filter(DR.dag_id.in_(dag_ids))
+             .filter(DR.run_id == self.run_id)
+             .delete(synchronize_session='fetch'))
+        self.session.commit()
+
+    def prepare_dagruns(self):
+        """Create a RUNNING run at ``default_date`` for each example DAG."""
+        dagbag = models.DagBag(include_examples=True)
+        self.bash_dag = dagbag.dags['example_bash_operator']
+        self.sub_dag = dagbag.dags['example_subdag_operator']
+        self.xcom_dag = dagbag.dags['example_xcom']
+
+        self.bash_dagrun = self._create_running_dagrun(self.bash_dag)
+        self.sub_dagrun = self._create_running_dagrun(self.sub_dag)
+        self.xcom_dagrun = self._create_running_dagrun(self.xcom_dag)
+
+    def _create_running_dagrun(self, dag):
+        """Create and return a RUNNING run of ``dag`` at ``default_date``."""
+        return dag.create_dagrun(
+            run_id=self.run_id,
+            execution_date=self.default_date,
+            start_date=timezone.utcnow(),
+            state=State.RUNNING)
+
+    def test_index(self):
+        resp = self.client.get('/', follow_redirects=True)
+        self.check_content_in_response('DAGs', resp)
+
+    def test_health(self):
+        resp = self.client.get('health')
+        self.check_content_in_response('The server is healthy!', resp)
+
+    def test_home(self):
+        resp = self.client.get('home', follow_redirects=True)
+        self.check_content_in_response('DAGs', resp)
+
+    def test_task(self):
+        url = ('task?task_id=runme_0&dag_id=example_bash_operator&'
+               'execution_date={}'
+               .format(self.percent_encode(self.default_date)))
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('Task Instance Details', resp)
+
+    def test_xcom(self):
+        url = ('xcom?task_id=runme_0&dag_id=example_bash_operator&'
+               'execution_date={}'
+               .format(self.percent_encode(self.default_date)))
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('XCom', resp)
+
+    def test_rendered(self):
+        url = ('rendered?task_id=runme_0&dag_id=example_bash_operator&'
+               'execution_date={}'
+               .format(self.percent_encode(self.default_date)))
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('Rendered Template', resp)
+
+    def test_pickle_info(self):
+        url = 'pickle_info?dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.assertEqual(resp.status_code, 200)
+
+    def test_blocked(self):
+        url = 'blocked'
+        resp = self.client.get(url, follow_redirects=True)
+        self.assertEqual(200, resp.status_code)
+
+    def test_dag_stats(self):
+        resp = self.client.get('dag_stats', follow_redirects=True)
+        self.assertEqual(resp.status_code, 200)
+
+    def test_task_stats(self):
+        resp = self.client.get('task_stats', follow_redirects=True)
+        self.assertEqual(resp.status_code, 200)
+
+    def test_dag_details(self):
+        url = 'dag_details?dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('DAG details', resp)
+
+    def test_graph(self):
+        url = 'graph?dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('runme_1', resp)
+
+    def test_tree(self):
+        url = 'tree?dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('runme_1', resp)
+
+    def test_duration(self):
+        url = 'duration?days=30&dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('example_bash_operator', resp)
+
+    def test_tries(self):
+        url = 'tries?days=30&dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('example_bash_operator', resp)
+
+    def test_landing_times(self):
+        # Use the DAG loaded from the example dagbag; a nonexistent id such
+        # as "test_example_bash_operator" would still contain the asserted
+        # substring and let the check pass vacuously.
+        url = 'landing_times?days=30&dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('example_bash_operator', resp)
+
+    def test_gantt(self):
+        url = 'gantt?dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('example_bash_operator', resp)
+
+    def test_code(self):
+        url = 'code?dag_id=example_bash_operator'
+        resp = self.client.get(url, follow_redirects=True)
+        self.check_content_in_response('example_bash_operator', resp)
+
+    def test_paused(self):
+        url = 'paused?dag_id=example_bash_operator&is_paused=false'
+        resp = self.client.post(url, follow_redirects=True)
+        self.check_content_in_response('OK', resp)
+
+    def test_success(self):
+
+        url = ('success?task_id=run_this_last&dag_id=example_bash_operator&'
+               'execution_date={}&upstream=false&downstream=false&'
+               'future=false&past=false'
+               .format(self.percent_encode(self.default_date)))
+        resp = self.client.get(url)
+        self.check_content_in_response('Wait a minute', resp)
+
+    def test_clear(self):
+        url = ('clear?task_id=runme_1&dag_id=example_bash_operator&'
+               'execution_date={}&upstream=false&downstream=false&'
+               'future=false&past=false'
+               .format(self.percent_encode(self.default_date)))
+        resp = self.client.get(url)
+        self.check_content_in_response(
+            ['example_bash_operator', 'Wait a minute'], resp)
+
+    def test_run(self):
+        url = ('run?task_id=runme_0&dag_id=example_bash_operator&'
+               'ignore_all_deps=false&ignore_ti_state=true&execution_date={}'
+               .format(self.percent_encode(self.default_date)))
+        resp = self.client.get(url)
+        self.check_content_in_response('', resp, resp_code=302)
+
+    def test_refresh(self):
+        resp = self.client.get('refresh?dag_id=example_bash_operator')
+        self.check_content_in_response('', resp, resp_code=302)
+
+
+class TestConfigurationView(TestBase):
+    """The configuration page renders both of its configuration sections."""
+
+    def test_configuration(self):
+        expected = ['Airflow Configuration', 'Running Configuration']
+        response = self.client.get('configuration', follow_redirects=True)
+        self.check_content_in_response(expected, response)
+
+
+class TestVersionView(TestBase):
+    """The version page renders its heading."""
+
+    def test_version(self):
+        response = self.client.get('version', follow_redirects=True)
+        self.check_content_in_response('Version Info', response)
+
+
+if __name__ == "__main__":
+    # Allow running this module directly with the unittest runner.
+    unittest.main()


Reply via email to