This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fe4a6c843a Add a check for trailing slash in webserver base_url (#31833)
fe4a6c843a is described below
commit fe4a6c843acd97c776d5890116bfa85356a54eee
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Jun 19 09:29:11 2023 +0200
Add a check for trailing slash in webserver base_url (#31833)
* Remove the trailing slash from webserver base_url
Signed-off-by: Hussein Awala <[email protected]>
* Use urljoin instead of removing the trailing slash
Signed-off-by: Hussein Awala <[email protected]>
* Raise an exception when base_url contains a trailing slash
Signed-off-by: Hussein Awala <[email protected]>
* Update airflow/www/extensions/init_wsgi_middlewares.py
Co-authored-by: Tzu-ping Chung <[email protected]>
---------
Signed-off-by: Hussein Awala <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
---
airflow/models/taskinstance.py | 18 ++++-----
airflow/www/extensions/init_wsgi_middlewares.py | 6 ++-
tests/www/test_app.py | 51 +++++++++++++++++++------
3 files changed, 54 insertions(+), 21 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 28dc168ec3..f429e27e15 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -33,7 +33,7 @@ from functools import partial
from pathlib import PurePath
from types import TracebackType
from typing import TYPE_CHECKING, Any, Callable, Collection, Generator,
Iterable, Tuple
-from urllib.parse import quote
+from urllib.parse import quote, urljoin
import dill
import jinja2
@@ -759,26 +759,26 @@ class TaskInstance(Base, LoggingMixin):
"""Log URL for TaskInstance."""
iso = quote(self.execution_date.isoformat())
base_url = conf.get_mandatory_value("webserver", "BASE_URL")
- return (
- f"{base_url}/log"
- f"?execution_date={iso}"
+ return urljoin(
+ base_url,
+ f"log?execution_date={iso}"
f"&task_id={self.task_id}"
f"&dag_id={self.dag_id}"
- f"&map_index={self.map_index}"
+ f"&map_index={self.map_index}",
)
@property
def mark_success_url(self) -> str:
"""URL to mark TI success."""
base_url = conf.get_mandatory_value("webserver", "BASE_URL")
- return (
- f"{base_url}/confirm"
- f"?task_id={self.task_id}"
+ return urljoin(
+ base_url,
+ f"confirm?task_id={self.task_id}"
f"&dag_id={self.dag_id}"
f"&dag_run_id={quote(self.run_id)}"
"&upstream=false"
"&downstream=false"
- "&state=success"
+ "&state=success",
)
@provide_session
diff --git a/airflow/www/extensions/init_wsgi_middlewares.py
b/airflow/www/extensions/init_wsgi_middlewares.py
index 3ea47e92c8..37d4a074b4 100644
--- a/airflow/www/extensions/init_wsgi_middlewares.py
+++ b/airflow/www/extensions/init_wsgi_middlewares.py
@@ -25,6 +25,7 @@ from werkzeug.middleware.dispatcher import
DispatcherMiddleware
from werkzeug.middleware.proxy_fix import ProxyFix
from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException
if TYPE_CHECKING:
from _typeshed.wsgi import StartResponse, WSGIEnvironment
@@ -37,8 +38,11 @@ def _root_app(env: WSGIEnvironment, resp: StartResponse) ->
Iterable[bytes]:
def init_wsgi_middleware(flask_app: Flask) -> None:
"""Handle X-Forwarded-* headers and base_url support."""
+ webserver_base_url = conf.get_mandatory_value("webserver", "BASE_URL",
fallback="")
+ if webserver_base_url.endswith("/"):
+ raise AirflowConfigException("webserver.base_url conf cannot have a
trailing slash.")
# Apply DispatcherMiddleware
- base_url = urlsplit(conf.get("webserver", "base_url"))[2]
+ base_url = urlsplit(webserver_base_url)[2]
if not base_url or base_url == "/":
base_url = ""
if base_url:
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index 61bfae3e12..8dd3b57b2e 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -18,6 +18,7 @@
from __future__ import annotations
import hashlib
+import re
import runpy
import sys
from datetime import timedelta
@@ -86,15 +87,26 @@ class TestApp:
assert b"success" == response.get_data()
assert response.status_code == 200
- @conf_vars(
- {
- ("webserver", "base_url"): "http://localhost:8080/internal-client",
- }
+ @pytest.mark.parametrize(
+ "base_url, expected_exception",
+ [
+ ("http://localhost:8080/internal-client", None),
+ (
+ "http://localhost:8080/internal-client/",
+ AirflowConfigException("webserver.base_url conf cannot have a
trailing slash."),
+ ),
+ ],
)
@dont_initialize_flask_app_submodules
- def test_should_respect_base_url_ignore_proxy_headers(self):
- app = application.cached_app(testing=True)
- app.url_map.add(Rule("/debug", endpoint="debug"))
+ def test_should_respect_base_url_ignore_proxy_headers(self, base_url,
expected_exception):
+ with conf_vars({("webserver", "base_url"): base_url}):
+ if expected_exception:
+ with pytest.raises(expected_exception.__class__,
match=re.escape(str(expected_exception))):
+ app = application.cached_app(testing=True)
+ app.url_map.add(Rule("/debug", endpoint="debug"))
+ return
+ app = application.cached_app(testing=True)
+ app.url_map.add(Rule("/debug", endpoint="debug"))
def debug_view():
from flask import request
@@ -126,9 +138,18 @@ class TestApp:
assert b"success" == response.get_data()
assert response.status_code == 200
+ @pytest.mark.parametrize(
+ "base_url, expected_exception",
+ [
+ ("http://localhost:8080/internal-client", None),
+ (
+ "http://localhost:8080/internal-client/",
+ AirflowConfigException("webserver.base_url conf cannot have a
trailing slash."),
+ ),
+ ],
+ )
@conf_vars(
{
- ("webserver", "base_url"): "http://localhost:8080/internal-client",
("webserver", "enable_proxy_fix"): "True",
("webserver", "proxy_fix_x_for"): "1",
("webserver", "proxy_fix_x_proto"): "1",
@@ -138,9 +159,17 @@ class TestApp:
}
)
@dont_initialize_flask_app_submodules
- def
test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self):
- app = application.cached_app(testing=True)
- app.url_map.add(Rule("/debug", endpoint="debug"))
+ def
test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(
+ self, base_url, expected_exception
+ ):
+ with conf_vars({("webserver", "base_url"): base_url}):
+ if expected_exception:
+ with pytest.raises(expected_exception.__class__,
match=re.escape(str(expected_exception))):
+ app = application.cached_app(testing=True)
+ app.url_map.add(Rule("/debug", endpoint="debug"))
+ return
+ app = application.cached_app(testing=True)
+ app.url_map.add(Rule("/debug", endpoint="debug"))
def debug_view():
from flask import request