This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3df1af4c45 REST API: Fix task instance access issue in the batch 
endpoint (#34315)
3df1af4c45 is described below

commit 3df1af4c45705d67598753a96debf5619bbfee04
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Sep 12 21:40:26 2023 +0100

    REST API: Fix task instance access issue in the batch endpoint (#34315)
    
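    Currently, there's no restriction on the task instances a user can access in
    the REST API batch fetch task instances endpoint.
    This PR fixes it: if the request names specific dag_ids, the user must be able
    to read every one of them (otherwise a 403 listing the inaccessible DAGs is
    returned); if no dag_ids are given, results are limited to the DAGs the user
    is allowed to read.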
---
 .../endpoints/task_instance_endpoint.py            | 18 ++++++++--
 .../endpoints/test_task_instance_endpoint.py       | 39 ++++++++++++++++++++++
 2 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py 
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 451edd968f..f0d530958b 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Iterable, TypeVar
 
+from flask import g
 from marshmallow import ValidationError
 from sqlalchemy import and_, or_, select
 from sqlalchemy.exc import MultipleResultsFound
@@ -25,7 +26,7 @@ from sqlalchemy.orm import joinedload
 
 from airflow.api_connexion import security
 from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
-from airflow.api_connexion.exceptions import BadRequest, NotFound
+from airflow.api_connexion.exceptions import BadRequest, NotFound, 
PermissionDenied
 from airflow.api_connexion.parameters import format_datetime, format_parameters
 from airflow.api_connexion.schemas.task_instance_schema import (
     TaskInstanceCollection,
@@ -400,10 +401,23 @@ def get_task_instances_batch(session: Session = 
NEW_SESSION) -> APIResponse:
         data = task_instance_batch_form.load(body)
     except ValidationError as err:
         raise BadRequest(detail=str(err.messages))
+    dag_ids = data["dag_ids"]
+    if dag_ids:
+        cannot_access_dag_ids = set()
+        for id in dag_ids:
+            if not get_airflow_app().appbuilder.sm.can_read_dag(id, g.user):
+                cannot_access_dag_ids.add(id)
+        if cannot_access_dag_ids:
+            raise PermissionDenied(
+                detail=f"User not allowed to access these DAGs: 
{list(cannot_access_dag_ids)}"
+            )
+    else:
+        dag_ids = 
get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
+
     states = _convert_ti_states(data["state"])
     base_query = select(TI).join(TI.dag_run)
 
-    base_query = _apply_array_filter(base_query, key=TI.dag_id, 
values=data["dag_ids"])
+    base_query = _apply_array_filter(base_query, key=TI.dag_id, values=dag_ids)
     base_query = _apply_array_filter(base_query, key=TI.run_id, 
values=data["dag_run_ids"])
     base_query = _apply_array_filter(base_query, key=TI.task_id, 
values=data["task_ids"])
     base_query = _apply_range_filter(
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py 
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index f39fad6d6e..f09b55cf41 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -24,6 +24,7 @@ import pendulum
 import pytest
 from sqlalchemy.orm import contains_eager
 
+from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
 from airflow.jobs.job import Job
 from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 from airflow.models import DagRun, SlaMiss, TaskInstance, Trigger
@@ -82,6 +83,25 @@ def configured_app(minimal_app_for_api):
             (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
         ],
     )
+    create_user(
+        app,  # type: ignore
+        username="test_read_only_one_dag",
+        role_name="TestReadOnlyOneDag",
+        permissions=[
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+            (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE),
+        ],
+    )
+    # For some reason, "DAG:example_python_operator" is not synced when in the 
above list of perms,
+    # so do it manually here:
+    app.appbuilder.sm.bulk_sync_roles(
+        [
+            {
+                "role": "TestReadOnlyOneDag",
+                "perms": [(permissions.ACTION_CAN_READ, 
"DAG:example_python_operator")],
+            }
+        ]
+    )
     create_user(app, username="test_no_permissions", 
role_name="TestNoPermissions")  # type: ignore
 
     yield app
@@ -90,6 +110,7 @@ def configured_app(minimal_app_for_api):
     delete_user(app, username="test_dag_read_only")  # type: ignore
     delete_user(app, username="test_task_read_only")  # type: ignore
     delete_user(app, username="test_no_permissions")  # type: ignore
+    delete_user(app, username="test_read_only_one_dag")  # type: ignore
     delete_roles(app)
 
 
@@ -905,6 +926,24 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         )
         assert response.status_code == 403
 
+    def 
test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, 
session):
+        self.create_task_instances(session=session)
+        self.create_task_instances(session=session, dag_id="example_skip_dag")
+        payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]}
+
+        response = self.client.post(
+            "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+            environ_overrides={"REMOTE_USER": "test_read_only_one_dag"},
+            json=payload,
+        )
+        assert response.status_code == 403
+        assert response.json == {
+            "detail": "User not allowed to access these DAGs: 
['example_skip_dag']",
+            "status": 403,
+            "title": "Forbidden",
+            "type": EXCEPTIONS_LINK_MAP[403],
+        }
+
     def test_should_raise_400_for_no_json(self):
         response = self.client.post(
             "/api/v1/dags/~/dagRuns/~/taskInstances/list",

Reply via email to