This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 39befdce12 Add system test to test the AWS auth manager (#37947) 39befdce12 is described below commit 39befdce1205decb871fea86379b427cfc7106bc Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Fri Mar 8 15:24:47 2024 -0500 Add system test to test the AWS auth manager (#37947) --- .../amazon/aws/auth_manager/views/auth.py | 2 +- tests/conftest.py | 12 +- .../amazon/aws/auth_manager/views/test_auth.py | 33 ++-- .../system/providers/amazon/aws/tests/__init__.py | 16 ++ .../amazon/aws/tests/test_aws_auth_manager.py | 210 +++++++++++++++++++++ .../system/providers/amazon/aws/utils/__init__.py | 8 +- 6 files changed, 255 insertions(+), 26 deletions(-) diff --git a/airflow/providers/amazon/aws/auth_manager/views/auth.py b/airflow/providers/amazon/aws/auth_manager/views/auth.py index 213af783dc..7ea602d0dd 100644 --- a/airflow/providers/amazon/aws/auth_manager/views/auth.py +++ b/airflow/providers/amazon/aws/auth_manager/views/auth.py @@ -93,7 +93,7 @@ class AwsAuthManagerAuthenticationViews(AirflowBaseView): user_id=attributes["id"][0], groups=attributes["groups"], username=saml_auth.get_nameid(), - email=attributes["email"][0], + email=attributes["email"][0] if "email" in attributes else None, ) session["aws_user"] = user diff --git a/tests/conftest.py b/tests/conftest.py index 7fb6a2402f..7cacce0621 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1092,11 +1092,15 @@ def refuse_to_run_test_from_wrongly_named_files(request): dirname: str = request.node.fspath.dirname filename: str = request.node.fspath.basename is_system_test: bool = "tests/system/" in dirname - if is_system_test and not request.node.fspath.basename.startswith("example_"): + if is_system_test and not ( + request.node.fspath.basename.startswith("example_") + or request.node.fspath.basename.startswith("test_") + ): raise Exception( - f"All test method files in tests/system must start with 'example_'. Seems that {filename} " - f"contains {request.function} that looks like a test case. Please rename the file to " - f"follow the example_* pattern if you want to run the tests in it." + f"All test method files in tests/system must start with 'example_' or 'test_'. " + f"Seems that {filename} contains {request.function} that looks like a test case. " + f"Please rename the file to follow the example_* or test_* pattern if you want to run the tests " + f"in it." ) if not is_system_test and not request.node.fspath.basename.startswith("test_"): raise Exception( diff --git a/tests/providers/amazon/aws/auth_manager/views/test_auth.py b/tests/providers/amazon/aws/auth_manager/views/test_auth.py index 10b0e89af0..2a69a96bd2 100644 --- a/tests/providers/amazon/aws/auth_manager/views/test_auth.py +++ b/tests/providers/amazon/aws/auth_manager/views/test_auth.py @@ -48,24 +48,21 @@ SAML_METADATA_PARSED = { @pytest.fixture def aws_app(): - def factory(): - with conf_vars( - { - ( - "core", - "auth_manager", - ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", - ("aws_auth_manager", "enable"): "True", - ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL, - } - ): - with patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" - ) as mock_parser: - mock_parser.parse_remote.return_value = SAML_METADATA_PARSED - return application.create_app(testing=True) - - return factory() + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL, + } + ): + with patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" + ) as mock_parser: + mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + return application.create_app(testing=True) @pytest.mark.db_test diff --git a/tests/system/providers/amazon/aws/tests/__init__.py b/tests/system/providers/amazon/aws/tests/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/system/providers/amazon/aws/tests/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py new file mode 100644 index 0000000000..fda8a0922a --- /dev/null +++ b/tests/system/providers/amazon/aws/tests/test_aws_auth_manager.py @@ -0,0 +1,210 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock, patch + +import boto3 +import pytest + +from airflow.www import app as application +from tests.system.providers.amazon.aws.utils import set_env_id +from tests.test_utils.config import conf_vars +from tests.test_utils.www import check_content_in_response + +pytest.importorskip("onelogin") + +SAML_METADATA_URL = "/saml/metadata" +SAML_METADATA_PARSED = { + "idp": { + "entityId": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>", + "singleSignOnService": { + "url": "https://portal.sso.us-east-1.amazonaws.com/saml/assertion/<assertion>", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "singleLogoutService": { + "url": "https://portal.sso.us-east-1.amazonaws.com/saml/logout/<assertion>", + "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", + }, + "x509cert": "<cert>", + }, + "security": {"authnRequestsSigned": False}, + "sp": {"NameIDFormat": "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"}, +} + +AVP_POLICY_ADMIN = """ +permit ( + principal in Airflow::Role::"Admin", + action, + resource +); +""" + +env_id_cache: str | None = None +policy_store_id_cache: str | None = None + + +def create_avp_policy_store(env_id): + description = f"Created by system test TestAwsAuthManager: {env_id}" + client = boto3.client("verifiedpermissions") + response = client.create_policy_store( + validationSettings={"mode": "OFF"}, + description=description, + ) + policy_store_id = response["policyStoreId"] + + schema_path = ( + Path(__file__) + .parents[6] + .joinpath("airflow", "providers", "amazon", "aws", "auth_manager", "cli", "schema.json") + .resolve() + ) + with open(schema_path) as schema_file: + client.put_schema( + policyStoreId=policy_store_id, + definition={ + "cedarJson": schema_file.read(), + }, + ) + + client.update_policy_store( + policyStoreId=policy_store_id, + validationSettings={ + "mode": "STRICT", + }, + description=description, + ) + + client.create_policy( + policyStoreId=policy_store_id, + definition={ + "static": {"description": "Admin permissions", "statement": AVP_POLICY_ADMIN}, + }, + ) + + return policy_store_id + + +@pytest.fixture +def env_id(): + global env_id_cache + if not env_id_cache: + env_id_cache = set_env_id() + return env_id_cache + + +@pytest.fixture +def region_name(): + return boto3.session.Session().region_name + + +@pytest.fixture +def avp_policy_store_id(env_id): + global policy_store_id_cache + if not policy_store_id_cache: + policy_store_id_cache = create_avp_policy_store(env_id) + return policy_store_id_cache + + +@pytest.fixture +def base_app(region_name, avp_policy_store_id): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.providers.amazon.aws.auth_manager.aws_auth_manager.AwsAuthManager", + ("aws_auth_manager", "enable"): "True", + ("aws_auth_manager", "region_name"): region_name, + ("aws_auth_manager", "saml_metadata_url"): SAML_METADATA_URL, + ("aws_auth_manager", "avp_policy_store_id"): avp_policy_store_id, + } + ): + with patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.OneLogin_Saml2_IdPMetadataParser" + ) as mock_parser, patch( + "airflow.providers.amazon.aws.auth_manager.views.auth.AwsAuthManagerAuthenticationViews._init_saml_auth" + ) as mock_init_saml_auth: + mock_parser.parse_remote.return_value = SAML_METADATA_PARSED + + yield mock_init_saml_auth + + +@pytest.fixture +def client_no_permissions(base_app): + auth = Mock() + auth.is_authenticated.return_value = True + auth.get_nameid.return_value = "user_no_permissions" + auth.get_attributes.return_value = { + "id": ["user_no_permissions"], + "groups": [], + "email": ["email"], + } + base_app.return_value = auth + return application.create_app(testing=True) + + +@pytest.fixture +def client_admin_permissions(base_app): + auth = Mock() + auth.is_authenticated.return_value = True + auth.get_nameid.return_value = "user_admin_permissions" + auth.get_attributes.return_value = { + "id": ["user_admin_permissions"], + "groups": ["Admin"], + } + base_app.return_value = auth + return application.create_app(testing=True) + + +@pytest.mark.system("amazon") +class TestAwsAuthManager: + """ + Run tests on Airflow using AWS auth manager with real credentials + """ + + @classmethod + def teardown_class(cls): + cls.delete_avp_policy_store() + + @classmethod + def delete_avp_policy_store(cls): + client = boto3.client("verifiedpermissions") + + paginator = client.get_paginator("list_policy_stores") + pages = paginator.paginate() + policy_store_ids = [ + store["policyStoreId"] + for page in pages + for store in page["policyStores"] + if "description" in store + and f"Created by system test TestAwsAuthManager: {env_id_cache}" in store["description"] + ] + + for policy_store_id in policy_store_ids: + client.delete_policy_store(policyStoreId=policy_store_id) + + def test_login_no_permissions(self, client_no_permissions): + with client_no_permissions.test_client() as client: + response = client.get("/login_callback", follow_redirects=True) + check_content_in_response("Your user has no roles and/or permissions!", response, 403) + + def test_login_admin(self, client_admin_permissions): + with client_admin_permissions.test_client() as client: + response = client.get("/login_callback", follow_redirects=True) + check_content_in_response("<h2>DAGs</h2>", response, 200) diff --git a/tests/system/providers/amazon/aws/utils/__init__.py b/tests/system/providers/amazon/aws/utils/__init__.py index 1bdcbf656f..175fe0911b 100644 --- a/tests/system/providers/amazon/aws/utils/__init__.py +++ b/tests/system/providers/amazon/aws/utils/__init__.py @@ -43,8 +43,8 @@ DEFAULT_ENV_ID_LEN: int = 8 DEFAULT_ENV_ID: str = f"{DEFAULT_ENV_ID_PREFIX}{uuid4()!s:.{DEFAULT_ENV_ID_LEN}}" PURGE_LOGS_INTERVAL_PERIOD = 5 -# All test file names will contain this string. -TEST_FILE_IDENTIFIER: str = "example" +# All test file names will contain one of these strings. +TEST_FILE_IDENTIFIERS: list[str] = ["example", "test"] INVALID_ENV_ID_MSG: str = ( "In order to maximize compatibility, the SYSTEM_TESTS_ENV_ID must be an alphanumeric string " @@ -68,7 +68,9 @@ def _get_test_name() -> str: # The exact layer of the stack will depend on if this is called directly # or from another helper, but the test will always contain the identifier. test_filename: str = next( - frame.filename for frame in inspect.stack() if TEST_FILE_IDENTIFIER in frame.filename + frame.filename + for frame in inspect.stack() + if any(identifier in frame.filename for identifier in TEST_FILE_IDENTIFIERS) ) return Path(test_filename).stem