This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 985d0589af AWS auth manager: implement all `is_authorized_*` methods (but `is_authorized_dag`) (#35928)
985d0589af is described below

commit 985d0589affe5ad6d6f57a5d85009bdae0d4b637
Author: Vincent <[email protected]>
AuthorDate: Wed Nov 29 15:08:17 2023 -0500

    AWS auth manager: implement all `is_authorized_*` methods (but `is_authorized_dag`) (#35928)
---
 .../amazon/aws/auth_manager/avp/entities.py        |   9 +-
 .../amazon/aws/auth_manager/aws_auth_manager.py    |  39 +++++-
 .../aws/auth_manager/test_aws_auth_manager.py      | 135 ++++++++++++++++++++-
 3 files changed, 175 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/amazon/aws/auth_manager/avp/entities.py b/airflow/providers/amazon/aws/auth_manager/avp/entities.py
index fad5ee1c3f..f64b7c7ef0 100644
--- a/airflow/providers/amazon/aws/auth_manager/avp/entities.py
+++ b/airflow/providers/amazon/aws/auth_manager/avp/entities.py
@@ -30,9 +30,16 @@ class AvpEntities(Enum):
 
     ACTION = "Action"
     ROLE = "Role"
-    VARIABLE = "Variable"
     USER = "User"
 
+    # Resource types
+    CONFIGURATION = "Configuration"
+    CONNECTION = "Connection"
+    DATASET = "Dataset"
+    POOL = "Pool"
+    VARIABLE = "Variable"
+    VIEW = "View"
+
 
 def get_entity_type(resource_type: AvpEntities) -> str:
     """
diff --git a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index d552662532..99b6d4f70c 100644
--- a/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -91,7 +91,13 @@ class AwsAuthManager(BaseAuthManager):
         details: ConfigurationDetails | None = None,
         user: BaseUser | None = None,
     ) -> bool:
-        return self.is_logged_in()
+        config_section = details.section if details else None
+        return self.avp_facade.is_authorized(
+            method=method,
+            entity_type=AvpEntities.CONFIGURATION,
+            user=user or self.get_user(),
+            entity_id=config_section,
+        )
 
     def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: 
BaseUser | None = None) -> bool:
         return self.is_logged_in()
@@ -103,7 +109,13 @@ class AwsAuthManager(BaseAuthManager):
         details: ConnectionDetails | None = None,
         user: BaseUser | None = None,
     ) -> bool:
-        return self.is_logged_in()
+        connection_id = details.conn_id if details else None
+        return self.avp_facade.is_authorized(
+            method=method,
+            entity_type=AvpEntities.CONNECTION,
+            user=user or self.get_user(),
+            entity_id=connection_id,
+        )
 
     def is_authorized_dag(
         self,
@@ -118,12 +130,24 @@ class AwsAuthManager(BaseAuthManager):
     def is_authorized_dataset(
         self, *, method: ResourceMethod, details: DatasetDetails | None = 
None, user: BaseUser | None = None
     ) -> bool:
-        return self.is_logged_in()
+        dataset_uri = details.uri if details else None
+        return self.avp_facade.is_authorized(
+            method=method,
+            entity_type=AvpEntities.DATASET,
+            user=user or self.get_user(),
+            entity_id=dataset_uri,
+        )
 
     def is_authorized_pool(
         self, *, method: ResourceMethod, details: PoolDetails | None = None, 
user: BaseUser | None = None
     ) -> bool:
-        return self.is_logged_in()
+        pool_name = details.name if details else None
+        return self.avp_facade.is_authorized(
+            method=method,
+            entity_type=AvpEntities.POOL,
+            user=user or self.get_user(),
+            entity_id=pool_name,
+        )
 
     def is_authorized_variable(
         self, *, method: ResourceMethod, details: VariableDetails | None = 
None, user: BaseUser | None = None
@@ -142,7 +166,12 @@ class AwsAuthManager(BaseAuthManager):
         access_view: AccessView,
         user: BaseUser | None = None,
     ) -> bool:
-        return self.is_logged_in()
+        return self.avp_facade.is_authorized(
+            method="GET",
+            entity_type=AvpEntities.VIEW,
+            user=user or self.get_user(),
+            entity_id=access_view.value,
+        )
 
     def get_url_login(self, **kwargs) -> str:
         return url_for("AwsAuthManagerAuthenticationViews.login")
diff --git a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
index 9cc4fc602b..440314e67b 100644
--- a/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/tests/providers/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -22,7 +22,14 @@ from unittest.mock import ANY, Mock, patch
 import pytest
 from flask import Flask, session
 
-from airflow.auth.managers.models.resource_details import VariableDetails
+from airflow.auth.managers.models.resource_details import (
+    AccessView,
+    ConfigurationDetails,
+    ConnectionDetails,
+    DatasetDetails,
+    PoolDetails,
+    VariableDetails,
+)
 from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
 from airflow.providers.amazon.aws.auth_manager.aws_auth_manager import 
AwsAuthManager
 from 
airflow.providers.amazon.aws.auth_manager.security_manager.aws_security_manager_override
 import (
@@ -110,6 +117,108 @@ class TestAwsAuthManager:
 
         assert result is False
 
+    @pytest.mark.parametrize(
+        "details, user, expected_user, expected_entity_id",
+        [
+            (None, None, ANY, None),
+            (ConfigurationDetails(section="test"), mock, mock, "test"),
+        ],
+    )
+    @patch.object(AwsAuthManager, "avp_facade")
+    @patch.object(AwsAuthManager, "get_user")
+    def test_is_authorized_configuration(
+        self, mock_get_user, mock_avp_facade, details, user, expected_user, 
expected_entity_id, auth_manager
+    ):
+        is_authorized = Mock()
+        mock_avp_facade.is_authorized = is_authorized
+
+        method: ResourceMethod = "GET"
+        auth_manager.is_authorized_configuration(method=method, 
details=details, user=user)
+
+        if not user:
+            mock_get_user.assert_called_once()
+        is_authorized.assert_called_once_with(
+            method=method,
+            entity_type=AvpEntities.CONFIGURATION,
+            user=expected_user,
+            entity_id=expected_entity_id,
+        )
+
+    @pytest.mark.parametrize(
+        "details, user, expected_user, expected_entity_id",
+        [
+            (None, None, ANY, None),
+            (ConnectionDetails(conn_id="conn_id"), mock, mock, "conn_id"),
+        ],
+    )
+    @patch.object(AwsAuthManager, "avp_facade")
+    @patch.object(AwsAuthManager, "get_user")
+    def test_is_authorized_connection(
+        self, mock_get_user, mock_avp_facade, details, user, expected_user, 
expected_entity_id, auth_manager
+    ):
+        is_authorized = Mock()
+        mock_avp_facade.is_authorized = is_authorized
+
+        method: ResourceMethod = "GET"
+        auth_manager.is_authorized_connection(method=method, details=details, 
user=user)
+
+        if not user:
+            mock_get_user.assert_called_once()
+        is_authorized.assert_called_once_with(
+            method=method,
+            entity_type=AvpEntities.CONNECTION,
+            user=expected_user,
+            entity_id=expected_entity_id,
+        )
+
+    @pytest.mark.parametrize(
+        "details, user, expected_user, expected_entity_id",
+        [
+            (None, None, ANY, None),
+            (DatasetDetails(uri="uri"), mock, mock, "uri"),
+        ],
+    )
+    @patch.object(AwsAuthManager, "avp_facade")
+    @patch.object(AwsAuthManager, "get_user")
+    def test_is_authorized_dataset(
+        self, mock_get_user, mock_avp_facade, details, user, expected_user, 
expected_entity_id, auth_manager
+    ):
+        is_authorized = Mock()
+        mock_avp_facade.is_authorized = is_authorized
+
+        method: ResourceMethod = "GET"
+        auth_manager.is_authorized_dataset(method=method, details=details, 
user=user)
+
+        if not user:
+            mock_get_user.assert_called_once()
+        is_authorized.assert_called_once_with(
+            method=method, entity_type=AvpEntities.DATASET, 
user=expected_user, entity_id=expected_entity_id
+        )
+
+    @pytest.mark.parametrize(
+        "details, user, expected_user, expected_entity_id",
+        [
+            (None, None, ANY, None),
+            (PoolDetails(name="pool1"), mock, mock, "pool1"),
+        ],
+    )
+    @patch.object(AwsAuthManager, "avp_facade")
+    @patch.object(AwsAuthManager, "get_user")
+    def test_is_authorized_pool(
+        self, mock_get_user, mock_avp_facade, details, user, expected_user, 
expected_entity_id, auth_manager
+    ):
+        is_authorized = Mock()
+        mock_avp_facade.is_authorized = is_authorized
+
+        method: ResourceMethod = "GET"
+        auth_manager.is_authorized_pool(method=method, details=details, 
user=user)
+
+        if not user:
+            mock_get_user.assert_called_once()
+        is_authorized.assert_called_once_with(
+            method=method, entity_type=AvpEntities.POOL, user=expected_user, 
entity_id=expected_entity_id
+        )
+
     @pytest.mark.parametrize(
         "details, user, expected_user, expected_entity_id",
         [
@@ -126,7 +235,6 @@ class TestAwsAuthManager:
         mock_avp_facade.is_authorized = is_authorized
 
         method: ResourceMethod = "GET"
-
         auth_manager.is_authorized_variable(method=method, details=details, 
user=user)
 
         if not user:
@@ -135,6 +243,29 @@ class TestAwsAuthManager:
             method=method, entity_type=AvpEntities.VARIABLE, 
user=expected_user, entity_id=expected_entity_id
         )
 
+    @pytest.mark.parametrize(
+        "access_view, user, expected_user",
+        [
+            (AccessView.CLUSTER_ACTIVITY, None, ANY),
+            (AccessView.PLUGINS, mock, mock),
+        ],
+    )
+    @patch.object(AwsAuthManager, "avp_facade")
+    @patch.object(AwsAuthManager, "get_user")
+    def test_is_authorized_view(
+        self, mock_get_user, mock_avp_facade, access_view, user, 
expected_user, auth_manager
+    ):
+        is_authorized = Mock()
+        mock_avp_facade.is_authorized = is_authorized
+
+        auth_manager.is_authorized_view(access_view=access_view, user=user)
+
+        if not user:
+            mock_get_user.assert_called_once()
+        is_authorized.assert_called_once_with(
+            method="GET", entity_type=AvpEntities.VIEW, user=expected_user, 
entity_id=access_view.value
+        )
+
     
@patch("airflow.providers.amazon.aws.auth_manager.aws_auth_manager.url_for")
     def test_get_url_login(self, mock_url_for, auth_manager):
         auth_manager.get_url_login()

Reply via email to