This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fe4a6c843a Add a check for trailing slash in webserver base_url (#31833)
fe4a6c843a is described below

commit fe4a6c843acd97c776d5890116bfa85356a54eee
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Jun 19 09:29:11 2023 +0200

    Add a check for trailing slash in webserver base_url (#31833)
    
    * Remove right trailing / from webserver base_url
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * use url join instead of removing trailing slash
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * raise an exception when base_url contains a trailing slash
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * Update airflow/www/extensions/init_wsgi_middlewares.py
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    ---------
    
    Signed-off-by: Hussein Awala <[email protected]>
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 airflow/models/taskinstance.py                  | 18 ++++-----
 airflow/www/extensions/init_wsgi_middlewares.py |  6 ++-
 tests/www/test_app.py                           | 51 +++++++++++++++++++------
 3 files changed, 54 insertions(+), 21 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 28dc168ec3..f429e27e15 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -33,7 +33,7 @@ from functools import partial
 from pathlib import PurePath
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
-from urllib.parse import quote
+from urllib.parse import quote, urljoin
 
 import dill
 import jinja2
@@ -759,26 +759,26 @@ class TaskInstance(Base, LoggingMixin):
         """Log URL for TaskInstance."""
         iso = quote(self.execution_date.isoformat())
         base_url = conf.get_mandatory_value("webserver", "BASE_URL")
-        return (
-            f"{base_url}/log"
-            f"?execution_date={iso}"
+        return urljoin(
+            base_url,
+            f"log?execution_date={iso}"
             f"&task_id={self.task_id}"
             f"&dag_id={self.dag_id}"
-            f"&map_index={self.map_index}"
+            f"&map_index={self.map_index}",
         )
 
     @property
     def mark_success_url(self) -> str:
         """URL to mark TI success."""
         base_url = conf.get_mandatory_value("webserver", "BASE_URL")
-        return (
-            f"{base_url}/confirm"
-            f"?task_id={self.task_id}"
+        return urljoin(
+            base_url,
+            f"confirm?task_id={self.task_id}"
             f"&dag_id={self.dag_id}"
             f"&dag_run_id={quote(self.run_id)}"
             "&upstream=false"
             "&downstream=false"
-            "&state=success"
+            "&state=success",
         )
 
     @provide_session
diff --git a/airflow/www/extensions/init_wsgi_middlewares.py b/airflow/www/extensions/init_wsgi_middlewares.py
index 3ea47e92c8..37d4a074b4 100644
--- a/airflow/www/extensions/init_wsgi_middlewares.py
+++ b/airflow/www/extensions/init_wsgi_middlewares.py
@@ -25,6 +25,7 @@ from werkzeug.middleware.dispatcher import DispatcherMiddleware
 from werkzeug.middleware.proxy_fix import ProxyFix
 
 from airflow.configuration import conf
+from airflow.exceptions import AirflowConfigException
 
 if TYPE_CHECKING:
     from _typeshed.wsgi import StartResponse, WSGIEnvironment
@@ -37,8 +38,11 @@ def _root_app(env: WSGIEnvironment, resp: StartResponse) -> Iterable[bytes]:
 
 def init_wsgi_middleware(flask_app: Flask) -> None:
     """Handle X-Forwarded-* headers and base_url support."""
+    webserver_base_url = conf.get_mandatory_value("webserver", "BASE_URL", fallback="")
+    if webserver_base_url.endswith("/"):
+        raise AirflowConfigException("webserver.base_url conf cannot have a trailing slash.")
     # Apply DispatcherMiddleware
-    base_url = urlsplit(conf.get("webserver", "base_url"))[2]
+    base_url = urlsplit(webserver_base_url)[2]
     if not base_url or base_url == "/":
         base_url = ""
     if base_url:
diff --git a/tests/www/test_app.py b/tests/www/test_app.py
index 61bfae3e12..8dd3b57b2e 100644
--- a/tests/www/test_app.py
+++ b/tests/www/test_app.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import hashlib
+import re
 import runpy
 import sys
 from datetime import timedelta
@@ -86,15 +87,26 @@ class TestApp:
         assert b"success" == response.get_data()
         assert response.status_code == 200
 
-    @conf_vars(
-        {
-            ("webserver", "base_url"): "http://localhost:8080/internal-client",
-        }
+    @pytest.mark.parametrize(
+        "base_url, expected_exception",
+        [
+            ("http://localhost:8080/internal-client", None),
+            (
+                "http://localhost:8080/internal-client/",
+                AirflowConfigException("webserver.base_url conf cannot have a trailing slash."),
+            ),
+        ],
     )
     @dont_initialize_flask_app_submodules
-    def test_should_respect_base_url_ignore_proxy_headers(self):
-        app = application.cached_app(testing=True)
-        app.url_map.add(Rule("/debug", endpoint="debug"))
+    def test_should_respect_base_url_ignore_proxy_headers(self, base_url, expected_exception):
+        with conf_vars({("webserver", "base_url"): base_url}):
+            if expected_exception:
+                with pytest.raises(expected_exception.__class__, match=re.escape(str(expected_exception))):
+                    app = application.cached_app(testing=True)
+                    app.url_map.add(Rule("/debug", endpoint="debug"))
+                return
+            app = application.cached_app(testing=True)
+            app.url_map.add(Rule("/debug", endpoint="debug"))
 
         def debug_view():
             from flask import request
@@ -126,9 +138,18 @@ class TestApp:
         assert b"success" == response.get_data()
         assert response.status_code == 200
 
+    @pytest.mark.parametrize(
+        "base_url, expected_exception",
+        [
+            ("http://localhost:8080/internal-client", None),
+            (
+                "http://localhost:8080/internal-client/",
+                AirflowConfigException("webserver.base_url conf cannot have a trailing slash."),
+            ),
+        ],
+    )
     @conf_vars(
         {
-            ("webserver", "base_url"): "http://localhost:8080/internal-client",
             ("webserver", "enable_proxy_fix"): "True",
             ("webserver", "proxy_fix_x_for"): "1",
             ("webserver", "proxy_fix_x_proto"): "1",
@@ -138,9 +159,17 @@ class TestApp:
         }
     )
     @dont_initialize_flask_app_submodules
-    def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(self):
-        app = application.cached_app(testing=True)
-        app.url_map.add(Rule("/debug", endpoint="debug"))
+    def test_should_respect_base_url_when_proxy_fix_and_base_url_is_set_up_but_headers_missing(
+        self, base_url, expected_exception
+    ):
+        with conf_vars({("webserver", "base_url"): base_url}):
+            if expected_exception:
+                with pytest.raises(expected_exception.__class__, match=re.escape(str(expected_exception))):
+                    app = application.cached_app(testing=True)
+                    app.url_map.add(Rule("/debug", endpoint="debug"))
+                return
+            app = application.cached_app(testing=True)
+            app.url_map.add(Rule("/debug", endpoint="debug"))
 
         def debug_view():
             from flask import request

Reply via email to