This is an automated email from the ASF dual-hosted git repository. turbaszek pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new 5eb2808 Add read-only Task endpoint (#9330) 5eb2808 is described below commit 5eb2808da1e33cb7a72fad8693d75e8a21401828 Author: Tomek Urbaszek <turbas...@gmail.com> AuthorDate: Thu Jun 25 12:11:11 2020 +0200 Add read-only Task endpoint (#9330) Add API endpoints for tasks and DAG details Co-authored-by: Kamil Breguła <kamil.breg...@polidea.com> --- airflow/api_connexion/endpoints/dag_endpoint.py | 12 +- airflow/api_connexion/endpoints/task_endpoint.py | 28 +++- airflow/api_connexion/openapi/v1.yaml | 29 +++- airflow/api_connexion/schemas/common_schema.py | 169 +++++++++++++++++++++ airflow/api_connexion/schemas/dag_schema.py | 93 ++++++++++++ airflow/api_connexion/schemas/task_schema.py | 80 ++++++++++ requirements/requirements-python3.6.txt | 1 + requirements/requirements-python3.7.txt | 1 + requirements/requirements-python3.8.txt | 3 +- requirements/setup-3.6.md5 | 2 +- requirements/setup-3.7.md5 | 2 +- requirements/setup-3.8.md5 | 2 +- setup.py | 1 + tests/api_connexion/endpoints/test_dag_endpoint.py | 91 ++++++++++- .../api_connexion/endpoints/test_task_endpoint.py | 149 +++++++++++++++++- tests/api_connexion/schemas/test_common_schema.py | 145 ++++++++++++++++++ tests/api_connexion/schemas/test_dag_schema.py | 123 +++++++++++++++ tests/api_connexion/schemas/test_task_schema.py | 99 ++++++++++++ tests/cli/commands/test_dag_command.py | 7 + tests/test_utils/db.py | 6 + 20 files changed, 1021 insertions(+), 22 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 7cbb0ef..7cdeeb6 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -15,10 +15,15 @@ # specific language governing permissions and limitations # under the License. 
+from flask import current_app + +from airflow import DAG +from airflow.api_connexion.exceptions import NotFound # TODO(mik-laj): We have to implement it. # Do you want to help? Please look at: # * https://github.com/apache/airflow/issues/8128 # * https://github.com/apache/airflow/issues/8138 +from airflow.api_connexion.schemas.dag_schema import dag_detail_schema def get_dag(): @@ -28,11 +33,14 @@ def get_dag(): raise NotImplementedError("Not implemented yet.") -def get_dag_details(): +def get_dag_details(dag_id): """ Get details of DAG. """ - raise NotImplementedError("Not implemented yet.") + dag: DAG = current_app.dag_bag.get_dag(dag_id) + if not dag: + raise NotFound("DAG not found") + return dag_detail_schema.dump(dag) def get_dags(): diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index de7eaa4..e23483a 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -14,20 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from flask import current_app -# TODO(mik-laj): We have to implement it. -# Do you want to help? Please look at: https://github.com/apache/airflow/issues/8138 +from airflow import DAG +from airflow.api_connexion.exceptions import NotFound +from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema +from airflow.exceptions import TaskNotFound -def get_task(): +def get_task(dag_id, task_id): """ Get simplified representation of a task. 
""" - raise NotImplementedError("Not implemented yet.") + dag: DAG = current_app.dag_bag.get_dag(dag_id) + if not dag: + raise NotFound("DAG not found") + try: + task = dag.get_task(task_id=task_id) + except TaskNotFound: + raise NotFound("Task not found") + return task_schema.dump(task) -def get_tasks(): + +def get_tasks(dag_id): """ Get tasks for DAG """ - raise NotImplementedError("Not implemented yet.") + dag: DAG = current_app.dag_bag.get_dag(dag_id) + if not dag: + raise NotFound("DAG not found") + tasks = dag.tasks + task_collection = TaskCollection(tasks=tasks, total_entries=len(tasks)) + return task_collection_schema.dump(task_collection) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index e6ab5f6..1794b65 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -1234,8 +1234,10 @@ components: root_dag_id: type: string readOnly: true + nullable: true is_paused: type: boolean + nullable: true is_subdag: type: boolean readOnly: true @@ -1257,11 +1259,13 @@ components: description: type: string readOnly: true + nullable: true schedule_interval: $ref: '#/components/schemas/ScheduleInterval' readOnly: true tags: type: array + nullable: true items: $ref: '#/components/schemas/Tag' readOnly: true @@ -1638,6 +1642,7 @@ components: format: 'date-time' readOnly: true dag_run_timeout: + nullable: true $ref: '#/components/schemas/TimeDelta' doc_md: type: string @@ -1685,6 +1690,7 @@ components: type: string format: 'date-time' readOnly: true + nullable: true trigger_rule: $ref: '#/components/schemas/TriggerRule' extra_links: @@ -1715,8 +1721,10 @@ components: readOnly: true execution_timeout: $ref: '#/components/schemas/TimeDelta' + nullable: true retry_delay: $ref: '#/components/schemas/TimeDelta' + nullable: true retry_exponential_backoff: type: boolean readOnly: true @@ -2003,17 +2011,35 @@ components: type: object required: - __type + - days + - seconds + - microseconds 
properties: __type: {type: string} days: {type: integer} seconds: {type: integer} - microsecond: {type: integer} + microseconds: {type: integer} RelativeDelta: # TODO: Why we need these fields? type: object required: - __type + - years + - months + - days + - leapdays + - hours + - minutes + - seconds + - microseconds + - year + - month + - day + - hour + - minute + - second + - microsecond properties: __type: {type: string} years: {type: integer} @@ -2036,6 +2062,7 @@ components: type: object required: - __type + - value properties: __type: {type: string} value: {type: string} diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py new file mode 100644 index 0000000..5e3afe6 --- /dev/null +++ b/airflow/api_connexion/schemas/common_schema.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import datetime +import inspect +import typing + +import marshmallow +from dateutil import relativedelta +from marshmallow import Schema, fields, validate +from marshmallow_oneofschema import OneOfSchema + +from airflow.serialization.serialized_objects import SerializedBaseOperator +from airflow.utils.weight_rule import WeightRule + + +class CronExpression(typing.NamedTuple): + """Cron expression schema""" + value: str + + +class TimeDeltaSchema(Schema): + """Time delta schema""" + + objectType = fields.Constant("TimeDelta", dump_to="__type") + days = fields.Integer() + seconds = fields.Integer() + microseconds = fields.Integer() + + @marshmallow.post_load + def make_time_delta(self, data, **kwargs): + """Create time delta based on data""" + + if "objectType" in data: + del data["objectType"] + return datetime.timedelta(**data) + + +class RelativeDeltaSchema(Schema): + """Relative delta schema""" + + objectType = fields.Constant("RelativeDelta", dump_to="__type") + years = fields.Integer() + months = fields.Integer() + days = fields.Integer() + leapdays = fields.Integer() + hours = fields.Integer() + minutes = fields.Integer() + seconds = fields.Integer() + microseconds = fields.Integer() + year = fields.Integer() + month = fields.Integer() + day = fields.Integer() + hour = fields.Integer() + minute = fields.Integer() + second = fields.Integer() + microsecond = fields.Integer() + + @marshmallow.post_load + def make_relative_delta(self, data, **kwargs): + """Create relative delta based on data""" + + if "objectType" in data: + del data["objectType"] + + return relativedelta.relativedelta(**data) + + +class CronExpressionSchema(Schema): + """Cron expression schema""" + + objectType = fields.Constant("CronExpression", dump_to="__type", required=True) + value = fields.String(required=True) + + @marshmallow.post_load + def make_cron_expression(self, data, **kwargs): + """Create cron expression based on data""" + return CronExpression(data["value"]) + + +class 
ScheduleIntervalSchema(OneOfSchema): + """ + Schedule interval. + + It supports the following types: + + * TimeDelta + * RelativeDelta + * CronExpression + """ + type_field = "__type" + type_schemas = { + "TimeDelta": TimeDeltaSchema, + "RelativeDelta": RelativeDeltaSchema, + "CronExpression": CronExpressionSchema, + } + + def _dump(self, obj, update_fields=True, **kwargs): + if isinstance(obj, str): + obj = CronExpression(obj) + + return super()._dump(obj, update_fields=update_fields, **kwargs) + + def get_obj_type(self, obj): + """Select schema based on object type""" + if isinstance(obj, datetime.timedelta): + return "TimeDelta" + elif isinstance(obj, relativedelta.relativedelta): + return "RelativeDelta" + elif isinstance(obj, CronExpression): + return "CronExpression" + else: + raise Exception("Unknown object type: {}".format(obj.__class__.__name__)) + + +class ColorField(fields.String): + """Schema for color property""" + def __init__(self, **metadata): + super().__init__(**metadata) + self.validators = ( + [validate.Regexp("^#[a-fA-F0-9]{3,6}$")] + list(self.validators) + ) + + +class WeightRuleField(fields.String): + """Schema for WeightRule""" + def __init__(self, **metadata): + super().__init__(**metadata) + self.validators = ( + [validate.OneOf(WeightRule.all_weight_rules())] + list(self.validators) + ) + + +class TimezoneField(fields.String): + """Schema for timezone""" + + +class ClassReferenceSchema(Schema): + """ + Class reference schema. 
+ """ + module_path = fields.Method("_get_module", required=True) + class_name = fields.Method("_get_class_name", required=True) + + def _get_module(self, obj): + if isinstance(obj, SerializedBaseOperator): + return obj._task_module # pylint: disable=protected-access + return inspect.getmodule(obj).__name__ + + def _get_class_name(self, obj): + if isinstance(obj, SerializedBaseOperator): + return obj._task_type # pylint: disable=protected-access + if isinstance(obj, type): + return obj.__name__ + return type(obj).__name__ diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py new file mode 100644 index 0000000..5104d70 --- /dev/null +++ b/airflow/api_connexion/schemas/dag_schema.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import List, NamedTuple + +from marshmallow import Schema, fields +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.api_connexion.schemas.common_schema import ScheduleIntervalSchema, TimeDeltaSchema, TimezoneField +from airflow.models.dag import DagModel, DagTag + + +class DagTagSchema(SQLAlchemySchema): + """Dag Tag schema""" + class Meta: + """Meta""" + + model = DagTag + + name = auto_field() + + +class DAGSchema(SQLAlchemySchema): + """DAG schema""" + + class Meta: + """Meta""" + + model = DagModel + + dag_id = auto_field(dump_only=True) + root_dag_id = auto_field(dump_only=True) + is_paused = auto_field(dump_only=True) + is_subdag = auto_field(dump_only=True) + fileloc = auto_field(dump_only=True) + owners = fields.Method("get_owners", dump_only=True) + description = auto_field(dump_only=True) + schedule_interval = fields.Nested(ScheduleIntervalSchema, dump_only=True) + tags = fields.List(fields.Nested(DagTagSchema), dump_only=True) + + @staticmethod + def get_owners(obj: DagModel): + """Convert owners attribute to DAG representation""" + + if not obj.owners: + return [] + return obj.owners.split(",") + + +class DAGDetailSchema(DAGSchema): + """DAG details""" + + timezone = TimezoneField(dump_only=True) + catchup = fields.Boolean(dump_only=True) + orientation = fields.String(dump_only=True) + concurrency = fields.Integer(dump_only=True) + start_date = fields.DateTime(dump_only=True) + dag_run_timeout = fields.Nested(TimeDeltaSchema, dump_only=True, attribute="dagrun_timeout") + doc_md = fields.String(dump_only=True) + default_view = fields.String(dump_only=True) + + +class DAGCollection(NamedTuple): + """List of DAGs with metadata""" + + dags: List[DagModel] + total_entries: int + + +class DAGCollectionSchema(Schema): + """DAG Collection schema""" + + dags = fields.List(fields.Nested(DAGSchema)) + total_entries = fields.Int() + + +dags_collection_schema = DAGCollectionSchema() +dag_schema = DAGSchema() 
+dag_detail_schema = DAGDetailSchema() diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py new file mode 100644 index 0000000..52a6a30 --- /dev/null +++ b/airflow/api_connexion/schemas/task_schema.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import List, NamedTuple + +from marshmallow import Schema, fields + +from airflow.api_connexion.schemas.common_schema import ( + ClassReferenceSchema, ColorField, TimeDeltaSchema, WeightRuleField, +) +from airflow.api_connexion.schemas.dag_schema import DAGSchema +from airflow.models.baseoperator import BaseOperator + + +class TaskSchema(Schema): + """Task schema""" + + class_ref = fields.Method("_get_class_reference", dump_only=True) + task_id = fields.String(dump_only=True) + owner = fields.String(dump_only=True) + start_date = fields.DateTime(dump_only=True) + end_date = fields.DateTime(dump_only=True) + trigger_rule = fields.String(dump_only=True) + extra_links = fields.List( + fields.Nested(ClassReferenceSchema), + dump_only=True, + attribute="operator_extra_links" + ) + depends_on_past = fields.Boolean(dump_only=True) + wait_for_downstream = fields.Boolean(dump_only=True) + retries = fields.Number(dump_only=True) + queue = fields.String(dump_only=True) + pool = fields.String(dump_only=True) + pool_slots = fields.Number(dump_only=True) + execution_timeout = fields.Nested(TimeDeltaSchema, dump_only=True) + retry_delay = fields.Nested(TimeDeltaSchema, dump_only=True) + retry_exponential_backoff = fields.Boolean(dump_only=True) + priority_weight = fields.Number(dump_only=True) + weight_rule = WeightRuleField(dump_only=True) + ui_color = ColorField(dump_only=True) + ui_fgcolor = ColorField(dump_only=True) + template_fields = fields.List(fields.String(), dump_only=True) + sub_dag = fields.Nested(DAGSchema, dump_only=True) + downstream_task_ids = fields.List(fields.String(), dump_only=True) + + def _get_class_reference(self, obj): + result = ClassReferenceSchema().dump(obj) + return result.data if hasattr(result, "data") else result + + +class TaskCollection(NamedTuple): + """List of Tasks with metadata""" + + tasks: List[BaseOperator] + total_entries: int + + +class TaskCollectionSchema(Schema): + """Schema for TaskCollection""" + + tasks = 
fields.List(fields.Nested(TaskSchema)) + total_entries = fields.Int() + + +task_schema = TaskSchema() +task_collection_schema = TaskCollectionSchema() diff --git a/requirements/requirements-python3.6.txt b/requirements/requirements-python3.6.txt index 8e0e6ad..b6b690b 100644 --- a/requirements/requirements-python3.6.txt +++ b/requirements/requirements-python3.6.txt @@ -214,6 +214,7 @@ lazy-object-proxy==1.5.0 ldap3==2.7 lockfile==0.12.2 marshmallow-enum==1.5.1 +marshmallow-oneofschema==1.0.6 marshmallow-sqlalchemy==0.23.1 marshmallow==2.21.0 mccabe==0.6.1 diff --git a/requirements/requirements-python3.7.txt b/requirements/requirements-python3.7.txt index 06d4761..e18701c 100644 --- a/requirements/requirements-python3.7.txt +++ b/requirements/requirements-python3.7.txt @@ -210,6 +210,7 @@ lazy-object-proxy==1.5.0 ldap3==2.7 lockfile==0.12.2 marshmallow-enum==1.5.1 +marshmallow-oneofschema==1.0.6 marshmallow-sqlalchemy==0.23.1 marshmallow==2.21.0 mccabe==0.6.1 diff --git a/requirements/requirements-python3.8.txt b/requirements/requirements-python3.8.txt index 3885f0b..918f3ef 100644 --- a/requirements/requirements-python3.8.txt +++ b/requirements/requirements-python3.8.txt @@ -45,7 +45,7 @@ apispec==1.3.3 appdirs==1.4.4 argcomplete==1.11.1 asn1crypto==1.3.0 -astroid==2.4.2 +astroid==2.3.3 async-generator==1.10 async-timeout==3.0.1 atlasclient==1.0.0 @@ -210,6 +210,7 @@ lazy-object-proxy==1.5.0 ldap3==2.7 lockfile==0.12.2 marshmallow-enum==1.5.1 +marshmallow-oneofschema==1.0.6 marshmallow-sqlalchemy==0.23.1 marshmallow==2.21.0 mccabe==0.6.1 diff --git a/requirements/setup-3.6.md5 b/requirements/setup-3.6.md5 index 5b4b71f..86c4da4 100644 --- a/requirements/setup-3.6.md5 +++ b/requirements/setup-3.6.md5 @@ -1 +1 @@ -58b2fa003085a21989e7c2cc68a10461 /opt/airflow/setup.py +cac9433ddd48ca884fa160b007be3818 /opt/airflow/setup.py diff --git a/requirements/setup-3.7.md5 b/requirements/setup-3.7.md5 index 5b4b71f..86c4da4 100644 --- a/requirements/setup-3.7.md5 +++ 
b/requirements/setup-3.7.md5 @@ -1 +1 @@ -58b2fa003085a21989e7c2cc68a10461 /opt/airflow/setup.py +cac9433ddd48ca884fa160b007be3818 /opt/airflow/setup.py diff --git a/requirements/setup-3.8.md5 b/requirements/setup-3.8.md5 index 5b4b71f..86c4da4 100644 --- a/requirements/setup-3.8.md5 +++ b/requirements/setup-3.8.md5 @@ -1 +1 @@ -58b2fa003085a21989e7c2cc68a10461 /opt/airflow/setup.py +cac9433ddd48ca884fa160b007be3818 /opt/airflow/setup.py diff --git a/setup.py b/setup.py index 4a2d451..4443514 100644 --- a/setup.py +++ b/setup.py @@ -710,6 +710,7 @@ INSTALL_REQUIREMENTS = [ 'lockfile>=0.12.2', 'markdown>=2.5.2, <3.0', 'markupsafe>=1.1.1, <2.0', + 'marshmallow-oneofschema<2', 'pandas>=0.17.1, <2.0', 'pendulum~=2.0', 'pep562~=1.0;python_version<"3.7"', diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 5234d36..9261401 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -14,35 +14,120 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import os import unittest +from datetime import datetime import pytest +from airflow import DAG +from airflow.models import DagBag +from airflow.models.serialized_dag import SerializedDagModel +from airflow.operators.dummy_operator import DummyOperator from airflow.www import app +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags class TestDagEndpoint(unittest.TestCase): + dag_id = "test_dag" + task_id = "op1" + + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + @classmethod def setUpClass(cls) -> None: super().setUpClass() cls.app = app.create_app(testing=True) # type:ignore + with DAG(cls.dag_id, start_date=datetime(2020, 6, 15), doc_md="details") as dag: + DummyOperator(task_id=cls.task_id) + + cls.dag = dag # type:ignore + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag} + cls.app.dag_bag = dag_bag # type:ignore + def setUp(self) -> None: + self.clean_db() self.client = self.app.test_client() # type:ignore + def tearDown(self) -> None: + self.clean_db() + class TestGetDag(TestDagEndpoint): @pytest.mark.skip(reason="Not implemented yet") def test_should_response_200(self): - response = self.client.get("/api/v1/dag/1/") + response = self.client.get("/api/v1/dags/1/") assert response.status_code == 200 class TestGetDagDetails(TestDagEndpoint): - @pytest.mark.skip(reason="Not implemented yet") def test_should_response_200(self): - response = self.client.get("/api/v1/dag/TEST_DAG_ID/details") + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details") + assert response.status_code == 200 + expected = { + 'catchup': True, + 'concurrency': 16, + 'dag_id': 'test_dag', + 'dag_run_timeout': None, + 'default_view': 'tree', + 'description': None, + 'doc_md': 'details', + 'fileloc': __file__, + 'is_paused': None, + 'is_subdag': False, + 'orientation': 'LR', + 'schedule_interval': { + '__type': 'TimeDelta', + 'days': 1, + 'microseconds': 0, + 
'seconds': 0 + }, + 'start_date': '2020-06-15T00:00:00+00:00', + 'tags': None, + 'timezone': "Timezone('UTC')" + } + assert response.json == expected + + def test_should_response_200_serialized(self): + # Create empty app with empty dagbag to check if DAG is read from db + app_serialized = app.create_app(testing=True) # type:ignore + dag_bag = DagBag(os.devnull, include_examples=False, store_serialized_dags=True) + app_serialized.dag_bag = dag_bag # type:ignore + client = app_serialized.test_client() + + SerializedDagModel.write_dag(self.dag) + + expected = { + 'catchup': True, + 'concurrency': 16, + 'dag_id': 'test_dag', + 'dag_run_timeout': None, + 'default_view': 'tree', + 'description': None, + 'doc_md': 'details', + 'fileloc': __file__, + 'is_paused': None, + 'is_subdag': False, + 'orientation': 'LR', + 'schedule_interval': { + '__type': 'TimeDelta', + 'days': 1, + 'microseconds': 0, + 'seconds': 0 + }, + 'start_date': '2020-06-15T00:00:00+00:00', + 'tags': None, + 'timezone': "Timezone('UTC')" + } + response = client.get(f"/api/v1/dags/{self.dag_id}/details") assert response.status_code == 200 + assert response.json == expected class TestGetDags(TestDagEndpoint): diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index ab6b649..92d08ef 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -14,32 +14,169 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+import os import unittest +from datetime import datetime -import pytest - +from airflow import DAG +from airflow.models import DagBag +from airflow.models.serialized_dag import SerializedDagModel +from airflow.operators.dummy_operator import DummyOperator from airflow.www import app +from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags class TestTaskEndpoint(unittest.TestCase): + dag_id = "test_dag" + task_id = "op1" + + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + @classmethod def setUpClass(cls) -> None: super().setUpClass() cls.app = app.create_app(testing=True) # type:ignore + with DAG(cls.dag_id, start_date=datetime(2020, 6, 15), doc_md="details") as dag: + DummyOperator(task_id=cls.task_id) + + cls.dag = dag # type:ignore + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag} + cls.app.dag_bag = dag_bag # type:ignore + def setUp(self) -> None: + self.clean_db() self.client = self.app.test_client() # type:ignore + def tearDown(self) -> None: + self.clean_db() + class TestGetTask(TestTaskEndpoint): - @pytest.mark.skip(reason="Not implemented yet") def test_should_response_200(self): - response = self.client.get("/api/v1/dags/TEST_DAG_ID/tasks/TEST_TASK_ID") + expected = { + "class_ref": { + "class_name": "DummyOperator", + "module_path": "airflow.operators.dummy_operator", + }, + "depends_on_past": False, + "downstream_task_ids": [], + "end_date": None, + "execution_timeout": None, + "extra_links": [], + "owner": "airflow", + "pool": "default_pool", + "pool_slots": 1.0, + "priority_weight": 1.0, + "queue": "default", + "retries": 0.0, + "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, + "retry_exponential_backoff": False, + "start_date": "2020-06-15T00:00:00+00:00", + "task_id": "op1", + "template_fields": [], + "trigger_rule": "all_success", + "ui_color": 
"#e8f7e4", + "ui_fgcolor": "#000", + "wait_for_downstream": False, + "weight_rule": "downstream", + } + response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}") assert response.status_code == 200 + assert response.json == expected + + @conf_vars({("core", "store_serialized_dags"): "True"}) + def test_should_response_200_serialized(self): + # Create empty app with empty dagbag to check if DAG is read from db + app_serialized = app.create_app(testing=True) # type:ignore + dag_bag = DagBag(os.devnull, include_examples=False, store_serialized_dags=True) + app_serialized.dag_bag = dag_bag # type:ignore + client = app_serialized.test_client() + + SerializedDagModel.write_dag(self.dag) + + expected = { + "class_ref": { + "class_name": "DummyOperator", + "module_path": "airflow.operators.dummy_operator", + }, + "depends_on_past": False, + "downstream_task_ids": [], + "end_date": None, + "execution_timeout": None, + "extra_links": [], + "owner": "airflow", + "pool": "default_pool", + "pool_slots": 1.0, + "priority_weight": 1.0, + "queue": "default", + "retries": 0.0, + "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, + "retry_exponential_backoff": False, + "start_date": "2020-06-15T00:00:00+00:00", + "task_id": "op1", + "template_fields": [], + "trigger_rule": "all_success", + "ui_color": "#e8f7e4", + "ui_fgcolor": "#000", + "wait_for_downstream": False, + "weight_rule": "downstream", + } + response = client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}") + assert response.status_code == 200 + assert response.json == expected + + def test_should_response_404(self): + task_id = "xxxx_not_existing" + response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{task_id}") + assert response.status_code == 404 class TestGetTasks(TestTaskEndpoint): - @pytest.mark.skip(reason="Not implemented yet") def test_should_response_200(self): - response = self.client.get("/api/v1/dags/TEST_DAG_ID/tasks") + expected = { + 
"tasks": [ + { + "class_ref": { + "class_name": "DummyOperator", + "module_path": "airflow.operators.dummy_operator", + }, + "depends_on_past": False, + "downstream_task_ids": [], + "end_date": None, + "execution_timeout": None, + "extra_links": [], + "owner": "airflow", + "pool": "default_pool", + "pool_slots": 1.0, + "priority_weight": 1.0, + "queue": "default", + "retries": 0.0, + "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0}, + "retry_exponential_backoff": False, + "start_date": "2020-06-15T00:00:00+00:00", + "task_id": "op1", + "template_fields": [], + "trigger_rule": "all_success", + "ui_color": "#e8f7e4", + "ui_fgcolor": "#000", + "wait_for_downstream": False, + "weight_rule": "downstream", + } + ], + "total_entries": 1, + } + response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks") assert response.status_code == 200 + assert response.json == expected + + def test_should_response_404(self): + dag_id = "xxxx_not_existing" + response = self.client.get(f"/api/v1/dags/{dag_id}/tasks") + assert response.status_code == 404 diff --git a/tests/api_connexion/schemas/test_common_schema.py b/tests/api_connexion/schemas/test_common_schema.py new file mode 100644 index 0000000..d0419b0 --- /dev/null +++ b/tests/api_connexion/schemas/test_common_schema.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import datetime
import unittest

from dateutil import relativedelta

from airflow.api_connexion.schemas.common_schema import (
    CronExpression, CronExpressionSchema, RelativeDeltaSchema, ScheduleIntervalSchema, TimeDeltaSchema,
)


def _timedelta_payload():
    """Return a fresh dict: the serialized form these tests use for ``timedelta(days=12)``."""
    return {"__type": "TimeDelta", "days": 12, "seconds": 0, "microseconds": 0}


def _relativedelta_dump():
    """Return a fresh dict: the payload the tests expect when dumping ``relativedelta(days=+12)``."""
    return {
        "__type": "RelativeDelta",
        "day": None,
        "days": 12,
        "hour": None,
        "hours": 0,
        "leapdays": 0,
        "microsecond": None,
        "microseconds": 0,
        "minute": None,
        "minutes": 0,
        "month": None,
        "months": 0,
        "second": None,
        "seconds": 0,
        "year": None,
        "years": 0,
    }


def _relativedelta_load_input():
    """Return a fresh dict: the partial payload the tests load into ``relativedelta(days=+12)``."""
    return {"__type": "RelativeDelta", "days": 12, "seconds": 0}


class TestTimeDeltaSchema(unittest.TestCase):
    """``TimeDeltaSchema`` dumps and loads a ``datetime.timedelta``."""

    def test_should_serialize(self):
        result = TimeDeltaSchema().dump(datetime.timedelta(days=12))
        self.assertEqual(_timedelta_payload(), result.data)

    def test_should_deserialize(self):
        result = TimeDeltaSchema().load(_timedelta_payload())
        self.assertEqual(datetime.timedelta(days=12), result.data)


class TestRelativeDeltaSchema(unittest.TestCase):
    """``RelativeDeltaSchema`` dumps and loads a ``dateutil`` ``relativedelta``."""

    def test_should_serialize(self):
        result = RelativeDeltaSchema().dump(relativedelta.relativedelta(days=+12))
        self.assertEqual(_relativedelta_dump(), result.data)

    def test_should_deserialize(self):
        result = RelativeDeltaSchema().load(_relativedelta_load_input())
        self.assertEqual(relativedelta.relativedelta(days=+12), result.data)


class TestCronExpressionSchema(unittest.TestCase):
    """``CronExpressionSchema`` loads a typed payload into a ``CronExpression``."""

    def test_should_deserialize(self):
        payload = {"__type": "CronExpression", "value": "5 4 * * *"}
        result = CronExpressionSchema().load(payload)
        self.assertEqual(CronExpression("5 4 * * *"), result.data)


class TestScheduleIntervalSchema(unittest.TestCase):
    """``ScheduleIntervalSchema`` handles timedelta, relativedelta and cron-string values.

    The timedelta and relativedelta cases assert the same payloads as the
    dedicated schema tests above, this time going through the umbrella schema.
    """

    def test_should_serialize_timedelta(self):
        result = ScheduleIntervalSchema().dump(datetime.timedelta(days=12))
        self.assertEqual(_timedelta_payload(), result.data)

    def test_should_deserialize_timedelta(self):
        result = ScheduleIntervalSchema().load(_timedelta_payload())
        self.assertEqual(datetime.timedelta(days=12), result.data)

    def test_should_serialize_relative_delta(self):
        result = ScheduleIntervalSchema().dump(relativedelta.relativedelta(days=+12))
        self.assertEqual(_relativedelta_dump(), result.data)

    def test_should_deserialize_relative_delta(self):
        result = ScheduleIntervalSchema().load(_relativedelta_load_input())
        self.assertEqual(relativedelta.relativedelta(days=+12), result.data)

    def test_should_serialize_cron_expression(self):
        # The input is a plain cron string; the expected dump is the typed CronExpression payload.
        result = ScheduleIntervalSchema().dump("5 4 * * *")
        self.assertEqual({"__type": "CronExpression", "value": "5 4 * * *"}, result.data)


# Patch metadata and licence header of the next file in this commit, preserved as comments:
# diff --git a/tests/api_connexion/schemas/test_dag_schema.py b/tests/api_connexion/schemas/test_dag_schema.py
# new file mode 100644
# index 0000000..327bce5
# --- /dev/null
# +++ b/tests/api_connexion/schemas/test_dag_schema.py
# @@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
from datetime import datetime

from airflow import DAG
from airflow.api_connexion.schemas.dag_schema import (
    DAGCollection, DAGCollectionSchema, DAGDetailSchema, DAGSchema,
)
from airflow.models import DagModel, DagTag


class TestDagSchema(unittest.TestCase):
    """Dump of a ``DagModel`` with every field populated through ``DAGSchema``."""

    def test_serialize(self):
        # The expected payload shows `owners` as a list and the cron string
        # as a typed CronExpression object.
        expected = {
            "dag_id": "test_dag_id",
            "description": "The description",
            "fileloc": "/root/airflow/dags/my_dag.py",
            "is_paused": True,
            "is_subdag": False,
            "owners": ["airflow1", "airflow2"],
            "root_dag_id": "test_root_dag_id",
            "schedule_interval": {"__type": "CronExpression", "value": "5 4 * * *"},
            "tags": [{"name": "tag-1"}, {"name": "tag-2"}],
        }
        model = DagModel(
            dag_id="test_dag_id",
            root_dag_id="test_root_dag_id",
            is_paused=True,
            is_subdag=False,
            fileloc="/root/airflow/dags/my_dag.py",
            owners="airflow1,airflow2",
            description="The description",
            schedule_interval="5 4 * * *",
            tags=[DagTag(name=tag_name) for tag_name in ("tag-1", "tag-2")],
        )
        dumped = DAGSchema().dump(model)
        self.assertEqual(expected, dumped.data)


class TestDAGCollectionSchema(unittest.TestCase):
    """Dump of a ``DAGCollection`` holding two sparsely populated ``DagModel`` rows."""

    @staticmethod
    def _sparse_entry(dag_id):
        """Build the expected payload for a model created with only ``dag_id`` and ``fileloc``."""
        return {
            "dag_id": dag_id,
            "description": None,
            "fileloc": "/tmp/a.py",
            "is_paused": None,
            "is_subdag": None,
            "owners": [],
            "root_dag_id": None,
            "schedule_interval": None,
            "tags": [],
        }

    def test_serialize(self):
        dag_ids = ["test_dag_id_a", "test_dag_id_b"]
        models = [DagModel(dag_id=dag_id, fileloc="/tmp/a.py") for dag_id in dag_ids]
        collection = DAGCollection(dags=models, total_entries=2)
        expected = {
            "dags": [self._sparse_entry(dag_id) for dag_id in dag_ids],
            "total_entries": 2,
        }
        self.assertEqual(expected, DAGCollectionSchema().dump(collection).data)


class TestDAGDetailSchema:
    """Dump of a real ``DAG`` object through ``DAGDetailSchema`` (plain pytest-style class)."""

    def test_serialize(self):
        expected = {
            'catchup': True,
            'concurrency': 16,
            'dag_id': 'test_dag',
            'dag_run_timeout': None,
            'default_view': 'duration',
            'description': None,
            'doc_md': 'docs',
            # The DAG is created in this module, so its fileloc is expected to be this file.
            'fileloc': __file__,
            'is_paused': None,
            'is_subdag': False,
            'orientation': 'LR',
            'schedule_interval': {'__type': 'TimeDelta', 'days': 1, 'seconds': 0, 'microseconds': 0},
            'start_date': '2020-06-19T00:00:00+00:00',
            'tags': None,
            'timezone': "Timezone('UTC')"
        }
        dag = DAG(
            dag_id="test_dag",
            start_date=datetime(2020, 6, 19),
            doc_md="docs",
            orientation="LR",
            default_view="duration",
        )
        actual = DAGDetailSchema().dump(dag).data
        assert actual == expected


# Patch metadata and licence header of the next file in this commit, preserved as comments:
# diff --git a/tests/api_connexion/schemas/test_task_schema.py b/tests/api_connexion/schemas/test_task_schema.py
# new file mode 100644
# index 0000000..a804869
# --- /dev/null
# +++ b/tests/api_connexion/schemas/test_task_schema.py
# @@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime

from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema
from airflow.operators.dummy_operator import DummyOperator


def _expected_dummy_task(task_id, start_date=None, end_date=None):
    """Build the payload these tests expect for a ``DummyOperator`` dump.

    Only ``task_id``, ``start_date`` and ``end_date`` vary between the two
    tests; every other key carries the same value in both.
    """
    return {
        "class_ref": {
            "module_path": "airflow.operators.dummy_operator",
            "class_name": "DummyOperator",
        },
        "depends_on_past": False,
        "downstream_task_ids": [],
        "end_date": end_date,
        "execution_timeout": None,
        "extra_links": [],
        "owner": "airflow",
        "pool": "default_pool",
        "pool_slots": 1.0,
        "priority_weight": 1.0,
        "queue": "default",
        "retries": 0.0,
        "retry_delay": {"__type": "TimeDelta", "days": 0, "seconds": 300, "microseconds": 0},
        "retry_exponential_backoff": False,
        "start_date": start_date,
        "task_id": task_id,
        "template_fields": [],
        "trigger_rule": "all_success",
        "ui_color": "#e8f7e4",
        "ui_fgcolor": "#000",
        "wait_for_downstream": False,
        "weight_rule": "downstream",
    }


class TestTaskSchema:
    """Dump of a single operator through ``task_schema``."""

    def test_serialize(self):
        operator = DummyOperator(
            task_id="task_id",
            start_date=datetime(2020, 6, 16),
            end_date=datetime(2020, 6, 26),
        )
        dumped = task_schema.dump(operator)
        expected = _expected_dummy_task(
            "task_id",
            start_date="2020-06-16T00:00:00+00:00",
            end_date="2020-06-26T00:00:00+00:00",
        )
        assert expected == dumped.data


class TestTaskCollectionSchema:
    """Dump of a one-element ``TaskCollection`` through ``task_collection_schema``."""

    def test_serialize(self):
        operators = [DummyOperator(task_id="task_id1")]
        dumped = task_collection_schema.dump(TaskCollection(operators, 1))
        expected = {
            "tasks": [_expected_dummy_task("task_id1")],
            "total_entries": 1,
        }
        assert expected == dumped.data


# Remaining hunks of the same patch (changes to other files), preserved as a unified diff:
# diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
# index 6ad235b..fa4a32e 100644
# --- a/tests/cli/commands/test_dag_command.py
# +++ b/tests/cli/commands/test_dag_command.py
# @@ -35,6 +35,7 @@ from airflow.utils.session import create_session
#  from airflow.utils.state import State
#  from airflow.utils.types import DagRunType
#  from tests.test_utils.config import conf_vars
# +from tests.test_utils.db import clear_db_dags, clear_db_runs
#
#  dag_folder_path = '/'.join(os.path.realpath(__file__).split('/')[:-1])
#
# @@ -59,8 +60,14 @@ class TestCliDags(unittest.TestCase):
#      @classmethod
#      def setUpClass(cls):
#          cls.dagbag = DagBag(include_examples=True)
# +        cls.dagbag.sync_to_db()
#          cls.parser = cli_parser.get_parser()
#
# +    @classmethod
# +    def tearDownClass(cls) -> None:
# +        clear_db_runs()
# +        clear_db_dags()
# +
#      @mock.patch("airflow.cli.commands.dag_command.DAG.run")
#      def test_backfill(self, mock_run):
#          dag_command.dag_backfill(self.parser.parse_args([
# diff --git a/tests/test_utils/db.py b/tests/test_utils/db.py
# index e967712..6c2c297 100644
# --- a/tests/test_utils/db.py
# +++ b/tests/test_utils/db.py
# @@ -20,6 +20,7 @@ from airflow.models import (
#      XCom, errors,
#  )
#  from airflow.models.dagcode import DagCode
# +from airflow.models.serialized_dag import SerializedDagModel
#  from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections
#  from airflow.utils.session import create_session
#
# @@ -36,6 +37,11 @@ def clear_db_dags():
#          session.query(DagModel).delete()
#
#
# +def clear_db_serialized_dags():
# +    with create_session() as session:
# +        session.query(SerializedDagModel).delete()
# +
# +
#  def clear_db_sla_miss():
#      with create_session() as session:
#          session.query(SlaMiss).delete()