This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 8b94ace Add read-only endpoints for DAG Model (#9045)
8b94ace is described below
commit 8b94ace597f47e350161d799b6b45aad80f45ae4
Author: Kamil Breguła <[email protected]>
AuthorDate: Thu Jul 9 07:28:34 2020 +0200
Add read-only endpoints for DAG Model (#9045)
Co-authored-by: Tomek Urbaszek <[email protected]>
Co-authored-by: Tomek Urbaszek <[email protected]>
---
airflow/api_connexion/endpoints/dag_endpoint.py | 35 ++++--
airflow/api_connexion/schemas/dag_schema.py | 1 +
tests/api_connexion/endpoints/test_dag_endpoint.py | 130 ++++++++++++++++++++-
3 files changed, 151 insertions(+), 15 deletions(-)
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index 7cdeeb6..4f6aa2e 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -14,23 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from flask import current_app
+from sqlalchemy import func
from airflow import DAG
from airflow.api_connexion.exceptions import NotFound
-# TODO(mik-laj): We have to implement it.
-# Do you want to help? Please look at:
-# * https://github.com/apache/airflow/issues/8128
-# * https://github.com/apache/airflow/issues/8138
-from airflow.api_connexion.schemas.dag_schema import dag_detail_schema
+from airflow.api_connexion.parameters import check_limit, format_parameters
+from airflow.api_connexion.schemas.dag_schema import (
+ DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema,
+)
+from airflow.models.dag import DagModel
+from airflow.utils.session import provide_session
-def get_dag():
+@provide_session
+def get_dag(dag_id, session):
"""
Get basic information about a DAG.
"""
- raise NotImplementedError("Not implemented yet.")
+ dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
+
+ if dag is None:
+ raise NotFound("DAG not found")
+
+ return dag_schema.dump(dag)
def get_dag_details(dag_id):
@@ -43,11 +50,19 @@ def get_dag_details(dag_id):
return dag_detail_schema.dump(dag)
-def get_dags():
+@format_parameters({
+ 'limit': check_limit
+})
+@provide_session
+def get_dags(session, limit, offset=0):
"""
Get all DAGs.
"""
- raise NotImplementedError("Not implemented yet.")
+ dags = session.query(DagModel).order_by(DagModel.dag_id).offset(offset).limit(limit).all()
+
+ total_entries = session.query(func.count(DagModel.dag_id)).scalar()
+
+ return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))
def patch_dag():
diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py
index aff859a..bae2228 100644
--- a/airflow/api_connexion/schemas/dag_schema.py
+++ b/airflow/api_connexion/schemas/dag_schema.py
@@ -89,4 +89,5 @@ class DAGCollectionSchema(Schema):
dags_collection_schema = DAGCollectionSchema()
dag_schema = DAGSchema()
+
dag_detail_schema = DAGDetailSchema()
diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py
index 6289b6f..1ba360f 100644
--- a/tests/api_connexion/endpoints/test_dag_endpoint.py
+++ b/tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -19,11 +19,13 @@ import unittest
from datetime import datetime
import pytest
+from parameterized import parameterized
from airflow import DAG
-from airflow.models import DagBag
+from airflow.models import DagBag, DagModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.session import provide_session
from airflow.www import app
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
@@ -58,13 +60,41 @@ class TestDagEndpoint(unittest.TestCase):
def tearDown(self) -> None:
self.clean_db()
+ @provide_session
+ def _create_dag_models(self, count, session=None):
+ for num in range(1, count + 1):
+ dag_model = DagModel(
+ dag_id=f"TEST_DAG_{num}",
+ fileloc=f"/tmp/dag_{num}.py",
+ schedule_interval="2 2 * * *"
+ )
+ session.add(dag_model)
+
class TestGetDag(TestDagEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
def test_should_response_200(self):
- response = self.client.get("/api/v1/dags/1/")
+ self._create_dag_models(1)
+ response = self.client.get("/api/v1/dags/TEST_DAG_1")
assert response.status_code == 200
+ current_response = response.json
+ current_response["fileloc"] = "/tmp/test-dag.py"
+ self.assertEqual({
+ 'dag_id': 'TEST_DAG_1',
+ 'description': None,
+ 'fileloc': '/tmp/test-dag.py',
+ 'is_paused': False,
+ 'is_subdag': False,
+ 'owners': [],
+ 'root_dag_id': None,
+ 'schedule_interval': {'__type': 'CronExpression', 'value': '2 2 * * *'},
+ 'tags': []
+ }, current_response)
+
+ def test_should_response_404(self):
+ response = self.client.get("/api/v1/dags/INVALID_DAG")
+ assert response.status_code == 404
+
class TestGetDagDetails(TestDagEndpoint):
def test_should_response_200(self):
@@ -133,11 +163,101 @@ class TestGetDagDetails(TestDagEndpoint):
class TestGetDags(TestDagEndpoint):
- @pytest.mark.skip(reason="Not implemented yet")
+
def test_should_response_200(self):
- response = self.client.get("/api/v1/dags/1")
+ self._create_dag_models(2)
+
+ response = self.client.get("api/v1/dags")
+
+ assert response.status_code == 200
+
+ self.assertEqual(
+ {
+ "dags": [
+ {
+ "dag_id": "TEST_DAG_1",
+ "description": None,
+ "fileloc": "/tmp/dag_1.py",
+ "is_paused": False,
+ "is_subdag": False,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
+ "tags": [],
+ },
+ {
+ "dag_id": "TEST_DAG_2",
+ "description": None,
+ "fileloc": "/tmp/dag_2.py",
+ "is_paused": False,
+ "is_subdag": False,
+ "owners": [],
+ "root_dag_id": None,
+ "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *"},
+ "tags": [],
+ },
+ ],
+ "total_entries": 2,
+ },
+ response.json,
+ )
+
+ @parameterized.expand(
+ [
+ ("api/v1/dags?limit=1", ["TEST_DAG_1"]),
+ ("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]),
+ (
+ "api/v1/dags?offset=5",
+ [
+ "TEST_DAG_5",
+ "TEST_DAG_6",
+ "TEST_DAG_7",
+ "TEST_DAG_8",
+ "TEST_DAG_9",
+ ],
+ ),
+ (
+ "api/v1/dags?offset=0",
+ [
+ "TEST_DAG_1",
+ "TEST_DAG_10",
+ "TEST_DAG_2",
+ "TEST_DAG_3",
+ "TEST_DAG_4",
+ "TEST_DAG_5",
+ "TEST_DAG_6",
+ "TEST_DAG_7",
+ "TEST_DAG_8",
+ "TEST_DAG_9",
+ ],
+ ),
+ ("api/v1/dags?limit=1&offset=5", ["TEST_DAG_5"]),
+ ("api/v1/dags?limit=1&offset=1", ["TEST_DAG_10"]),
+ ("api/v1/dags?limit=2&offset=2", ["TEST_DAG_2", "TEST_DAG_3"]),
+ ]
+ )
+ def test_should_response_200_and_handle_pagination(self, url, expected_dag_ids):
+ self._create_dag_models(10)
+
+ response = self.client.get(url)
+
+ assert response.status_code == 200
+
+ dag_ids = [dag["dag_id"] for dag in response.json['dags']]
+
+ self.assertEqual(expected_dag_ids, dag_ids)
+ self.assertEqual(10, response.json['total_entries'])
+
+ def test_should_response_200_default_limit(self):
+ self._create_dag_models(101)
+
+ response = self.client.get("api/v1/dags")
+
assert response.status_code == 200
+ self.assertEqual(100, len(response.json['dags']))
+ self.assertEqual(101, response.json['total_entries'])
+
class TestPatchDag(TestDagEndpoint):
@pytest.mark.skip(reason="Not implemented yet")