This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4f1b500c47 Check that dag_ids passed in request are consistent (#34366)
4f1b500c47 is described below
commit 4f1b500c47813c54349b7d3e48df0a444fb4826c
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Sep 14 13:03:55 2023 +0200
Check that dag_ids passed in request are consistent (#34366)
There are several ways to pass a dag_id in a request: via query args,
via view kwargs, via form data, or via a JSON body. If several of
those are passed, they must all be the same; otherwise the request
is rejected with 403.
---
airflow/www/auth.py | 37 ++++++++++++++++----
tests/www/test_auth.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 123 insertions(+), 7 deletions(-)
diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index f4cb53886e..d7821f14f2 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import logging
from functools import wraps
from typing import Callable, Sequence, TypeVar, cast
@@ -27,6 +28,8 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager
T = TypeVar("T", bound=Callable)
+log = logging.getLogger(__name__)
+
def get_access_denied_message():
return conf.get("webserver", "access_denied_message")
@@ -42,13 +45,33 @@ def has_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable
appbuilder = current_app.appbuilder
- dag_id = (
- kwargs.get("dag_id")
- or request.args.get("dag_id")
- or request.form.get("dag_id")
- or (request.is_json and request.json.get("dag_id"))
- or None
- )
+ dag_id_kwargs = kwargs.get("dag_id")
+ dag_id_args = request.args.get("dag_id")
+ dag_id_form = request.form.get("dag_id")
+ dag_id_json = request.json.get("dag_id") if request.is_json else
None
+ all_dag_ids = [dag_id_kwargs, dag_id_args, dag_id_form,
dag_id_json]
+ unique_dag_ids = set(dag_id for dag_id in all_dag_ids if dag_id is
not None)
+
+ if len(unique_dag_ids) > 1:
+ log.warning(
+ f"There are different dag_ids passed in the request:
{unique_dag_ids}. Returning 403."
+ )
+ log.warning(
+ f"kwargs: {dag_id_kwargs}, args: {dag_id_args}, "
+ f"form: {dag_id_form}, json: {dag_id_json}"
+ )
+ return (
+ render_template(
+ "airflow/no_roles_permissions.html",
+ hostname=get_hostname()
+ if conf.getboolean("webserver", "EXPOSE_HOSTNAME")
+ else "redact",
+ logout_url=get_auth_manager().get_url_logout(),
+ ),
+ 403,
+ )
+ dag_id = unique_dag_ids.pop() if unique_dag_ids else None
+
if appbuilder.sm.check_authorization(permissions, dag_id):
return func(*args, **kwargs)
elif get_auth_manager().is_logged_in() and not g.user.perms:
diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py
new file mode 100644
index 0000000000..5ff6768e65
--- /dev/null
+++ b/tests/www/test_auth.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.security import permissions
+from airflow.settings import json
+from tests.test_utils.api_connexion_utils import create_user_scope
+
+
+@pytest.mark.parametrize(
+    "dag_id_args, dag_id_kwargs, dag_id_form, dag_id_json, fail",
+    [
+        ("a", None, None, None, False),
+        (None, "b", None, None, False),
+        (None, None, "c", None, False),
+        (None, None, None, "d", False),
+        ("a", "a", None, None, False),
+        ("a", "a", "a", None, False),
+        ("a", "a", "a", "a", False),
+        (None, "a", "a", "a", False),
+        (None, None, "a", "a", False),
+        ("a", None, None, "a", False),
+        ("a", None, "a", None, False),
+        ("a", None, "c", None, True),
+        (None, "b", "c", None, True),
+        (None, None, "c", "d", True),
+        ("a", "b", "c", "d", True),
+    ],
+)
+def test_dag_id_consistency(
+    app,
+    dag_id_args: str | None,
+    dag_id_kwargs: str | None,
+    dag_id_form: str | None,
+    dag_id_json: str | None,
+    fail: bool,
+):
+    """Verify that has_access requires all dag_id sources in a request to agree.
+
+    A dag_id may be supplied via view kwargs, query args, form data or a JSON
+    body. When every supplied value is identical the decorated view must run;
+    when any two differ the request must be rejected with HTTP 403.
+    """
+    with app.test_request_context() as mock_context:
+        from airflow.www.auth import has_access
+
+        mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args is not None else {}
+        kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs is not None else {}
+        mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form is not None else {}
+        if dag_id_json is not None:
+            # A real request cannot carry form data and a JSON body at once, so the
+            # JSON payload is planted in Werkzeug's private request caches (internal API).
+            mock_context.request._cached_data = json.dumps({"dag_id": dag_id_json})
+            mock_context.request._parsed_content_type = ["application/json"]
+
+        with create_user_scope(
+            app,
+            username="test-user",
+            role_name="limited-role",
+            permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)],
+        ) as user:
+            with patch("airflow.www.security_manager.g") as mock_g:
+                mock_g.user = user
+
+                @has_access(permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
+                def test_func(**kwargs):
+                    return True
+
+                result = test_func(**kwargs)
+                if fail:
+                    # Conflicting dag_ids produce a (rendered page, 403) response tuple.
+                    assert result[1] == 403
+                else:
+                    assert result is True