This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit dbdec0bf7a8c1072051a546677701bbeb3a81347
Author: andylamp <[email protected]>
AuthorDate: Wed Jan 17 00:17:45 2024 +0200

    Sanitize the conn_id to disallow potential script execution (#32867)
    
    (cherry picked from commit 71f422c7bc217129ac8614f14e6aeb586c6c88da)
---
 airflow/models/connection.py    | 37 ++++++++++++++++++++++--
 airflow/www/forms.py            |  4 +--
 airflow/www/validators.py       | 28 +++++++++++++++++-
 tests/models/test_connection.py | 64 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 128 insertions(+), 5 deletions(-)

diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 4e8e3c7aaf..0af3f13768 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -24,6 +24,7 @@ from json import JSONDecodeError
 from typing import Any
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
 
+import re2
 from sqlalchemy import Boolean, Column, Integer, String, Text
 from sqlalchemy.orm import declared_attr, reconstructor, synonym
 
@@ -38,6 +39,13 @@ from airflow.utils.log.secrets_masker import mask_secret
 from airflow.utils.module_loading import import_string
 
 log = logging.getLogger(__name__)
+# sanitize the `conn_id` pattern by allowing alphanumeric characters plus
+# the symbols #,!,-,_,.,:,\,/ and () requiring at least one match.
+#
+# You can try the regex here: https://regex101.com/r/69033B/1
+RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")
+# the conn ID max len should be 250
+CONN_ID_MAX_LEN: int = 250
 
 
 def parse_netloc_to_hostname(*args, **kwargs):
@@ -46,10 +54,35 @@ def parse_netloc_to_hostname(*args, **kwargs):
     return _parse_netloc_to_hostname(*args, **kwargs)
 
 
+def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
+    r"""Sanitizes the connection id and allows only specific characters to be within.
+
+    Namely, it allows alphanumeric characters plus the symbols #,!,-,_,.,:,\,/ and () from 1 and up to
+    250 consecutive matches. If desired, the max length can be adjusted by setting `max_length`.
+
+    You can try to play with the regex here: https://regex101.com/r/69033B/1
+
+    The character selection is such that it prevents the injection of javascript or
+    executable bits to avoid any awkward behaviour in the front-end.
+
+    :param conn_id: The connection id to sanitize.
+    :param max_length: The max length of the connection ID, by default it is 250.
+    :return: the sanitized string, `None` otherwise.
+    """
+    # check if `conn_id` or our match group is `None` and the `conn_id` is within the specified length.
+    if (not isinstance(conn_id, str) or len(conn_id) > max_length) or (
+        res := re2.match(RE_SANITIZE_CONN_ID, conn_id)
+    ) is None:
+        return None
+
+    # if we reach here, then we matched something, return the first match
+    return res.group(0)
+
+
 # Python automatically converts all letters to lowercase in hostname
 # See: https://issues.apache.org/jira/browse/AIRFLOW-3615
 def _parse_netloc_to_hostname(uri_parts):
-    """Parse a URI string to get correct Hostname."""
+    """Parse a URI string to get the correct Hostname."""
     hostname = unquote(uri_parts.hostname or "")
     if "/" in hostname:
         hostname = uri_parts.netloc
@@ -115,7 +148,7 @@ class Connection(Base, LoggingMixin):
         uri: str | None = None,
     ):
         super().__init__()
-        self.conn_id = conn_id
+        self.conn_id = sanitize_conn_id(conn_id)
         self.description = description
         if extra and not isinstance(extra, str):
             extra = json.dumps(extra)
diff --git a/airflow/www/forms.py b/airflow/www/forms.py
index 8a8f69cf44..aa5d3a6249 100644
--- a/airflow/www/forms.py
+++ b/airflow/www/forms.py
@@ -41,7 +41,7 @@ from airflow.configuration import conf
 from airflow.providers_manager import ProvidersManager
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
-from airflow.www.validators import ReadOnly, ValidKey
+from airflow.www.validators import ReadOnly, ValidConnID
 from airflow.www.widgets import (
     AirflowDateTimePickerROWidget,
     AirflowDateTimePickerWidget,
@@ -221,7 +221,7 @@ def create_connection_form_class() -> type[DynamicForm]:
 
         conn_id = StringField(
             lazy_gettext("Connection Id"),
-            validators=[InputRequired(), ValidKey()],
+            validators=[InputRequired(), ValidConnID()],
             widget=BS3TextFieldWidget(),
         )
         conn_type = SelectField(
diff --git a/airflow/www/validators.py b/airflow/www/validators.py
index ce273308df..8deacb8763 100644
--- a/airflow/www/validators.py
+++ b/airflow/www/validators.py
@@ -22,6 +22,7 @@ from json import JSONDecodeError
 
 from wtforms.validators import EqualTo, ValidationError
 
+from airflow.models.connection import CONN_ID_MAX_LEN, sanitize_conn_id
 from airflow.utils import helpers
 
 
@@ -85,7 +86,7 @@ class ValidKey:
     Validates values that will be used as keys.
 
     :param max_length:
-        The maximum length of the given key
+        The maximum allowed length of the given key
     """
 
     def __init__(self, max_length=200):
@@ -108,3 +109,28 @@ class ReadOnly:
 
     def __call__(self, form, field):
         field.flags.readonly = True
+
+
+class ValidConnID:
+    """
+    Validates the connection ID adheres to the desired format.
+
+    :param max_length:
+        The maximum allowed length of the given Connection ID.
+    """
+
+    message = (
+        "Connection ID must be alphanumeric characters plus dashes, dots, hashes, colons, semicolons, "
+        "underscores, exclamation marks, and parentheses"
+    )
+
+    def __init__(
+        self,
+        max_length: int = CONN_ID_MAX_LEN,
+    ):
+        self.max_length = max_length
+
+    def __call__(self, form, field):
+        if field.data:
+            if sanitize_conn_id(field.data, self.max_length) is None:
+                raise ValidationError(f"{self.message} for 1 and up to {self.max_length} matches")
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index cac38c1451..21e5682c8d 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -186,3 +186,67 @@ class TestConnection:
     )
     def test_get_uri(self, connection, expected_uri):
         assert connection.get_uri() == expected_uri
+
+    @pytest.mark.parametrize(
+        "connection, expected_conn_id",
+        [
+            # a valid example of connection id
+            (
+                Connection(
+                    conn_id="12312312312213___12312321",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "12312312312213___12312321",
+            ),
+            # an invalid example of connection id, which allows potential code execution
+            (
+                Connection(
+                    conn_id="<script>alert(1)</script>",
+                    conn_type="type",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                None,
+            ),
+            # a valid connection as well
+            (
+                Connection(
+                    conn_id="a_valid_conn_id_!!##",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "a_valid_conn_id_!!##",
+            ),
+            # a valid connection as well testing dashes
+            (
+                Connection(
+                    conn_id="a_-.11",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "a_-.11",
+            ),
+        ],
+    )
+    # Responsible for ensuring that the sanitized connection id
+    # string works as expected.
+    def test_sanitize_conn_id(self, connection, expected_conn_id):
+        assert connection.conn_id == expected_conn_id

Reply via email to