This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 031e3945e4 Allow auth managers to override the security manager (#32525)
031e3945e4 is described below
commit 031e3945e44030d4f085753ffdc43dc104708f91
Author: Vincent <[email protected]>
AuthorDate: Mon Jul 24 13:46:44 2023 -0400
Allow auth managers to override the security manager (#32525)
Allow auth managers to override the security manager
---
airflow/api/auth/backend/session.py | 4 +-
airflow/auth/managers/base_auth_manager.py | 12 ++
airflow/auth/managers/fab/fab_auth_manager.py | 5 +
.../auth/managers/fab/security_manager_override.py | 220 +++++++++++++++++++++
airflow/configuration.py | 34 +++-
airflow/www/auth.py | 5 +-
airflow/www/decorators.py | 4 +-
airflow/www/extensions/init_appbuilder.py | 7 -
.../extensions/init_auth_manager.py} | 30 +--
airflow/www/extensions/init_jinja_globals.py | 5 +-
airflow/www/extensions/init_security.py | 5 +-
airflow/www/fab_security/manager.py | 137 +------------
airflow/www/fab_security/sqla/manager.py | 2 +-
airflow/www/security.py | 40 +++-
airflow/www/views.py | 12 +-
tests/auh/managers/fab/test_fab_auth_manager.py | 4 +
.../auh/managers/test_base_auth_manager.py | 27 ++-
17 files changed, 360 insertions(+), 193 deletions(-)
diff --git a/airflow/api/auth/backend/session.py
b/airflow/api/auth/backend/session.py
index c55f748460..ef914b57e4 100644
--- a/airflow/api/auth/backend/session.py
+++ b/airflow/api/auth/backend/session.py
@@ -22,7 +22,7 @@ from typing import Any, Callable, TypeVar, cast
from flask import Response
-from airflow.configuration import auth_manager
+from airflow.www.extensions.init_auth_manager import get_auth_manager
CLIENT_AUTH: tuple[str, str] | Any | None = None
@@ -39,7 +39,7 @@ def requires_authentication(function: T):
@wraps(function)
def decorated(*args, **kwargs):
- if not auth_manager.is_logged_in():
+ # get_auth_manager() is decorated with @cache, so the same auth manager instance is reused on every request.
+ if not get_auth_manager().is_logged_in():
return Response("Unauthorized", 401, {})
return function(*args, **kwargs)
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/auth/managers/base_auth_manager.py
index 462fe34d63..ab5356bf8c 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -38,3 +38,15 @@ class BaseAuthManager(LoggingMixin):
def is_logged_in(self) -> bool:
"""Return whether the user is logged in."""
...
+
+ def get_security_manager_override_class(self) -> type:
+ """
+ Return the security manager override class.
+
+ The returned class is used as the first base class of
+ airflow.www.security.AirflowSecurityManager, so the methods it defines take precedence
+ over those of SecurityManager (but not over methods AirflowSecurityManager defines itself).
+
+ By default, return ``object``, meaning that no override is provided.
+ """
+ # NOTE(review): ``object`` cannot be listed before other base classes in a class statement
+ # (inconsistent MRO -> TypeError), so consumers of this value must special-case it.
+ return object
diff --git a/airflow/auth/managers/fab/fab_auth_manager.py
b/airflow/auth/managers/fab/fab_auth_manager.py
index b9f0c1e1df..f90a9ac063 100644
--- a/airflow/auth/managers/fab/fab_auth_manager.py
+++ b/airflow/auth/managers/fab/fab_auth_manager.py
@@ -20,6 +20,7 @@ from __future__ import annotations
from flask_login import current_user
from airflow.auth.managers.base_auth_manager import BaseAuthManager
+from airflow.auth.managers.fab.security_manager_override import
FabAirflowSecurityManagerOverride
class FabAuthManager(BaseAuthManager):
@@ -43,3 +44,7 @@ class FabAuthManager(BaseAuthManager):
def is_logged_in(self) -> bool:
"""Return whether the user is logged in."""
return current_user and not current_user.is_anonymous
+
+ def get_security_manager_override_class(self) -> type:
+ """Return FabAirflowSecurityManagerOverride, which registers the FAB security views."""
+ return FabAirflowSecurityManagerOverride
diff --git a/airflow/auth/managers/fab/security_manager_override.py
b/airflow/auth/managers/fab/security_manager_override.py
new file mode 100644
index 0000000000..5be9ee1f36
--- /dev/null
+++ b/airflow/auth/managers/fab/security_manager_override.py
@@ -0,0 +1,220 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+
+from flask_appbuilder.const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID,
AUTH_REMOTE_USER
+from flask_babel import lazy_gettext
+
+
+class FabAirflowSecurityManagerOverride:
+ """
+ This security manager overrides the default AirflowSecurityManager security manager.
+
+ This security manager is used only if the auth manager FabAuthManager is used. It defines everything in
+ the security manager that is needed for the FabAuthManager to work. Any operation specific to
+ the FAB auth manager should be defined here instead of in AirflowSecurityManager.
+
+ :param appbuilder: The appbuilder.
+ :param actionmodelview: The obj instance for action model view.
+ :param authdbview: The class for auth db view.
+ :param authldapview: The class for auth ldap view.
+ :param authoauthview: The class for auth oauth view.
+ :param authoidview: The class for auth oid view.
+ :param authremoteuserview: The class for auth remote user view.
+ :param permissionmodelview: The class for permission model view.
+ :param registeruser_view: The class for register user view.
+ :param registeruserdbview: The class for register user db view.
+ :param registeruseroauthview: The class for register user oauth view.
+ :param registerusermodelview: The class for register user model view.
+ :param registeruseroidview: The class for register user oid view.
+ :param resetmypasswordview: The class for reset my password view.
+ :param resetpasswordview: The class for reset password view.
+ :param rolemodelview: The class for role model view.
+ :param userinfoeditview: The class for user info edit view.
+ :param userdbmodelview: The class for user db model view.
+ :param userldapmodelview: The class for user ldap model view.
+ :param useroauthmodelview: The class for user oauth model view.
+ :param useroidmodelview: The class for user oid model view.
+ :param userremoteusermodelview: The class for user remote user model view.
+ :param userstatschartview: The class for user stats chart view.
+ """
+
+ """ The obj instance for authentication view """
+ auth_view = None
+ """ The obj instance for user view """
+ user_view = None
+
+ def __init__(self, **kwargs):
+ # When mixed into AirflowSecurityManager, the next class in the MRO is SecurityManager,
+ # whose __init__ accepts **kwargs, so every keyword argument is forwarded unchanged.
+ super().__init__(**kwargs)
+
+ # Every key below is mandatory: a missing one raises KeyError.
+ self.appbuilder = kwargs["appbuilder"]
+ self.actionmodelview = kwargs["actionmodelview"]
+ self.authdbview = kwargs["authdbview"]
+ self.authldapview = kwargs["authldapview"]
+ self.authoauthview = kwargs["authoauthview"]
+ self.authoidview = kwargs["authoidview"]
+ self.authremoteuserview = kwargs["authremoteuserview"]
+ self.permissionmodelview = kwargs["permissionmodelview"]
+ self.registeruser_view = kwargs["registeruser_view"]
+ self.registeruserdbview = kwargs["registeruserdbview"]
+ self.registeruseroauthview = kwargs["registeruseroauthview"]
+ self.registerusermodelview = kwargs["registerusermodelview"]
+ self.registeruseroidview = kwargs["registeruseroidview"]
+ self.resetmypasswordview = kwargs["resetmypasswordview"]
+ self.resetpasswordview = kwargs["resetpasswordview"]
+ self.rolemodelview = kwargs["rolemodelview"]
+ self.userinfoeditview = kwargs["userinfoeditview"]
+ self.userdbmodelview = kwargs["userdbmodelview"]
+ self.userldapmodelview = kwargs["userldapmodelview"]
+ self.useroauthmodelview = kwargs["useroauthmodelview"]
+ self.useroidmodelview = kwargs["useroidmodelview"]
+ self.userremoteusermodelview = kwargs["userremoteusermodelview"]
+ self.userstatschartview = kwargs["userstatschartview"]
+
+ def register_views(self):
+ """Register FAB auth manager related views."""
+ if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True):
+ return
+
+ if self.auth_user_registration:
+ if self.auth_type == AUTH_DB:
+ self.registeruser_view = self.registeruserdbview()
+ elif self.auth_type == AUTH_OID:
+ self.registeruser_view = self.registeruseroidview()
+ elif self.auth_type == AUTH_OAUTH:
+ self.registeruser_view = self.registeruseroauthview()
+ if self.registeruser_view:
+ self.appbuilder.add_view_no_menu(self.registeruser_view)
+
+ self.appbuilder.add_view_no_menu(self.resetpasswordview())
+ self.appbuilder.add_view_no_menu(self.resetmypasswordview())
+ self.appbuilder.add_view_no_menu(self.userinfoeditview())
+
+ if self.auth_type == AUTH_DB:
+ self.user_view = self.userdbmodelview
+ self.auth_view = self.authdbview()
+ elif self.auth_type == AUTH_LDAP:
+ self.user_view = self.userldapmodelview
+ self.auth_view = self.authldapview()
+ elif self.auth_type == AUTH_OAUTH:
+ self.user_view = self.useroauthmodelview
+ self.auth_view = self.authoauthview()
+ elif self.auth_type == AUTH_REMOTE_USER:
+ self.user_view = self.userremoteusermodelview
+ self.auth_view = self.authremoteuserview()
+ else:
+ self.user_view = self.useroidmodelview
+ self.auth_view = self.authoidview()
+
+ self.appbuilder.add_view_no_menu(self.auth_view)
+
+ # this needs to be done after the view is added, otherwise the blueprint
+ # is not initialized
+ if self.is_auth_limited:
+ self.limiter.limit(self.auth_rate_limit,
methods=["POST"])(self.auth_view.blueprint)
+
+ self.user_view = self.appbuilder.add_view(
+ self.user_view,
+ "List Users",
+ icon="fa-user",
+ label=lazy_gettext("List Users"),
+ category="Security",
+ category_icon="fa-cogs",
+ category_label=lazy_gettext("Security"),
+ )
+
+ role_view = self.appbuilder.add_view(
+ self.rolemodelview,
+ "List Roles",
+ icon="fa-group",
+ label=lazy_gettext("List Roles"),
+ category="Security",
+ category_icon="fa-cogs",
+ )
+ role_view.related_views = [self.user_view.__class__]
+
+ if self.userstatschartview:
+ self.appbuilder.add_view(
+ self.userstatschartview,
+ "User's Statistics",
+ icon="fa-bar-chart-o",
+ label=lazy_gettext("User's Statistics"),
+ category="Security",
+ )
+ if self.auth_user_registration:
+ # NOTE(review): this view reuses the "User's Statistics" name of the statistics view above.
+ # The code this class replaces did the same; confirm that sharing the name is intended.
+ self.appbuilder.add_view(
+ self.registerusermodelview,
+ "User's Statistics",
+ icon="fa-user-plus",
+ label=lazy_gettext("User Registrations"),
+ category="Security",
+ )
+ self.appbuilder.menu.add_separator("Security")
+ if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEW",
True):
+ self.appbuilder.add_view(
+ self.actionmodelview,
+ "Actions",
+ icon="fa-lock",
+ label=lazy_gettext("Actions"),
+ category="Security",
+ )
+ if self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEW_MENU_VIEW",
True):
+ self.appbuilder.add_view(
+ self.resourcemodelview,
+ "Resources",
+ icon="fa-list-alt",
+ label=lazy_gettext("Resources"),
+ category="Security",
+ )
+ if
self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True):
+ self.appbuilder.add_view(
+ self.permissionmodelview,
+ "Permission Pairs",
+ icon="fa-link",
+ label=lazy_gettext("Permissions"),
+ category="Security",
+ )
+
+ @property
+ def auth_user_registration(self):
+ """Whether user self-registration is allowed (AUTH_USER_REGISTRATION)."""
+ return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION"]
+
+ @property
+ def auth_type(self):
+ """Get the auth type (AUTH_TYPE)."""
+ return self.appbuilder.get_app.config["AUTH_TYPE"]
+
+ @property
+ def is_auth_limited(self) -> bool:
+ """Whether authentication requests are rate limited (AUTH_RATE_LIMITED)."""
+ return self.appbuilder.get_app.config["AUTH_RATE_LIMITED"]
+
+ @property
+ def auth_rate_limit(self) -> str:
+ """Get the auth rate limit (AUTH_RATE_LIMIT)."""
+ return self.appbuilder.get_app.config["AUTH_RATE_LIMIT"]
+
+ @cached_property
+ def resourcemodelview(self):
+ """Return the resource model view (imported lazily from airflow.www.views)."""
+ from airflow.www.views import ResourceModelView
+
+ return ResourceModelView
diff --git a/airflow/configuration.py b/airflow/configuration.py
index cca5588e54..d6492891d3 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -2209,6 +2209,43 @@ def initialize_secrets_backends() -> list[BaseSecretsBackend]:
     return backend_list
 
 
