This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 97b3949f0ac Implement `batch_is_authorized_` methods in AWS auth manager (#55307) 97b3949f0ac is described below commit 97b3949f0ac582685976063f4292438e705d815d Author: Vincent <97131062+vincb...@users.noreply.github.com> AuthorDate: Mon Sep 8 14:12:46 2025 -0400 Implement `batch_is_authorized_` methods in AWS auth manager (#55307) --- .../amazon/aws/auth_manager/aws_auth_manager.py | 95 ++++++++++++++++--- .../aws/auth_manager/test_aws_auth_manager.py | 102 +++++++++++++++++++++ 2 files changed, 185 insertions(+), 12 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 387d968ec15..6f3e5851c2f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -44,7 +44,10 @@ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod from airflow.api_fastapi.auth.managers.models.batch_apis import ( + IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, + IsAuthorizedPoolRequest, + IsAuthorizedVariableRequest, ) from airflow.api_fastapi.auth.managers.models.resource_details import ( AccessView, @@ -244,6 +247,27 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]): return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])] + def batch_is_authorized_connection( + self, + requests: Sequence[IsAuthorizedConnectionRequest], + *, + user: AwsAuthManagerUser, + ) -> bool: + facade_requests: Sequence[IsAuthorizedRequest] = [ + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.CONNECTION, + "entity_id": cast("ConnectionDetails", request["details"]).conn_id + if request.get("details") + else None, + }, + ) + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) + def batch_is_authorized_dag( self, requests: Sequence[IsAuthorizedDagRequest], @@ -251,18 +275,65 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]): user: AwsAuthManagerUser, ) -> bool: facade_requests: Sequence[IsAuthorizedRequest] = [ - { - "method": request["method"], - "entity_type": AvpEntities.DAG, - "entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None, - "context": { - "dag_entity": { - "string": cast("DagAccessEntity", request["access_entity"]).value, - }, - } - if request.get("access_entity") - else None, - } + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.DAG, + "entity_id": cast("DagDetails", request["details"]).id + if request.get("details") + else None, + "context": { + "dag_entity": { + "string": cast("DagAccessEntity", request["access_entity"]).value, + }, + } + if request.get("access_entity") + else None, + }, + ) + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) + + def batch_is_authorized_pool( + self, + requests: Sequence[IsAuthorizedPoolRequest], + *, + user: AwsAuthManagerUser, + ) -> bool: + facade_requests: Sequence[IsAuthorizedRequest] = [ + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.POOL, + "entity_id": cast("PoolDetails", request["details"]).name + if request.get("details") + else None, + }, + ) + for request in requests + ] + return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) + + def batch_is_authorized_variable( + self, + requests: Sequence[IsAuthorizedVariableRequest], + *, + user: AwsAuthManagerUser, + ) -> bool: + facade_requests: Sequence[IsAuthorizedRequest] = [ + cast( + "IsAuthorizedRequest", + { + "method": request["method"], + "entity_type": AvpEntities.VARIABLE, + "entity_id": cast("VariableDetails", request["details"]).key + if request.get("details") + else None, + }, + ) for request in requests ] return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user) diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py index eab473ea9a5..70d5a31986e 100644 --- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py +++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py @@ -439,6 +439,40 @@ class TestAwsAuthManager: ) assert result == [MenuItem.VARIABLES, MenuItem.DAGS] + @patch.object(AwsAuthManager, "avp_facade") + def test_batch_is_authorized_connection( + self, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_connection( + requests=[ + {"method": "GET"}, + {"method": "PUT", "details": ConnectionDetails(conn_id="test")}, + ], + user=mock, + ) + + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.CONNECTION, + "entity_id": None, + }, + { + "method": "PUT", + "entity_type": AvpEntities.CONNECTION, + "entity_id": "test", + }, + ], + user=ANY, + ) + assert result + @patch.object(AwsAuthManager, "avp_facade") def test_batch_is_authorized_dag( self, @@ -510,6 +544,74 @@ class TestAwsAuthManager: ) assert result + @patch.object(AwsAuthManager, "avp_facade") + def test_batch_is_authorized_pool( + self, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_pool( + requests=[ + {"method": "GET"}, + {"method": "PUT", "details": PoolDetails(name="test")}, + ], + user=mock, + ) + + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.POOL, + "entity_id": None, + }, + { + "method": "PUT", + "entity_type": AvpEntities.POOL, + "entity_id": "test", + }, + ], + user=ANY, + ) + assert result + + @patch.object(AwsAuthManager, "avp_facade") + def test_batch_is_authorized_variable( + self, + mock_avp_facade, + auth_manager, + ): + batch_is_authorized = Mock(return_value=True) + mock_avp_facade.batch_is_authorized = batch_is_authorized + + result = auth_manager.batch_is_authorized_variable( + requests=[ + {"method": "GET"}, + {"method": "PUT", "details": VariableDetails(key="test")}, + ], + user=mock, + ) + + batch_is_authorized.assert_called_once_with( + requests=[ + { + "method": "GET", + "entity_type": AvpEntities.VARIABLE, + "entity_id": None, + }, + { + "method": "PUT", + "entity_type": AvpEntities.VARIABLE, + "entity_id": "test", + }, + ], + user=ANY, + ) + assert result + @pytest.mark.parametrize( "method, user, expected_result", [