This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8775f841dd Move methods from security managers to
`FabAirflowSecurityManagerOverride` (#33044)
8775f841dd is described below
commit 8775f841ddc57960358c83050841f02a3bc0f8ca
Author: Vincent <[email protected]>
AuthorDate: Tue Aug 15 10:21:24 2023 -0400
Move methods from security managers to `FabAirflowSecurityManagerOverride`
(#33044)
---
.pre-commit-config.yaml | 1 +
.../fab/security_manager/modules/__init__.py | 17 +
.../managers/fab/security_manager/modules/db.py} | 455 ++++++++++----------
.../managers/fab/security_manager/modules/oauth.py | 186 +++++++++
.../auth/managers/fab/security_manager/override.py | 167 +++++++-
airflow/www/fab_security/manager.py | 301 +-------------
airflow/www/fab_security/sqla/manager.py | 461 +--------------------
airflow/www/security.py | 21 -
.../managers/fab/security_manager/test_override.py | 11 +-
9 files changed, 588 insertions(+), 1032 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2abf7de0b7..1c8feab070 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -499,6 +499,7 @@ repos:
exclude: >
(?x)
^airflow/api_connexion/openapi/v1.yaml$|
+ ^airflow/auth/managers/fab/security_manager/|
^airflow/cli/commands/webserver_command.py$|
^airflow/config_templates/|
^airflow/models/baseoperator.py$|
diff --git a/airflow/auth/managers/fab/security_manager/modules/__init__.py
b/airflow/auth/managers/fab/security_manager/modules/__init__.py
new file mode 100644
index 0000000000..217e5db960
--- /dev/null
+++ b/airflow/auth/managers/fab/security_manager/modules/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/www/fab_security/sqla/manager.py
b/airflow/auth/managers/fab/security_manager/modules/db.py
similarity index 68%
copy from airflow/www/fab_security/sqla/manager.py
copy to airflow/auth/managers/fab/security_manager/modules/db.py
index 83e2119492..31a1cfd7e2 100644
--- a/airflow/www/fab_security/sqla/manager.py
+++ b/airflow/auth/managers/fab/security_manager/modules/db.py
@@ -1,3 +1,4 @@
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -19,100 +20,190 @@ from __future__ import annotations
import logging
import uuid
-from flask_appbuilder import const as c
+from flask_appbuilder import const
from flask_appbuilder.models.sqla import Base
-from flask_appbuilder.models.sqla.interface import SQLAInterface
-from sqlalchemy import and_, func, inspect, literal
-from sqlalchemy.orm.exc import MultipleResultsFound
+from sqlalchemy import func, inspect
+from sqlalchemy.exc import MultipleResultsFound
from werkzeug.security import generate_password_hash
-from airflow.auth.managers.fab.models import (
- Action,
- Permission,
- RegisterUser,
- Resource,
- Role,
- User,
- assoc_permission_role,
-)
-from airflow.www.fab_security.manager import BaseSecurityManager
+from airflow import AirflowException
+from airflow.auth.managers.fab.models import Action, Permission, Resource, Role
log = logging.getLogger(__name__)
-class SecurityManager(BaseSecurityManager):
+class FabAirflowSecurityManagerOverrideDb:
"""
- Responsible for authentication, registering security views, role and
permission auto management.
+ This class contains all methods in
+
airflow.auth.managers.fab.security_manager.override.FabAirflowSecurityManagerOverride
related to the
+ database.
+
+ FabAirflowSecurityManagerOverride is split into multiple classes to avoid
having one massive class.
- If you want to change anything just inherit and override, then
- pass your own security manager to AppBuilder.
+ :param appbuilder: The appbuilder.
"""
- user_model = User
- """ Override to set your own User Model """
+ """ Models """
role_model = Role
- """ Override to set your own Role Model """
+ permission_model = Permission
action_model = Action
resource_model = Resource
- permission_model = Permission
- registeruser_model = RegisterUser
-
- def __init__(self, appbuilder, **kwargs):
- """
- Class constructor.
-
- :param appbuilder: F.A.B AppBuilder main object
- """
- super().__init__(appbuilder)
- user_datamodel = SQLAInterface(self.user_model)
- if self.auth_type == c.AUTH_DB:
- self.userdbmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_LDAP:
- self.userldapmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_OID:
- self.useroidmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_OAUTH:
- self.useroauthmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_REMOTE_USER:
- self.userremoteusermodelview.datamodel = user_datamodel
-
- if self.userstatschartview:
- self.userstatschartview.datamodel = user_datamodel
- if self.auth_user_registration:
- self.registerusermodelview.datamodel =
SQLAInterface(self.registeruser_model)
-
- self.rolemodelview.datamodel = SQLAInterface(self.role_model)
- self.actionmodelview.datamodel = SQLAInterface(self.action_model)
- self.resourcemodelview.datamodel = SQLAInterface(self.resource_model)
- self.permissionmodelview.datamodel =
SQLAInterface(self.permission_model)
- self.create_db()
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.appbuilder = kwargs["appbuilder"]
@property
def get_session(self):
return self.appbuilder.get_session
- def register_views(self):
- super().register_views()
-
def create_db(self):
+ """
+ Create the database.
+
+ Creates admin and public roles if they don't exist.
+ """
+ if not self.appbuilder.update_perms:
+ log.debug("Skipping db since appbuilder disables update_perms")
+ return
try:
engine = self.get_session.get_bind(mapper=None, clause=None)
inspector = inspect(engine)
if "ab_user" not in inspector.get_table_names():
- log.info(c.LOGMSG_INF_SEC_NO_DB)
+ log.info(const.LOGMSG_INF_SEC_NO_DB)
Base.metadata.create_all(engine)
- log.info(c.LOGMSG_INF_SEC_ADD_DB)
- super().create_db()
+ log.info(const.LOGMSG_INF_SEC_ADD_DB)
+
+ roles_mapping =
self.appbuilder.app.config.get("FAB_ROLES_MAPPING", {})
+ for pk, name in roles_mapping.items():
+ self.update_role(pk, name)
+ for role_name in self._builtin_roles:
+ self.add_role(role_name)
+ if self.auth_role_admin not in self._builtin_roles:
+ self.add_role(self.auth_role_admin)
+ self.add_role(self.auth_role_public)
+ if self.count_users() == 0 and self.auth_role_public !=
self.auth_role_admin:
+ log.warning(const.LOGMSG_WAR_SEC_NO_USER)
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_CREATE_DB.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_CREATE_DB.format(str(e)))
exit(1)
- def find_register_user(self, registration_hash):
- return (
- self.get_session.query(self.registeruser_model)
- .filter(self.registeruser_model.registration_hash ==
registration_hash)
- .scalar()
- )
+ """
+ -----------
+ Role entity
+ -----------
+ """
+
+ def update_role(self, role_id, name: str) -> Role | None:
+ """Update a role in the database."""
+ role = self.get_session.get(self.role_model, role_id)
+ if not role:
+ return None
+ try:
+ role.name = name
+ self.get_session.merge(role)
+ self.get_session.commit()
+ log.info(const.LOGMSG_INF_SEC_UPD_ROLE.format(role))
+ except Exception as e:
+ log.error(const.LOGMSG_ERR_SEC_UPD_ROLE.format(str(e)))
+ self.get_session.rollback()
+ return None
+ return role
+
+ def add_role(self, name: str) -> Role:
+ """Add a role in the database."""
+ role = self.find_role(name)
+ if role is None:
+ try:
+ role = self.role_model()
+ role.name = name
+ self.get_session.add(role)
+ self.get_session.commit()
+ log.info(const.LOGMSG_INF_SEC_ADD_ROLE.format(name))
+ return role
+ except Exception as e:
+ log.error(const.LOGMSG_ERR_SEC_ADD_ROLE.format(str(e)))
+ self.get_session.rollback()
+ return role
+
+ def find_role(self, name):
+ """
+ Find a role in the database.
+
+ :param name: the role name
+ """
+ return
self.get_session.query(self.role_model).filter_by(name=name).one_or_none()
+
+ def get_all_roles(self):
+ return self.get_session.query(self.role_model).all()
+
+ def get_public_role(self):
+ return
self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none()
+
+    def delete_role(self, role_name: str) -> None:
+        """
+        Delete the given Role.
+
+        :param role_name: the name of a role in the ab_role table
+        :raises AirflowException: if no role with that name exists
+        """
+        session = self.get_session
+        role = session.query(self.role_model).filter(self.role_model.name == role_name).first()
+        if not role:
+            raise AirflowException(f"Role named '{role_name}' does not exist")
+        log.info("Deleting role '%s'", role_name)
+        session.delete(role)
+        session.commit()
+
+ """
+ -----------
+ User entity
+ -----------
+ """
+
+ def add_user(
+ self,
+ username,
+ first_name,
+ last_name,
+ email,
+ role,
+ password="",
+ hashed_password="",
+ ):
+ """Generic function to create user."""
+ try:
+ user = self.user_model()
+ user.first_name = first_name
+ user.last_name = last_name
+ user.username = username
+ user.email = email
+ user.active = True
+ user.roles = role if isinstance(role, list) else [role]
+ if hashed_password:
+ user.password = hashed_password
+ else:
+ user.password = generate_password_hash(password)
+ self.get_session.add(user)
+ self.get_session.commit()
+ log.info(const.LOGMSG_INF_SEC_ADD_USER.format(username))
+ return user
+ except Exception as e:
+ log.error(const.LOGMSG_ERR_SEC_ADD_USER.format(str(e)))
+ self.get_session.rollback()
+ return False
+
+ def load_user(self, user_id):
+ """Load user by ID."""
+ return self.get_user_by_id(int(user_id))
+
+ def get_user_by_id(self, pk):
+ return self.get_session.get(self.user_model, pk)
+
+ def count_users(self):
+ """Return the number of users in the database."""
+ return self.get_session.query(func.count(self.user_model.id)).scalar()
def add_register_user(self, username, first_name, last_name, email,
password="", hashed_password=""):
"""
@@ -135,24 +226,9 @@ class SecurityManager(BaseSecurityManager):
self.get_session.commit()
return register_user
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_REGISTER_USER.format(str(e)))
- self.appbuilder.get_session.rollback()
- return None
-
- def del_register_user(self, register_user):
- """
- Deletes registration object from database.
-
- :param register_user: RegisterUser object to delete
- """
- try:
- self.get_session.delete(register_user)
- self.get_session.commit()
- return True
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_REGISTER_USER.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_ADD_REGISTER_USER.format(str(e)))
self.get_session.rollback()
- return False
+ return None
def find_user(self, username=None, email=None):
"""Finds user by username or email."""
@@ -180,92 +256,46 @@ class SecurityManager(BaseSecurityManager):
log.error("Multiple results found for user with email %s",
email)
return None
- def get_all_users(self):
- return self.get_session.query(self.user_model).all()
-
- def add_user(
- self,
- username,
- first_name,
- last_name,
- email,
- role,
- password="",
- hashed_password="",
- ):
- """Generic function to create user."""
- try:
- user = self.user_model()
- user.first_name = first_name
- user.last_name = last_name
- user.username = username
- user.email = email
- user.active = True
- user.roles = role if isinstance(role, list) else [role]
- if hashed_password:
- user.password = hashed_password
- else:
- user.password = generate_password_hash(password)
- self.get_session.add(user)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_USER.format(username))
- return user
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_USER.format(str(e)))
- self.get_session.rollback()
- return False
-
- def count_users(self):
- return self.get_session.query(func.count(self.user_model.id)).scalar()
+ def find_register_user(self, registration_hash):
+ return (
+ self.get_session.query(self.registeruser_model)
+ .filter(self.registeruser_model.registration_hash ==
registration_hash)
+ .scalar()
+ )
def update_user(self, user):
try:
self.get_session.merge(user)
self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_UPD_USER.format(user))
+ log.info(const.LOGMSG_INF_SEC_UPD_USER.format(user))
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_UPD_USER.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_UPD_USER.format(str(e)))
self.get_session.rollback()
return False
- def add_role(self, name: str) -> Role:
- role = self.find_role(name)
- if role is None:
- try:
- role = self.role_model()
- role.name = name
- self.get_session.add(role)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_ROLE.format(name))
- return role
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_ROLE.format(str(e)))
- self.get_session.rollback()
- return role
+ def del_register_user(self, register_user):
+ """
+ Deletes registration object from database.
- def update_role(self, role_id, name: str) -> Role | None:
- role = self.get_session.get(self.role_model, role_id)
- if not role:
- return None
+ :param register_user: RegisterUser object to delete
+ """
try:
- role.name = name
- self.get_session.merge(role)
+ self.get_session.delete(register_user)
self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_UPD_ROLE.format(role))
+ return True
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_UPD_ROLE.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_DEL_REGISTER_USER.format(str(e)))
self.get_session.rollback()
- return None
- return role
-
- def find_role(self, name):
- return
self.get_session.query(self.role_model).filter_by(name=name).one_or_none()
+ return False
- def get_all_roles(self):
- return self.get_session.query(self.role_model).all()
+ def get_all_users(self):
+ return self.get_session.query(self.user_model).all()
- def get_public_role(self):
- return
self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none()
+ """
+ -------------
+ Action entity
+ -------------
+ """
def get_action(self, name: str) -> Action:
"""
@@ -276,55 +306,6 @@ class SecurityManager(BaseSecurityManager):
"""
return
self.get_session.query(self.action_model).filter_by(name=name).one_or_none()
- def permission_exists_in_one_or_more_roles(
- self, resource_name: str, action_name: str, role_ids: list[int]
- ) -> bool:
- """
- Efficiently check if a certain permission exists on a list of role
ids; used by `has_access`.
-
- :param resource_name: The view's name to check if exists on one of the
roles
- :param action_name: The permission name to check if exists
- :param role_ids: a list of Role ids
- :return: Boolean
- """
- q = (
- self.appbuilder.get_session.query(self.permission_model)
- .join(
- assoc_permission_role,
- and_(self.permission_model.id ==
assoc_permission_role.c.permission_view_id),
- )
- .join(self.role_model)
- .join(self.action_model)
- .join(self.resource_model)
- .filter(
- self.resource_model.name == resource_name,
- self.action_model.name == action_name,
- self.role_model.id.in_(role_ids),
- )
- .exists()
- )
- # Special case for MSSQL/Oracle (works on PG and MySQL > 8)
- if self.appbuilder.get_session.bind.dialect.name in ("mssql",
"oracle"):
- return
self.appbuilder.get_session.query(literal(True)).filter(q).scalar()
- return self.appbuilder.get_session.query(q).scalar()
-
- def filter_roles_by_perm_with_action(self, action_name: str, role_ids:
list[int]):
- """Find roles with permission."""
- return (
- self.appbuilder.get_session.query(self.permission_model)
- .join(
- assoc_permission_role,
- and_(self.permission_model.id ==
assoc_permission_role.c.permission_view_id),
- )
- .join(self.role_model)
- .join(self.action_model)
- .join(self.resource_model)
- .filter(
- self.action_model.name == action_name,
- self.role_model.id.in_(role_ids),
- )
- ).all()
-
def create_action(self, name):
"""
Adds an action to the backend, model action.
@@ -341,7 +322,7 @@ class SecurityManager(BaseSecurityManager):
self.get_session.commit()
return action
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_PERMISSION.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_ADD_PERMISSION.format(str(e)))
self.get_session.rollback()
return action
@@ -350,11 +331,10 @@ class SecurityManager(BaseSecurityManager):
Deletes a permission action.
:param name: Name of action to delete (e.g. can_read).
- :return: Whether or not delete was successful.
"""
action = self.get_action(name)
if not action:
- log.warning(c.LOGMSG_WAR_SEC_DEL_PERMISSION.format(name))
+ log.warning(const.LOGMSG_WAR_SEC_DEL_PERMISSION.format(name))
return False
try:
perms = (
@@ -363,33 +343,30 @@ class SecurityManager(BaseSecurityManager):
.all()
)
if perms:
- log.warning(c.LOGMSG_WAR_SEC_DEL_PERM_PVM.format(action,
perms))
+ log.warning(const.LOGMSG_WAR_SEC_DEL_PERM_PVM.format(action,
perms))
return False
self.get_session.delete(action)
self.get_session.commit()
return True
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMISSION.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_DEL_PERMISSION.format(str(e)))
self.get_session.rollback()
return False
+ """
+ ---------------
+ Resource entity
+ ---------------
+ """
+
def get_resource(self, name: str) -> Resource:
"""
Returns a resource record by name, if it exists.
:param name: Name of resource
- :return: Resource record
"""
return
self.get_session.query(self.resource_model).filter_by(name=name).one_or_none()
- def get_all_resources(self) -> list[Resource]:
- """
- Gets all existing resource records.
-
- :return: List of all resources
- """
- return self.get_session.query(self.resource_model).all()
-
def create_resource(self, name) -> Resource:
"""
Create a resource with the given name.
@@ -406,10 +383,18 @@ class SecurityManager(BaseSecurityManager):
self.get_session.commit()
return resource
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_VIEWMENU.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_ADD_VIEWMENU.format(str(e)))
self.get_session.rollback()
return resource
+ def get_all_resources(self) -> list[Resource]:
+ """
+ Gets all existing resource records.
+
+ :return: List of all resources
+ """
+ return self.get_session.query(self.resource_model).all()
+
def delete_resource(self, name: str) -> bool:
"""
Deletes a Resource from the backend.
@@ -419,7 +404,7 @@ class SecurityManager(BaseSecurityManager):
"""
resource = self.get_resource(name)
if not resource:
- log.warning(c.LOGMSG_WAR_SEC_DEL_VIEWMENU.format(name))
+ log.warning(const.LOGMSG_WAR_SEC_DEL_VIEWMENU.format(name))
return False
try:
perms = (
@@ -428,20 +413,20 @@ class SecurityManager(BaseSecurityManager):
.all()
)
if perms:
- log.warning(c.LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM.format(resource,
perms))
+
log.warning(const.LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM.format(resource, perms))
return False
self.get_session.delete(resource)
self.get_session.commit()
return True
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMISSION.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_DEL_PERMISSION.format(str(e)))
self.get_session.rollback()
return False
"""
- ----------------------
- PERMISSION VIEW MENU
- ----------------------
+ ---------------
+ Permission entity
+ ---------------
"""
def get_permission(
@@ -496,10 +481,10 @@ class SecurityManager(BaseSecurityManager):
try:
self.get_session.add(perm)
self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_PERMVIEW.format(str(perm)))
+ log.info(const.LOGMSG_INF_SEC_ADD_PERMVIEW.format(str(perm)))
return perm
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_PERMVIEW.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_ADD_PERMVIEW.format(str(e)))
self.get_session.rollback()
return None
@@ -522,7 +507,7 @@ class SecurityManager(BaseSecurityManager):
self.get_session.query(self.role_model).filter(self.role_model.permissions.contains(perm)).first()
)
if roles:
- log.warning(c.LOGMSG_WAR_SEC_DEL_PERMVIEW.format(resource_name,
action_name, roles))
+
log.warning(const.LOGMSG_WAR_SEC_DEL_PERMVIEW.format(resource_name,
action_name, roles))
return
try:
# delete permission on resource
@@ -531,17 +516,11 @@ class SecurityManager(BaseSecurityManager):
# if no more permission on permission view, delete permission
if not
self.get_session.query(self.permission_model).filter_by(action=perm.action).all():
self.delete_action(perm.action.name)
- log.info(c.LOGMSG_INF_SEC_DEL_PERMVIEW.format(action_name,
resource_name))
+ log.info(const.LOGMSG_INF_SEC_DEL_PERMVIEW.format(action_name,
resource_name))
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMVIEW.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_DEL_PERMVIEW.format(str(e)))
self.get_session.rollback()
- def perms_include_action(self, perms, action_name):
- for perm in perms:
- if perm.action and perm.action.name == action_name:
- return True
- return False
-
def add_permission_to_role(self, role: Role, permission: Permission |
None) -> None:
"""
Add an existing permission pair to a role.
@@ -555,9 +534,9 @@ class SecurityManager(BaseSecurityManager):
role.permissions.append(permission)
self.get_session.merge(role)
self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_PERMROLE.format(str(permission),
role.name))
+
log.info(const.LOGMSG_INF_SEC_ADD_PERMROLE.format(str(permission), role.name))
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_PERMROLE.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_ADD_PERMROLE.format(str(e)))
self.get_session.rollback()
def remove_permission_from_role(self, role: Role, permission: Permission)
-> None:
@@ -572,7 +551,7 @@ class SecurityManager(BaseSecurityManager):
role.permissions.remove(permission)
self.get_session.merge(role)
self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_DEL_PERMROLE.format(str(permission),
role.name))
+
log.info(const.LOGMSG_INF_SEC_DEL_PERMROLE.format(str(permission), role.name))
except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMROLE.format(str(e)))
+ log.error(const.LOGMSG_ERR_SEC_DEL_PERMROLE.format(str(e)))
self.get_session.rollback()
diff --git a/airflow/auth/managers/fab/security_manager/modules/oauth.py
b/airflow/auth/managers/fab/security_manager/modules/oauth.py
new file mode 100644
index 0000000000..ad5bffff77
--- /dev/null
+++ b/airflow/auth/managers/fab/security_manager/modules/oauth.py
@@ -0,0 +1,186 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import base64
+import json
+import logging
+
+import re2
+from flask import session
+
+log = logging.getLogger(__name__)
+
+
+class FabAirflowSecurityManagerOverrideOauth:
+    """
+    This class contains all methods in
+    airflow.auth.managers.fab.security_manager.override.FabAirflowSecurityManagerOverride related to the
+    oauth authentication.
+
+    FabAirflowSecurityManagerOverride is split into multiple classes to avoid having one massive class.
+    """
+
+    def get_oauth_user_info(self, provider, resp):
+        """
+        Get the OAuth user information from different OAuth APIs.
+
+        All providers have different ways to retrieve user info.
+        """
+        # for GITHUB
+        if provider == "github" or provider == "githublocal":
+            me = self.oauth_remotes[provider].get("user")
+            data = me.json()
+            log.debug("User info from GitHub: %s", data)
+            return {"username": "github_" + data.get("login")}
+        # for twitter
+        if provider == "twitter":
+            me = self.oauth_remotes[provider].get("account/settings.json")
+            data = me.json()
+            log.debug("User info from Twitter: %s", data)
+            return {"username": "twitter_" + data.get("screen_name", "")}
+        # for linkedin
+        if provider == "linkedin":
+            me = self.oauth_remotes[provider].get(
+                "people/~:(id,email-address,first-name,last-name)?format=json"
+            )
+            data = me.json()
+            log.debug("User info from LinkedIn: %s", data)
+            return {
+                "username": "linkedin_" + data.get("id", ""),
+                "email": data.get("email-address", ""),
+                "first_name": data.get("firstName", ""),
+                "last_name": data.get("lastName", ""),
+            }
+        # for Google
+        if provider == "google":
+            me = self.oauth_remotes[provider].get("userinfo")
+            data = me.json()
+            log.debug("User info from Google: %s", data)
+            return {
+                "username": "google_" + data.get("id", ""),
+                "first_name": data.get("given_name", ""),
+                "last_name": data.get("family_name", ""),
+                "email": data.get("email", ""),
+            }
+        # for Azure AD Tenant. Azure OAuth response contains
+        # JWT token which has user info.
+        # JWT token needs to be base64 decoded (see _azure_jwt_token_parse).
+        # https://docs.microsoft.com/en-us/azure/active-directory/develop/
+        # active-directory-protocols-oauth-code
+        if provider == "azure":
+            log.debug("Azure response received : %s", resp)
+            id_token = resp["id_token"]
+            log.debug(str(id_token))
+            me = FabAirflowSecurityManagerOverrideOauth._azure_jwt_token_parse(id_token)
+            log.debug("Parse JWT token : %s", me)
+            return {
+                "name": me.get("name", ""),
+                "email": me["upn"],
+                "first_name": me.get("given_name", ""),
+                "last_name": me.get("family_name", ""),
+                "id": me["oid"],
+                "username": me["oid"],
+                "role_keys": me.get("roles", []),
+            }
+        # for OpenShift
+        if provider == "openshift":
+            me = self.oauth_remotes[provider].get("apis/user.openshift.io/v1/users/~")
+            data = me.json()
+            log.debug("User info from OpenShift: %s", data)
+            return {"username": "openshift_" + data.get("metadata").get("name")}
+        # for Okta
+        if provider == "okta":
+            me = self.oauth_remotes[provider].get("userinfo")
+            data = me.json()
+            log.debug("User info from Okta: %s", data)
+            return {
+                "username": "okta_" + data.get("sub", ""),
+                "first_name": data.get("given_name", ""),
+                "last_name": data.get("family_name", ""),
+                "email": data.get("email", ""),
+                "role_keys": data.get("groups", []),
+            }
+        # for Keycloak
+        if provider in ["keycloak", "keycloak_before_17"]:
+            me = self.oauth_remotes[provider].get("openid-connect/userinfo")
+            me.raise_for_status()
+            data = me.json()
+            log.debug("User info from Keycloak: %s", data)
+            return {
+                "username": data.get("preferred_username", ""),
+                "first_name": data.get("given_name", ""),
+                "last_name": data.get("family_name", ""),
+                "email": data.get("email", ""),
+            }
+        else:
+            return {}
+
+    @staticmethod
+    def oauth_token_getter():
+        """Authentication (OAuth) token getter function."""
+        token = session.get("oauth")
+        log.debug("Token Get: %s", token)
+        return token
+
+    @staticmethod
+    def _azure_parse_jwt(token):
+        """
+        Parse Azure JWT token content.
+
+        :param token: the JWT token
+
+        :meta private:
+        """
+        jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
+        matches = re2.search(jwt_token_parts, token)
+        if not matches or len(matches.groups()) < 3:
+            log.error("Unable to parse token.")
+            return {}
+        return {
+            "header": matches.group(1),
+            "Payload": matches.group(2),
+            "Sig": matches.group(3),
+        }
+
+    @staticmethod
+    def _azure_jwt_token_parse(token):
+        """
+        Parse and decode Azure JWT token.
+
+        :param token: the JWT token
+        :return: the decoded JWT payload (parsed JSON), or None if it cannot be parsed or decoded
+        :meta private:
+        """
+        jwt_split_token = FabAirflowSecurityManagerOverrideOauth._azure_parse_jwt(token)
+        if not jwt_split_token:
+            return
+
+        jwt_payload = jwt_split_token["Payload"]
+        # JWT segments are unpadded base64url; restore the 0-3 "=" padding chars before decoding
+        payload_b64_string = jwt_payload
+        payload_b64_string += "=" * (-len(jwt_payload) % 4)
+        decoded_payload = base64.urlsafe_b64decode(payload_b64_string.encode("ascii"))
+
+        if not decoded_payload:
+            log.error("Payload of id_token could not be base64 url decoded.")
+            return
+
+        jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8"))
+
+        return jwt_decoded_payload
diff --git a/airflow/auth/managers/fab/security_manager/override.py
b/airflow/auth/managers/fab/security_manager/override.py
index 089b449422..bdcc44d8c1 100644
--- a/airflow/auth/managers/fab/security_manager/override.py
+++ b/airflow/auth/managers/fab/security_manager/override.py
@@ -17,10 +17,14 @@
# under the License.
from __future__ import annotations
+import logging
+import warnings
from functools import cached_property
from flask import flash, g
+from flask_appbuilder import const
from flask_appbuilder.const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID,
AUTH_REMOTE_USER
+from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext
from flask_jwt_extended import JWTManager
from flask_login import LoginManager
@@ -28,10 +32,14 @@ from itsdangerous import want_bytes
from markupsafe import Markup
from werkzeug.security import generate_password_hash
-from airflow.auth.managers.fab.models import User
+from airflow.auth.managers.fab.models import Action, Permission, RegisterUser,
Resource, Role, User
from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser
+from airflow.auth.managers.fab.security_manager.modules.db import
FabAirflowSecurityManagerOverrideDb
+from airflow.auth.managers.fab.security_manager.modules.oauth import
FabAirflowSecurityManagerOverrideOauth
from airflow.www.session import AirflowDatabaseSessionInterface
+log = logging.getLogger(__name__)
+
# This is the limit of DB user sessions that we consider as "healthy". If you
have more sessions that this
# number then we will refuse to delete sessions that have expired and old user
sessions when resetting
# user's password, and raise a warning in the UI instead. Usually when you
have that many sessions, it means
@@ -41,7 +49,9 @@ from airflow.www.session import
AirflowDatabaseSessionInterface
MAX_NUM_DATABASE_USER_SESSIONS = 50000
-class FabAirflowSecurityManagerOverride:
+class FabAirflowSecurityManagerOverride(
+ FabAirflowSecurityManagerOverrideDb, FabAirflowSecurityManagerOverrideOauth
+):
"""
This security manager overrides the default AirflowSecurityManager
security manager.
@@ -79,6 +89,15 @@ class FabAirflowSecurityManagerOverride:
auth_view = None
""" The obj instance for user view """
user_view = None
+ """ Models """
+ role_model = Role
+ action_model = Action
+ resource_model = Resource
+ permission_model = Permission
+ registeruser_model = RegisterUser
+
+ """ Initialized (remote_app) providers dict {'provider_name', OBJ } """
+ oauth_allow_list: dict[str, list] = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -108,6 +127,14 @@ class FabAirflowSecurityManagerOverride:
self.userremoteusermodelview = kwargs["userremoteusermodelview"]
self.userstatschartview = kwargs["userstatschartview"]
+ self._init_config()
+ self._init_auth()
+ self._init_data_model()
+
+ self._builtin_roles: dict = self.create_builtin_roles()
+
+ self.create_db()
+
# Setup Flask login
self.lm = self.create_login_manager()
@@ -294,32 +321,152 @@ class FabAirflowSecurityManagerOverride:
g.user = user
return user
- def get_user_by_id(self, pk):
- return self.appbuilder.get_session.get(self.user_model, pk)
-
@property
def auth_user_registration(self):
"""Will user self registration be allowed."""
- return self.appbuilder.get_app.config["AUTH_USER_REGISTRATION"]
+ return self.appbuilder.app.config["AUTH_USER_REGISTRATION"]
@property
def auth_type(self):
"""Get the auth type."""
- return self.appbuilder.get_app.config["AUTH_TYPE"]
+ return self.appbuilder.app.config["AUTH_TYPE"]
@property
def is_auth_limited(self) -> bool:
"""Is the auth rate limited."""
- return self.appbuilder.get_app.config["AUTH_RATE_LIMITED"]
+ return self.appbuilder.app.config["AUTH_RATE_LIMITED"]
@property
def auth_rate_limit(self) -> str:
"""Get the auth rate limit."""
- return self.appbuilder.get_app.config["AUTH_RATE_LIMIT"]
+ return self.appbuilder.app.config["AUTH_RATE_LIMIT"]
@cached_property
def resourcemodelview(self):
"""Return the resource model view."""
- from airflow.www.views import ResourceModelView
+ from airflow.auth.managers.fab.views.permissions import
ResourceModelView
return ResourceModelView
+
+ @property
+ def auth_role_public(self):
+ """Gets the public role."""
+ return self.appbuilder.app.config["AUTH_ROLE_PUBLIC"]
+
+ @property
+ def oauth_providers(self):
+ """Oauth providers."""
+ return self.appbuilder.app.config["OAUTH_PROVIDERS"]
+
+ @property
+ def oauth_whitelists(self):
+ warnings.warn(
+ "The 'oauth_whitelists' property is deprecated. Please use
'oauth_allow_list' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.oauth_allow_list
+
+ def create_builtin_roles(self):
+ """Returns FAB builtin roles."""
+ return self.appbuilder.app.config.get("FAB_ROLES", {})
+
+ def _init_config(self):
+ """
+ Initialize config.
+
+ :meta private:
+ """
+ app = self.appbuilder.get_app
+ # Base Security Config
+ app.config.setdefault("AUTH_ROLE_ADMIN", "Admin")
+ app.config.setdefault("AUTH_ROLE_PUBLIC", "Public")
+ app.config.setdefault("AUTH_TYPE", AUTH_DB)
+ # Self Registration
+ app.config.setdefault("AUTH_USER_REGISTRATION", False)
+ app.config.setdefault("AUTH_USER_REGISTRATION_ROLE",
self.auth_role_public)
+ app.config.setdefault("AUTH_USER_REGISTRATION_ROLE_JMESPATH", None)
+ # Role Mapping
+ app.config.setdefault("AUTH_ROLES_MAPPING", {})
+ app.config.setdefault("AUTH_ROLES_SYNC_AT_LOGIN", False)
+ app.config.setdefault("AUTH_API_LOGIN_ALLOW_MULTIPLE_PROVIDERS", False)
+
+ # LDAP Config
+ if self.auth_type == AUTH_LDAP:
+ if "AUTH_LDAP_SERVER" not in app.config:
+ raise Exception("No AUTH_LDAP_SERVER defined on config with
AUTH_LDAP authentication type.")
+ app.config.setdefault("AUTH_LDAP_SEARCH", "")
+ app.config.setdefault("AUTH_LDAP_SEARCH_FILTER", "")
+ app.config.setdefault("AUTH_LDAP_APPEND_DOMAIN", "")
+ app.config.setdefault("AUTH_LDAP_USERNAME_FORMAT", "")
+ app.config.setdefault("AUTH_LDAP_BIND_USER", "")
+ app.config.setdefault("AUTH_LDAP_BIND_PASSWORD", "")
+ # TLS options
+ app.config.setdefault("AUTH_LDAP_USE_TLS", False)
+ app.config.setdefault("AUTH_LDAP_ALLOW_SELF_SIGNED", False)
+ app.config.setdefault("AUTH_LDAP_TLS_DEMAND", False)
+ app.config.setdefault("AUTH_LDAP_TLS_CACERTDIR", "")
+ app.config.setdefault("AUTH_LDAP_TLS_CACERTFILE", "")
+ app.config.setdefault("AUTH_LDAP_TLS_CERTFILE", "")
+ app.config.setdefault("AUTH_LDAP_TLS_KEYFILE", "")
+ # Mapping options
+ app.config.setdefault("AUTH_LDAP_UID_FIELD", "uid")
+ app.config.setdefault("AUTH_LDAP_GROUP_FIELD", "memberOf")
+ app.config.setdefault("AUTH_LDAP_FIRSTNAME_FIELD", "givenName")
+ app.config.setdefault("AUTH_LDAP_LASTNAME_FIELD", "sn")
+ app.config.setdefault("AUTH_LDAP_EMAIL_FIELD", "mail")
+
+ # Rate limiting
+ app.config.setdefault("AUTH_RATE_LIMITED", True)
+ app.config.setdefault("AUTH_RATE_LIMIT", "5 per 40 second")
+
+ def _init_auth(self):
+ """
+ Initialize authentication configuration.
+
+ :meta private:
+ """
+ app = self.appbuilder.get_app
+ if self.auth_type == AUTH_OID:
+ from flask_openid import OpenID
+
+ self.oid = OpenID(app)
+ if self.auth_type == AUTH_OAUTH:
+ from authlib.integrations.flask_client import OAuth
+
+ self.oauth = OAuth(app)
+ self.oauth_remotes = {}
+ for provider in self.oauth_providers:
+ provider_name = provider["name"]
+ log.debug("OAuth providers init %s", provider_name)
+ obj_provider = self.oauth.register(provider_name,
**provider["remote_app"])
+ obj_provider._tokengetter =
FabAirflowSecurityManagerOverrideOauth.oauth_token_getter
+ if not self.oauth_user_info:
+ self.oauth_user_info = self.get_oauth_user_info
+ # Whitelist only users with matching emails
+ if "whitelist" in provider:
+ self.oauth_allow_list[provider_name] =
provider["whitelist"]
+ self.oauth_remotes[provider_name] = obj_provider
+
+ def _init_data_model(self):
+ user_data_model = SQLAInterface(self.user_model)
+ if self.auth_type == const.AUTH_DB:
+ self.userdbmodelview.datamodel = user_data_model
+ elif self.auth_type == const.AUTH_LDAP:
+ self.userldapmodelview.datamodel = user_data_model
+ elif self.auth_type == const.AUTH_OID:
+ self.useroidmodelview.datamodel = user_data_model
+ elif self.auth_type == const.AUTH_OAUTH:
+ self.useroauthmodelview.datamodel = user_data_model
+ elif self.auth_type == const.AUTH_REMOTE_USER:
+ self.userremoteusermodelview.datamodel = user_data_model
+
+ if self.userstatschartview:
+ self.userstatschartview.datamodel = user_data_model
+ if self.auth_user_registration:
+ self.registerusermodelview.datamodel =
SQLAInterface(self.registeruser_model)
+
+ self.rolemodelview.datamodel = SQLAInterface(self.role_model)
+ self.actionmodelview.datamodel = SQLAInterface(self.action_model)
+ self.resourcemodelview.datamodel = SQLAInterface(self.resource_model)
+ self.permissionmodelview.datamodel =
SQLAInterface(self.permission_model)
diff --git a/airflow/www/fab_security/manager.py
b/airflow/www/fab_security/manager.py
index 223ffbffc5..0e08c93d91 100644
--- a/airflow/www/fab_security/manager.py
+++ b/airflow/www/fab_security/manager.py
@@ -18,9 +18,7 @@
# mypy: disable-error-code=var-annotated
from __future__ import annotations
-import base64
import datetime
-import json
import logging
from typing import Any
from uuid import uuid4
@@ -31,13 +29,10 @@ from flask_appbuilder import AppBuilder
from flask_appbuilder.const import (
AUTH_DB,
AUTH_LDAP,
- AUTH_OAUTH,
- AUTH_OID,
LOGMSG_ERR_SEC_ADD_REGISTER_USER,
LOGMSG_ERR_SEC_AUTH_LDAP,
LOGMSG_ERR_SEC_AUTH_LDAP_TLS,
LOGMSG_WAR_SEC_LOGIN_FAILED,
- LOGMSG_WAR_SEC_NO_USER,
LOGMSG_WAR_SEC_NOLDAP_OBJ,
)
from flask_appbuilder.security.registerviews import (
@@ -56,7 +51,6 @@ from flask_appbuilder.security.views import (
ResetMyPasswordView,
ResetPasswordView,
RoleModelView,
- UserDBModelView,
UserInfoEditView,
UserLDAPModelView,
UserOAuthModelView,
@@ -107,16 +101,6 @@ class BaseSecurityManager:
""" Flask-OAuth """
oauth_remotes: dict[str, Any]
""" OAuth email whitelists """
- oauth_whitelists: dict[str, list] = {}
- """ Initialized (remote_app) providers dict {'provider_name', OBJ } """
-
- @staticmethod
- def oauth_tokengetter(token=None):
- """Authentication (OAuth) token getter function.
-
- Override to implement your own token getter method.
- """
- return _oauth_tokengetter(token)
oauth_user_info = None
@@ -133,8 +117,6 @@ class BaseSecurityManager:
registeruser_model: type[RegisterUser]
""" Override to set your own RegisterUser Model """
- userdbmodelview = UserDBModelView
- """ Override if you want your own user db view """
userldapmodelview = UserLDAPModelView
""" Override if you want your own user ldap view """
useroidmodelview = UserOIDModelView
@@ -178,70 +160,6 @@ class BaseSecurityManager:
def __init__(self, appbuilder):
self.appbuilder = appbuilder
app = self.appbuilder.get_app
- # Base Security Config
- app.config.setdefault("AUTH_ROLE_ADMIN", "Admin")
- app.config.setdefault("AUTH_ROLE_PUBLIC", "Public")
- app.config.setdefault("AUTH_TYPE", AUTH_DB)
- # Self Registration
- app.config.setdefault("AUTH_USER_REGISTRATION", False)
- app.config.setdefault("AUTH_USER_REGISTRATION_ROLE",
self.auth_role_public)
- app.config.setdefault("AUTH_USER_REGISTRATION_ROLE_JMESPATH", None)
- # Role Mapping
- app.config.setdefault("AUTH_ROLES_MAPPING", {})
- app.config.setdefault("AUTH_ROLES_SYNC_AT_LOGIN", False)
- app.config.setdefault("AUTH_API_LOGIN_ALLOW_MULTIPLE_PROVIDERS", False)
-
- # LDAP Config
- if self.auth_type == AUTH_LDAP:
- if "AUTH_LDAP_SERVER" not in app.config:
- raise Exception("No AUTH_LDAP_SERVER defined on config with
AUTH_LDAP authentication type.")
- app.config.setdefault("AUTH_LDAP_SEARCH", "")
- app.config.setdefault("AUTH_LDAP_SEARCH_FILTER", "")
- app.config.setdefault("AUTH_LDAP_APPEND_DOMAIN", "")
- app.config.setdefault("AUTH_LDAP_USERNAME_FORMAT", "")
- app.config.setdefault("AUTH_LDAP_BIND_USER", "")
- app.config.setdefault("AUTH_LDAP_BIND_PASSWORD", "")
- # TLS options
- app.config.setdefault("AUTH_LDAP_USE_TLS", False)
- app.config.setdefault("AUTH_LDAP_ALLOW_SELF_SIGNED", False)
- app.config.setdefault("AUTH_LDAP_TLS_DEMAND", False)
- app.config.setdefault("AUTH_LDAP_TLS_CACERTDIR", "")
- app.config.setdefault("AUTH_LDAP_TLS_CACERTFILE", "")
- app.config.setdefault("AUTH_LDAP_TLS_CERTFILE", "")
- app.config.setdefault("AUTH_LDAP_TLS_KEYFILE", "")
- # Mapping options
- app.config.setdefault("AUTH_LDAP_UID_FIELD", "uid")
- app.config.setdefault("AUTH_LDAP_GROUP_FIELD", "memberOf")
- app.config.setdefault("AUTH_LDAP_FIRSTNAME_FIELD", "givenName")
- app.config.setdefault("AUTH_LDAP_LASTNAME_FIELD", "sn")
- app.config.setdefault("AUTH_LDAP_EMAIL_FIELD", "mail")
-
- # Rate limiting
- app.config.setdefault("AUTH_RATE_LIMITED", True)
- app.config.setdefault("AUTH_RATE_LIMIT", "5 per 40 second")
-
- if self.auth_type == AUTH_OID:
- from flask_openid import OpenID
-
- self.oid = OpenID(app)
- if self.auth_type == AUTH_OAUTH:
- from authlib.integrations.flask_client import OAuth
-
- self.oauth = OAuth(app)
- self.oauth_remotes = {}
- for _provider in self.oauth_providers:
- provider_name = _provider["name"]
- log.debug("OAuth providers init %s", provider_name)
- obj_provider = self.oauth.register(provider_name,
**_provider["remote_app"])
- obj_provider._tokengetter = self.oauth_tokengetter
- if not self.oauth_user_info:
- self.oauth_user_info = self.get_oauth_user_info
- # Whitelist only users with matching emails
- if "whitelist" in _provider:
- self.oauth_whitelists[provider_name] =
_provider["whitelist"]
- self.oauth_remotes[provider_name] = obj_provider
-
- self._builtin_roles = self.create_builtin_roles()
# Setup Flask-Limiter
self.limiter = self.create_limiter(app)
@@ -251,10 +169,6 @@ class BaseSecurityManager:
limiter.init_app(app)
return limiter
- def create_builtin_roles(self):
- """Returns FAB builtin roles."""
- return self.appbuilder.get_app.config.get("FAB_ROLES", {})
-
def get_roles_from_keys(self, role_keys: list[str]) -> set[Role]:
"""
Construct a list of FAB role objects, from a list of keys.
@@ -317,11 +231,6 @@ class BaseSecurityManager:
"""Gets the admin role."""
return self.appbuilder.get_app.config["AUTH_ROLE_ADMIN"]
- @property
- def auth_role_public(self):
- """Gets the public role."""
- return self.appbuilder.get_app.config["AUTH_ROLE_PUBLIC"]
-
@property
def auth_ldap_server(self):
"""Gets the LDAP server object."""
@@ -452,11 +361,6 @@ class BaseSecurityManager:
"""Openid providers."""
return self.appbuilder.get_app.config["OPENID_PROVIDERS"]
- @property
- def oauth_providers(self):
- """Oauth providers."""
- return self.appbuilder.get_app.config["OAUTH_PROVIDERS"]
-
@property
def current_user(self):
"""Current user object."""
@@ -527,144 +431,6 @@ class BaseSecurityManager:
)
session["oauth_provider"] = provider
- def get_oauth_user_info(self, provider, resp):
- """Get the OAuth user information from different OAuth APIs.
-
- All providers have different ways to retrieve user info.
- """
- # for GITHUB
- if provider == "github" or provider == "githublocal":
- me = self.appbuilder.sm.oauth_remotes[provider].get("user")
- data = me.json()
- log.debug("User info from GitHub: %s", data)
- return {"username": "github_" + data.get("login")}
- # for twitter
- if provider == "twitter":
- me =
self.appbuilder.sm.oauth_remotes[provider].get("account/settings.json")
- data = me.json()
- log.debug("User info from Twitter: %s", data)
- return {"username": "twitter_" + data.get("screen_name", "")}
- # for linkedin
- if provider == "linkedin":
- me = self.appbuilder.sm.oauth_remotes[provider].get(
- "people/~:(id,email-address,first-name,last-name)?format=json"
- )
- data = me.json()
- log.debug("User info from LinkedIn: %s", data)
- return {
- "username": "linkedin_" + data.get("id", ""),
- "email": data.get("email-address", ""),
- "first_name": data.get("firstName", ""),
- "last_name": data.get("lastName", ""),
- }
- # for Google
- if provider == "google":
- me = self.appbuilder.sm.oauth_remotes[provider].get("userinfo")
- data = me.json()
- log.debug("User info from Google: %s", data)
- return {
- "username": "google_" + data.get("id", ""),
- "first_name": data.get("given_name", ""),
- "last_name": data.get("family_name", ""),
- "email": data.get("email", ""),
- }
- # for Azure AD Tenant. Azure OAuth response contains
- # JWT token which has user info.
- # JWT token needs to be base64 decoded.
- # https://docs.microsoft.com/en-us/azure/active-directory/develop/
- # active-directory-protocols-oauth-code
- if provider == "azure":
- log.debug("Azure response received : %s", resp)
- id_token = resp["id_token"]
- log.debug(str(id_token))
- me = self._azure_jwt_token_parse(id_token)
- log.debug("Parse JWT token : %s", me)
- return {
- "name": me.get("name", ""),
- "email": me["upn"],
- "first_name": me.get("given_name", ""),
- "last_name": me.get("family_name", ""),
- "id": me["oid"],
- "username": me["oid"],
- "role_keys": me.get("roles", []),
- }
- # for OpenShift
- if provider == "openshift":
- me =
self.appbuilder.sm.oauth_remotes[provider].get("apis/user.openshift.io/v1/users/~")
- data = me.json()
- log.debug("User info from OpenShift: %s", data)
- return {"username": "openshift_" +
data.get("metadata").get("name")}
- # for Okta
- if provider == "okta":
- me = self.appbuilder.sm.oauth_remotes[provider].get("userinfo")
- data = me.json()
- log.debug("User info from Okta: %s", data)
- return {
- "username": "okta_" + data.get("sub", ""),
- "first_name": data.get("given_name", ""),
- "last_name": data.get("family_name", ""),
- "email": data.get("email", ""),
- "role_keys": data.get("groups", []),
- }
- # for Keycloak
- if provider in ["keycloak", "keycloak_before_17"]:
- me =
self.appbuilder.sm.oauth_remotes[provider].get("openid-connect/userinfo")
- me.raise_for_status()
- data = me.json()
- log.debug("User info from Keycloak: %s", data)
- return {
- "username": data.get("preferred_username", ""),
- "first_name": data.get("given_name", ""),
- "last_name": data.get("family_name", ""),
- "email": data.get("email", ""),
- }
- else:
- return {}
-
- def _azure_parse_jwt(self, id_token):
- jwt_token_parts = r"^([^\.\s]*)\.([^\.\s]+)\.([^\.\s]*)$"
- matches = re2.search(jwt_token_parts, id_token)
- if not matches or len(matches.groups()) < 3:
- log.error("Unable to parse token.")
- return {}
- return {
- "header": matches.group(1),
- "Payload": matches.group(2),
- "Sig": matches.group(3),
- }
-
- def _azure_jwt_token_parse(self, id_token):
- jwt_split_token = self._azure_parse_jwt(id_token)
- if not jwt_split_token:
- return
-
- jwt_payload = jwt_split_token["Payload"]
- # Prepare for base64 decoding
- payload_b64_string = jwt_payload
- payload_b64_string += "=" * (4 - (len(jwt_payload) % 4))
- decoded_payload =
base64.urlsafe_b64decode(payload_b64_string.encode("ascii"))
-
- if not decoded_payload:
- log.error("Payload of id_token could not be base64 url decoded.")
- return
-
- jwt_decoded_payload = json.loads(decoded_payload.decode("utf-8"))
-
- return jwt_decoded_payload
-
- def create_db(self):
- """Setups the DB, creates admin and public roles if they don't
exist."""
- roles_mapping =
self.appbuilder.get_app.config.get("FAB_ROLES_MAPPING", {})
- for pk, name in roles_mapping.items():
- self.update_role(pk, name)
- for role_name in self.builtin_roles:
- self.add_role(role_name)
- if self.auth_role_admin not in self.builtin_roles:
- self.add_role(self.auth_role_admin)
- self.add_role(self.auth_role_public)
- if self.count_users() == 0 and self.auth_role_public !=
self.auth_role_admin:
- log.warning(LOGMSG_WAR_SEC_NO_USER)
-
def update_user_auth_stat(self, user, success=True):
"""Update user authentication stats.
@@ -857,7 +623,7 @@ class BaseSecurityManager:
"""
Method for authenticating user with LDAP.
- NOTE: this depends on python-ldap module
+ NOTE: this depends on python-ldap module.
:param username: the username
:param password: the password
@@ -1329,30 +1095,10 @@ class BaseSecurityManager:
self.delete_permission(permission.action.name,
resource.name)
self.delete_resource(resource.name)
- def find_register_user(self, registration_hash):
- """Generic function to return user registration."""
- raise NotImplementedError
-
- def add_register_user(self, username, first_name, last_name, email,
password="", hashed_password=""):
- """Generic function to add user registration."""
- raise NotImplementedError
-
- def del_register_user(self, register_user):
- """Generic function to delete user registration."""
- raise NotImplementedError
-
- def get_user_by_id(self, pk):
- """Generic function to return user by it's id (pk)."""
- raise NotImplementedError
-
def find_user(self, username=None, email=None):
"""Generic function find a user by it's username or email."""
raise NotImplementedError
- def get_all_users(self):
- """Generic function that returns all existing users."""
- raise NotImplementedError
-
def get_role_permissions_from_db(self, role_id: int) -> list[Permission]:
"""Get all DB permissions from a role id."""
raise NotImplementedError
@@ -1369,19 +1115,9 @@ class BaseSecurityManager:
"""
raise NotImplementedError
- def count_users(self):
- """Generic function to count the existing users."""
- raise NotImplementedError
-
def find_role(self, name):
raise NotImplementedError
- def add_role(self, name):
- raise NotImplementedError
-
- def update_role(self, role_id, name):
- raise NotImplementedError
-
def get_all_roles(self):
raise NotImplementedError
@@ -1389,15 +1125,6 @@ class BaseSecurityManager:
"""Returns all permissions from public role."""
raise NotImplementedError
- def get_action(self, name: str) -> Action:
- """
- Gets an existing action record.
-
- :param name: name
- :return: Action record, if it exists
- """
- raise NotImplementedError
-
def filter_roles_by_perm_with_action(self, permission_name: str, role_ids:
list[int]):
raise NotImplementedError
@@ -1407,38 +1134,12 @@ class BaseSecurityManager:
"""Finds and returns permission views for a group of roles."""
raise NotImplementedError
- def create_action(self, name):
- """
- Adds a permission to the backend, model permission.
-
- :param name:
- name of the permission: 'can_add','can_edit' etc...
- """
- raise NotImplementedError
-
- def delete_action(self, name: str) -> bool:
- """
- Deletes a permission action.
-
- :param name: Name of action to delete (e.g. can_read).
- :return: Whether or not delete was successful.
- """
- raise NotImplementedError
-
"""
----------------------
PRIMITIVES VIEW MENU
----------------------
"""
- def get_resource(self, name: str):
- """
- Returns a resource record by name, if it exists.
-
- :param name: Name of resource
- """
- raise NotImplementedError
-
def get_all_resources(self) -> list[Resource]:
"""
Gets all existing resource records.
diff --git a/airflow/www/fab_security/sqla/manager.py
b/airflow/www/fab_security/sqla/manager.py
index 83e2119492..5084a24be8 100644
--- a/airflow/www/fab_security/sqla/manager.py
+++ b/airflow/www/fab_security/sqla/manager.py
@@ -17,14 +17,8 @@
from __future__ import annotations
import logging
-import uuid
-from flask_appbuilder import const as c
-from flask_appbuilder.models.sqla import Base
-from flask_appbuilder.models.sqla.interface import SQLAInterface
-from sqlalchemy import and_, func, inspect, literal
-from sqlalchemy.orm.exc import MultipleResultsFound
-from werkzeug.security import generate_password_hash
+from sqlalchemy import and_, literal
from airflow.auth.managers.fab.models import (
Action,
@@ -64,218 +58,11 @@ class SecurityManager(BaseSecurityManager):
:param appbuilder: F.A.B AppBuilder main object
"""
super().__init__(appbuilder)
- user_datamodel = SQLAInterface(self.user_model)
- if self.auth_type == c.AUTH_DB:
- self.userdbmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_LDAP:
- self.userldapmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_OID:
- self.useroidmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_OAUTH:
- self.useroauthmodelview.datamodel = user_datamodel
- elif self.auth_type == c.AUTH_REMOTE_USER:
- self.userremoteusermodelview.datamodel = user_datamodel
-
- if self.userstatschartview:
- self.userstatschartview.datamodel = user_datamodel
- if self.auth_user_registration:
- self.registerusermodelview.datamodel =
SQLAInterface(self.registeruser_model)
-
- self.rolemodelview.datamodel = SQLAInterface(self.role_model)
- self.actionmodelview.datamodel = SQLAInterface(self.action_model)
- self.resourcemodelview.datamodel = SQLAInterface(self.resource_model)
- self.permissionmodelview.datamodel =
SQLAInterface(self.permission_model)
- self.create_db()
@property
def get_session(self):
return self.appbuilder.get_session
- def register_views(self):
- super().register_views()
-
- def create_db(self):
- try:
- engine = self.get_session.get_bind(mapper=None, clause=None)
- inspector = inspect(engine)
- if "ab_user" not in inspector.get_table_names():
- log.info(c.LOGMSG_INF_SEC_NO_DB)
- Base.metadata.create_all(engine)
- log.info(c.LOGMSG_INF_SEC_ADD_DB)
- super().create_db()
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_CREATE_DB.format(str(e)))
- exit(1)
-
- def find_register_user(self, registration_hash):
- return (
- self.get_session.query(self.registeruser_model)
- .filter(self.registeruser_model.registration_hash ==
registration_hash)
- .scalar()
- )
-
- def add_register_user(self, username, first_name, last_name, email,
password="", hashed_password=""):
- """
- Add a registration request for the user.
-
- :rtype : RegisterUser
- """
- register_user = self.registeruser_model()
- register_user.username = username
- register_user.email = email
- register_user.first_name = first_name
- register_user.last_name = last_name
- if hashed_password:
- register_user.password = hashed_password
- else:
- register_user.password = generate_password_hash(password)
- register_user.registration_hash = str(uuid.uuid1())
- try:
- self.get_session.add(register_user)
- self.get_session.commit()
- return register_user
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_REGISTER_USER.format(str(e)))
- self.appbuilder.get_session.rollback()
- return None
-
- def del_register_user(self, register_user):
- """
- Deletes registration object from database.
-
- :param register_user: RegisterUser object to delete
- """
- try:
- self.get_session.delete(register_user)
- self.get_session.commit()
- return True
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_REGISTER_USER.format(str(e)))
- self.get_session.rollback()
- return False
-
- def find_user(self, username=None, email=None):
- """Finds user by username or email."""
- if username:
- try:
- if self.auth_username_ci:
- return (
- self.get_session.query(self.user_model)
- .filter(func.lower(self.user_model.username) ==
func.lower(username))
- .one_or_none()
- )
- else:
- return (
- self.get_session.query(self.user_model)
- .filter(func.lower(self.user_model.username) ==
func.lower(username))
- .one_or_none()
- )
- except MultipleResultsFound:
- log.error("Multiple results found for user %s", username)
- return None
- elif email:
- try:
- return
self.get_session.query(self.user_model).filter_by(email=email).one_or_none()
- except MultipleResultsFound:
- log.error("Multiple results found for user with email %s",
email)
- return None
-
- def get_all_users(self):
- return self.get_session.query(self.user_model).all()
-
- def add_user(
- self,
- username,
- first_name,
- last_name,
- email,
- role,
- password="",
- hashed_password="",
- ):
- """Generic function to create user."""
- try:
- user = self.user_model()
- user.first_name = first_name
- user.last_name = last_name
- user.username = username
- user.email = email
- user.active = True
- user.roles = role if isinstance(role, list) else [role]
- if hashed_password:
- user.password = hashed_password
- else:
- user.password = generate_password_hash(password)
- self.get_session.add(user)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_USER.format(username))
- return user
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_USER.format(str(e)))
- self.get_session.rollback()
- return False
-
- def count_users(self):
- return self.get_session.query(func.count(self.user_model.id)).scalar()
-
- def update_user(self, user):
- try:
- self.get_session.merge(user)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_UPD_USER.format(user))
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_UPD_USER.format(str(e)))
- self.get_session.rollback()
- return False
-
- def add_role(self, name: str) -> Role:
- role = self.find_role(name)
- if role is None:
- try:
- role = self.role_model()
- role.name = name
- self.get_session.add(role)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_ROLE.format(name))
- return role
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_ROLE.format(str(e)))
- self.get_session.rollback()
- return role
-
- def update_role(self, role_id, name: str) -> Role | None:
- role = self.get_session.get(self.role_model, role_id)
- if not role:
- return None
- try:
- role.name = name
- self.get_session.merge(role)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_UPD_ROLE.format(role))
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_UPD_ROLE.format(str(e)))
- self.get_session.rollback()
- return None
- return role
-
- def find_role(self, name):
- return
self.get_session.query(self.role_model).filter_by(name=name).one_or_none()
-
- def get_all_roles(self):
- return self.get_session.query(self.role_model).all()
-
- def get_public_role(self):
- return
self.get_session.query(self.role_model).filter_by(name=self.auth_role_public).one_or_none()
-
- def get_action(self, name: str) -> Action:
- """
- Gets an existing action record.
-
- :param name: name
- :return: Action record, if it exists
- """
- return
self.get_session.query(self.action_model).filter_by(name=name).one_or_none()
-
def permission_exists_in_one_or_more_roles(
self, resource_name: str, action_name: str, role_ids: list[int]
) -> bool:
@@ -325,254 +112,8 @@ class SecurityManager(BaseSecurityManager):
)
).all()
- def create_action(self, name):
- """
- Adds an action to the backend, model action.
-
- :param name:
- name of the action: 'can_add','can_edit' etc...
- """
- action = self.get_action(name)
- if action is None:
- try:
- action = self.action_model()
- action.name = name
- self.get_session.add(action)
- self.get_session.commit()
- return action
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_PERMISSION.format(str(e)))
- self.get_session.rollback()
- return action
-
- def delete_action(self, name: str) -> bool:
- """
- Deletes a permission action.
-
- :param name: Name of action to delete (e.g. can_read).
- :return: Whether or not delete was successful.
- """
- action = self.get_action(name)
- if not action:
- log.warning(c.LOGMSG_WAR_SEC_DEL_PERMISSION.format(name))
- return False
- try:
- perms = (
- self.get_session.query(self.permission_model)
- .filter(self.permission_model.action == action)
- .all()
- )
- if perms:
- log.warning(c.LOGMSG_WAR_SEC_DEL_PERM_PVM.format(action,
perms))
- return False
- self.get_session.delete(action)
- self.get_session.commit()
- return True
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMISSION.format(str(e)))
- self.get_session.rollback()
- return False
-
- def get_resource(self, name: str) -> Resource:
- """
- Returns a resource record by name, if it exists.
-
- :param name: Name of resource
- :return: Resource record
- """
- return
self.get_session.query(self.resource_model).filter_by(name=name).one_or_none()
-
- def get_all_resources(self) -> list[Resource]:
- """
- Gets all existing resource records.
-
- :return: List of all resources
- """
- return self.get_session.query(self.resource_model).all()
-
- def create_resource(self, name) -> Resource:
- """
- Create a resource with the given name.
-
- :param name: The name of the resource to create created.
- :return: The FAB resource created.
- """
- resource = self.get_resource(name)
- if resource is None:
- try:
- resource = self.resource_model()
- resource.name = name
- self.get_session.add(resource)
- self.get_session.commit()
- return resource
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_VIEWMENU.format(str(e)))
- self.get_session.rollback()
- return resource
-
- def delete_resource(self, name: str) -> bool:
- """
- Deletes a Resource from the backend.
-
- :param name:
- name of the resource
- """
- resource = self.get_resource(name)
- if not resource:
- log.warning(c.LOGMSG_WAR_SEC_DEL_VIEWMENU.format(name))
- return False
- try:
- perms = (
- self.get_session.query(self.permission_model)
- .filter(self.permission_model.resource == resource)
- .all()
- )
- if perms:
- log.warning(c.LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM.format(resource,
perms))
- return False
- self.get_session.delete(resource)
- self.get_session.commit()
- return True
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMISSION.format(str(e)))
- self.get_session.rollback()
- return False
-
- """
- ----------------------
- PERMISSION VIEW MENU
- ----------------------
- """
-
- def get_permission(
- self,
- action_name: str,
- resource_name: str,
- ) -> Permission | None:
- """
- Gets a permission made with the given action->resource pair, if the
permission already exists.
-
- :param action_name: Name of action
- :param resource_name: Name of resource
- :return: The existing permission
- """
- action = self.get_action(action_name)
- resource = self.get_resource(resource_name)
- if action and resource:
- return (
- self.get_session.query(self.permission_model)
- .filter_by(action=action, resource=resource)
- .one_or_none()
- )
- return None
-
- def get_resource_permissions(self, resource: Resource) -> Permission:
- """
- Retrieve permission pairs associated with a specific resource object.
-
- :param resource: Object representing a single resource.
- :return: Action objects representing resource->action pair
- """
- return
self.get_session.query(self.permission_model).filter_by(resource_id=resource.id).all()
-
- def create_permission(self, action_name, resource_name) -> Permission |
None:
- """
- Adds a permission on a resource to the backend.
-
- :param action_name:
- name of the action to add: 'can_add','can_edit' etc...
- :param resource_name:
- name of the resource to add
- """
- if not (action_name and resource_name):
- return None
- perm = self.get_permission(action_name, resource_name)
- if perm:
- return perm
- resource = self.create_resource(resource_name)
- action = self.create_action(action_name)
- perm = self.permission_model()
- perm.resource_id, perm.action_id = resource.id, action.id
- try:
- self.get_session.add(perm)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_PERMVIEW.format(str(perm)))
- return perm
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_PERMVIEW.format(str(e)))
- self.get_session.rollback()
- return None
-
- def delete_permission(self, action_name: str, resource_name: str) -> None:
- """
- Deletes the permission linking an action->resource pair.
-
- Doesn't delete the underlying action or resource.
-
- :param action_name: Name of existing action
- :param resource_name: Name of existing resource
- :return: None
- """
- if not (action_name and resource_name):
- return
- perm = self.get_permission(action_name, resource_name)
- if not perm:
- return
- roles = (
-
self.get_session.query(self.role_model).filter(self.role_model.permissions.contains(perm)).first()
- )
- if roles:
- log.warning(c.LOGMSG_WAR_SEC_DEL_PERMVIEW.format(resource_name,
action_name, roles))
- return
- try:
- # delete permission on resource
- self.get_session.delete(perm)
- self.get_session.commit()
- # if no more permission on permission view, delete permission
- if not
self.get_session.query(self.permission_model).filter_by(action=perm.action).all():
- self.delete_action(perm.action.name)
- log.info(c.LOGMSG_INF_SEC_DEL_PERMVIEW.format(action_name,
resource_name))
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMVIEW.format(str(e)))
- self.get_session.rollback()
-
def perms_include_action(self, perms, action_name):
for perm in perms:
if perm.action and perm.action.name == action_name:
return True
return False
-
- def add_permission_to_role(self, role: Role, permission: Permission |
None) -> None:
- """
- Add an existing permission pair to a role.
-
- :param role: The role about to get a new permission.
- :param permission: The permission pair to add to a role.
- :return: None
- """
- if permission and permission not in role.permissions:
- try:
- role.permissions.append(permission)
- self.get_session.merge(role)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_ADD_PERMROLE.format(str(permission),
role.name))
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_ADD_PERMROLE.format(str(e)))
- self.get_session.rollback()
-
- def remove_permission_from_role(self, role: Role, permission: Permission)
-> None:
- """
- Remove a permission pair from a role.
-
- :param role: User role containing permissions.
- :param permission: Object representing resource-> action pair
- """
- if permission in role.permissions:
- try:
- role.permissions.remove(permission)
- self.get_session.merge(role)
- self.get_session.commit()
- log.info(c.LOGMSG_INF_SEC_DEL_PERMROLE.format(str(permission),
role.name))
- except Exception as e:
- log.error(c.LOGMSG_ERR_SEC_DEL_PERMROLE.format(str(e)))
- self.get_session.rollback()
diff --git a/airflow/www/security.py b/airflow/www/security.py
index ef0411eb73..a988dccd2e 100644
--- a/airflow/www/security.py
+++ b/airflow/www/security.py
@@ -241,12 +241,6 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
view.datamodel = CustomSQLAInterface(view.datamodel.obj)
self.perms = None
- def create_db(self) -> None:
- if not self.appbuilder.update_perms:
- self.log.debug("Skipping db since appbuilder disables
update_perms")
- return
- super().create_db()
-
def _get_root_dag_id(self, dag_id: str) -> str:
if "." in dag_id:
dm = (
@@ -290,21 +284,6 @@ class AirflowSecurityManager(SecurityManagerOverride,
SecurityManager, LoggingMi
if perm not in role.permissions:
self.add_permission_to_role(role, perm)
- def delete_role(self, role_name: str) -> None:
- """
- Delete the given Role.
-
- :param role_name: the name of a role in the ab_role table
- """
- session = self.appbuilder.get_session
- role = session.query(Role).filter(Role.name == role_name).first()
- if role:
- self.log.info("Deleting role '%s'", role_name)
- session.delete(role)
- session.commit()
- else:
- raise AirflowException(f"Role named '{role_name}' does not exist")
-
@staticmethod
def get_user_roles(user=None):
"""
diff --git a/tests/auth/managers/fab/security_manager/test_override.py
b/tests/auth/managers/fab/security_manager/test_override.py
index 6b6965b9a6..9f931e6a2d 100644
--- a/tests/auth/managers/fab/security_manager/test_override.py
+++ b/tests/auth/managers/fab/security_manager/test_override.py
@@ -17,10 +17,11 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
import pytest
+from airflow.auth.managers.fab.models import User
from airflow.auth.managers.fab.security_manager.override import
FabAirflowSecurityManagerOverride
appbuilder = Mock()
@@ -39,7 +40,7 @@ registeruseroidview = Mock()
resetmypasswordview = Mock()
resetpasswordview = Mock()
rolemodelview = Mock()
-user_model = Mock()
+user_model = User
userinfoeditview = Mock()
userdbmodelview = Mock()
userldapmodelview = Mock()
@@ -88,12 +89,16 @@ def security_manager_override():
"airflow.auth.managers.fab.security_manager.override.LoginManager"
) as mock_login_manager, mock.patch(
"airflow.auth.managers.fab.security_manager.override.JWTManager"
- ) as mock_jwt_manager:
+ ) as mock_jwt_manager, mock.patch.object(
+ FabAirflowSecurityManagerOverride, "create_db"
+ ):
mock_login_manager_instance = Mock()
mock_login_manager.return_value = mock_login_manager_instance
mock_jwt_manager_instance = Mock()
mock_jwt_manager.return_value = mock_jwt_manager_instance
+ appbuilder.app.config = MagicMock()
+
security_manager_override = EmptySecurityManager(appbuilder)
mock_login_manager.assert_called_once_with(appbuilder.app)