This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-4-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 85895b567f5f70ffc497d84570223c5fb80f7de4 Author: Jed Cunningham <[email protected]> AuthorDate: Fri Sep 23 13:28:33 2022 -0700 Check user is active (#26635) (cherry picked from commit 59707cdf7eacb698ca375b5220af30a39ca1018c) --- airflow/www/app.py | 7 ++++++- airflow/www/extensions/init_security.py | 11 +++++++++++ tests/test_utils/decorators.py | 1 + tests/www/views/conftest.py | 1 + tests/www/views/test_session.py | 14 ++++++++++++++ tests/www/views/test_views_base.py | 13 +++++++++++-- 6 files changed, 44 insertions(+), 3 deletions(-) diff --git a/airflow/www/app.py b/airflow/www/app.py index b67314c99a..d0c38b2936 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -39,7 +39,11 @@ from airflow.www.extensions.init_dagbag import init_dagbag from airflow.www.extensions.init_jinja_globals import init_jinja_globals from airflow.www.extensions.init_manifest_files import configure_manifest_files from airflow.www.extensions.init_robots import init_robots -from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection +from airflow.www.extensions.init_security import ( + init_api_experimental_auth, + init_check_user_active, + init_xframe_protection, +) from airflow.www.extensions.init_session import init_airflow_session_interface from airflow.www.extensions.init_views import ( init_api_connexion, @@ -152,6 +156,7 @@ def create_app(config=None, testing=False): init_jinja_globals(flask_app) init_xframe_protection(flask_app) init_airflow_session_interface(flask_app) + init_check_user_active(flask_app) return flask_app diff --git a/airflow/www/extensions/init_security.py b/airflow/www/extensions/init_security.py index 1d96e351df..b967b74084 100644 --- a/airflow/www/extensions/init_security.py +++ b/airflow/www/extensions/init_security.py @@ -19,6 +19,9 @@ from __future__ import annotations import logging from importlib import import_module +from flask import g, redirect, url_for +from flask_login import logout_user + from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException @@ -60,3 +63,11 @@ def init_api_experimental_auth(app): except ImportError as err: log.critical("Cannot import %s for API authentication due to: %s", backend, err) raise AirflowException(err) + + +def init_check_user_active(app): + @app.before_request + def check_user_active(): + if g.user is not None and not g.user.is_anonymous and not g.user.is_active: + logout_user() + return redirect(url_for(app.appbuilder.sm.auth_view.endpoint + ".login")) diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py index bdb8d67807..d0b71b502c 100644 --- a/tests/test_utils/decorators.py +++ b/tests/test_utils/decorators.py @@ -45,6 +45,7 @@ def dont_initialize_flask_app_submodules(_func=None, *, skip_all_except=None): "init_xframe_protection", "init_airflow_session_interface", "init_appbuilder", + "init_check_user_active", ] @functools.wraps(f) diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 02c857180f..ad562385bc 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -58,6 +58,7 @@ def app(examples_dag_bag): "init_jinja_globals", "init_plugins", "init_airflow_session_interface", + "init_check_user_active", ] ) def factory(): diff --git a/tests/www/views/test_session.py b/tests/www/views/test_session.py index 090bc503a8..3802399264 100644 --- a/tests/www/views/test_session.py +++ b/tests/www/views/test_session.py @@ -88,3 +88,17 @@ def test_session_id_rotates(app, user_client): new_session_cookie = get_session_cookie(user_client) assert new_session_cookie is not None assert old_session_cookie.value != new_session_cookie.value + + +def test_check_active_user(app, user_client): + user = app.appbuilder.sm.find_user(username="test_user") + user.active = False + resp = user_client.get("/home") + assert resp.status_code == 302 + assert "/login" in resp.headers.get("Location") + + # And they were logged out + user.active = True + resp = user_client.get("/home") + assert resp.status_code == 302 + assert "/login" in resp.headers.get("Location") diff --git a/tests/www/views/test_views_base.py b/tests/www/views/test_views_base.py index d0acc4df27..9c9c4f0aba 100644 --- a/tests/www/views/test_views_base.py +++ b/tests/www/views/test_views_base.py @@ -30,9 +30,18 @@ from tests.test_utils.config import conf_vars from tests.test_utils.www import check_content_in_response, check_content_not_in_response -def test_index(admin_client): +def test_index_redirect(admin_client): + resp = admin_client.get('/') + assert resp.status_code == 302 + assert '/home' in resp.headers.get("Location") + + resp = admin_client.get('/', follow_redirects=True) + check_content_in_response('DAGs', resp) + + +def test_homepage_query_count(admin_client): with assert_queries_count(16): - resp = admin_client.get('/', follow_redirects=True) + resp = admin_client.get('/home') check_content_in_response('DAGs', resp)
