This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 71f422c7bc Sanitize the conn_id to disallow potential script execution (#32867)
71f422c7bc is described below
commit 71f422c7bc217129ac8614f14e6aeb586c6c88da
Author: andylamp <[email protected]>
AuthorDate: Wed Jan 17 00:17:45 2024 +0200
Sanitize the conn_id to disallow potential script execution (#32867)
---
airflow/models/connection.py | 37 ++++++++++++++++++++++--
airflow/www/forms.py | 4 +--
airflow/www/validators.py | 28 +++++++++++++++++-
tests/models/test_connection.py | 64 +++++++++++++++++++++++++++++++++++++++++
4 files changed, 128 insertions(+), 5 deletions(-)
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 521a5d880e..6e43751b9d 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -24,6 +24,7 @@ from json import JSONDecodeError
 from typing import Any
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
 
+import re2
 from sqlalchemy import Boolean, Column, Integer, String, Text
 from sqlalchemy.orm import declared_attr, reconstructor, synonym
 
@@ -38,6 +39,13 @@ from airflow.utils.log.secrets_masker import mask_secret
 from airflow.utils.module_loading import import_string
 
 log = logging.getLogger(__name__)
+# sanitize the `conn_id` pattern by allowing alphanumeric characters plus
+# the symbols #,!,-,_,.,:,\,/ and () requiring at least one match.
+#
+# You can try the regex here: https://regex101.com/r/69033B/1
+RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")
+# the conn ID max len should be 250
+CONN_ID_MAX_LEN: int = 250
 
 
 def parse_netloc_to_hostname(*args, **kwargs):
@@ -46,10 +54,35 @@ def parse_netloc_to_hostname(*args, **kwargs):
     return _parse_netloc_to_hostname(*args, **kwargs)
 
 
+def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
+    r"""Sanitizes the connection id and allows only specific characters to be within.
+
+    Namely, it allows alphanumeric characters plus the symbols #,!,-,_,.,:,\,/ and () from 1 and up to
+    250 consecutive matches. If desired, the max length can be adjusted by setting `max_length`.
+
+    You can try to play with the regex here: https://regex101.com/r/69033B/1
+
+    The character selection is such that it prevents the injection of javascript or
+    executable bits to avoid any awkward behaviour in the front-end.
+
+    :param conn_id: The connection id to sanitize.
+    :param max_length: The max length of the connection ID, by default it is 250.
+    :return: the sanitized string, `None` otherwise.
+    """
+    # check if `conn_id` or our match group is `None` and the `conn_id` is within the specified length.
+    if (not isinstance(conn_id, str) or len(conn_id) > max_length) or (
+        res := re2.match(RE_SANITIZE_CONN_ID, conn_id)
+    ) is None:
+        return None
+
+    # if we reach here, then we matched something, return the first match
+    return res.group(0)
+
+
 # Python automatically converts all letters to lowercase in hostname
 # See: https://issues.apache.org/jira/browse/AIRFLOW-3615
 def _parse_netloc_to_hostname(uri_parts):
-    """Parse a URI string to get correct Hostname."""
+    """Parse a URI string to get the correct Hostname."""
     hostname = unquote(uri_parts.hostname or "")
     if "/" in hostname:
         hostname = uri_parts.netloc
@@ -115,7 +148,7 @@ class Connection(Base, LoggingMixin):
         uri: str | None = None,
     ):
         super().__init__()
-        self.conn_id = conn_id
+        self.conn_id = sanitize_conn_id(conn_id)
         self.description = description
         if extra and not isinstance(extra, str):
             extra = json.dumps(extra)
diff --git a/airflow/www/forms.py b/airflow/www/forms.py
index 8a8f69cf44..aa5d3a6249 100644
--- a/airflow/www/forms.py
+++ b/airflow/www/forms.py
@@ -41,7 +41,7 @@ from airflow.configuration import conf
 from airflow.providers_manager import ProvidersManager
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
-from airflow.www.validators import ReadOnly, ValidKey
+from airflow.www.validators import ReadOnly, ValidConnID
 from airflow.www.widgets import (
     AirflowDateTimePickerROWidget,
     AirflowDateTimePickerWidget,
@@ -221,7 +221,7 @@ def create_connection_form_class() -> type[DynamicForm]:
 
         conn_id = StringField(
             lazy_gettext("Connection Id"),
-            validators=[InputRequired(), ValidKey()],
+            validators=[InputRequired(), ValidConnID()],
             widget=BS3TextFieldWidget(),
         )
         conn_type = SelectField(
diff --git a/airflow/www/validators.py b/airflow/www/validators.py
index ce273308df..8deacb8763 100644
--- a/airflow/www/validators.py
+++ b/airflow/www/validators.py
@@ -22,6 +22,7 @@ from json import JSONDecodeError
 
 from wtforms.validators import EqualTo, ValidationError
 
+from airflow.models.connection import CONN_ID_MAX_LEN, sanitize_conn_id
 from airflow.utils import helpers
 
 
@@ -85,7 +86,7 @@ class ValidKey:
     Validates values that will be used as keys.
 
     :param max_length:
-        The maximum length of the given key
+        The maximum allowed length of the given key
     """
 
     def __init__(self, max_length=200):
@@ -108,3 +109,28 @@ class ReadOnly:
 
     def __call__(self, form, field):
         field.flags.readonly = True
+
+
+class ValidConnID:
+    """
+    Validates the connection ID adheres to the desired format.
+
+    :param max_length:
+        The maximum allowed length of the given Connection ID.
+    """
+
+    message = (
+        "Connection ID must be alphanumeric characters plus dashes, dots, hashes, colons, semicolons, "
+        "underscores, exclamation marks, and parentheses"
+    )
+
+    def __init__(
+        self,
+        max_length: int = CONN_ID_MAX_LEN,
+    ):
+        self.max_length = max_length
+
+    def __call__(self, form, field):
+        if field.data:
+            if sanitize_conn_id(field.data, self.max_length) is None:
+                raise ValidationError(f"{self.message} for 1 and up to {self.max_length} matches")
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index cac38c1451..21e5682c8d 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -186,3 +186,67 @@ class TestConnection:
     )
     def test_get_uri(self, connection, expected_uri):
         assert connection.get_uri() == expected_uri
+
+    @pytest.mark.parametrize(
+        "connection, expected_conn_id",
+        [
+            # a valid example of connection id
+            (
+                Connection(
+                    conn_id="12312312312213___12312321",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "12312312312213___12312321",
+            ),
+            # an invalid example of connection id, which allows potential code execution
+            (
+                Connection(
+                    conn_id="<script>alert(1)</script>",
+                    conn_type="type",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                None,
+            ),
+            # a valid connection as well
+            (
+                Connection(
+                    conn_id="a_valid_conn_id_!!##",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "a_valid_conn_id_!!##",
+            ),
+            # a valid connection as well testing dashes
+            (
+                Connection(
+                    conn_id="a_-.11",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "a_-.11",
+            ),
+        ],
+    )
+    # Responsible for ensuring that the sanitized connection id
+    # string works as expected.
+    def test_sanitize_conn_id(self, connection, expected_conn_id):
+        assert connection.conn_id == expected_conn_id