This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bf15d63721 More code coverage for the REST API (#35421)
bf15d63721 is described below
commit bf15d63721082ce93782b5e51829c8afbe8fd60b
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Nov 7 13:12:20 2023 +0100
More code coverage for the REST API (#35421)
* More code coverage for the REST API
This commit adds more tests for the REST API modules.
---
.../endpoints/forward_to_fab_endpoint.py | 4 +
airflow/api_connexion/schemas/common_schema.py | 7 +-
scripts/cov/restapi_coverage.py | 1 -
.../endpoints/test_forward_to_fab_endpoint.py | 238 +++++++++++++++++++++
.../api_connexion/endpoints/test_xcom_endpoint.py | 57 +++++
5 files changed, 301 insertions(+), 6 deletions(-)
diff --git a/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
b/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
index ded340d82a..95507b7e9f 100644
--- a/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
+++ b/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
@@ -79,12 +79,14 @@ def delete_role(**kwargs) -> APIResponse:
@_require_fab
def patch_role(**kwargs) -> APIResponse:
"""Update a role."""
+ # Drop the "body" kwarg so only the remaining kwargs are forwarded to the FAB endpoint.
+ # NOTE(review): presumably connexion injects the parsed JSON payload as "body" and the FAB handler does not accept it - confirm.
+ kwargs.pop("body", None)
return role_and_permission_endpoint.patch_role(**kwargs)
@_require_fab
def post_role(**kwargs) -> APIResponse:
"""Create a new role."""
+ # Drop the "body" kwarg so only the remaining kwargs are forwarded to the FAB endpoint.
+ # NOTE(review): presumably connexion injects the parsed JSON payload as "body" and the FAB handler does not accept it - confirm.
+ kwargs.pop("body", None)
return role_and_permission_endpoint.post_role(**kwargs)
@@ -111,12 +113,14 @@ def get_users(**kwargs) -> APIResponse:
@_require_fab
def post_user(**kwargs) -> APIResponse:
"""Create a new user."""
+ # Drop the "body" kwarg so only the remaining kwargs are forwarded to the FAB endpoint.
+ # NOTE(review): presumably connexion injects the parsed JSON payload as "body" and the FAB handler does not accept it - confirm.
+ kwargs.pop("body", None)
return user_endpoint.post_user(**kwargs)
@_require_fab
def patch_user(**kwargs) -> APIResponse:
"""Update a user."""
+ # Drop the "body" kwarg so only the remaining kwargs are forwarded to the FAB endpoint.
+ # NOTE(review): presumably connexion injects the parsed JSON payload as "body" and the FAB handler does not accept it - confirm.
+ kwargs.pop("body", None)
return user_endpoint.patch_user(**kwargs)
diff --git a/airflow/api_connexion/schemas/common_schema.py
b/airflow/api_connexion/schemas/common_schema.py
index a470e6b1c0..16740fc4c0 100644
--- a/airflow/api_connexion/schemas/common_schema.py
+++ b/airflow/api_connexion/schemas/common_schema.py
@@ -48,8 +48,7 @@ class TimeDeltaSchema(Schema):
@marshmallow.post_load
def make_time_delta(self, data, **kwargs):
"""Create time delta based on data."""
- if "objectType" in data:
- del data["objectType"]
+ # "objectType" is not a datetime.timedelta argument; remove it (if present) before unpacking.
+ data.pop("objectType", None)
return datetime.timedelta(**data)
@@ -76,9 +75,7 @@ class RelativeDeltaSchema(Schema):
@marshmallow.post_load
def make_relative_delta(self, data, **kwargs):
"""Create relative delta based on data."""
- if "objectType" in data:
- del data["objectType"]
-
+ # "objectType" is not a relativedelta argument; remove it (if present) before unpacking.
+ data.pop("objectType", None)
return relativedelta.relativedelta(**data)
diff --git a/scripts/cov/restapi_coverage.py b/scripts/cov/restapi_coverage.py
index 9a3dc2a143..1de8e40c48 100644
--- a/scripts/cov/restapi_coverage.py
+++ b/scripts/cov/restapi_coverage.py
@@ -30,7 +30,6 @@ restapi_files = ["tests/api_experimental",
"tests/api_connexion", "tests/api_int
files_not_fully_covered = [
"airflow/api_connexion/endpoints/forward_to_fab_endpoint.py",
"airflow/api_connexion/endpoints/task_instance_endpoint.py",
- "airflow/api_connexion/endpoints/xcom_endpoint.py",
"airflow/api_connexion/exceptions.py",
"airflow/api_connexion/schemas/common_schema.py",
"airflow/api_connexion/security.py",
diff --git a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
new file mode 100644
index 0000000000..5560a4d877
--- /dev/null
+++ b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
+from airflow.auth.managers.fab.models import Role, User
+from airflow.security import permissions
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.www.security import EXISTING_ROLES
+from tests.test_utils.api_connexion_utils import create_role, create_user,
delete_role, delete_user
+
+# Apply the db_test marker to every test in this module.
+pytestmark = pytest.mark.db_test
+
+# Fixed timestamp stamped on users built by _create_users; teardown_method uses it to find and delete them.
+DEFAULT_TIME = "2020-06-11T18:00:00+00:00"
+
+# Username used by the autoclean_username fixture (deleted before and after each test that uses it).
+EXAMPLE_USER_NAME = "example_user"
+
+# Email used by the autoclean_email fixture (deleted before and after each test that uses it).
+# NOTE(review): the archived copy showed an obfuscated "[email protected]" placeholder here, which is not a
+# valid address; restored to a syntactically valid one - confirm against the repository.
+EXAMPLE_USER_EMAIL = "example_user@example.com"
+
+
+def _delete_user(**filters):
+    """Delete the first user matching ``filters`` (if any), clearing its roles before deletion."""
+    with create_session() as session:
+        match = session.query(User).filter_by(**filters).first()
+        if match is not None:
+            match.roles = []
+            session.delete(match)
+
+
+@pytest.fixture()
+def autoclean_user_payload(autoclean_username, autoclean_email):
+    """Return a user-creation payload built from the self-cleaning username and email fixtures."""
+    return {
+        "username": autoclean_username,
+        "password": "resutsop",
+        "email": autoclean_email,
+        "first_name": "Tester",
+        "last_name": "",
+    }
+
+
+@pytest.fixture()
+def autoclean_admin_user(configured_app, autoclean_user_payload):
+    """Add a user with the Admin role from ``autoclean_user_payload``; return what ``add_user`` returns."""
+    security_manager = configured_app.appbuilder.sm
+    return security_manager.add_user(
+        role=security_manager.find_role("Admin"),
+        **autoclean_user_payload,
+    )
+
+
+@pytest.fixture()
+def autoclean_username():
+    """Yield EXAMPLE_USER_NAME, deleting any user with that username before and after the test."""
+    _delete_user(username=EXAMPLE_USER_NAME)
+    yield EXAMPLE_USER_NAME
+    _delete_user(username=EXAMPLE_USER_NAME)
+
+
+@pytest.fixture()
+def autoclean_email():
+    """Yield EXAMPLE_USER_EMAIL, deleting any user with that email before and after the test."""
+    _delete_user(email=EXAMPLE_USER_EMAIL)
+    yield EXAMPLE_USER_EMAIL
+    _delete_user(email=EXAMPLE_USER_EMAIL)
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_api):
+    """Yield the API app with a "test" user whose "Test" role can manage roles and users.
+
+    The "test" user is deleted again after the module's tests have finished.
+    """
+    app = minimal_app_for_api
+    create_user(
+        app,  # type: ignore
+        username="test",
+        role_name="Test",
+        permissions=[
+            (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION),
+            (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_USER),
+            (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER),
+            (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_USER),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER),
+        ],
+    )
+
+    yield app
+
+    delete_user(app, username="test")  # type: ignore
+
+
+class TestFABforwarding:
+    """Shared setup and teardown for the FAB forwarding endpoint tests."""
+
+    @pytest.fixture(autouse=True)
+    def setup_attrs(self, configured_app) -> None:
+        self.app = configured_app
+        self.client = self.app.test_client()  # type:ignore
+
+    def teardown_method(self):
+        """
+        Delete the roles and users a test created.
+
+        Every role except those in EXISTING_ROLES, "Test" and "TestNoPermissions" is
+        deleted; the "test" user itself is removed by the ``configured_app`` fixture.
+        Users are matched by the fixed ``changed_on`` timestamp (DEFAULT_TIME) that
+        ``_create_users`` assigns.
+        """
+        session = self.app.appbuilder.get_session
+        existing_roles = set(EXISTING_ROLES)
+        existing_roles.update(["Test", "TestNoPermissions"])
+        roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all()
+        for role in roles:
+            delete_role(self.app, role.name)
+        users = session.query(User).filter(User.changed_on == timezone.parse(DEFAULT_TIME))
+        users.delete(synchronize_session=False)
+        session.commit()
+
+
+class TestFABRoleForwarding(TestFABforwarding):
+    """Role and permission endpoints are forwarded to the FAB implementation."""
+
+    @mock.patch("airflow.api_connexion.endpoints.forward_to_fab_endpoint.get_auth_manager")
+    def test_raises_400_if_manager_is_not_fab(self, mock_get_auth_manager):
+        # With a non-FAB auth manager the forwarding endpoint must refuse the request.
+        mock_get_auth_manager.return_value = BaseAuthManager(self.app, self.app.appbuilder)
+        response = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"})
+        assert response.status_code == 400
+        assert (
+            response.json["detail"]
+            == "This endpoint is only available when using the default auth manager FabAuthManager."
+        )
+
+    def test_get_role_forwards_to_fab(self):
+        resp = self.client.get("api/v1/roles/Test", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_get_roles_forwards_to_fab(self):
+        resp = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_delete_role_forwards_to_fab(self):
+        role = create_role(self.app, "mytestrole")
+        resp = self.client.delete(f"api/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 204
+
+    def test_patch_role_forwards_to_fab(self):
+        role = create_role(self.app, "mytestrole")
+        resp = self.client.patch(
+            f"api/v1/roles/{role.name}", json={"name": "Test2"}, environ_overrides={"REMOTE_USER": "test"}
+        )
+        assert resp.status_code == 200
+
+    def test_post_role_forwards_to_fab(self):
+        payload = {
+            "name": "Test2",
+            "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}],
+        }
+        resp = self.client.post("api/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_get_role_permissions_forwards_to_fab(self):
+        resp = self.client.get("api/v1/permissions", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+
+class TestFABUserForwarding(TestFABforwarding):
+    """User endpoints are forwarded to the FAB implementation."""
+
+    def _create_users(self, count, roles=None):
+        """Return ``count`` new, not-yet-added users named TEST_USER1..TEST_USER<count>.
+
+        created_on/changed_on are pinned to DEFAULT_TIME so the inherited
+        ``teardown_method`` can find and delete them.
+        """
+        roles = roles or []
+        fixed_time = timezone.parse(DEFAULT_TIME)
+        return [
+            User(
+                first_name=f"test{i}",
+                last_name=f"test{i}",
+                username=f"TEST_USER{i}",
+                email=f"mytest@test{i}.org",
+                roles=roles,
+                created_on=fixed_time,
+                changed_on=fixed_time,
+            )
+            for i in range(1, count + 1)
+        ]
+
+    def _add_users(self, count):
+        """Create ``count`` users via ``_create_users`` and commit them through the app session."""
+        session = self.app.appbuilder.get_session
+        session.add_all(self._create_users(count))
+        session.commit()
+
+    def test_get_user_forwards_to_fab(self):
+        self._add_users(1)
+        resp = self.client.get("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_get_users_forwards_to_fab(self):
+        self._add_users(2)
+        resp = self.client.get("api/v1/users", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_post_user_forwards_to_fab(self, autoclean_username, autoclean_user_payload):
+        response = self.client.post(
+            "/api/v1/users",
+            json=autoclean_user_payload,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 200, response.json
+
+        security_manager = self.app.appbuilder.sm
+        user = security_manager.find_user(autoclean_username)
+        assert user is not None
+        assert user.roles == [security_manager.find_role("Public")]
+
+    @pytest.mark.usefixtures("autoclean_admin_user")
+    def test_patch_user_forwards_to_fab(self, autoclean_username, autoclean_user_payload):
+        autoclean_user_payload["first_name"] = "Changed"
+        response = self.client.patch(
+            f"/api/v1/users/{autoclean_username}",
+            json=autoclean_user_payload,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 200, response.json
+
+    def test_delete_user_forwards_to_fab(self):
+        self._add_users(1)
+        resp = self.client.delete("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 204
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py
b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index c4919314c5..2efb08f705 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -148,6 +148,21 @@ class TestGetXComEntry(TestXComEndpoint):
"value": "TEST_VALUE",
}
+    def test_should_raise_404_for_non_existent_xcom(self):
+        """Requesting an existing XCom key under a different DAG id returns 404."""
+        dag_id = "test-dag-id"
+        task_id = "test-task-id"
+        execution_date = "2005-04-02T00:00:00+00:00"
+        xcom_key = "test-xcom-key"
+        execution_date_parsed = parse_execution_date(execution_date)
+        run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key)
+        # The entry is stored under dag_id, so looking it up under another DAG must not find it.
+        response = self.client.get(
+            f"/api/v1/dags/nonexistentdagid/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}",
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert 404 == response.status_code
+        assert response.json["title"] == "XCom entry not found"
+
def test_should_raises_401_unauthenticated(self):
dag_id = "test-dag-id"
task_id = "test-task-id"
@@ -453,6 +468,48 @@ class TestGetXComEntries(TestXComEndpoint):
assert_expected_result([expected_entry2], map_index=1)
assert_expected_result([expected_entry1, expected_entry2],
map_index=None)
+    def test_should_respond_200_with_xcom_key(self):
+        """Filtering by ``xcom_key`` returns only the entries stored under that key."""
+        dag_id = "test-dag-id"
+        task_id = "test-task-id"
+        execution_date = "2005-04-02T00:00:00+00:00"
+        execution_date_parsed = parse_execution_date(execution_date)
+        dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        self._create_xcom_entries(dag_id, dag_run_id, execution_date_parsed, task_id, mapped_ti=True)
+
+        def assert_expected_result(expected_entries, key=None):
+            url = "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries"
+            # Only add the filter when a key is given; concatenating None onto the query would raise TypeError.
+            if key is not None:
+                url += f"?xcom_key={key}"
+            response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})
+
+            assert 200 == response.status_code
+            response_data = response.json
+            # Timestamps are generated at insert time, so normalise them before comparing.
+            for xcom_entry in response_data["xcom_entries"]:
+                xcom_entry["timestamp"] = "TIMESTAMP"
+            assert response_data == {
+                "xcom_entries": expected_entries,
+                "total_entries": len(expected_entries),
+            }
+
+        expected_entry1 = {
+            "dag_id": dag_id,
+            "execution_date": execution_date,
+            "key": "test-xcom-key",
+            "task_id": task_id,
+            "timestamp": "TIMESTAMP",
+            "map_index": 0,
+        }
+        expected_entry2 = {
+            "dag_id": dag_id,
+            "execution_date": execution_date,
+            "key": "test-xcom-key",
+            "task_id": task_id,
+            "timestamp": "TIMESTAMP",
+            "map_index": 1,
+        }
+        assert_expected_result([expected_entry1, expected_entry2], key="test-xcom-key")
+        assert_expected_result([], key="test-xcom-key-1")
+
def test_should_raises_401_unauthenticated(self):
dag_id = "test-dag-id"
task_id = "test-task-id"