This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 39befdce12 Add system test to test the AWS auth manager (#37947)
39befdce12 is described below
commit 39befdce1205decb871fea86379b427cfc7106bc
Author: Vincent <[email protected]>
AuthorDate: Fri Mar 8 15:24:47 2024 -0500
Add system test to test the AWS auth manager (#37947)
---
.../amazon/aws/auth_manager/views/auth.py | 2 +-
tests/conftest.py | 12 +-
.../amazon/aws/auth_manager/views/test_auth.py | 33 ++--
.../system/providers/amazon/aws/tests/__init__.py | 16 ++
.../amazon/aws/tests/test_aws_auth_manager.py | 210 +++++++++++++++++++++
.../system/providers/amazon/aws/utils/__init__.py | 8 +-
6 files changed, 255 insertions(+), 26 deletions(-)
diff --git a/airflow/providers/amazon/aws/auth_manager/views/auth.py
b/airflow/providers/amazon/aws/auth_manager/views/auth.py
index 213af783dc..7ea602d0dd 100644
--- a/airflow/providers/amazon/aws/auth_manager/views/auth.py
+++ b/airflow/providers/amazon/aws/auth_manager/views/auth.py
@@ -93,7 +93,7 @@ class AwsAuthManagerAuthenticationViews(AirflowBaseView):
user_id=attributes["id"][0],
groups=attributes["groups"],
username=saml_auth.get_nameid(),
- email=attributes["email"][0],
+ email=attributes["email"][0] if "email" in attributes else None,
)
session["aws_user"] = user
diff --git a/tests/conftest.py b/tests/conftest.py
index 7fb6a2402f..7cacce0621 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1092,11 +1092,15 @@ def
refuse_to_run_test_from_wrongly_named_files(request):
dirname: str = request.node.fspath.dirname
filename: str = request.node.fspath.basename
is_system_test: bool = "tests/system/" in dirname
- if is_system_test and not
request.node.fspath.basename.startswith("example_"):
+ if is_system_test and not (
+ request.node.fspath.basename.startswith("example_")
+ or request.node.fspath.basename.startswith("test_")
+ ):
raise Exception(
- f"All test method files in tests/system must start with
'example_'. Seems that {filename} "
- f"contains {request.function} that looks like a test case. Please
rename the file to "
- f"follow the example_* pattern if you want to run the tests in it."
+ f"All test method files in tests/system must start with 'example_'
or 'test_'. "
+ f"Seems that {filename} contains {request.function} that looks
like a test case. "
+ f"Please rename the file to follow the example_* or test_* pattern
if you want to run the tests "
+ f"in it."
)
if not is_system_test and not
request.node.fspath.basename.startswith("test_"):
raise Exception(
diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py
b/tests/providers/amazon/aws/auth_manager/views/test_auth.py
index 10b0e89af0..2a69a96bd2 100644
--- a/tests/providers/amazon/aws/auth_manager/views/test_auth.py
+++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py
@@ -48,24 +48,21 @@ SAML_METADATA_PARSED = {
@pytest.fixture
def aws_app():
- def factory():
- with conf_vars(
- {
- (
- "core",
- "auth_manager",
- ):
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
- ("aws_auth_manager", "enable"): "True",
- ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
- }
- ):
- with patch(
-
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
- ) as mock_parser:
- mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
- return application.create_app(testing=True)
-
- return factory()
+ with conf_vars(
+ {
+ (
+ "core",
+ "auth_manager",
+ ):
"airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+ ("aws_auth_manager", "enable"): "True",
+ ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+ }
+ ):
+ with patch(
+
"airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+ ) as mock_parser:
+ mock_parser.parse_remote.return_value = SAML_METADATA_PARSED
+ return application.create_app(testing=True)
@pytest.mark.db_test
diff --git a/tests/system/providers/amazon/aws/tests/__init__.py
b/tests/system/providers/amazon/aws/tests/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/amazon/aws/tests/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
new file mode 100644
index 0000000000..fda8a0922a
--- /dev/null
+++ b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import boto3
+import pytest
+
+from airflow.www import app as application
+from tests.system.providers.amazon.aws.utils import set_env_id
+from tests.test_utils.config import conf_vars
+from tests.test_utils.www import check_content_in_response
+
+pytest.importorskip("onelogin")
+
+# Value of the ``[aws_auth_manager] saml_metadata_url`` option set in ``base_app``. That fixture
+# patches ``OneLogin_Saml2_IdPMetadataParser`` so ``parse_remote`` returns ``SAML_METADATA_PARSED``.
+SAML_METADATA_URL = "/saml/metadata"
+# Canned identity-provider metadata returned by the patched ``parse_remote`` in ``base_app``.
+SAML_METADATA_PARSED = {
+    "idp": {
+        "entityId": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
+        "singleSignOnService": {
+            "url": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>",
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+        },
+        "singleLogoutService": {
+            "url": "https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>",
+            "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
+        },
+        "x509cert": "<cert>",
+    },
+    "security": {"authnRequestsSigned": False},
+    "sp": {"NameIDFormat": "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"},
+}
+
+# Cedar policy added to the test policy store by ``create_avp_policy_store``: it permits every
+# action on every resource for principals in ``Airflow::Role::"Admin"``.
+AVP_POLICY_ADMIN = """
+permit (
+ principal in Airflow::Role::"Admin",
+ action,
+ resource
+);
+"""
+
+# Caches filled lazily by the ``env_id`` / ``avp_policy_store_id`` fixtures so each value is
+# created at most once per process. ``env_id_cache`` is also read by
+# ``TestAwsAuthManager.delete_avp_policy_store`` to find the policy store to clean up.
+env_id_cache: str | None = None
+policy_store_id_cache: str | None = None
+
+
+def create_avp_policy_store(env_id):
+    """
+    Create an Amazon Verified Permissions policy store for this test run and return its id.
+
+    The store is created with validation ``OFF``. The Airflow Cedar schema is then uploaded,
+    validation is switched to ``STRICT``, and finally the ``AVP_POLICY_ADMIN`` policy is added.
+    The description embeds ``env_id`` so the store can be found again during teardown.
+    """
+    description = f"Created by system test TestAwsAuthManager: {env_id}"
+    avp = boto3.client("verifiedpermissions")
+
+    store_id = avp.create_policy_store(
+        validationSettings={"mode": "OFF"},
+        description=description,
+    )["policyStoreId"]
+
+    # parents[6] walks up from tests/system/providers/amazon/aws/tests/ to the repository root.
+    repo_root = Path(__file__).parents[6]
+    schema_file = repo_root.joinpath("airflow", "providers", "amazon", "aws", "auth_manager", "cli", "schema.json")
+    with open(schema_file.resolve()) as schema:
+        avp.put_schema(policyStoreId=store_id, definition={"cedarJson": schema.read()})
+
+    avp.update_policy_store(
+        policyStoreId=store_id,
+        validationSettings={"mode": "STRICT"},
+        description=description,
+    )
+    avp.create_policy(
+        policyStoreId=store_id,
+        definition={"static": {"description": "Admin permissions", "statement": AVP_POLICY_ADMIN}},
+    )
+    return store_id
+
+
+@pytest.fixture
+def env_id():
+    """Return the system-test environment id, calling ``set_env_id`` only while none is cached."""
+    global env_id_cache
+    if env_id_cache:
+        return env_id_cache
+    env_id_cache = set_env_id()
+    return env_id_cache
+
+
+@pytest.fixture
+def region_name():
+    """Region of the default boto3 session; passed to ``[aws_auth_manager] region_name``."""
+    session = boto3.session.Session()
+    return session.region_name
+
+
+@pytest.fixture
+def avp_policy_store_id(env_id):
+    """Return the id of the test policy store, creating the store only while none is cached."""
+    global policy_store_id_cache
+    if policy_store_id_cache:
+        return policy_store_id_cache
+    policy_store_id_cache = create_avp_policy_store(env_id)
+    return policy_store_id_cache
+
+
+@pytest.fixture
+def base_app(region_name, avp_policy_store_id):
+    """
+    Configure Airflow for the AWS auth manager and patch its SAML plumbing.
+
+    Despite its name, this fixture yields the mock that replaces
+    ``AwsAuthManagerAuthenticationViews._init_saml_auth``. Dependent fixtures set that mock's
+    ``return_value`` and create the web app while the config overrides and patches are active.
+    """
+    config = {
+        ("core", "auth_manager"): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager",
+        ("aws_auth_manager", "enable"): "True",
+        ("aws_auth_manager", "region_name"): region_name,
+        ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL,
+        ("aws_auth_manager", "avp_policy_store_id"): avp_policy_store_id,
+    }
+    with conf_vars(config), patch(
+        "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser"
+    ) as metadata_parser, patch(
+        "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth"
+    ) as init_saml_auth:
+        # Serve canned IdP metadata instead of reading SAML_METADATA_URL.
+        metadata_parser.parse_remote.return_value = SAML_METADATA_PARSED
+        yield init_saml_auth
+
+
+@pytest.fixture
+def client_no_permissions(base_app):
+    """Airflow web app whose SAML login resolves to a user that belongs to no group."""
+    saml_user = Mock(
+        **{
+            "is_authenticated.return_value": True,
+            "get_nameid.return_value": "user_no_permissions",
+            "get_attributes.return_value": {
+                "id": ["user_no_permissions"],
+                "groups": [],
+                "email": ["email"],
+            },
+        }
+    )
+    # ``base_app`` is the patched ``_init_saml_auth``; make it hand back the fake SAML session.
+    base_app.return_value = saml_user
+    return application.create_app(testing=True)
+
+
+@pytest.fixture
+def client_admin_permissions(base_app):
+    """
+    Airflow web app whose SAML login resolves to a member of the ``Admin`` group.
+
+    No ``email`` attribute is supplied, so the login callback has to cope with its absence.
+    """
+    saml_user = Mock(
+        **{
+            "is_authenticated.return_value": True,
+            "get_nameid.return_value": "user_admin_permissions",
+            "get_attributes.return_value": {
+                "id": ["user_admin_permissions"],
+                "groups": ["Admin"],
+            },
+        }
+    )
+    # ``base_app`` is the patched ``_init_saml_auth``; make it hand back the fake SAML session.
+    base_app.return_value = saml_user
+    return application.create_app(testing=True)
+
+
+@pytest.mark.system("amazon")
+class TestAwsAuthManager:
+    """
+    Run tests on Airflow using AWS auth manager with real credentials.
+
+    The Verified Permissions policy store created through the ``avp_policy_store_id`` fixture
+    is deleted in ``teardown_class``.
+    """
+
+    @classmethod
+    def teardown_class(cls):
+        cls.delete_avp_policy_store()
+
+    @classmethod
+    def delete_avp_policy_store(cls):
+        """
+        Delete the policy store(s) whose description was written for the current env id.
+
+        Stores are matched on their exact description. A substring test would also match
+        another run whose env id merely starts with this one (or every run, for an empty env
+        id), deleting policy stores that belong to other, possibly concurrent, test runs.
+        """
+        expected_description = f"Created by system test TestAwsAuthManager: {env_id_cache}"
+        client = boto3.client("verifiedpermissions")
+
+        paginator = client.get_paginator("list_policy_stores")
+        policy_store_ids = [
+            store["policyStoreId"]
+            for page in paginator.paginate()
+            for store in page["policyStores"]
+            # ``description`` is optional in the listing, hence ``.get``.
+            if store.get("description") == expected_description
+        ]
+
+        for policy_store_id in policy_store_ids:
+            client.delete_policy_store(policyStoreId=policy_store_id)
+
+    def test_login_no_permissions(self, client_no_permissions):
+        with client_no_permissions.test_client() as client:
+            response = client.get("/login_callback", follow_redirects=True)
+            check_content_in_response("Your user has no roles and/or permissions!", response, 403)
+
+    def test_login_admin(self, client_admin_permissions):
+        with client_admin_permissions.test_client() as client:
+            response = client.get("/login_callback", follow_redirects=True)
+            check_content_in_response("<h2>DAGs</h2>", response, 200)
diff --git a/tests/system/providers/amazon/aws/utils/__init__.py
b/tests/system/providers/amazon/aws/utils/__init__.py
index 1bdcbf656f..175fe0911b 100644
--- a/tests/system/providers/amazon/aws/utils/__init__.py
+++ b/tests/system/providers/amazon/aws/utils/__init__.py
@@ -43,8 +43,8 @@ DEFAULT_ENV_ID_LEN: int = 8
DEFAULT_ENV_ID: str =
f"{DEFAULT_ENV_ID_PREFIX}{uuid4()!s:.{DEFAULT_ENV_ID_LEN}}"
PURGE_LOGS_INTERVAL_PERIOD = 5
-# All test file names will contain this string.
-TEST_FILE_IDENTIFIER: str = "example"
+# All test file names will contain one of these strings.
+TEST_FILE_IDENTIFIERS: list[str] = ["example", "test"]
INVALID_ENV_ID_MSG: str = (
"In order to maximize compatibility, the SYSTEM_TESTS_ENV_ID must be an
alphanumeric string "
@@ -68,7 +68,9 @@ def _get_test_name() -> str:
# The exact layer of the stack will depend on if this is called directly
# or from another helper, but the test will always contain the identifier.
test_filename: str = next(
- frame.filename for frame in inspect.stack() if TEST_FILE_IDENTIFIER in
frame.filename
+ frame.filename
+ for frame in inspect.stack()
+ if any(identifier in frame.filename for identifier in
TEST_FILE_IDENTIFIERS)
)
return Path(test_filename).stem