This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4f1b500c47 Check that dag_ids passed in request are consistent (#34366)
4f1b500c47 is described below

commit 4f1b500c47813c54349b7d3e48df0a444fb4826c
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Sep 14 13:03:55 2023 +0200

    Check that dag_ids passed in request are consistent (#34366)
    
    There are several ways to pass dag_id in a request: via args, via
    kwargs, via form data, or via JSON. If several of those are passed,
    they must all be the same; otherwise the request is rejected with 403.
---
 airflow/www/auth.py    | 37 ++++++++++++++++----
 tests/www/test_auth.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 123 insertions(+), 7 deletions(-)

diff --git a/airflow/www/auth.py b/airflow/www/auth.py
index f4cb53886e..d7821f14f2 100644
--- a/airflow/www/auth.py
+++ b/airflow/www/auth.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from functools import wraps
 from typing import Callable, Sequence, TypeVar, cast
 
@@ -27,6 +28,8 @@ from airflow.www.extensions.init_auth_manager import get_auth_manager
 
 T = TypeVar("T", bound=Callable)
 
+log = logging.getLogger(__name__)
+
 
 def get_access_denied_message():
     return conf.get("webserver", "access_denied_message")
@@ -42,13 +45,33 @@ def has_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable
 
             appbuilder = current_app.appbuilder
 
-            dag_id = (
-                kwargs.get("dag_id")
-                or request.args.get("dag_id")
-                or request.form.get("dag_id")
-                or (request.is_json and request.json.get("dag_id"))
-                or None
-            )
+            dag_id_kwargs = kwargs.get("dag_id")
+            dag_id_args = request.args.get("dag_id")
+            dag_id_form = request.form.get("dag_id")
+            dag_id_json = request.json.get("dag_id") if request.is_json else None
+            all_dag_ids = [dag_id_kwargs, dag_id_args, dag_id_form, dag_id_json]
+            unique_dag_ids = set(dag_id for dag_id in all_dag_ids if dag_id is not None)
+
+            if len(unique_dag_ids) > 1:
+                log.warning(
+                    f"There are different dag_ids passed in the request: {unique_dag_ids}. Returning 403."
+                )
+                log.warning(
+                    f"kwargs: {dag_id_kwargs}, args: {dag_id_args}, "
+                    f"form: {dag_id_form}, json: {dag_id_json}"
+                )
+                return (
+                    render_template(
+                        "airflow/no_roles_permissions.html",
+                        hostname=get_hostname()
+                        if conf.getboolean("webserver", "EXPOSE_HOSTNAME")
+                        else "redact",
+                        logout_url=get_auth_manager().get_url_logout(),
+                    ),
+                    403,
+                )
+            dag_id = unique_dag_ids.pop() if unique_dag_ids else None
+
             if appbuilder.sm.check_authorization(permissions, dag_id):
                 return func(*args, **kwargs)
             elif get_auth_manager().is_logged_in() and not g.user.perms:
diff --git a/tests/www/test_auth.py b/tests/www/test_auth.py
new file mode 100644
index 0000000000..5ff6768e65
--- /dev/null
+++ b/tests/www/test_auth.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.security import permissions
+from airflow.settings import json
+from tests.test_utils.api_connexion_utils import create_user_scope
+from tests.www.test_security import SomeBaseView, SomeModelView
+
+
+@pytest.fixture(scope="module")
+def app_builder(app):
+    app_builder = app.appbuilder
+    app_builder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
+    app_builder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
+    return app.appbuilder
+
+
+@pytest.mark.parametrize(
+    "dag_id_args, dag_id_kwargs, dag_id_form, dag_id_json, fail",
+    [
+        ("a", None, None, None, False),
+        (None, "b", None, None, False),
+        (None, None, "c", None, False),
+        (None, None, None, "d", False),
+        ("a", "a", None, None, False),
+        ("a", "a", "a", None, False),
+        ("a", "a", "a", "a", False),
+        (None, "a", "a", "a", False),
+        (None, None, "a", "a", False),
+        ("a", None, None, "a", False),
+        ("a", None, "a", None, False),
+        ("a", None, "c", None, True),
+        (None, "b", "c", None, True),
+        (None, None, "c", "d", True),
+        ("a", "b", "c", "d", True),
+    ],
+)
+def test_dag_id_consistency(
+    app,
+    dag_id_args: str | None,
+    dag_id_kwargs: str | None,
+    dag_id_form: str | None,
+    dag_id_json: str | None,
+    fail: bool,
+):
+    with app.test_request_context() as mock_context:
+        from airflow.www.auth import has_access
+
+        mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {}
+        kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {}
+        mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {}
+        if dag_id_json:
+            mock_context.request._cached_data = json.dumps({"dag_id": dag_id_json})
+            mock_context.request._parsed_content_type = ["application/json"]
+
+        with create_user_scope(
+            app,
+            username="test-user",
+            role_name="limited-role",
+            permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)],
+        ) as user:
+            with patch("airflow.www.security_manager.g") as mock_g:
+                mock_g.user = user
+
+                @has_access(permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
+                def test_func(**kwargs):
+                    return True
+
+                result = test_func(**kwargs)
+                if fail:
+                    assert result[1] == 403
+                else:
+                    assert result is True

Reply via email to