Repository: incubator-airflow
Updated Branches:
  refs/heads/master a45e2d188 -> 9958aa9d5
[AIRFLOW-1275] Put 'airflow pool' into API Closes #2346 from skudriashev/airflow-1275 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/9958aa9d Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/9958aa9d Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/9958aa9d Branch: refs/heads/master Commit: 9958aa9d5326b75cf7082c0bc36c13b063f1924f Parents: a45e2d1 Author: Stanislav Kudriashev <[email protected]> Authored: Wed Jun 21 16:36:45 2017 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Wed Jun 21 16:36:51 2017 +0200 ---------------------------------------------------------------------- airflow/api/client/api_client.py | 36 +- airflow/api/client/json_client.py | 60 +++- airflow/api/client/local_client.py | 20 +- airflow/api/common/experimental/pool.py | 85 +++++ airflow/bin/cli.py | 52 ++- airflow/models.py | 8 + airflow/www/api/experimental/endpoints.py | 68 +++- tests/api/__init__.py | 6 - tests/api/client/local_client.py | 107 ------ tests/api/client/test_local_client.py | 144 ++++++++ tests/api/common/experimental/__init__.py | 13 + tests/api/common/experimental/mark_tasks.py | 396 ++++++++++++++++++++++ tests/api/common/experimental/test_pool.py | 132 ++++++++ tests/api/common/mark_tasks.py | 396 ---------------------- tests/core.py | 59 +++- tests/www/api/experimental/test_endpoints.py | 153 ++++++++- 16 files changed, 1155 insertions(+), 580 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/client/api_client.py ---------------------------------------------------------------------- diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py index 6a77538..f24d809 100644 --- a/airflow/api/client/api_client.py +++ b/airflow/api/client/api_client.py @@ -14,17 +14,47 @@ # -class Client: +class Client(object): + """Base API client for all API clients.""" + def __init__(self, api_base_url, auth): self._api_base_url = api_base_url self._auth = auth def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): - """ - Creates a dag run for the specified dag + """Create a dag run for the specified dag. + :param dag_id: :param run_id: :param conf: + :param execution_date: :return: """ raise NotImplementedError() + + def get_pool(self, name): + """Get pool. + + :param name: pool name + """ + raise NotImplementedError() + + def get_pools(self): + """Get all pools.""" + raise NotImplementedError() + + def create_pool(self, name, slots, description): + """Create a pool. + + :param name: pool name + :param slots: pool slots amount + :param description: pool description + """ + raise NotImplementedError() + + def delete_pool(self, name): + """Delete pool. + + :param name: pool name + """ + raise NotImplementedError() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/client/json_client.py ---------------------------------------------------------------------- diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py index d74fc63..37e24d3 100644 --- a/airflow/api/client/json_client.py +++ b/airflow/api/client/json_client.py @@ -11,30 +11,70 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
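----------------------------------------------------------------------
Usage sketch (illustrative, not part of the patch): the abstract Client
above only fixes the pool interface; the json and local implementations
below each return plain (name, slots, description) tuples, which is the
contract the reworked CLI relies on. A toy subclass honouring it
(MyClient is a made-up name) could look like:

    from airflow.api.client import api_client

    class MyClient(api_client.Client):
        """Toy in-memory client; the real implementations follow below."""
        def __init__(self, api_base_url=None, auth=None):
            super(MyClient, self).__init__(api_base_url, auth)
            self._pools = {}

        def get_pool(self, name):
            slots, description = self._pools[name]
            return name, slots, description

        def get_pools(self):
            return [(n, s, d) for n, (s, d) in sorted(self._pools.items())]

        def create_pool(self, name, slots, description):
            self._pools[name] = (int(slots), description)
            return name, int(slots), description

        def delete_pool(self, name):
            slots, description = self._pools.pop(name)
            return name, slots, description
----------------------------------------------------------------------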
-# + from future.moves.urllib.parse import urljoin +import requests from airflow.api.client import api_client -import requests - class Client(api_client.Client): + """Json API client implementation.""" + + def _request(self, url, method='GET', json=None): + params = { + 'url': url, + 'auth': self._auth, + } + if json is not None: + params['json'] = json + + resp = getattr(requests, method.lower())(**params) + if not resp.ok: + try: + data = resp.json() + except Exception: + data = {} + raise IOError(data.get('error', 'Server error')) + + return resp.json() + def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id) url = urljoin(self._api_base_url, endpoint) - - resp = requests.post(url, - auth=self._auth, + data = self._request(url, method='POST', json={ "run_id": run_id, "conf": conf, "execution_date": execution_date, }) + return data['message'] - if not resp.ok: - raise IOError() + def get_pool(self, name): + endpoint = '/api/experimental/pools/{}'.format(name) + url = urljoin(self._api_base_url, endpoint) + pool = self._request(url) + return pool['pool'], pool['slots'], pool['description'] - data = resp.json() + def get_pools(self): + endpoint = '/api/experimental/pools' + url = urljoin(self._api_base_url, endpoint) + pools = self._request(url) + return [(p['pool'], p['slots'], p['description']) for p in pools] - return data['message'] + def create_pool(self, name, slots, description): + endpoint = '/api/experimental/pools' + url = urljoin(self._api_base_url, endpoint) + pool = self._request(url, method='POST', + json={ + 'name': name, + 'slots': slots, + 'description': description, + }) + return pool['pool'], pool['slots'], pool['description'] + + def delete_pool(self, name): + endpoint = '/api/experimental/pools/{}'.format(name) + url = urljoin(self._api_base_url, endpoint) + pool = self._request(url, method='DELETE') + return pool['pool'], pool['slots'], pool['description'] http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/client/local_client.py ---------------------------------------------------------------------- diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py index 05f27f6..5bc7f76 100644 --- a/airflow/api/client/local_client.py +++ b/airflow/api/client/local_client.py @@ -11,15 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
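----------------------------------------------------------------------
A usage sketch for the JSON client above against a running webserver;
the base URL and credentials are assumptions, and the experimental API
has to be reachable under /api/experimental:

    from airflow.api.client.json_client import Client

    client = Client(api_base_url='http://localhost:8080',
                    auth=('admin', 'secret'))  # handed through to requests

    client.create_pool(name='backfill', slots=4, description='backfill jobs')
    print(client.get_pool(name='backfill'))  # ('backfill', 4, 'backfill jobs')
    print(client.get_pools())                # list of (name, slots, description)
    client.delete_pool(name='backfill')

On any non-2xx response _request() raises IOError carrying the
server-side 'error' message, which the CLI logs instead of crashing.
----------------------------------------------------------------------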
-# + from airflow.api.client import api_client +from airflow.api.common.experimental import pool from airflow.api.common.experimental import trigger_dag class Client(api_client.Client): + """Local API client implementation.""" + def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None): dr = trigger_dag.trigger_dag(dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date) return "Created {}".format(dr) + + def get_pool(self, name): + p = pool.get_pool(name=name) + return p.pool, p.slots, p.description + + def get_pools(self): + return [(p.pool, p.slots, p.description) for p in pool.get_pools()] + + def create_pool(self, name, slots, description): + p = pool.create_pool(name=name, slots=slots, description=description) + return p.pool, p.slots, p.description + + def delete_pool(self, name): + p = pool.delete_pool(name=name) + return p.pool, p.slots, p.description http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/common/experimental/pool.py ---------------------------------------------------------------------- diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py new file mode 100644 index 0000000..6e963a2 --- /dev/null +++ b/airflow/api/common/experimental/pool.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
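----------------------------------------------------------------------
The local client above serves the same interface in-process by
delegating to the new airflow.api.common.experimental.pool module whose
diff follows; api_base_url and auth are accepted but unused. A minimal
sketch, assuming an initialized metadata database:

    from airflow.api.client.local_client import Client

    client = Client(api_base_url=None, auth=None)
    client.create_pool(name='backfill', slots=4, description='backfill jobs')
    assert client.get_pool(name='backfill') == ('backfill', 4, 'backfill jobs')
    client.delete_pool(name='backfill')
----------------------------------------------------------------------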
+ +from airflow.exceptions import AirflowException +from airflow.models import Pool +from airflow.utils.db import provide_session + + +class PoolBadRequest(AirflowException): + status = 400 + + +class PoolNotFound(AirflowException): + status = 404 + + +@provide_session +def get_pool(name, session=None): + """Get pool by a given name.""" + if not (name and name.strip()): + raise PoolBadRequest("Pool name shouldn't be empty") + + pool = session.query(Pool).filter_by(pool=name).first() + if pool is None: + raise PoolNotFound("Pool '%s' doesn't exist" % name) + + return pool + + +@provide_session +def get_pools(session=None): + """Get all pools.""" + return session.query(Pool).all() + + +@provide_session +def create_pool(name, slots, description, session=None): + """Create a pool with a given parameters.""" + if not (name and name.strip()): + raise PoolBadRequest("Pool name shouldn't be empty") + + try: + slots = int(slots) + except ValueError: + raise PoolBadRequest("Bad value for `slots`: %s" % slots) + + session.expire_on_commit = False + pool = session.query(Pool).filter_by(pool=name).first() + if pool is None: + pool = Pool(pool=name, slots=slots, description=description) + session.add(pool) + else: + pool.slots = slots + pool.description = description + + session.commit() + + return pool + + +@provide_session +def delete_pool(name, session=None): + """Delete pool by a given name.""" + if not (name and name.strip()): + raise PoolBadRequest("Pool name shouldn't be empty") + + pool = session.query(Pool).filter_by(pool=name).first() + if pool is None: + raise PoolNotFound("Pool '%s' doesn't exist" % name) + + session.delete(pool) + session.commit() + + return pool http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/bin/cli.py ---------------------------------------------------------------------- diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index 41f979f..4b3a0ed 100755 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException from airflow.executors import GetDefaultExecutor from airflow.models import (DagModel, DagBag, TaskInstance, DagPickle, DagRun, Variable, DagStat, - Pool, Connection) + Connection) from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS) from airflow.utils import db as db_utils from airflow.utils import logging as logging_utils @@ -187,40 +187,28 @@ def trigger_dag(args): def pool(args): - session = settings.Session() - if args.get or (args.set and args.set[0]) or args.delete: - name = args.get or args.delete or args.set[0] - pool = ( - session.query(Pool) - .filter(Pool.pool == name) - .first()) - if pool and args.get: - print("{} ".format(pool)) - return - elif not pool and (args.get or args.delete): - print("No pool named {} found".format(name)) - elif not pool and args.set: - pool = Pool( - pool=name, - slots=args.set[1], - description=args.set[2]) - session.add(pool) - session.commit() - print("{} ".format(pool)) - elif pool and args.set: - pool.slots = args.set[1] - pool.description = args.set[2] - session.commit() - print("{} ".format(pool)) - return - elif pool and args.delete: - session.query(Pool).filter_by(pool=args.delete).delete() - session.commit() - print("Pool {} deleted".format(name)) + def _tabulate(pools): + return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'], + tablefmt="fancy_grid") + try: + if args.get is not None: + pools = [api_client.get_pool(name=args.get)] + elif args.set: + pools = [api_client.create_pool(name=args.set[0], + 
slots=args.set[1], + description=args.set[2])] + elif args.delete: + pools = [api_client.delete_pool(name=args.delete)] + else: + pools = api_client.get_pools() + except (AirflowException, IOError) as err: + logging.error(err) + else: + logging.info(_tabulate(pools=pools)) -def variables(args): +def variables(args): if args.get: try: var = Variable.get(args.get, http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index 2c433ad..0002572 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -4395,6 +4395,14 @@ class Pool(Base): def __repr__(self): return self.pool + def to_json(self): + return { + 'id': self.id, + 'pool': self.pool, + 'slots': self.slots, + 'description': self.description, + } + @provide_session def used_slots(self, session): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/www/api/experimental/endpoints.py ---------------------------------------------------------------------- diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index be92735..a8d7f5c 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import logging import airflow.api +from airflow.api.common.experimental import pool as pool_api from airflow.api.common.experimental import trigger_dag as trigger from airflow.api.common.experimental.get_task import get_task from airflow.api.common.experimental.get_task_instance import get_task_instance @@ -96,7 +98,6 @@ def test(): @requires_authentication def task_info(dag_id, task_id): """Returns a JSON with a task's public instance variables. 
""" - try: info = get_task(dag_id, task_id) except AirflowException as err: @@ -169,4 +170,67 @@ def latest_dag_runs(): 'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id, execution_date=dagrun.execution_date) }) - return jsonify(items=payload) # old flask versions dont support jsonifying arrays + return jsonify(items=payload) # old flask versions dont support jsonifying arrays + + +@api_experimental.route('/pools/<string:name>', methods=['GET']) +@requires_authentication +def get_pool(name): + """Get pool by a given name.""" + try: + pool = pool_api.get_pool(name=name) + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify(pool.to_json()) + + +@api_experimental.route('/pools', methods=['GET']) +@requires_authentication +def get_pools(): + """Get all pools.""" + try: + pools = pool_api.get_pools() + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify([p.to_json() for p in pools]) + + [email protected] +@api_experimental.route('/pools', methods=['POST']) +@requires_authentication +def create_pool(): + """Create a pool.""" + params = request.get_json(force=True) + try: + pool = pool_api.create_pool(**params) + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify(pool.to_json()) + + [email protected] +@api_experimental.route('/pools/<string:name>', methods=['DELETE']) +@requires_authentication +def delete_pool(name): + """Delete pool.""" + try: + pool = pool_api.delete_pool(name=name) + except AirflowException as e: + _log.error(e) + response = jsonify(error="{}".format(e)) + response.status_code = getattr(e, 'status', 500) + return response + else: + return jsonify(pool.to_json()) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/__init__.py ---------------------------------------------------------------------- diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 37d59f0..9d7677a 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -11,9 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import absolute_import - -from .client import * -from .common import * - http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/client/local_client.py ---------------------------------------------------------------------- diff --git a/tests/api/client/local_client.py b/tests/api/client/local_client.py deleted file mode 100644 index a36b71f..0000000 --- a/tests/api/client/local_client.py +++ /dev/null @@ -1,107 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import unittest -import datetime - -from mock import patch - -from airflow import AirflowException -from airflow import models - -from airflow.api.client.local_client import Client -from airflow.utils.state import State - -EXECDATE = datetime.datetime.now() -EXECDATE_NOFRACTIONS = EXECDATE.replace(microsecond=0) -EXECDATE_ISO = EXECDATE_NOFRACTIONS.isoformat() - -real_datetime_class = datetime.datetime - - -def mock_datetime_now(target, dt): - class DatetimeSubclassMeta(type): - @classmethod - def __instancecheck__(mcs, obj): - return isinstance(obj, real_datetime_class) - - class BaseMockedDatetime(real_datetime_class): - @classmethod - def now(cls, tz=None): - return target.replace(tzinfo=tz) - - @classmethod - def utcnow(cls): - return target - - # Python2 & Python3 compatible metaclass - MockedDatetime = DatetimeSubclassMeta('datetime', (BaseMockedDatetime,), {}) - - return patch.object(dt, 'datetime', MockedDatetime) - - -class TestLocalClient(unittest.TestCase): - def setUp(self): - self.client = Client(api_base_url=None, auth=None) - - @patch.object(models.DAG, 'create_dagrun') - def test_trigger_dag(self, mock): - client = self.client - - # non existent - with self.assertRaises(AirflowException): - client.trigger_dag(dag_id="blablabla") - - import airflow.api.common.experimental.trigger_dag - with mock_datetime_now(EXECDATE, airflow.api.common.experimental.trigger_dag.datetime): - # no execution date, execution date should be set automatically - client.trigger_dag(dag_id="test_start_date_scheduling") - mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO), - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=None, - external_trigger=True) - mock.reset_mock() - - # execution date with microseconds cutoff - client.trigger_dag(dag_id="test_start_date_scheduling", execution_date=EXECDATE) - mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO), - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=None, - external_trigger=True) - mock.reset_mock() - - # run id - run_id = "my_run_id" - client.trigger_dag(dag_id="test_start_date_scheduling", run_id=run_id) - mock.assert_called_once_with(run_id=run_id, - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=None, - external_trigger=True) - mock.reset_mock() - - # test conf - conf = '{"name": "John"}' - client.trigger_dag(dag_id="test_start_date_scheduling", conf=conf) - mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO), - execution_date=EXECDATE_NOFRACTIONS, - state=State.RUNNING, - conf=json.loads(conf), - external_trigger=True) - mock.reset_mock() - - # this is a unit test only, cannot verify existing dag run http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/client/test_local_client.py ---------------------------------------------------------------------- diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py new file mode 100644 index 0000000..7a759fe --- /dev/null +++ b/tests/api/client/test_local_client.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import unittest + +from mock import patch + +from airflow import AirflowException +from airflow.api.client.local_client import Client +from airflow import models +from airflow import settings +from airflow.utils.state import State + +EXECDATE = datetime.datetime.now() +EXECDATE_NOFRACTIONS = EXECDATE.replace(microsecond=0) +EXECDATE_ISO = EXECDATE_NOFRACTIONS.isoformat() + +real_datetime_class = datetime.datetime + + +def mock_datetime_now(target, dt): + class DatetimeSubclassMeta(type): + @classmethod + def __instancecheck__(mcs, obj): + return isinstance(obj, real_datetime_class) + + class BaseMockedDatetime(real_datetime_class): + @classmethod + def now(cls, tz=None): + return target.replace(tzinfo=tz) + + @classmethod + def utcnow(cls): + return target + + # Python2 & Python3 compatible metaclass + MockedDatetime = DatetimeSubclassMeta('datetime', (BaseMockedDatetime,), {}) + + return patch.object(dt, 'datetime', MockedDatetime) + + +class TestLocalClient(unittest.TestCase): + + @classmethod + def setUpClass(cls): + super(TestLocalClient, cls).setUpClass() + session = settings.Session() + session.query(models.Pool).delete() + session.commit() + session.close() + + def setUp(self): + super(TestLocalClient, self).setUp() + self.client = Client(api_base_url=None, auth=None) + self.session = settings.Session() + + def tearDown(self): + self.session.query(models.Pool).delete() + self.session.commit() + self.session.close() + super(TestLocalClient, self).tearDown() + + @patch.object(models.DAG, 'create_dagrun') + def test_trigger_dag(self, mock): + client = self.client + + # non existent + with self.assertRaises(AirflowException): + client.trigger_dag(dag_id="blablabla") + + import airflow.api.common.experimental.trigger_dag + with mock_datetime_now(EXECDATE, airflow.api.common.experimental.trigger_dag.datetime): + # no execution date, execution date should be set automatically + client.trigger_dag(dag_id="test_start_date_scheduling") + mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO), + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=None, + external_trigger=True) + mock.reset_mock() + + # execution date with microseconds cutoff + client.trigger_dag(dag_id="test_start_date_scheduling", execution_date=EXECDATE) + mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO), + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=None, + external_trigger=True) + mock.reset_mock() + + # run id + run_id = "my_run_id" + client.trigger_dag(dag_id="test_start_date_scheduling", run_id=run_id) + mock.assert_called_once_with(run_id=run_id, + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=None, + external_trigger=True) + mock.reset_mock() + + # test conf + conf = '{"name": "John"}' + client.trigger_dag(dag_id="test_start_date_scheduling", conf=conf) + mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO), + execution_date=EXECDATE_NOFRACTIONS, + state=State.RUNNING, + conf=json.loads(conf), + external_trigger=True) + mock.reset_mock() + + def 
test_get_pool(self): + self.client.create_pool(name='foo', slots=1, description='') + pool = self.client.get_pool(name='foo') + self.assertEqual(pool, ('foo', 1, '')) + + def test_get_pools(self): + self.client.create_pool(name='foo1', slots=1, description='') + self.client.create_pool(name='foo2', slots=2, description='') + pools = sorted(self.client.get_pools(), key=lambda p: p[0]) + self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')]) + + def test_create_pool(self): + pool = self.client.create_pool(name='foo', slots=1, description='') + self.assertEqual(pool, ('foo', 1, '')) + self.assertEqual(self.session.query(models.Pool).count(), 1) + + def test_delete_pool(self): + self.client.create_pool(name='foo', slots=1, description='') + self.assertEqual(self.session.query(models.Pool).count(), 1) + self.client.delete_pool(name='foo') + self.assertEqual(self.session.query(models.Pool).count(), 0) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/experimental/__init__.py ---------------------------------------------------------------------- diff --git a/tests/api/common/experimental/__init__.py b/tests/api/common/experimental/__init__.py new file mode 100644 index 0000000..9d7677a --- /dev/null +++ b/tests/api/common/experimental/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/experimental/mark_tasks.py ---------------------------------------------------------------------- diff --git a/tests/api/common/experimental/mark_tasks.py b/tests/api/common/experimental/mark_tasks.py new file mode 100644 index 0000000..e4395ae --- /dev/null +++ b/tests/api/common/experimental/mark_tasks.py @@ -0,0 +1,396 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from airflow import models +from airflow.api.common.experimental.mark_tasks import ( + set_state, _create_dagruns, set_dag_run_state) +from airflow.settings import Session +from airflow.utils.dates import days_ago +from airflow.utils.state import State +from datetime import datetime, timedelta + +DEV_NULL = "/dev/null" + + +class TestMarkTasks(unittest.TestCase): + + def setUp(self): + self.dagbag = models.DagBag(include_examples=True) + self.dag1 = self.dagbag.dags['test_example_bash_operator'] + self.dag2 = self.dagbag.dags['example_subdag_operator'] + + self.execution_dates = [days_ago(2), days_ago(1)] + + drs = _create_dagruns(self.dag1, self.execution_dates, + state=State.RUNNING, + run_id_template="scheduled__{}") + for dr in drs: + dr.dag = self.dag1 + dr.verify_integrity() + + drs = _create_dagruns(self.dag2, + [self.dag2.default_args['start_date']], + state=State.RUNNING, + run_id_template="scheduled__{}") + + for dr in drs: + dr.dag = self.dag2 + dr.verify_integrity() + + self.session = Session() + + def tearDown(self): + self.dag1.clear() + self.dag2.clear() + + # just to make sure we are fully cleaned up + self.session.query(models.DagRun).delete() + self.session.query(models.TaskInstance).delete() + self.session.commit() + self.session.close() + + def snapshot_state(self, dag, execution_dates): + TI = models.TaskInstance + tis = self.session.query(TI).filter( + TI.dag_id==dag.dag_id, + TI.execution_date.in_(execution_dates) + ).all() + + self.session.expunge_all() + + return tis + + def verify_state(self, dag, task_ids, execution_dates, state, old_tis): + TI = models.TaskInstance + + tis = self.session.query(TI).filter( + TI.dag_id==dag.dag_id, + TI.execution_date.in_(execution_dates) + ).all() + + self.assertTrue(len(tis) > 0) + + for ti in tis: + if ti.task_id in task_ids and ti.execution_date in execution_dates: + self.assertEqual(ti.state, state) + else: + for old_ti in old_tis: + if (old_ti.task_id == ti.task_id + and old_ti.execution_date == ti.execution_date): + self.assertEqual(ti.state, old_ti.state) + + def test_mark_tasks_now(self): + # set one task to success but do not commit + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=False) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + None, snapshot) + + # set one and only one task to success + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.SUCCESS, snapshot) + + # set no tasks + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 0) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.SUCCESS, snapshot) + + # set task to other than success + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.FAILED, commit=True) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], 
[self.execution_dates[0]], + State.FAILED, snapshot) + + # dont alter other tasks + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_0") + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.SUCCESS, snapshot) + + def test_mark_downstream(self): + # test downstream + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + relatives = task.get_flat_relatives(upstream=False) + task_ids = [t.task_id for t in relatives] + task_ids.append(task.task_id) + + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=True, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 3) + self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], + State.SUCCESS, snapshot) + + def test_mark_upstream(self): + # test upstream + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("run_after_loop") + relatives = task.get_flat_relatives(upstream=True) + task_ids = [t.task_id for t in relatives] + task_ids.append(task.task_id) + + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=True, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 4) + self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], + State.SUCCESS, snapshot) + + def test_mark_tasks_future(self): + # set one task to success towards end of scheduled dag runs + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=True, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 2) + self.verify_state(self.dag1, [task.task_id], self.execution_dates, + State.SUCCESS, snapshot) + + def test_mark_tasks_past(self): + # set one task to success towards end of scheduled dag runs + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + altered = set_state(task=task, execution_date=self.execution_dates[1], + upstream=False, downstream=False, future=False, + past=True, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 2) + self.verify_state(self.dag1, [task.task_id], self.execution_dates, + State.SUCCESS, snapshot) + + def test_mark_tasks_subdag(self): + # set one task to success towards end of scheduled dag runs + task = self.dag2.get_task("section-1") + relatives = task.get_flat_relatives(upstream=False) + task_ids = [t.task_id for t in relatives] + task_ids.append(task.task_id) + + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=True, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 14) + + # cannot use snapshot here as that will require drilling down the + # the sub dag tree essentially recreating the same code as in the + # tested logic. 
+ self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], + State.SUCCESS, []) + + +class TestMarkDAGRun(unittest.TestCase): + def setUp(self): + self.dagbag = models.DagBag(include_examples=True) + self.dag1 = self.dagbag.dags['test_example_bash_operator'] + self.dag2 = self.dagbag.dags['example_subdag_operator'] + + self.execution_dates = [days_ago(3), days_ago(2), days_ago(1)] + + self.session = Session() + + def verify_dag_run_states(self, dag, date, state=State.SUCCESS): + drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date) + dr = drs[0] + self.assertEqual(dr.get_state(), state) + tis = dr.get_task_instances(session=self.session) + for ti in tis: + self.assertEqual(ti.state, state) + + def test_set_running_dag_run_state(self): + date = self.execution_dates[0] + dr = self.dag1.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.RUNNING, + execution_date=date, + session=self.session + ) + for ti in dr.get_task_instances(session=self.session): + ti.set_state(State.RUNNING, self.session) + + altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True) + + # All of the task should be altered + self.assertEqual(len(altered), len(self.dag1.tasks)) + self.verify_dag_run_states(self.dag1, date) + + def test_set_success_dag_run_state(self): + date = self.execution_dates[0] + + dr = self.dag1.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.SUCCESS, + execution_date=date, + session=self.session + ) + for ti in dr.get_task_instances(session=self.session): + ti.set_state(State.SUCCESS, self.session) + + altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True) + + # None of the task should be altered + self.assertEqual(len(altered), 0) + self.verify_dag_run_states(self.dag1, date) + + def test_set_failed_dag_run_state(self): + date = self.execution_dates[0] + dr = self.dag1.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.FAILED, + execution_date=date, + session=self.session + ) + dr.get_task_instance('runme_0').set_state(State.FAILED, self.session) + + altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True) + + # All of the task should be altered + self.assertEqual(len(altered), len(self.dag1.tasks)) + self.verify_dag_run_states(self.dag1, date) + + def test_set_mixed_dag_run_state(self): + """ + This test checks function set_dag_run_state with mixed task instance + state. 
+ """ + date = self.execution_dates[0] + dr = self.dag1.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.FAILED, + execution_date=date, + session=self.session + ) + # success task + dr.get_task_instance('runme_0').set_state(State.SUCCESS, self.session) + # skipped task + dr.get_task_instance('runme_1').set_state(State.SKIPPED, self.session) + # retry task + dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY, self.session) + # queued task + dr.get_task_instance('also_run_this').set_state(State.QUEUED, self.session) + # running task + dr.get_task_instance('run_after_loop').set_state(State.RUNNING, self.session) + # failed task + dr.get_task_instance('run_this_last').set_state(State.FAILED, self.session) + + altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True) + + self.assertEqual(len(altered), len(self.dag1.tasks) - 1) # only 1 task succeeded + self.verify_dag_run_states(self.dag1, date) + + def test_set_state_without_commit(self): + date = self.execution_dates[0] + + # Running dag run and task instances + dr = self.dag1.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.RUNNING, + execution_date=date, + session=self.session + ) + for ti in dr.get_task_instances(session=self.session): + ti.set_state(State.RUNNING, self.session) + + altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=False) + + # All of the task should be altered + self.assertEqual(len(altered), len(self.dag1.tasks)) + + # Both dag run and task instances' states should remain the same + self.verify_dag_run_states(self.dag1, date, State.RUNNING) + + def test_set_state_with_multiple_dagruns(self): + dr1 = self.dag2.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.FAILED, + execution_date=self.execution_dates[0], + session=self.session + ) + dr2 = self.dag2.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.FAILED, + execution_date=self.execution_dates[1], + session=self.session + ) + dr3 = self.dag2.create_dagrun( + run_id='manual__' + datetime.now().isoformat(), + state=State.RUNNING, + execution_date=self.execution_dates[2], + session=self.session + ) + + altered = set_dag_run_state(self.dag2, self.execution_dates[1], + state=State.SUCCESS, commit=True) + + # Recursively count number of tasks in the dag + def count_dag_tasks(dag): + count = len(dag.tasks) + subdag_counts = [count_dag_tasks(subdag) for subdag in dag.subdags] + count += sum(subdag_counts) + return count + + self.assertEqual(len(altered), count_dag_tasks(self.dag2)) + self.verify_dag_run_states(self.dag2, self.execution_dates[1]) + + # Make sure other dag status are not changed + dr1 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[0]) + dr1 = dr1[0] + self.assertEqual(dr1.get_state(), State.FAILED) + dr3 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[2]) + dr3 = dr3[0] + self.assertEqual(dr3.get_state(), State.RUNNING) + + def test_set_dag_run_state_edge_cases(self): + # Dag does not exist + altered = set_dag_run_state(None, self.execution_dates[0]) + self.assertEqual(len(altered), 0) + + # Invalid execution date + altered = set_dag_run_state(self.dag1, None) + self.assertEqual(len(altered), 0) + self.assertRaises(AssertionError, set_dag_run_state, self.dag1, timedelta(microseconds=-1)) + + # DagRun does not exist + # This will throw AssertionError since dag.latest_execution_date does not exist + 
self.assertRaises(AssertionError, set_dag_run_state, self.dag1, self.execution_dates[0]) + + def tearDown(self): + self.dag1.clear() + self.dag2.clear() + + self.session.query(models.DagRun).delete() + self.session.query(models.TaskInstance).delete() + self.session.query(models.DagStat).delete() + self.session.commit() + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/experimental/test_pool.py ---------------------------------------------------------------------- diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py new file mode 100644 index 0000000..98969b8 --- /dev/null +++ b/tests/api/common/experimental/test_pool.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from airflow.api.common.experimental import pool as pool_api +from airflow import models +from airflow import settings + + +class TestPool(unittest.TestCase): + + def setUp(self): + super(TestPool, self).setUp() + self.session = settings.Session() + self.pools = [] + for i in range(2): + name = 'experimental_%s' % (i + 1) + pool = models.Pool( + pool=name, + slots=i, + description=name, + ) + self.session.add(pool) + self.pools.append(pool) + self.session.commit() + + def tearDown(self): + self.session.query(models.Pool).delete() + self.session.commit() + self.session.close() + super(TestPool, self).tearDown() + + def test_get_pool(self): + pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session) + self.assertEqual(pool.pool, self.pools[0].pool) + + def test_get_pool_non_existing(self): + self.assertRaisesRegexp(pool_api.PoolNotFound, + "^Pool 'test' doesn't exist$", + pool_api.get_pool, + name='test', + session=self.session) + + def test_get_pool_bad_name(self): + for name in ('', ' '): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Pool name shouldn't be empty$", + pool_api.get_pool, + name=name, + session=self.session) + + def test_get_pools(self): + pools = sorted(pool_api.get_pools(session=self.session), + key=lambda p: p.pool) + self.assertEqual(pools[0].pool, self.pools[0].pool) + self.assertEqual(pools[1].pool, self.pools[1].pool) + + def test_create_pool(self): + pool = pool_api.create_pool(name='foo', + slots=5, + description='', + session=self.session) + self.assertEqual(pool.pool, 'foo') + self.assertEqual(pool.slots, 5) + self.assertEqual(pool.description, '') + self.assertEqual(self.session.query(models.Pool).count(), 3) + + def test_create_pool_existing(self): + pool = pool_api.create_pool(name=self.pools[0].pool, + slots=5, + description='', + session=self.session) + self.assertEqual(pool.pool, self.pools[0].pool) + self.assertEqual(pool.slots, 5) + self.assertEqual(pool.description, '') + self.assertEqual(self.session.query(models.Pool).count(), 2) + + def test_create_pool_bad_name(self): + for name in ('', ' '): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Pool name shouldn't be 
empty$", + pool_api.create_pool, + name=name, + slots=5, + description='', + session=self.session) + + def test_create_pool_bad_slots(self): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Bad value for `slots`: foo$", + pool_api.create_pool, + name='foo', + slots='foo', + description='', + session=self.session) + + def test_delete_pool(self): + pool = pool_api.delete_pool(name=self.pools[0].pool, + session=self.session) + self.assertEqual(pool.pool, self.pools[0].pool) + self.assertEqual(self.session.query(models.Pool).count(), 1) + + def test_delete_pool_non_existing(self): + self.assertRaisesRegexp(pool_api.PoolNotFound, + "^Pool 'test' doesn't exist$", + pool_api.delete_pool, + name='test', + session=self.session) + + def test_delete_pool_bad_name(self): + for name in ('', ' '): + self.assertRaisesRegexp(pool_api.PoolBadRequest, + "^Pool name shouldn't be empty$", + pool_api.delete_pool, + name=name, + session=self.session) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/mark_tasks.py ---------------------------------------------------------------------- diff --git a/tests/api/common/mark_tasks.py b/tests/api/common/mark_tasks.py deleted file mode 100644 index 8a3759f..0000000 --- a/tests/api/common/mark_tasks.py +++ /dev/null @@ -1,396 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
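----------------------------------------------------------------------
The helpers exercised by the tests above can also be driven directly; a
sketch, assuming an initialized metadata database (provide_session
supplies a session when none is passed):

    from airflow.api.common.experimental import pool as pool_api

    p = pool_api.create_pool(name='backfill', slots=4, description='')
    assert (p.pool, p.slots) == ('backfill', 4)

    # creating an existing pool updates it in place
    pool_api.create_pool(name='backfill', slots=8, description='')

    pool_api.delete_pool(name='backfill')
    try:
        pool_api.get_pool(name='backfill')
    except pool_api.PoolNotFound as e:
        print(e)  # Pool 'backfill' doesn't exist
----------------------------------------------------------------------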
-# - -import unittest - -from airflow import models -from airflow.api.common.experimental.mark_tasks import ( - set_state, _create_dagruns, set_dag_run_state) -from airflow.settings import Session -from airflow.utils.dates import days_ago -from airflow.utils.state import State -from datetime import datetime, timedelta - -DEV_NULL = "/dev/null" - - -class TestMarkTasks(unittest.TestCase): - def setUp(self): - self.dagbag = models.DagBag(include_examples=True) - self.dag1 = self.dagbag.dags['test_example_bash_operator'] - self.dag2 = self.dagbag.dags['example_subdag_operator'] - - self.execution_dates = [days_ago(2), days_ago(1)] - - drs = _create_dagruns(self.dag1, self.execution_dates, - state=State.RUNNING, - run_id_template="scheduled__{}") - for dr in drs: - dr.dag = self.dag1 - dr.verify_integrity() - - drs = _create_dagruns(self.dag2, - [self.dag2.default_args['start_date']], - state=State.RUNNING, - run_id_template="scheduled__{}") - - for dr in drs: - dr.dag = self.dag2 - dr.verify_integrity() - - self.session = Session() - - def snapshot_state(self, dag, execution_dates): - TI = models.TaskInstance - tis = self.session.query(TI).filter( - TI.dag_id==dag.dag_id, - TI.execution_date.in_(execution_dates) - ).all() - - self.session.expunge_all() - - return tis - - def verify_state(self, dag, task_ids, execution_dates, state, old_tis): - TI = models.TaskInstance - - tis = self.session.query(TI).filter( - TI.dag_id==dag.dag_id, - TI.execution_date.in_(execution_dates) - ).all() - - self.assertTrue(len(tis) > 0) - - for ti in tis: - if ti.task_id in task_ids and ti.execution_date in execution_dates: - self.assertEqual(ti.state, state) - else: - for old_ti in old_tis: - if (old_ti.task_id == ti.task_id - and old_ti.execution_date == ti.execution_date): - self.assertEqual(ti.state, old_ti.state) - - def test_mark_tasks_now(self): - # set one task to success but do not commit - snapshot = self.snapshot_state(self.dag1, self.execution_dates) - task = self.dag1.get_task("runme_1") - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=False) - self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - None, snapshot) - - # set one and only one task to success - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.SUCCESS, snapshot) - - # set no tasks - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 0) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.SUCCESS, snapshot) - - # set task to other than success - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=False, - past=False, state=State.FAILED, commit=True) - self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.FAILED, snapshot) - - # dont alter other tasks - snapshot = self.snapshot_state(self.dag1, self.execution_dates) - task = self.dag1.get_task("runme_0") - altered = set_state(task=task, execution_date=self.execution_dates[0], - 
upstream=False, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 1) - self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], - State.SUCCESS, snapshot) - - def test_mark_downstream(self): - # test downstream - snapshot = self.snapshot_state(self.dag1, self.execution_dates) - task = self.dag1.get_task("runme_1") - relatives = task.get_flat_relatives(upstream=False) - task_ids = [t.task_id for t in relatives] - task_ids.append(task.task_id) - - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=True, future=False, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 3) - self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], - State.SUCCESS, snapshot) - - def test_mark_upstream(self): - # test upstream - snapshot = self.snapshot_state(self.dag1, self.execution_dates) - task = self.dag1.get_task("run_after_loop") - relatives = task.get_flat_relatives(upstream=True) - task_ids = [t.task_id for t in relatives] - task_ids.append(task.task_id) - - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=True, downstream=False, future=False, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 4) - self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], - State.SUCCESS, snapshot) - - def test_mark_tasks_future(self): - # set one task to success towards end of scheduled dag runs - snapshot = self.snapshot_state(self.dag1, self.execution_dates) - task = self.dag1.get_task("runme_1") - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=False, future=True, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 2) - self.verify_state(self.dag1, [task.task_id], self.execution_dates, - State.SUCCESS, snapshot) - - def test_mark_tasks_past(self): - # set one task to success towards end of scheduled dag runs - snapshot = self.snapshot_state(self.dag1, self.execution_dates) - task = self.dag1.get_task("runme_1") - altered = set_state(task=task, execution_date=self.execution_dates[1], - upstream=False, downstream=False, future=False, - past=True, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 2) - self.verify_state(self.dag1, [task.task_id], self.execution_dates, - State.SUCCESS, snapshot) - - def test_mark_tasks_subdag(self): - # set one task to success towards end of scheduled dag runs - task = self.dag2.get_task("section-1") - relatives = task.get_flat_relatives(upstream=False) - task_ids = [t.task_id for t in relatives] - task_ids.append(task.task_id) - - altered = set_state(task=task, execution_date=self.execution_dates[0], - upstream=False, downstream=True, future=False, - past=False, state=State.SUCCESS, commit=True) - self.assertEqual(len(altered), 14) - - # cannot use snapshot here as that will require drilling down the - # the sub dag tree essentially recreating the same code as in the - # tested logic. 
-        self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
-                          State.SUCCESS, [])
-
-    def tearDown(self):
-        self.dag1.clear()
-        self.dag2.clear()
-
-        # just to make sure we are fully cleaned up
-        self.session.query(models.DagRun).delete()
-        self.session.query(models.TaskInstance).delete()
-        self.session.commit()
-
-        self.session.close()
-
-class TestMarkDAGRun(unittest.TestCase):
-    def setUp(self):
-        self.dagbag = models.DagBag(include_examples=True)
-        self.dag1 = self.dagbag.dags['test_example_bash_operator']
-        self.dag2 = self.dagbag.dags['example_subdag_operator']
-
-        self.execution_dates = [days_ago(3), days_ago(2), days_ago(1)]
-
-        self.session = Session()
-
-    def verify_dag_run_states(self, dag, date, state=State.SUCCESS):
-        drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
-        dr = drs[0]
-        self.assertEqual(dr.get_state(), state)
-        tis = dr.get_task_instances(session=self.session)
-        for ti in tis:
-            self.assertEqual(ti.state, state)
-
-    def test_set_running_dag_run_state(self):
-        date = self.execution_dates[0]
-        dr = self.dag1.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.RUNNING,
-            execution_date=date,
-            session=self.session
-        )
-        for ti in dr.get_task_instances(session=self.session):
-            ti.set_state(State.RUNNING, self.session)
-
-        altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
-        # All of the task should be altered
-        self.assertEqual(len(altered), len(self.dag1.tasks))
-        self.verify_dag_run_states(self.dag1, date)
-
-    def test_set_success_dag_run_state(self):
-        date = self.execution_dates[0]
-
-        dr = self.dag1.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.SUCCESS,
-            execution_date=date,
-            session=self.session
-        )
-        for ti in dr.get_task_instances(session=self.session):
-            ti.set_state(State.SUCCESS, self.session)
-
-        altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
-        # None of the task should be altered
-        self.assertEqual(len(altered), 0)
-        self.verify_dag_run_states(self.dag1, date)
-
-    def test_set_failed_dag_run_state(self):
-        date = self.execution_dates[0]
-        dr = self.dag1.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.FAILED,
-            execution_date=date,
-            session=self.session
-        )
-        dr.get_task_instance('runme_0').set_state(State.FAILED, self.session)
-
-        altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
-        # All of the task should be altered
-        self.assertEqual(len(altered), len(self.dag1.tasks))
-        self.verify_dag_run_states(self.dag1, date)
-
-    def test_set_mixed_dag_run_state(self):
-        """
-        This test checks function set_dag_run_state with mixed task instance
-        state.
-        """
-        date = self.execution_dates[0]
-        dr = self.dag1.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.FAILED,
-            execution_date=date,
-            session=self.session
-        )
-        # success task
-        dr.get_task_instance('runme_0').set_state(State.SUCCESS, self.session)
-        # skipped task
-        dr.get_task_instance('runme_1').set_state(State.SKIPPED, self.session)
-        # retry task
-        dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY, self.session)
-        # queued task
-        dr.get_task_instance('also_run_this').set_state(State.QUEUED, self.session)
-        # running task
-        dr.get_task_instance('run_after_loop').set_state(State.RUNNING, self.session)
-        # failed task
-        dr.get_task_instance('run_this_last').set_state(State.FAILED, self.session)
-
-        altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
-        self.assertEqual(len(altered), len(self.dag1.tasks) - 1)  # only 1 task succeeded
-        self.verify_dag_run_states(self.dag1, date)
-
-    def test_set_state_without_commit(self):
-        date = self.execution_dates[0]
-
-        # Running dag run and task instances
-        dr = self.dag1.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.RUNNING,
-            execution_date=date,
-            session=self.session
-        )
-        for ti in dr.get_task_instances(session=self.session):
-            ti.set_state(State.RUNNING, self.session)
-
-        altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=False)
-
-        # All of the task should be altered
-        self.assertEqual(len(altered), len(self.dag1.tasks))
-
-        # Both dag run and task instances' states should remain the same
-        self.verify_dag_run_states(self.dag1, date, State.RUNNING)
-
-    def test_set_state_with_multiple_dagruns(self):
-        dr1 = self.dag2.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.FAILED,
-            execution_date=self.execution_dates[0],
-            session=self.session
-        )
-        dr2 = self.dag2.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.FAILED,
-            execution_date=self.execution_dates[1],
-            session=self.session
-        )
-        dr3 = self.dag2.create_dagrun(
-            run_id='manual__' + datetime.now().isoformat(),
-            state=State.RUNNING,
-            execution_date=self.execution_dates[2],
-            session=self.session
-        )
-
-        altered = set_dag_run_state(self.dag2, self.execution_dates[1],
-                                    state=State.SUCCESS, commit=True)
-
-        # Recursively count number of tasks in the dag
-        def count_dag_tasks(dag):
-            count = len(dag.tasks)
-            subdag_counts = [count_dag_tasks(subdag) for subdag in dag.subdags]
-            count += sum(subdag_counts)
-            return count
-
-        self.assertEqual(len(altered), count_dag_tasks(self.dag2))
-        self.verify_dag_run_states(self.dag2, self.execution_dates[1])
-
-        # Make sure other dag status are not changed
-        dr1 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[0])
-        dr1 = dr1[0]
-        self.assertEqual(dr1.get_state(), State.FAILED)
-        dr3 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[2])
-        dr3 = dr3[0]
-        self.assertEqual(dr3.get_state(), State.RUNNING)
-
-    def test_set_dag_run_state_edge_cases(self):
-        # Dag does not exist
-        altered = set_dag_run_state(None, self.execution_dates[0])
-        self.assertEqual(len(altered), 0)
-
-        # Invalid execution date
-        altered = set_dag_run_state(self.dag1, None)
-        self.assertEqual(len(altered), 0)
-        self.assertRaises(AssertionError, set_dag_run_state, self.dag1, timedelta(microseconds=-1))
-
-        # DagRun does not exist
-        # This will throw AssertionError since dag.latest_execution_date does not exist
-        self.assertRaises(AssertionError, set_dag_run_state, self.dag1,
-                          self.execution_dates[0])
-
-    def tearDown(self):
-        self.dag1.clear()
-        self.dag2.clear()
-
-        self.session.query(models.DagRun).delete()
-        self.session.query(models.TaskInstance).delete()
-        self.session.query(models.DagStat).delete()
-        self.session.commit()
-
-if __name__ == '__main__':
-    unittest.main()
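The removed TestMarkDAGRun cases above all revolve around set_dag_run_state. For readers skimming the diff, here is a minimal sketch of how that helper is driven, based only on the calls visible in these tests; it assumes the function is importable from airflow.api.common.experimental.mark_tasks (the module these tests target), that a dag run already exists at the given execution date (otherwise, as the edge-case test shows, an AssertionError is raised), and the DAG id and date are illustrative:

    from datetime import datetime

    from airflow import models
    from airflow.api.common.experimental.mark_tasks import set_dag_run_state
    from airflow.utils.state import State

    # Load the bundled example DAGs, as the tests above do.
    dagbag = models.DagBag(include_examples=True)
    dag = dagbag.dags['test_example_bash_operator']

    # Mark the dag run at this (illustrative) execution date, and all of its
    # task instances, as SUCCESS. With commit=False the altered task
    # instances are still returned, but nothing is written to the database.
    altered = set_dag_run_state(dag, datetime(2017, 6, 1),
                                state=State.SUCCESS, commit=True)
    print('%d task instances altered' % len(altered))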
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index 8ccd4e7..259b61d 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase):
 
 
 class CliTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(CliTests, cls).setUpClass()
+        cls._cleanup()
+
     def setUp(self):
+        super(CliTests, self).setUp()
         configuration.load_test_config()
         app = application.create_app()
         app.config['TESTING'] = True
         self.parser = cli.CLIFactory.get_parser()
         self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
+        self.session = Session()
+
+    def tearDown(self):
+        self._cleanup(session=self.session)
+        super(CliTests, self).tearDown()
+
+    @staticmethod
+    def _cleanup(session=None):
+        if session is None:
+            session = Session()
+
+        session.query(models.Pool).delete()
+        session.query(models.Variable).delete()
+        session.commit()
+        session.close()
 
     def test_cli_list_dags(self):
         args = self.parser.parse_args(['list_dags', '--report'])
@@ -1100,8 +1122,8 @@
         cli.connections(self.parser.parse_args(['connections', '--list']))
         stdout = mock_stdout.getvalue()
         conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
-                for ii, line in enumerate(stdout.split('\n'))
-                if ii % 2 == 1]
+                 for ii, line in enumerate(stdout.split('\n'))
+                 if ii % 2 == 1]
         conns = [conn for conn in conns if len(conn) > 0]
 
         # Assert that some of the connections are present in the output as
@@ -1365,14 +1387,27 @@
             '-c', 'NOT JSON'])
         )
 
-    def test_pool(self):
-        # Checks if all subcommands are properly received
-        cli.pool(self.parser.parse_args([
-            'pool', '-s', 'foo', '1', '"my foo pool"']))
-        cli.pool(self.parser.parse_args([
-            'pool', '-g', 'foo']))
-        cli.pool(self.parser.parse_args([
-            'pool', '-x', 'foo']))
+    def test_pool_create(self):
+        cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+        self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+    def test_pool_get(self):
+        cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+        try:
+            cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
+        except Exception as e:
+            self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
+
+    def test_pool_delete(self):
+        cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+        cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
+        self.assertEqual(self.session.query(models.Pool).count(), 0)
+
+    def test_pool_no_args(self):
+        try:
+            cli.pool(self.parser.parse_args(['pool']))
+        except Exception as e:
+            self.fail("The 'pool' command raised unexpectedly: %s" % e)
 
     def test_variables(self):
         # Checks if all subcommands are properly received
@@ -1426,10 +1461,6 @@
         self.assertEqual('original', models.Variable.get('bar'))
         self.assertEqual('{"foo": "bar"}', models.Variable.get('foo'))
 
-        session = settings.Session()
-        session.query(Variable).delete()
-        session.commit()
-        session.close()
         os.remove('variables1.json')
         os.remove('variables2.json')
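The reworked CliTests above replace the old catch-all test_pool with one test per pool subcommand. Going only by the flags those tests drive (-s to set a pool, -g to get one, -x to delete one), the CLI can be exercised programmatically as sketched below; the pool name, slot count, and description are illustrative, and running this hits whatever metadata database Airflow is configured against:

    from airflow.bin import cli

    parser = cli.CLIFactory.get_parser()

    # Equivalent of: airflow pool -s foo 1 test
    # (create/update pool 'foo' with 1 slot and description 'test')
    cli.pool(parser.parse_args(['pool', '-s', 'foo', '1', 'test']))

    # Equivalent of: airflow pool -g foo  (print pool 'foo')
    cli.pool(parser.parse_args(['pool', '-g', 'foo']))

    # Equivalent of: airflow pool -x foo  (delete pool 'foo')
    cli.pool(parser.parse_args(['pool', '-x', 'foo']))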
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/www/api/experimental/test_endpoints.py
----------------------------------------------------------------------
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index dacee32..65a6f75 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -19,22 +19,35 @@
 from urllib.parse import quote_plus
 
 from airflow import configuration
 from airflow.api.common.experimental.trigger_dag import trigger_dag
-from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.models import DagBag, DagRun, Pool, TaskInstance
 from airflow.settings import Session
 from airflow.www import app as application
 
 
-class ApiExperimentalTests(unittest.TestCase):
+class TestApiExperimental(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+        session.commit()
+        session.close()
 
     def setUp(self):
+        super(TestApiExperimental, self).setUp()
         configuration.load_test_config()
         app = application.create_app(testing=True)
         self.app = app.test_client()
+
+    def tearDown(self):
         session = Session()
         session.query(DagRun).delete()
         session.query(TaskInstance).delete()
         session.commit()
         session.close()
+        super(TestApiExperimental, self).tearDown()
 
     def test_task_info(self):
         url_template = '/api/experimental/dags/{}/tasks/{}'
@@ -62,7 +75,7 @@
         url_template = '/api/experimental/dags/{}/dag_runs'
         response = self.app.post(
             url_template.format('example_bash_operator'),
-            data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
+            data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}),
             content_type="application/json"
         )
@@ -70,7 +83,7 @@
         response = self.app.post(
             url_template.format('does_not_exist_dag'),
-            data=json.dumps(dict()),
+            data=json.dumps({}),
             content_type="application/json"
         )
         self.assertEqual(404, response.status_code)
@@ -88,7 +101,7 @@
         # Test Correct execution
         response = self.app.post(
             url_template.format(dag_id),
-            data=json.dumps(dict(execution_date=datetime_string)),
+            data=json.dumps({'execution_date': datetime_string}),
             content_type="application/json"
         )
         self.assertEqual(200, response.status_code)
@@ -103,7 +116,7 @@
         # Test error for nonexistent dag
         response = self.app.post(
             url_template.format('does_not_exist_dag'),
-            data=json.dumps(dict(execution_date=execution_date.isoformat())),
+            data=json.dumps({'execution_date': execution_date.isoformat()}),
             content_type="application/json"
         )
         self.assertEqual(404, response.status_code)
@@ -111,7 +124,7 @@
         # Test error for bad datetime format
         response = self.app.post(
             url_template.format(dag_id),
-            data=json.dumps(dict(execution_date='not_a_datetime')),
+            data=json.dumps({'execution_date': 'not_a_datetime'}),
             content_type="application/json"
         )
         self.assertEqual(400, response.status_code)
@@ -122,7 +135,9 @@
         task_id = 'also_run_this'
         execution_date = datetime.now().replace(microsecond=0)
         datetime_string = quote_plus(execution_date.isoformat())
-        wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
+        wrong_datetime_string = quote_plus(
+            datetime(1990, 1, 1, 1, 1, 1).isoformat()
+        )
 
         # Create DagRun
         trigger_dag(dag_id=dag_id,
@@ -139,7 +154,8 @@
         # Test error for nonexistent dag
         response = self.app.get(
-            url_template.format('does_not_exist_dag', datetime_string, task_id),
+            url_template.format('does_not_exist_dag', datetime_string,
+                                task_id),
         )
         self.assertEqual(404, response.status_code)
         self.assertIn('error', response.data.decode('utf-8'))
@@ -164,3 +180,122 @@
         )
         self.assertEqual(400, response.status_code)
         self.assertIn('error', response.data.decode('utf-8'))
+
+
+class TestPoolApiExperimental(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestPoolApiExperimental, cls).setUpClass()
+        session = Session()
+        session.query(Pool).delete()
+        session.commit()
+        session.close()
+
+    def setUp(self):
+        super(TestPoolApiExperimental, self).setUp()
+        configuration.load_test_config()
+        app = application.create_app(testing=True)
+        self.app = app.test_client()
+        self.session = Session()
+        self.pools = []
+        for i in range(2):
+            name = 'experimental_%s' % (i + 1)
+            pool = Pool(
+                pool=name,
+                slots=i,
+                description=name,
+            )
+            self.session.add(pool)
+            self.pools.append(pool)
+        self.session.commit()
+        self.pool = self.pools[0]
+
+    def tearDown(self):
+        self.session.query(Pool).delete()
+        self.session.commit()
+        self.session.close()
+        super(TestPoolApiExperimental, self).tearDown()
+
+    def _get_pool_count(self):
+        response = self.app.get('/api/experimental/pools')
+        self.assertEqual(response.status_code, 200)
+        return len(json.loads(response.data.decode('utf-8')))
+
+    def test_get_pool(self):
+        response = self.app.get(
+            '/api/experimental/pools/{}'.format(self.pool.pool),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.pool.to_json())
+
+    def test_get_pool_non_existing(self):
+        response = self.app.get('/api/experimental/pools/foo')
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Pool 'foo' doesn't exist")
+
+    def test_get_pools(self):
+        response = self.app.get('/api/experimental/pools')
+        self.assertEqual(response.status_code, 200)
+        pools = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(len(pools), 2)
+        for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
+            self.assertDictEqual(pool, self.pools[i].to_json())
+
+    def test_create_pool(self):
+        response = self.app.post(
+            '/api/experimental/pools',
+            data=json.dumps({
+                'name': 'foo',
+                'slots': 1,
+                'description': '',
+            }),
+            content_type='application/json',
+        )
+        self.assertEqual(response.status_code, 200)
+        pool = json.loads(response.data.decode('utf-8'))
+        self.assertEqual(pool['pool'], 'foo')
+        self.assertEqual(pool['slots'], 1)
+        self.assertEqual(pool['description'], '')
+        self.assertEqual(self._get_pool_count(), 3)
+
+    def test_create_pool_with_bad_name(self):
+        for name in ('', ' '):
+            response = self.app.post(
+                '/api/experimental/pools',
+                data=json.dumps({
+                    'name': name,
+                    'slots': 1,
+                    'description': '',
+                }),
+                content_type='application/json',
+            )
+            self.assertEqual(response.status_code, 400)
+            self.assertEqual(
+                json.loads(response.data.decode('utf-8'))['error'],
+                "Pool name shouldn't be empty",
+            )
+        self.assertEqual(self._get_pool_count(), 2)
+
+    def test_delete_pool(self):
+        response = self.app.delete(
+            '/api/experimental/pools/{}'.format(self.pool.pool),
+        )
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(json.loads(response.data.decode('utf-8')),
+                         self.pool.to_json())
+        self.assertEqual(self._get_pool_count(), 1)
+
+    def test_delete_pool_non_existing(self):
+        response = self.app.delete(
+            '/api/experimental/pools/foo',
+        )
+        self.assertEqual(response.status_code, 404)
+        self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+                         "Pool 'foo' doesn't exist")
+
+
+if __name__ == '__main__':
+    unittest.main()
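Taken together, the TestPoolApiExperimental cases above pin down the new experimental pool endpoints: GET /api/experimental/pools lists pools, GET and DELETE /api/experimental/pools/<name> fetch and remove a single pool, and POST /api/experimental/pools creates one from a JSON body with 'name', 'slots', and 'description' keys, returning the pool as JSON. A rough client-side sketch using the third-party requests library follows; the base URL is illustrative and an unauthenticated, locally running webserver is assumed:

    import requests

    # Illustrative base URL for a local webserver on the default port.
    base = 'http://localhost:8080/api/experimental/pools'

    # Create a pool; mirrors the JSON body in test_create_pool above.
    resp = requests.post(base, json={'name': 'foo', 'slots': 1,
                                     'description': ''})
    resp.raise_for_status()
    print(resp.json())  # e.g. {'pool': 'foo', 'slots': 1, 'description': ''}

    # List all pools, fetch one by name, then delete it.
    print(requests.get(base).json())
    print(requests.get(base + '/foo').json())
    print(requests.delete(base + '/foo').json())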
