This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a6064f3a3c7 Handle the next URL after logging in in the simple auth manager (#44856)
a6064f3a3c7 is described below
commit a6064f3a3c74cd1c1de7a044672d60b84da4252b
Author: Vincent <[email protected]>
AuthorDate: Thu Dec 12 09:59:57 2024 -0500
Handle the next URL after logging in in the simple auth manager (#44856)
---
.../auth/managers/simple/simple_auth_manager.py | 3 +-
airflow/auth/managers/simple/views/auth.py | 39 ++++++++++++++++++++--
airflow/www/templates/airflow/login.html | 2 +-
.../managers/simple/test_simple_auth_manager.py | 2 +-
tests/auth/managers/simple/views/test_auth.py | 21 +++++++++---
5 files changed, 56 insertions(+), 11 deletions(-)
diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py
index 48baa02e7c7..d63aa480c9d 100644
--- a/airflow/auth/managers/simple/simple_auth_manager.py
+++ b/airflow/auth/managers/simple/simple_auth_manager.py
@@ -122,7 +122,8 @@ class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]):
)
def get_url_login(self, **kwargs) -> str:
- return url_for("SimpleAuthManagerAuthenticationViews.login")
+ """Return the login page url."""
+ return url_for("SimpleAuthManagerAuthenticationViews.login",
next=kwargs.get("next_url"))
def get_url_logout(self) -> str:
return url_for("SimpleAuthManagerAuthenticationViews.logout")
diff --git a/airflow/auth/managers/simple/views/auth.py b/airflow/auth/managers/simple/views/auth.py
index 6e4cf0c3994..bd06661d833 100644
--- a/airflow/auth/managers/simple/views/auth.py
+++ b/airflow/auth/managers/simple/views/auth.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import logging
+from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
from flask import redirect, request, session, url_for
from flask_appbuilder import expose
@@ -54,7 +55,9 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
return self.render_template(
"airflow/login.html",
disable_nav_bar=True,
-
login_submit_url=url_for("SimpleAuthManagerAuthenticationViews.login_submit"),
+ login_submit_url=url_for(
+ "SimpleAuthManagerAuthenticationViews.login_submit",
next=request.args.get("next")
+ ),
auto_refresh_interval=conf.getint("webserver",
"auto_refresh_interval"),
state_color_mapping=state_color_mapping,
standalone_dag_processor=standalone_dag_processor,
@@ -72,6 +75,7 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
"""Redirect the user to this callback after login attempt."""
username = request.form.get("username")
password = request.form.get("password")
+ next_url = request.args.get("next")
found_users = [
user
@@ -80,7 +84,7 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
]
if not username or not password or len(found_users) == 0:
- return
redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"]))
+ return
redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"],
next=next_url))
user = SimpleAuthManagerUser(
username=username,
@@ -96,4 +100,33 @@ class SimpleAuthManagerAuthenticationViews(AirflowBaseView):
)
token =
signer.generate_signed_token(get_auth_manager().serialize_user(user))
- return redirect(url_for("Airflow.index", token=token))
+ if next_url:
+ return redirect(self._get_redirect_url(next_url, token))
+ else:
+ return redirect(url_for("Airflow.index", token=token))
+
+ def _get_redirect_url(self, next_url: str, token: str) -> str:
+ if self._is_same_domain(next_url, request.url):
+ return self._add_query_params(next_url, {"token": token})
+ else:
+ return url_for("Airflow.index", token=token)
+
+ @staticmethod
+ def _is_same_domain(next_url: str, current_url: str) -> bool:
+ next_url_infos = urlsplit(next_url)
+ current_url_infos = urlsplit(current_url)
+ return (
+ current_url_infos.netloc.startswith("localhost:")
+ or (not next_url_infos.scheme or next_url_infos.scheme ==
current_url_infos.scheme)
+ and (not next_url_infos.netloc or next_url_infos.netloc ==
current_url_infos.netloc)
+ )
+
+ @staticmethod
+ def _add_query_params(url: str, params: dict) -> str:
+ url_infos = urlsplit(url)
+ existing_query = dict(parse_qsl(url_infos.query))
+ existing_query.update(params)
+ updated_query = urlencode(existing_query, doseq=True)
+ return urlunsplit(
+ (url_infos.scheme, url_infos.netloc, url_infos.path,
updated_query, url_infos.fragment)
+ )
diff --git a/airflow/www/templates/airflow/login.html b/airflow/www/templates/airflow/login.html
index 5a25fb3b5f2..afeac1104a7 100644
--- a/airflow/www/templates/airflow/login.html
+++ b/airflow/www/templates/airflow/login.html
@@ -21,7 +21,7 @@
{% block head_meta %}
{{ super() }}
- <meta name="login_submit_url" content="{{
url_for('SimpleAuthManagerAuthenticationViews.login_submit') }}">
+ <meta name="login_submit_url" content="{{ login_submit_url }}">
{% endblock %}
{% block messages %}
diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py b/tests/auth/managers/simple/test_simple_auth_manager.py
index 07289f6f002..0cc553ad420 100644
--- a/tests/auth/managers/simple/test_simple_auth_manager.py
+++ b/tests/auth/managers/simple/test_simple_auth_manager.py
@@ -95,7 +95,7 @@ class TestSimpleAuthManager:
@patch("airflow.auth.managers.simple.simple_auth_manager.url_for")
def test_get_url_login(self, mock_url_for, auth_manager):
auth_manager.get_url_login()
-
mock_url_for.assert_called_once_with("SimpleAuthManagerAuthenticationViews.login")
+
mock_url_for.assert_called_once_with("SimpleAuthManagerAuthenticationViews.login",
next=None)
@patch("airflow.auth.managers.simple.simple_auth_manager.url_for")
def test_get_url_logout(self, mock_url_for, auth_manager):
diff --git a/tests/auth/managers/simple/views/test_auth.py b/tests/auth/managers/simple/views/test_auth.py
index e3e7b29e2cc..86a8be4f444 100644
--- a/tests/auth/managers/simple/views/test_auth.py
+++ b/tests/auth/managers/simple/views/test_auth.py
@@ -65,18 +65,29 @@ class TestSimpleAuthManagerAuthenticationViews:
assert session.get("user") is None
@pytest.mark.parametrize(
- "username, password, is_successful",
- [("test", "test", True), ("test", "test2", False), ("", "", False)],
+ "username, password, is_successful, query_params, expected_redirect",
+ [
+ ("test", "test", True, {}, None),
+ ("test", "test2", False, {}, None),
+ ("", "", False, {}, None),
+ ("test", "test", True, {"next": "next_url"},
"next_url?token=token"),
+ ],
)
@patch("airflow.auth.managers.simple.views.auth.JWTSigner")
- def test_login_submit(self, mock_jwt_signer, simple_app, username,
password, is_successful):
+ def test_login_submit(
+ self, mock_jwt_signer, simple_app, username, password, is_successful,
query_params, expected_redirect
+ ):
signer = Mock()
signer.generate_signed_token.return_value = "token"
mock_jwt_signer.return_value = signer
with simple_app.test_client() as client:
- response = client.post("/login_submit", data={"username":
username, "password": password})
+ response = client.post(
+ "/login_submit", query_string=query_params, data={"username":
username, "password": password}
+ )
assert response.status_code == 302
if is_successful:
- assert response.location == url_for("Airflow.index",
token="token")
+ if not expected_redirect:
+ expected_redirect = url_for("Airflow.index", token="token")
+ assert response.location == expected_redirect
else:
assert response.location ==
url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"])