This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bf15d63721 More code coverage for the REST API (#35421)
bf15d63721 is described below

commit bf15d63721082ce93782b5e51829c8afbe8fd60b
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Nov 7 13:12:20 2023 +0100

    More code coverage for the REST API (#35421)
    
    * More code coverage for the REST API
    
    This commit adds more tests to the REST API modules
---
 .../endpoints/forward_to_fab_endpoint.py           |   4 +
 airflow/api_connexion/schemas/common_schema.py     |   7 +-
 scripts/cov/restapi_coverage.py                    |   1 -
 .../endpoints/test_forward_to_fab_endpoint.py      | 238 +++++++++++++++++++++
 .../api_connexion/endpoints/test_xcom_endpoint.py  |  57 +++++
 5 files changed, 301 insertions(+), 6 deletions(-)

diff --git a/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py b/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
index ded340d82a..95507b7e9f 100644
--- a/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
+++ b/airflow/api_connexion/endpoints/forward_to_fab_endpoint.py
@@ -79,12 +79,14 @@ def delete_role(**kwargs) -> APIResponse:
 @_require_fab
 def patch_role(**kwargs) -> APIResponse:
     """Update a role."""
+    kwargs.pop("body", None)
     return role_and_permission_endpoint.patch_role(**kwargs)
 
 
 @_require_fab
 def post_role(**kwargs) -> APIResponse:
     """Create a new role."""
+    kwargs.pop("body", None)
     return role_and_permission_endpoint.post_role(**kwargs)
 
 
@@ -111,12 +113,14 @@ def get_users(**kwargs) -> APIResponse:
 @_require_fab
 def post_user(**kwargs) -> APIResponse:
     """Create a new user."""
+    kwargs.pop("body", None)
     return user_endpoint.post_user(**kwargs)
 
 
 @_require_fab
 def patch_user(**kwargs) -> APIResponse:
     """Update a user."""
+    kwargs.pop("body", None)
     return user_endpoint.patch_user(**kwargs)
 
 
diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py
index a470e6b1c0..16740fc4c0 100644
--- a/airflow/api_connexion/schemas/common_schema.py
+++ b/airflow/api_connexion/schemas/common_schema.py
@@ -48,8 +48,7 @@ class TimeDeltaSchema(Schema):
     @marshmallow.post_load
     def make_time_delta(self, data, **kwargs):
         """Create time delta based on data."""
-        if "objectType" in data:
-            del data["objectType"]
+        data.pop("objectType", None)
         return datetime.timedelta(**data)
 
 
@@ -76,9 +75,7 @@ class RelativeDeltaSchema(Schema):
     @marshmallow.post_load
     def make_relative_delta(self, data, **kwargs):
         """Create relative delta based on data."""
-        if "objectType" in data:
-            del data["objectType"]
-
+        data.pop("objectType", None)
         return relativedelta.relativedelta(**data)
 
 
diff --git a/scripts/cov/restapi_coverage.py b/scripts/cov/restapi_coverage.py
index 9a3dc2a143..1de8e40c48 100644
--- a/scripts/cov/restapi_coverage.py
+++ b/scripts/cov/restapi_coverage.py
@@ -30,7 +30,6 @@ restapi_files = ["tests/api_experimental", "tests/api_connexion", "tests/api_int
 files_not_fully_covered = [
     "airflow/api_connexion/endpoints/forward_to_fab_endpoint.py",
     "airflow/api_connexion/endpoints/task_instance_endpoint.py",
-    "airflow/api_connexion/endpoints/xcom_endpoint.py",
     "airflow/api_connexion/exceptions.py",
     "airflow/api_connexion/schemas/common_schema.py",
     "airflow/api_connexion/security.py",
diff --git a/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
new file mode 100644
index 0000000000..5560a4d877
--- /dev/null
+++ b/tests/api_connexion/endpoints/test_forward_to_fab_endpoint.py
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
+from airflow.auth.managers.fab.models import Role, User
+from airflow.security import permissions
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.www.security import EXISTING_ROLES
+from tests.test_utils.api_connexion_utils import create_role, create_user, delete_role, delete_user
+
+pytestmark = pytest.mark.db_test
+
+DEFAULT_TIME = "2020-06-11T18:00:00+00:00"
+
+EXAMPLE_USER_NAME = "example_user"
+
+EXAMPLE_USER_EMAIL = "[email protected]"
+
+
+def _delete_user(**filters):
+    with create_session() as session:
+        user = session.query(User).filter_by(**filters).first()
+        if user is None:
+            return
+        user.roles = []
+        session.delete(user)
+
+
+@pytest.fixture()
+def autoclean_user_payload(autoclean_username, autoclean_email):
+    return {
+        "username": autoclean_username,
+        "password": "resutsop",
+        "email": autoclean_email,
+        "first_name": "Tester",
+        "last_name": "",
+    }
+
+
+@pytest.fixture()
+def autoclean_admin_user(configured_app, autoclean_user_payload):
+    security_manager = configured_app.appbuilder.sm
+    return security_manager.add_user(
+        role=security_manager.find_role("Admin"),
+        **autoclean_user_payload,
+    )
+
+
+@pytest.fixture()
+def autoclean_username():
+    _delete_user(username=EXAMPLE_USER_NAME)
+    yield EXAMPLE_USER_NAME
+    _delete_user(username=EXAMPLE_USER_NAME)
+
+
+@pytest.fixture()
+def autoclean_email():
+    _delete_user(email=EXAMPLE_USER_EMAIL)
+    yield EXAMPLE_USER_EMAIL
+    _delete_user(email=EXAMPLE_USER_EMAIL)
+
+
+@pytest.fixture(scope="module")
+def configured_app(minimal_app_for_api):
+    app = minimal_app_for_api
+    create_user(
+        app,  # type: ignore
+        username="test",
+        role_name="Test",
+        permissions=[
+            (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION),
+            (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_USER),
+            (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER),
+            (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_USER),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER),
+        ],
+    )
+
+    yield app
+
+    delete_user(app, username="test")  # type: ignore
+
+
+class TestFABforwarding:
+    @pytest.fixture(autouse=True)
+    def setup_attrs(self, configured_app) -> None:
+        self.app = configured_app
+        self.client = self.app.test_client()  # type:ignore
+
+    def teardown_method(self):
+        """
+        Delete all roles except these ones.
+        Test and TestNoPermissions are deleted by delete_user above
+        """
+        session = self.app.appbuilder.get_session
+        existing_roles = set(EXISTING_ROLES)
+        existing_roles.update(["Test", "TestNoPermissions"])
+        roles = session.query(Role).filter(~Role.name.in_(existing_roles)).all()
+        for role in roles:
+            delete_role(self.app, role.name)
+        users = session.query(User).filter(User.changed_on == timezone.parse(DEFAULT_TIME))
+        users.delete(synchronize_session=False)
+        session.commit()
+
+
+class TestFABRoleForwarding(TestFABforwarding):
+    @mock.patch("airflow.api_connexion.endpoints.forward_to_fab_endpoint.get_auth_manager")
+    def test_raises_400_if_manager_is_not_fab(self, mock_get_auth_manager):
+        mock_get_auth_manager.return_value = BaseAuthManager(self.app, self.app.appbuilder)
+        response = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"})
+        assert response.status_code == 400
+        assert (
+            response.json["detail"]
+            == "This endpoint is only available when using the default auth manager FabAuthManager."
+        )
+
+    def test_get_role_forwards_to_fab(self):
+        resp = self.client.get("api/v1/roles/Test", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_get_roles_forwards_to_fab(self):
+        resp = self.client.get("api/v1/roles", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_delete_role_forwards_to_fab(self):
+        role = create_role(self.app, "mytestrole")
+        resp = self.client.delete(f"api/v1/roles/{role.name}", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 204
+
+    def test_patch_role_forwards_to_fab(self):
+        role = create_role(self.app, "mytestrole")
+        resp = self.client.patch(
+            f"api/v1/roles/{role.name}", json={"name": "Test2"}, environ_overrides={"REMOTE_USER": "test"}
+        )
+        assert resp.status_code == 200
+
+    def test_post_role_forwards_to_fab(self):
+        payload = {
+            "name": "Test2",
+            "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}],
+        }
+        resp = self.client.post("api/v1/roles", json=payload, environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_get_role_permissions_forwards_to_fab(self):
+        resp = self.client.get("api/v1/permissions", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+
+class TestFABUserForwarding(TestFABforwarding):
+    def _create_users(self, count, roles=None):
+        # create users with defined created_on and changed_on date
+        # for easy testing
+        if roles is None:
+            roles = []
+        return [
+            User(
+                first_name=f"test{i}",
+                last_name=f"test{i}",
+                username=f"TEST_USER{i}",
+                email=f"mytest@test{i}.org",
+                roles=roles or [],
+                created_on=timezone.parse(DEFAULT_TIME),
+                changed_on=timezone.parse(DEFAULT_TIME),
+            )
+            for i in range(1, count + 1)
+        ]
+
+    def test_get_user_forwards_to_fab(self):
+        users = self._create_users(1)
+        session = self.app.appbuilder.get_session
+        session.add_all(users)
+        session.commit()
+        resp = self.client.get("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_get_users_forwards_to_fab(self):
+        users = self._create_users(2)
+        session = self.app.appbuilder.get_session
+        session.add_all(users)
+        session.commit()
+        resp = self.client.get("api/v1/users", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 200
+
+    def test_post_user_forwards_to_fab(self, autoclean_username, autoclean_user_payload):
+        response = self.client.post(
+            "/api/v1/users",
+            json=autoclean_user_payload,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 200, response.json
+
+        security_manager = self.app.appbuilder.sm
+        user = security_manager.find_user(autoclean_username)
+        assert user is not None
+        assert user.roles == [security_manager.find_role("Public")]
+
+    @pytest.mark.usefixtures("autoclean_admin_user")
+    def test_patch_user_forwards_to_fab(self, autoclean_username, autoclean_user_payload):
+        autoclean_user_payload["first_name"] = "Changed"
+        response = self.client.patch(
+            f"/api/v1/users/{autoclean_username}",
+            json=autoclean_user_payload,
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 200, response.json
+
+    def test_delete_user_forwards_to_fab(self):
+        users = self._create_users(1)
+        session = self.app.appbuilder.get_session
+        session.add_all(users)
+        session.commit()
+        resp = self.client.delete("api/v1/users/TEST_USER1", environ_overrides={"REMOTE_USER": "test"})
+        assert resp.status_code == 204
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index c4919314c5..2efb08f705 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -148,6 +148,21 @@ class TestGetXComEntry(TestXComEndpoint):
             "value": "TEST_VALUE",
         }
 
+    def test_should_raise_404_for_non_existent_xcom(self):
+        dag_id = "test-dag-id"
+        task_id = "test-task-id"
+        execution_date = "2005-04-02T00:00:00+00:00"
+        xcom_key = "test-xcom-key"
+        execution_date_parsed = parse_execution_date(execution_date)
+        run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        self._create_xcom_entry(dag_id, run_id, execution_date_parsed, task_id, xcom_key)
+        response = self.client.get(
+            f"/api/v1/dags/nonexistentdagid/dagRuns/{run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}",
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert 404 == response.status_code
+        assert response.json["title"] == "XCom entry not found"
+
     def test_should_raises_401_unauthenticated(self):
         dag_id = "test-dag-id"
         task_id = "test-task-id"
@@ -453,6 +468,48 @@ class TestGetXComEntries(TestXComEndpoint):
         assert_expected_result([expected_entry2], map_index=1)
         assert_expected_result([expected_entry1, expected_entry2], map_index=None)
 
+    def test_should_respond_200_with_xcom_key(self):
+        dag_id = "test-dag-id"
+        task_id = "test-task-id"
+        execution_date = "2005-04-02T00:00:00+00:00"
+        execution_date_parsed = parse_execution_date(execution_date)
+        dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        self._create_xcom_entries(dag_id, dag_run_id, execution_date_parsed, task_id, mapped_ti=True)
+
+        def assert_expected_result(expected_entries, key=None):
+            response = self.client.get(
+                "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries" f"{('?xcom_key='+key )}",
+                environ_overrides={"REMOTE_USER": "test"},
+            )
+
+            assert 200 == response.status_code
+            response_data = response.json
+            for xcom_entry in response_data["xcom_entries"]:
+                xcom_entry["timestamp"] = "TIMESTAMP"
+            assert response_data == {
+                "xcom_entries": expected_entries,
+                "total_entries": len(expected_entries),
+            }
+
+        expected_entry1 = {
+            "dag_id": dag_id,
+            "execution_date": execution_date,
+            "key": "test-xcom-key",
+            "task_id": task_id,
+            "timestamp": "TIMESTAMP",
+            "map_index": 0,
+        }
+        expected_entry2 = {
+            "dag_id": dag_id,
+            "execution_date": execution_date,
+            "key": "test-xcom-key",
+            "task_id": task_id,
+            "timestamp": "TIMESTAMP",
+            "map_index": 1,
+        }
+        assert_expected_result([expected_entry1, expected_entry2], key="test-xcom-key")
+        assert_expected_result([], key="test-xcom-key-1")
+
     def test_should_raises_401_unauthenticated(self):
         dag_id = "test-dag-id"
         task_id = "test-task-id"

Reply via email to