This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 39befdce12 Add system test to test the AWS auth manager (#37947)
39befdce12 is described below

commit 39befdce1205decb871fea86379b427cfc7106bc
Author: Vincent <97131062+vincb...@users.noreply.github.com>
AuthorDate: Fri Mar 8 15:24:47 2024 -0500

    Add system test to test the AWS auth manager (#37947)
---
 .../amazon/aws/auth_manager/views/auth.py          |   2 +-
 tests/conftest.py                                  |  12 +-
 .../amazon/aws/auth_manager/views/test_auth.py     |  33 ++--
 .../system/providers/amazon/aws/tests/__init__.py  |  16 ++
 .../amazon/aws/tests/test_aws_auth_manager.py      | 210 +++++++++++++++++++++
 .../system/providers/amazon/aws/utils/__init__.py  |   8 +-
 6 files changed, 255 insertions(+), 26 deletions(-)

diff --git a/airflow/providers/amazon/aws/auth_manager/views/auth.py 
b/airflow/providers/amazon/aws/auth_manager/views/auth.py
index 213af783dc..7ea602d0dd 100644
--- a/airflow/providers/amazon/aws/auth_manager/views/auth.py
+++ b/airflow/providers/amazon/aws/auth_manager/views/auth.py
@@ -93,7 +93,7 @@ class AwsAuthManagerAuthenticationViews(AirflowBaseView):
             user_id=attributes["id"][0],
             groups=attributes["groups"],
             username=saml_auth.get_nameid(),
-            email=attributes["email"][0],
+            email=attributes["email"][0] if "email" in attributes else None,
         )
         session["aws_user"] = user
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 7fb6a2402f..7cacce0621 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1092,11 +1092,15 @@ def 
refuse_to_run_test_from_wrongly_named_files(request):
     dirname: str = request.node.fspath.dirname
     filename: str = request.node.fspath.basename
     is_system_test: bool = "tests/system/" in dirname
-    if is_system_test and not 
request.node.fspath.basename.startswith("example_"):
+    if is_system_test and not (
+        request.node.fspath.basename.startswith("example_")
+        or request.node.fspath.basename.startswith("test_")
+    ):
         raise Exception(
-            f"All test method files in tests/system must start with 
'example_'. Seems that {filename} "
-            f"contains {request.function} that looks like a test case. Please 
rename the file to "
-            f"follow the example_* pattern if you want to run the tests in it."
+            f"All test method files in tests/system must start with 'example_' 
or 'test_'. "
+            f"Seems that {filename} contains {request.function} that looks 
like a test case. "
+            f"Please rename the file to follow the example_* or test_* pattern 
if you want to run the tests "
+            f"in it."
         )
     if not is_system_test and not 
request.node.fspath.basename.startswith("test_"):
         raise Exception(
diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py 
b/tests/providers/amazon/aws/auth_manager/views/test_auth.py
index 10b0e89af0..2a69a96bd2 100644
--- a/tests/providers/amazon/aws/auth_manager/views/test_auth.py
+++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py
@@ -48,24 +48,21 @@ SAML_METADATA_PARSED = {
 
 @pytest.fixture
 def aws_app():
-    def factory():
-        with conf_vars(
-            {
-                (
-                    "core",
-                    "auth_manager",
-                ): 
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
-                ("aws_auth_manager", "enable"): "True",
-                ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
-            }
-        ):
-            with patch(
-                
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
-            ) as mock_parser:
-                mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
-                return application.create_app(testing=True)
-
-    return factory()
+    with conf_vars(
+        {
+            (
+                "core",
+                "auth_manager",
+            ): 
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+            ("aws_auth_manager", "enable"): "True",
+            ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+        }
+    ):
+        with patch(
+            
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+        ) as mock_parser:
+            mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
+            return application.create_app(testing=True)
 
 
 @pytest.mark.db_test
diff --git a/tests/system/providers/amazon/aws/tests/__init__.py 
b/tests/system/providers/amazon/aws/tests/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/amazon/aws/tests/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py 
b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
new file mode 100644
index 0000000000..fda8a0922a
--- /dev/null
+++ b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import boto3
+import pytest
+
+from airflow.www import app as application
+from tests.system.providers.amazon.aws.utils import set_env_id
+from tests.test_utils.config import conf_vars
+from tests.test_utils.www import check_content_in_response
+
+pytest.importorskip("onelogin")
+
+SAML_METADATA_URL = "/saml/metadata"
+SAML_METADATA_PARSED = {
+    "idp": {
+        "entityId": 
"https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
+        "singleSignOnService": {
+            "url": 
"https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+        },
+        "singleLogoutService": {
+            "url": 
"https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>",
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+        },
+        "x509cert": "<cert>",
+    },
+    "security": {"authnRequestsSigned": False},
+    "sp": {"NameIDFormat": 
"urn:oasis:names:tc:SAML:2.0:nameid-format:transient"},
+}
+
+AVP_POLICY_ADMIN = """
+permit (
+    principal in Airflow::Role::"Admin",
+    action,
+    resource
+);
+"""
+
+env_id_cache: str | None = None
+policy_store_id_cache: str | None = None
+
+
+def create_avp_policy_store(env_id):
+    description = f"Created by system test TestAwsAuthManager: {env_id}"
+    client = boto3.client("verifiedpermissions")
+    response = client.create_policy_store(
+        validationSettings={"mode": "OFF"},
+        description=description,
+    )
+    policy_store_id = response["policyStoreId"]
+
+    schema_path = (
+        Path(__file__)
+        .parents[6]
+        .joinpath("airflow", "providers", "amazon", "aws", "auth_manager", 
"cli", "schema.json")
+        .resolve()
+    )
+    with open(schema_path) as schema_file:
+        client.put_schema(
+            policyStoreId=policy_store_id,
+            definition={
+                "cedarJson": schema_file.read(),
+            },
+        )
+
+    client.update_policy_store(
+        policyStoreId=policy_store_id,
+        validationSettings={
+            "mode": "STRICT",
+        },
+        description=description,
+    )
+
+    client.create_policy(
+        policyStoreId=policy_store_id,
+        definition={
+            "static": {"description": "Admin permissions", "statement": 
AVP_POLICY_ADMIN},
+        },
+    )
+
+    return policy_store_id
+
+
+@pytest.fixture
+def env_id():
+    global env_id_cache
+    if not env_id_cache:
+        env_id_cache = set_env_id()
+    return env_id_cache
+
+
+@pytest.fixture
+def region_name():
+    return boto3.session.Session().region_name
+
+
+@pytest.fixture
+def avp_policy_store_id(env_id):
+    global policy_store_id_cache
+    if not policy_store_id_cache:
+        policy_store_id_cache = create_avp_policy_store(env_id)
+    return policy_store_id_cache
+
+
+@pytest.fixture
+def base_app(region_name, avp_policy_store_id):
+    with conf_vars(
+        {
+            (
+                "core",
+                "auth_manager",
+            ): 
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+            ("aws_auth_manager", "enable"): "True",
+            ("aws_auth_manager", "region_name"): region_name,
+            ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+            ("aws_auth_manager", "avp_policy_store_id"): avp_policy_store_id,
+        }
+    ):
+        with patch(
+            
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+        ) as mock_parser, patch(
+            
"airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth"
+        ) as mock_init_saml_auth:
+            mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
+
+            yield mock_init_saml_auth
+
+
+@pytest.fixture
+def client_no_permissions(base_app):
+    auth = Mock()
+    auth.is_authenticated.return_value = True
+    auth.get_nameid.return_value = "user_no_permissions"
+    auth.get_attributes.return_value = {
+        "id": ["user_no_permissions"],
+        "groups": [],
+        "email": ["email"],
+    }
+    base_app.return_value = auth
+    return application.create_app(testing=True)
+
+
+@pytest.fixture
+def client_admin_permissions(base_app):
+    auth = Mock()
+    auth.is_authenticated.return_value = True
+    auth.get_nameid.return_value = "user_admin_permissions"
+    auth.get_attributes.return_value = {
+        "id": ["user_admin_permissions"],
+        "groups": ["Admin"],
+    }
+    base_app.return_value = auth
+    return application.create_app(testing=True)
+
+
+@pytest.mark.system("amazon")
+class TestAwsAuthManager:
+    """
+    Run tests on Airflow using AWS auth manager with real credentials
+    """
+
+    @classmethod
+    def teardown_class(cls):
+        cls.delete_avp_policy_store()
+
+    @classmethod
+    def delete_avp_policy_store(cls):
+        client = boto3.client("verifiedpermissions")
+
+        paginator = client.get_paginator("list_policy_stores")
+        pages = paginator.paginate()
+        policy_store_ids = [
+            store["policyStoreId"]
+            for page in pages
+            for store in page["policyStores"]
+            if "description" in store
+            and f"Created by system test TestAwsAuthManager: {env_id_cache}" 
in store["description"]
+        ]
+
+        for policy_store_id in policy_store_ids:
+            client.delete_policy_store(policyStoreId=policy_store_id)
+
+    def test_login_no_permissions(self, client_no_permissions):
+        with client_no_permissions.test_client() as client:
+            response = client.get("/login_callback", follow_redirects=True)
+            check_content_in_response("Your user has no roles and/or 
permissions!", response, 403)
+
+    def test_login_admin(self, client_admin_permissions):
+        with client_admin_permissions.test_client() as client:
+            response = client.get("/login_callback", follow_redirects=True)
+            check_content_in_response("<h2>DAGs</h2>", response, 200)
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py 
b/tests/system/providers/amazon/aws/utils/__init__.py
index 1bdcbf656f..175fe0911b 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -43,8 +43,8 @@ DEFAULT_ENV_ID_LEN: int = 8
 DEFAULT_ENV_ID: str = 
f"{DEFAULT_ENV_ID_PREFIX}{uuid4()!s:.{DEFAULT_ENV_ID_LEN}}"
 PURGE_LOGS_INTERVAL_PERIOD = 5
 
-# All test file names will contain this string.
-TEST_FILE_IDENTIFIER: str = "example"
+# All test file names will start with one of these prefixes.
+TEST_FILE_IDENTIFIERS: list[str] = ["example_", "test_"]
 
 INVALID_ENV_ID_MSG: str = (
     "In order to maximize compatibility, the SYSTEM_TESTS_ENV_ID must be an 
alphanumeric string "
@@ -68,7 +68,9 @@ def _get_test_name() -> str:
     # The exact layer of the stack will depend on if this is called directly
     # or from another helper, but the test will always contain the identifier.
     test_filename: str = next(
-        frame.filename for frame in inspect.stack() if TEST_FILE_IDENTIFIER in 
frame.filename
+        frame.filename
+        for frame in inspect.stack()
+        if Path(frame.filename).name.startswith(tuple(TEST_FILE_IDENTIFIERS))
     )
     return Path(test_filename).stem
 

Reply via email to