This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 33ec72948f Return only the TIs of the readable dags when ~ is provided as a dag_id (#34939)
33ec72948f is described below
commit 33ec72948f74f56f2adb5e2d388e60e88e8a3fa3
Author: Hussein Awala <[email protected]>
AuthorDate: Sun Oct 15 00:01:00 2023 +0200
Return only the TIs of the readable dags when ~ is provided as a dag_id (#34939)
---
.../endpoints/task_instance_endpoint.py | 3 ++
airflow/api_connexion/security.py | 10 ++++-
.../endpoints/test_task_instance_endpoint.py | 46 ++++++++++++++++++++++
3 files changed, 58 insertions(+), 1 deletion(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index f0d530958b..62d6cd4323 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -41,6 +41,7 @@ from airflow.api_connexion.schemas.task_instance_schema import (
task_instance_reference_schema,
task_instance_schema,
)
+from airflow.api_connexion.security import get_readable_dags
from airflow.models import SlaMiss
from airflow.models.dagrun import DagRun as DR
from airflow.models.operator import needs_expansion
@@ -342,6 +343,8 @@ def get_task_instances(
if dag_id != "~":
base_query = base_query.where(TI.dag_id == dag_id)
+ else:
+ base_query = base_query.where(TI.dag_id.in_(get_readable_dags()))
if dag_run_id != "~":
base_query = base_query.where(TI.run_id == dag_run_id)
base_query = _apply_range_filter(
diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py
index b108adc2c3..b19f15257c 100644
--- a/airflow/api_connexion/security.py
+++ b/airflow/api_connexion/security.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from functools import wraps
from typing import Callable, Sequence, TypeVar, cast
-from flask import Response
+from flask import Response, g
from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated
from airflow.utils.airflow_flask_app import get_airflow_app
@@ -55,3 +55,11 @@ def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Cal
return cast(T, decorated)
return requires_access_decorator
+
+
+def get_readable_dags() -> set[str]:
+    return get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)  # DAG ids g.user can access
+
+
+def can_read_dag(dag_id: str) -> bool:
+    return get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user)  # per-DAG read check for g.user
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 5056f7736d..676722a237 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -658,6 +658,52 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
assert response.json["total_entries"] == expected_ti
assert len(response.json["task_instances"]) == expected_ti
+    @pytest.mark.parametrize(
+        "task_instances, user, expected_ti",
+        [
+            pytest.param(
+                {
+                    "example_python_operator": 2,
+                    "example_skip_dag": 1,
+                },
+                "test_read_only_one_dag",
+                2,  # only the example_python_operator TIs are expected
+            ),
+            pytest.param(
+                {
+                    "example_python_operator": 1,
+                    "example_skip_dag": 2,
+                },
+                "test_read_only_one_dag",
+                1,  # again only the example_python_operator TI is expected
+            ),
+            pytest.param(
+                {
+                    "example_python_operator": 1,
+                    "example_skip_dag": 2,
+                },
+                "test",
+                3,  # TIs from both DAGs are expected
+            ),
+        ],
+    )
+    def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session):
+        for dag_id, ti_count in task_instances.items():
+            self.create_task_instances(
+                session,
+                task_instances=[
+                    {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)}
+                    for i in range(ti_count)
+                ],
+                dag_id=dag_id,
+            )
+        response = self.client.get(  # "~" for dag_id and dag_run_id: no explicit DAG/run filter
+            "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user}
+        )
+        assert response.status_code == 200
+        assert response.json["total_entries"] == expected_ti
+        assert len(response.json["task_instances"]) == expected_ti
+
def test_should_respond_200_for_dag_id_filter(self, session):
self.create_task_instances(session)
self.create_task_instances(session, dag_id="example_skip_dag")