This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit dbdec0bf7a8c1072051a546677701bbeb3a81347
Author: andylamp <[email protected]>
AuthorDate: Wed Jan 17 00:17:45 2024 +0200

    Sanitize the conn_id to disallow potential script execution (#32867)
    
    (cherry picked from commit 71f422c7bc217129ac8614f14e6aeb586c6c88da)
---
 airflow/models/connection.py    | 37 ++++++++++++++++++++++--
 airflow/www/forms.py            |  4 +--
 airflow/www/validators.py       | 28 +++++++++++++++++-
 tests/models/test_connection.py | 64 +++++++++++++++++++++++++++++++++++++++++
 4 files changed, 128 insertions(+), 5 deletions(-)

Review notes (outside the diff; `git am` / `git apply` ignore text before the first `diff --git`):
- airflow/www/validators.py, ValidConnID.message: the text says "semicolons", but
  RE_SANITIZE_CONN_ID (^[\w\#\!\(\)\-\.\:\/\\]{1,}$) does not accept ';'. It also does not
  mention '/' and '\', which the regex does accept. Suggested wording:
      "Connection ID must be alphanumeric characters plus dashes, dots, hashes, colons, underscores, "
      "exclamation marks, forward and back slashes, and parentheses"
- Same class: the ValidationError suffix "for 1 and up to {self.max_length} matches" reads
  oddly; suggest "and be between 1 and {self.max_length} characters long".
- airflow/models/connection.py: Connection.__init__ now stores conn_id=None (instead of raising)
  when sanitize_conn_id rejects the id, as the "<script>alert(1)</script>" test case expects.
  Callers that assume a non-None conn_id should be checked.

diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 4e8e3c7aaf..0af3f13768 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -24,6 +24,7 @@ from json import JSONDecodeError
 from typing import Any
 from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
 
+import re2
 from sqlalchemy import Boolean, Column, Integer, String, Text
 from sqlalchemy.orm import declared_attr, reconstructor, synonym
 
@@ -38,6 +39,13 @@ from airflow.utils.log.secrets_masker import mask_secret
 from airflow.utils.module_loading import import_string
 
 log = logging.getLogger(__name__)
+# sanitize the `conn_id` pattern by allowing alphanumeric characters plus
+# the symbols #,!,-,_,.,:,\,/ and () requiring at least one match.
+#
+# You can try the regex here: https://regex101.com/r/69033B/1
+RE_SANITIZE_CONN_ID = re2.compile(r"^[\w\#\!\(\)\-\.\:\/\\]{1,}$")
+# the conn ID max len should be 250
+CONN_ID_MAX_LEN: int = 250
 
 
 def parse_netloc_to_hostname(*args, **kwargs):
@@ -46,10 +54,35 @@ def parse_netloc_to_hostname(*args, **kwargs):
     return _parse_netloc_to_hostname(*args, **kwargs)
 
 
+def sanitize_conn_id(conn_id: str | None, max_length=CONN_ID_MAX_LEN) -> str | None:
+    r"""Sanitizes the connection id and allows only specific characters to be within.
+
+    Namely, it allows alphanumeric characters plus the symbols #,!,-,_,.,:,\,/ and () from 1 and up to
+    250 consecutive matches. If desired, the max length can be adjusted by setting `max_length`.
+
+    You can try to play with the regex here: https://regex101.com/r/69033B/1
+
+    The character selection is such that it prevents the injection of javascript or
+    executable bits to avoid any awkward behaviour in the front-end.
+
+    :param conn_id: The connection id to sanitize.
+    :param max_length: The max length of the connection ID, by default it is 250.
+    :return: the sanitized string, `None` otherwise.
+    """
+    # check if `conn_id` or our match group is `None` and the `conn_id` is within the specified length.
+    if (not isinstance(conn_id, str) or len(conn_id) > max_length) or (
+        res := re2.match(RE_SANITIZE_CONN_ID, conn_id)
+    ) is None:
+        return None
+
+    # if we reach here, then we matched something, return the first match
+    return res.group(0)
+
+
 # Python automatically converts all letters to lowercase in hostname
 # See: https://issues.apache.org/jira/browse/AIRFLOW-3615
 def _parse_netloc_to_hostname(uri_parts):
-    """Parse a URI string to get correct Hostname."""
+    """Parse a URI string to get the correct Hostname."""
     hostname = unquote(uri_parts.hostname or "")
     if "/" in hostname:
         hostname = uri_parts.netloc
@@ -115,7 +148,7 @@ class Connection(Base, LoggingMixin):
         uri: str | None = None,
     ):
         super().__init__()
-        self.conn_id = conn_id
+        self.conn_id = sanitize_conn_id(conn_id)
         self.description = description
         if extra and not isinstance(extra, str):
             extra = json.dumps(extra)
diff --git a/airflow/www/forms.py b/airflow/www/forms.py
index 8a8f69cf44..aa5d3a6249 100644
--- a/airflow/www/forms.py
+++ b/airflow/www/forms.py
@@ -41,7 +41,7 @@ from airflow.configuration import conf
 from airflow.providers_manager import ProvidersManager
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
-from airflow.www.validators import ReadOnly, ValidKey
+from airflow.www.validators import ReadOnly, ValidConnID
 from airflow.www.widgets import (
     AirflowDateTimePickerROWidget,
     AirflowDateTimePickerWidget,
@@ -221,7 +221,7 @@ def create_connection_form_class() -> type[DynamicForm]:
 
         conn_id = StringField(
             lazy_gettext("Connection Id"),
-            validators=[InputRequired(), ValidKey()],
+            validators=[InputRequired(), ValidConnID()],
             widget=BS3TextFieldWidget(),
         )
         conn_type = SelectField(
diff --git a/airflow/www/validators.py b/airflow/www/validators.py
index ce273308df..8deacb8763 100644
--- a/airflow/www/validators.py
+++ b/airflow/www/validators.py
@@ -22,6 +22,7 @@ from json import JSONDecodeError
 
 from wtforms.validators import EqualTo, ValidationError
 
+from airflow.models.connection import CONN_ID_MAX_LEN, sanitize_conn_id
 from airflow.utils import helpers
 
 
@@ -85,7 +86,7 @@ class ValidKey:
     Validates values that will be used as keys.
 
     :param max_length:
-        The maximum length of the given key
+        The maximum allowed length of the given key
     """
 
     def __init__(self, max_length=200):
@@ -108,3 +109,28 @@ class ReadOnly:
 
     def __call__(self, form, field):
         field.flags.readonly = True
+
+
+class ValidConnID:
+    """
+    Validates the connection ID adheres to the desired format.
+
+    :param max_length:
+        The maximum allowed length of the given Connection ID.
+    """
+
+    message = (
+        "Connection ID must be alphanumeric characters plus dashes, dots, hashes, colons, semicolons, "
+        "underscores, exclamation marks, and parentheses"
+    )
+
+    def __init__(
+        self,
+        max_length: int = CONN_ID_MAX_LEN,
+    ):
+        self.max_length = max_length
+
+    def __call__(self, form, field):
+        if field.data:
+            if sanitize_conn_id(field.data, self.max_length) is None:
+                raise ValidationError(f"{self.message} for 1 and up to {self.max_length} matches")
diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py
index cac38c1451..21e5682c8d 100644
--- a/tests/models/test_connection.py
+++ b/tests/models/test_connection.py
@@ -186,3 +186,67 @@ class TestConnection:
     )
     def test_get_uri(self, connection, expected_uri):
         assert connection.get_uri() == expected_uri
+
+    @pytest.mark.parametrize(
+        "connection, expected_conn_id",
+        [
+            # a valid example of connection id
+            (
+                Connection(
+                    conn_id="12312312312213___12312321",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "12312312312213___12312321",
+            ),
+            # an invalid example of connection id, which allows potential code execution
+            (
+                Connection(
+                    conn_id="<script>alert(1)</script>",
+                    conn_type="type",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                None,
+            ),
+            # a valid connection as well
+            (
+                Connection(
+                    conn_id="a_valid_conn_id_!!##",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "a_valid_conn_id_!!##",
+            ),
+            # a valid connection as well testing dashes
+            (
+                Connection(
+                    conn_id="a_-.11",
+                    conn_type="type",
+                    login="user",
+                    password="pass",
+                    host="protocol://host",
+                    port=100,
+                    schema="schema",
+                    extra={"param1": "val1", "param2": "val2"},
+                ),
+                "a_-.11",
+            ),
+        ],
+    )
+    # Responsible for ensuring that the sanitized connection id
+    # string works as expected.
+    def test_sanitize_conn_id(self, connection, expected_conn_id):
+        assert connection.conn_id == expected_conn_id