+@functools.lru_cache(maxsize=None)
+def _DEFAULT_CONFIG() -> str:
+    """Return the contents of default_airflow.cfg; the file is read once, then cached."""
+    path = _default_config_file_path("default_airflow.cfg")
+    with open(path) as fh:
+        return fh.read()
+
+
+@functools.lru_cache(maxsize=None)
+def _TEST_CONFIG() -> str:
+    """Return the contents of default_test.cfg; the file is read once, then cached."""
+    path = _default_config_file_path("default_test.cfg")
+    with open(path) as fh:
+        return fh.read()
+
+
+# Deprecated module attributes, resolved lazily by the module-level ``__getattr__`` below.
+_deprecated = {
+    "DEFAULT_CONFIG": _DEFAULT_CONFIG,
+    "TEST_CONFIG": _TEST_CONFIG,
+    "TEST_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_test.cfg"),
+    "DEFAULT_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_airflow.cfg"),
+}
+
+
+def __getattr__(name):
+    """Serve deprecated module attributes (PEP 562), emitting a DeprecationWarning on access."""
+    if name in _deprecated:
+        warnings.warn(
+            f"{__name__}.{name} is deprecated and will be removed in future",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return _deprecated[name]()
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
 def initialize_auth_manager() -> BaseAuthManager:
     """
     Initialize auth manager.
@@ -2257,5 +2294,4 @@ WEBSERVER_CONFIG = ""  # Set by initialize_config
 
 conf = initialize_config()
 secrets_backend_list = initialize_secrets_backends()
-auth_manager = initialize_auth_manager()
 conf.validate()
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index 82fb5d34c5..54114da1c8 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -21,8 +21,9 @@ from typing import Callable, Sequence, TypeVar, cast
from flask import current_app, flash, g, redirect, render_template, request,
url_for
-from airflow.configuration import auth_manager, conf
+from airflow.configuration import conf
from airflow.utils.net import get_hostname
+from airflow.www.extensions.init_auth_manager import get_auth_manager
T = TypeVar("T", bound=Callable)
@@ -46,7 +47,7 @@ def has_access(permissions: Sequence[tuple[str, str]] | None
= None) -> Callable
)
if appbuilder.sm.check_authorization(permissions, dag_id):
return func(*args, **kwargs)
- elif auth_manager.is_logged_in() and not g.user.perms:
+ # Logged-in user without any permission: render the "no roles / permissions" page.
+ elif get_auth_manager().is_logged_in() and not g.user.perms:
return (
render_template(
"airflow/no_roles_permissions.html",
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index fc386d220f..af316e3ed0 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -29,10 +29,10 @@ import pendulum
from flask import after_this_request, g, request
from pendulum.parsing.exceptions import ParserError
-from airflow.configuration import auth_manager
from airflow.models import Log
from airflow.utils.log import secrets_masker
from airflow.utils.session import create_session
+from airflow.www.extensions.init_auth_manager import get_auth_manager
T = TypeVar("T", bound=Callable)
@@ -85,7 +85,7 @@ def action_logging(func: Callable | None = None, event: str |
None = None) -> Ca
__tracebackhide__ = True # Hide from pytest traceback.
with create_session() as session:
- if not auth_manager.is_logged_in():
+ # Attribute the action to "anonymous" when nobody is logged in.
+ if not get_auth_manager().is_logged_in():
user = "anonymous"
else:
user = f"{g.user.username} ({g.user.get_full_name()})"
diff --git a/airflow/www/extensions/init_appbuilder.py
b/airflow/www/extensions/init_appbuilder.py
index ac9d2c9107..ae793ca956 100644
--- a/airflow/www/extensions/init_appbuilder.py
+++ b/airflow/www/extensions/init_appbuilder.py
@@ -208,13 +208,6 @@ class AirflowAppBuilder:
if self.update_perms: # default is True, if False takes precedence
from config
self.update_perms = app.config.get("FAB_UPDATE_PERMS", True)
- _security_manager_class_name =
app.config.get("FAB_SECURITY_MANAGER_CLASS", None)
- if _security_manager_class_name is not None:
- self.security_manager_class =
dynamic_class_import(_security_manager_class_name)
- if self.security_manager_class is None:
- from flask_appbuilder.security.sqla.manager import SecurityManager
-
- self.security_manager_class = SecurityManager
+ # NOTE(review): FAB_SECURITY_MANAGER_CLASS is no longer read and there is no SecurityManager
+ # fallback any more; confirm that self.security_manager_class is always set by the caller.
self._addon_managers = app.config["ADDON_MANAGERS"]
self.session = session
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/www/extensions/init_auth_manager.py
similarity index 56%
copy from airflow/auth/managers/base_auth_manager.py
copy to airflow/www/extensions/init_auth_manager.py
index 462fe34d63..d21139f670 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/www/extensions/init_auth_manager.py
@@ -1,4 +1,3 @@
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,24 +16,25 @@
# under the License.
from __future__ import annotations
-from abc import abstractmethod
-
-from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
+from airflow.compat.functools import cache
+from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException
-class BaseAuthManager(LoggingMixin):
+@cache
+def get_auth_manager() -> BaseAuthManager:
"""
- Class to derive in order to implement concrete auth managers.
+ Return the auth manager, instantiating it on first use.
- Auth managers are responsible for any user management related operation
such as login, logout, authz, ...
+ Import the auth manager class configured in [core] auth_manager, instantiate it and return it.
+ The function is decorated with @cache, so every call returns the same instance.
+
+ :raises AirflowConfigException: if no auth manager class is configured.
"""
+ auth_manager_cls = conf.getimport(section="core", key="auth_manager")
- @abstractmethod
- def get_user_name(self) -> str:
- """Return the username associated to the user in session."""
- ...
+ if not auth_manager_cls:
+ raise AirflowConfigException(
+ "No auth manager defined in the config. "
+ "Please specify one using section/key [core/auth_manager]."
+ )
- @abstractmethod
- def is_logged_in(self) -> bool:
- """Return whether the user is logged in."""
- ...
+ return auth_manager_cls()
diff --git a/airflow/www/extensions/init_jinja_globals.py
b/airflow/www/extensions/init_jinja_globals.py
index 9ef948084c..13baeea7bc 100644
--- a/airflow/www/extensions/init_jinja_globals.py
+++ b/airflow/www/extensions/init_jinja_globals.py
@@ -21,10 +21,11 @@ import logging
import pendulum
import airflow
-from airflow.configuration import auth_manager, conf
+from airflow.configuration import conf
from airflow.settings import IS_K8S_OR_K8SCELERY_EXECUTOR, STATE_COLORS
from airflow.utils.net import get_hostname
from airflow.utils.platform import get_airflow_git_version
+from airflow.www.extensions.init_auth_manager import get_auth_manager
def init_jinja_globals(app):
@@ -68,7 +69,7 @@ def init_jinja_globals(app):
"git_version": git_version,
"k8s_or_k8scelery_executor": IS_K8S_OR_K8SCELERY_EXECUTOR,
"rest_api_enabled": False,
- "auth_manager": auth_manager,
+ # Expose the configured auth manager instance to the Jinja templates.
+ "auth_manager": get_auth_manager(),
"config_test_connection": conf.get("core", "test_connection",
fallback="Disabled"),
}
diff --git a/airflow/www/extensions/init_security.py
b/airflow/www/extensions/init_security.py
index ba57f99b14..17f93fc1c5 100644
--- a/airflow/www/extensions/init_security.py
+++ b/airflow/www/extensions/init_security.py
@@ -22,8 +22,9 @@ from importlib import import_module
from flask import g, redirect, url_for
from flask_login import logout_user
-from airflow.configuration import auth_manager, conf
+from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, AirflowException
+from airflow.www.extensions.init_auth_manager import get_auth_manager
log = logging.getLogger(__name__)
@@ -68,6 +69,6 @@ def init_api_experimental_auth(app):
def init_check_user_active(app):
@app.before_request
def check_user_active():
- if auth_manager.is_logged_in() and not g.user.is_active:
+ # A logged-in user whose account is inactive is logged out and redirected to the login page.
+ if get_auth_manager().is_logged_in() and not g.user.is_active:
logout_user()
return redirect(url_for(app.appbuilder.sm.auth_view.endpoint +
".login"))
diff --git a/airflow/www/fab_security/manager.py
b/airflow/www/fab_security/manager.py
index 00d3806436..3c59522a62 100644
--- a/airflow/www/fab_security/manager.py
+++ b/airflow/www/fab_security/manager.py
@@ -22,7 +22,6 @@ import base64
import datetime
import json
import logging
-from functools import cached_property
from typing import Any
from uuid import uuid4
@@ -34,7 +33,6 @@ from flask_appbuilder.const import (
AUTH_LDAP,
AUTH_OAUTH,
AUTH_OID,
- AUTH_REMOTE_USER,
LOGMSG_ERR_SEC_ADD_REGISTER_USER,
LOGMSG_ERR_SEC_AUTH_LDAP,
LOGMSG_ERR_SEC_AUTH_LDAP_TLS,
@@ -66,14 +64,14 @@ from flask_appbuilder.security.views import (
UserRemoteUserModelView,
UserStatsChartView,
)
-from flask_babel import lazy_gettext as _
from flask_jwt_extended import JWTManager, current_user as current_user_jwt
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_login import AnonymousUserMixin, LoginManager, current_user
from werkzeug.security import check_password_hash, generate_password_hash
-from airflow.configuration import auth_manager, conf
+from airflow.configuration import conf
+from airflow.www.extensions.init_auth_manager import get_auth_manager
from airflow.www.fab_security.sqla.models import Action, Permission,
RegisterUser, Resource, Role, User
# This product contains a modified portion of 'Flask App Builder' developed by
Daniel Vaz Gaspar.
@@ -208,12 +206,6 @@ class BaseSecurityManager:
userstatschartview = UserStatsChartView
permissionmodelview = PermissionModelView
- @cached_property
- def resourcemodelview(self):
- from airflow.www.views import ResourceModelView
-
- return ResourceModelView
-
def __init__(self, appbuilder):
self.appbuilder = appbuilder
app = self.appbuilder.get_app
@@ -374,11 +366,6 @@ class BaseSecurityManager:
def api_login_allow_multiple_providers(self):
return
self.appbuilder.get_app.config["AUTH_API_LOGIN_ALLOW_MULTIPLE_PROVIDERS"]
- @property
- def auth_type(self):
- """Get the auth type."""
- return self.appbuilder.get_app.config["AUTH_TYPE"]
-
+ # NOTE(review): auth_type, is_auth_limited, auth_rate_limit and register_views are now defined
+ # only in FabAirflowSecurityManagerOverride. AUTH_DB/AUTH_LDAP/AUTH_OAUTH/AUTH_OID are still imported
+ # in this module, so confirm nothing left here needs those members when another override is used.
@property
def auth_username_ci(self):
"""Gets the auth username for CI."""
@@ -529,18 +516,10 @@ class BaseSecurityManager:
"""Oauth providers."""
return self.appbuilder.get_app.config["OAUTH_PROVIDERS"]
- @property
- def is_auth_limited(self) -> bool:
- return self.appbuilder.get_app.config["AUTH_RATE_LIMITED"]
-
- @property
- def auth_rate_limit(self) -> str:
- return self.appbuilder.get_app.config["AUTH_RATE_LIMIT"]
-
@property
def current_user(self):
"""Current user object."""
- if auth_manager.is_logged_in():
+ # Prefer the session user (g.user) when logged in; otherwise fall back to the JWT user.
+ if get_auth_manager().is_logged_in():
return g.user
elif current_user_jwt:
return current_user_jwt
@@ -732,114 +711,6 @@ class BaseSecurityManager:
return jwt_decoded_payload
- def register_views(self):
- if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True):
- return
-
- if self.auth_user_registration:
- if self.auth_type == AUTH_DB:
- self.registeruser_view = self.registeruserdbview()
- elif self.auth_type == AUTH_OID:
- self.registeruser_view = self.registeruseroidview()
- elif self.auth_type == AUTH_OAUTH:
- self.registeruser_view = self.registeruseroauthview()
- if self.registeruser_view:
- self.appbuilder.add_view_no_menu(self.registeruser_view)
-
- self.appbuilder.add_view_no_menu(self.resetpasswordview())
- self.appbuilder.add_view_no_menu(self.resetmypasswordview())
- self.appbuilder.add_view_no_menu(self.userinfoeditview())
-
- if self.auth_type == AUTH_DB:
- self.user_view = self.userdbmodelview
- self.auth_view = self.authdbview()
-
- elif self.auth_type == AUTH_LDAP:
- self.user_view = self.userldapmodelview
- self.auth_view = self.authldapview()
- elif self.auth_type == AUTH_OAUTH:
- self.user_view = self.useroauthmodelview
- self.auth_view = self.authoauthview()
- elif self.auth_type == AUTH_REMOTE_USER:
- self.user_view = self.userremoteusermodelview
- self.auth_view = self.authremoteuserview()
- else:
- self.user_view = self.useroidmodelview
- self.auth_view = self.authoidview()
- if self.auth_user_registration:
- pass
- # self.registeruser_view = self.registeruseroidview()
- # self.appbuilder.add_view_no_menu(self.registeruser_view)
-
- self.appbuilder.add_view_no_menu(self.auth_view)
-
- # this needs to be done after the view is added, otherwise the
blueprint
- # is not initialized
- if self.is_auth_limited:
- self.limiter.limit(self.auth_rate_limit,
methods=["POST"])(self.auth_view.blueprint)
-
- self.user_view = self.appbuilder.add_view(
- self.user_view,
- "List Users",
- icon="fa-user",
- label=_("List Users"),
- category="Security",
- category_icon="fa-cogs",
- category_label=_("Security"),
- )
-
- role_view = self.appbuilder.add_view(
- self.rolemodelview,
- "List Roles",
- icon="fa-group",
- label=_("List Roles"),
- category="Security",
- category_icon="fa-cogs",
- )
- role_view.related_views = [self.user_view.__class__]
-
- if self.userstatschartview:
- self.appbuilder.add_view(
- self.userstatschartview,
- "User's Statistics",
- icon="fa-bar-chart-o",
- label=_("User's Statistics"),
- category="Security",
- )
- if self.auth_user_registration:
- self.appbuilder.add_view(
- self.registerusermodelview,
- "User's Statistics",
- icon="fa-user-plus",
- label=_("User Registrations"),
- category="Security",
- )
- self.appbuilder.menu.add_separator("Security")
- if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEW",
True):
- self.appbuilder.add_view(
- self.actionmodelview,
- "Actions",
- icon="fa-lock",
- label=_("Actions"),
- category="Security",
- )
- if self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEW_MENU_VIEW",
True):
- self.appbuilder.add_view(
- self.resourcemodelview,
- "Resources",
- icon="fa-list-alt",
- label=_("Resources"),
- category="Security",
- )
- if
self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True):
- self.appbuilder.add_view(
- self.permissionmodelview,
- "Permission Pairs",
- icon="fa-link",
- label=_("Permissions"),
- category="Security",
- )
-
def create_db(self):
"""Setups the DB, creates admin and public roles if they don't
exist."""
roles_mapping =
self.appbuilder.get_app.config.get("FAB_ROLES_MAPPING", {})
@@ -1415,7 +1286,7 @@ class BaseSecurityManager:
return result
def get_user_menu_access(self, menu_names: list[str] | None = None) ->
set[str]:
- if auth_manager.is_logged_in():
+ if get_auth_manager().is_logged_in():
return self._get_user_permission_resources(g.user, "menu_access",
resource_names=menu_names)
elif current_user_jwt:
return self._get_user_permission_resources(
diff --git a/airflow/www/fab_security/sqla/manager.py
b/airflow/www/fab_security/sqla/manager.py
index 62decfb184..c0daf5553c 100644
--- a/airflow/www/fab_security/sqla/manager.py
+++ b/airflow/www/fab_security/sqla/manager.py
@@ -58,7 +58,7 @@ class SecurityManager(BaseSecurityManager):
permission_model = Permission
registeruser_model = RegisterUser
- def __init__(self, appbuilder):
+ # **kwargs accepts the view classes that AirflowSecurityManager.__init__ forwards by keyword, so the
+ # cooperative super().__init__() chain does not raise TypeError on unexpected keyword arguments.
+ def __init__(self, appbuilder, **kwargs):
"""
Class constructor.
diff --git a/airflow/www/security.py b/airflow/www/security.py
index 33188328fc..d9229c9b1e 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -18,18 +18,18 @@
from __future__ import annotations
import warnings
-from typing import Any, Collection, Container, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Collection, Container, Iterable,
Sequence
from flask import g
from sqlalchemy import or_
from sqlalchemy.orm import Session, joinedload
-from airflow.configuration import auth_manager
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models import DagBag, DagModel
from airflow.security import permissions
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.www.extensions.init_auth_manager import get_auth_manager
+# NOTE: this module calls get_auth_manager() at import time (not only per request) to pick the first base class of AirflowSecurityManager.
from airflow.www.fab_security.sqla.manager import SecurityManager
from airflow.www.fab_security.sqla.models import Permission, Resource, Role,
User
from airflow.www.fab_security.views import (
@@ -57,8 +57,19 @@ EXISTING_ROLES = {
     "Public",
 }
 
+if TYPE_CHECKING:
+    SecurityManagerOverride: type = object
+else:
+    # Fetch the security manager override from the auth manager
+    SecurityManagerOverride = get_auth_manager().get_security_manager_override_class()
+    if SecurityManagerOverride is object:
+        # ``object`` cannot be listed before other base classes: ``class C(object, SecurityManager)``
+        # has no consistent MRO and raises TypeError when the class is created. Auth managers that
+        # keep the default (``object``, i.e. no override) get an empty placeholder mixin instead.
+        SecurityManagerOverride = type("SecurityManagerOverride", (), {})
 
-class AirflowSecurityManager(SecurityManager, LoggingMixin):
+
+class AirflowSecurityManager(SecurityManagerOverride, SecurityManager, LoggingMixin):
     """Custom security manager, which introduces a permission model adapted to Airflow."""
 
     ###########################################################################
@@ -193,7 +199,31 @@ class AirflowSecurityManager(SecurityManager,
LoggingMixin):
userstatschartview = CustomUserStatsChartView
def __init__(self, appbuilder) -> None:
- super().__init__(appbuilder)
+ # Pass every view class by keyword so that a security manager override (the first base
+ # class of AirflowSecurityManager) can read them from its **kwargs.
+ super().__init__(
+ appbuilder=appbuilder,
+ actionmodelview=self.actionmodelview,
+ authdbview=self.authdbview,
+ authldapview=self.authldapview,
+ authoauthview=self.authoauthview,
+ authoidview=self.authoidview,
+ authremoteuserview=self.authremoteuserview,
+ permissionmodelview=self.permissionmodelview,
+ registeruser_view=self.registeruser_view,
+ registeruserdbview=self.registeruserdbview,
+ registeruseroauthview=self.registeruseroauthview,
+ registerusermodelview=self.registerusermodelview,
+ registeruseroidview=self.registeruseroidview,
+ resetmypasswordview=self.resetmypasswordview,
+ resetpasswordview=self.resetpasswordview,
+ rolemodelview=self.rolemodelview,
+ userinfoeditview=self.userinfoeditview,
+ userdbmodelview=self.userdbmodelview,
+ userldapmodelview=self.userldapmodelview,
+ useroauthmodelview=self.useroauthmodelview,
+ useroidmodelview=self.useroidmodelview,
+ userremoteusermodelview=self.userremoteusermodelview,
+ userstatschartview=self.userstatschartview,
+ )
# Go and fix up the SQLAInterface used from the stock one to our
subclass.
# This is needed to support the "hack" where we had to edit
@@ -339,7 +369,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin):
if not user_actions:
user_actions = [permissions.ACTION_CAN_EDIT,
permissions.ACTION_CAN_READ]
- if not auth_manager.is_logged_in():
+ if not get_auth_manager().is_logged_in():
roles = user.roles
else:
if (permissions.ACTION_CAN_EDIT in user_actions and
self.can_edit_all_dags(user)) or (
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 359f1ca52b..54f46575be 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -83,7 +83,7 @@ from airflow.api.common.mark_tasks import (
set_dag_run_state_to_success,
set_state,
)
-from airflow.configuration import AIRFLOW_CONFIG, auth_manager, conf
+from airflow.configuration import AIRFLOW_CONFIG, conf
from airflow.datasets import Dataset
from airflow.exceptions import (
AirflowConfigException,
@@ -131,6 +131,7 @@ from airflow.utils.timezone import td_format, utcnow
from airflow.version import version
from airflow.www import auth, utils as wwwutils
from airflow.www.decorators import action_logging, gzipped
+from airflow.www.extensions.init_auth_manager import get_auth_manager
from airflow.www.forms import (
DagRunEditForm,
DateTimeForm,
@@ -622,16 +623,17 @@ def method_not_allowed(error):
def show_traceback(error):
"""Show Traceback for a given error."""
+ # Resolve the login state once; it gates every environment detail rendered below.
+ is_logged_in = get_auth_manager().is_logged_in()
return (
render_template(
"airflow/traceback.html",
- python_version=sys.version.split(" ")[0] if
auth_manager.is_logged_in() else "redact",
- airflow_version=version if auth_manager.is_logged_in() else
"redact",
+ python_version=sys.version.split(" ")[0] if is_logged_in else
"redact",
+ airflow_version=version if is_logged_in else "redact",
hostname=get_hostname()
- if conf.getboolean("webserver", "EXPOSE_HOSTNAME") and
auth_manager.is_logged_in()
+ if conf.getboolean("webserver", "EXPOSE_HOSTNAME") and is_logged_in
else "redact",
info=traceback.format_exc()
- if conf.getboolean("webserver", "EXPOSE_STACKTRACE") and
auth_manager.is_logged_in()
+ if conf.getboolean("webserver", "EXPOSE_STACKTRACE") and
is_logged_in
else "Error! Please contact server admin.",
),
500,
diff --git a/tests/auh/managers/fab/test_fab_auth_manager.py
b/tests/auh/managers/fab/test_fab_auth_manager.py
index 4f24b1297e..baaec623f4 100644
--- a/tests/auh/managers/fab/test_fab_auth_manager.py
+++ b/tests/auh/managers/fab/test_fab_auth_manager.py
@@ -22,6 +22,7 @@ from unittest.mock import Mock
import pytest
from airflow.auth.managers.fab.fab_auth_manager import FabAuthManager
+from airflow.auth.managers.fab.security_manager_override import
FabAirflowSecurityManagerOverride
from airflow.www.fab_security.sqla.models import User
@@ -55,3 +56,6 @@ class TestFabAuthManager:
mock_current_user.return_value = user
assert auth_manager.is_logged_in() is False
+
+ # FabAuthManager must return the FAB-specific override class itself (identity check, not a subclass).
+ def
test_get_security_manager_override_class_return_fab_security_manager_override(self,
auth_manager):
+ assert auth_manager.get_security_manager_override_class() is
FabAirflowSecurityManagerOverride
diff --git a/airflow/auth/managers/base_auth_manager.py b/tests/auh/managers/test_base_auth_manager.py
similarity index 58%
copy from airflow/auth/managers/base_auth_manager.py
copy to tests/auh/managers/test_base_auth_manager.py
index 462fe34d63..61bee75371 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/tests/auh/managers/test_base_auth_manager.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,24 +16,22 @@
 # under the License.
 from __future__ import annotations
 
-from abc import abstractmethod
+import pytest
 
-from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
 
 
-class BaseAuthManager(LoggingMixin):
-    """
-    Class to derive in order to implement concrete auth managers.
+@pytest.fixture
+def auth_manager():
+    """Return an instance of a minimal BaseAuthManager subclass that only stubs get_user_name."""
+    class EmptyAuthManager(BaseAuthManager):
+        def get_user_name(self) -> str:
+            raise NotImplementedError()
 
-    Auth managers are responsible for any user management related operation such as login, logout, authz, ...
-    """
+    return EmptyAuthManager()
 
-    @abstractmethod
-    def get_user_name(self) -> str:
-        """Return the username associated to the user in session."""
-        ...
 
-    @abstractmethod
-    def is_logged_in(self) -> bool:
-        """Return whether the user is logged in."""
-        ...
+class TestBaseAuthManager:
+    def test_get_security_manager_override_class_return_empty_class(self, auth_manager):
+        # BaseAuthManager provides no override by default and signals it with ``object``.
+        assert auth_manager.get_security_manager_override_class() is object