ashb commented on code in PR #60154:
URL: https://github.com/apache/airflow/pull/60154#discussion_r2671782909


##########
airflow-core/run_access_control_tests.py:
##########


Review Comment:
   This is not how you should write any kind of test.



##########
airflow-core/src/airflow/api_fastapi/execution_api/deps.py:
##########
@@ -103,7 +103,7 @@ async def __call__(  # type: ignore[override]
 JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
 
 
-async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> 
str | None:
+async def get_team_name_dep(session: AsyncSessionDep, token: TIToken = 
JWTBearerDep) -> str | None:

Review Comment:
   This type is correct, but it is a change unrelated to this PR, so please 
avoid making changes like this.



##########
airflow.db-shm:
##########


Review Comment:
   You didn't want to commit these.



##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py:
##########
@@ -22,42 +22,83 @@
 
 from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, 
Request, Response, status
 from pydantic import JsonValue
-from sqlalchemy import delete
+from sqlalchemy import delete, select
 from sqlalchemy.sql.selectable import Select
 
-from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.common.db.common import AsyncSessionDep, SessionDep
 from airflow.api_fastapi.core_api.base import BaseModel
 from airflow.api_fastapi.execution_api.datamodels.xcom import (
     XComResponse,
     XComSequenceIndexResponse,
     XComSequenceSliceResponse,
 )
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
 from airflow.api_fastapi.execution_api.deps import JWTBearerDep
+from airflow.models import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XComModel
 from airflow.utils.db import get_query_count
 
+log = logging.getLogger(__name__)
+
 
 async def has_xcom_access(
+    session: AsyncSessionDep,
     dag_id: str,
     run_id: str,
     task_id: str,
     xcom_key: Annotated[str, Path(alias="key", min_length=1)],
     request: Request,
-    token=JWTBearerDep,
-) -> bool:
+    token: TIToken = JWTBearerDep,
+) -> None:
     """Check if the task has access to the XCom."""
-    # TODO: Placeholder for actual implementation
+    # We want to ensure that the task instance identified by the token
+    # is only accessing XComs from its own DAG and Run.
+    # Note: task_id might be different if pulling from an upstream task.
 
-    write = request.method not in {"GET", "HEAD", "OPTIONS"}
-
-    log.debug(
-        "Checking %s XCom access for xcom from TaskInstance with key '%s' to 
XCom '%s'",
-        "write" if write else "read",
-        token.id,
-        xcom_key,
+    stmt = select(TaskInstance.dag_id, TaskInstance.run_id, 
TaskInstance.task_id).where(
+        TaskInstance.id == str(token.id)
     )
-    return True
+    ti_context = await session.execute(stmt)
+    ti_row = ti_context.first()
+
+    if not ti_row:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="Task instance not found",
+        )
+
+    ti_dag_id, ti_run_id, ti_task_id = ti_row
+
+    if ti_dag_id != dag_id or ti_run_id != run_id:

Review Comment:
   This will break a huge number of dags. Right now it is possible to access 
the XCom for any task. We can't "just" introduce this change without 
considering the implications and letting users control this behavior.



##########
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_access_control.py:
##########
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import uuid
+from unittest import mock
+
+import pytest
+from sqlalchemy import select
+
+from airflow.models import DagModel, TaskInstance
+from airflow.models.connection import Connection
+from airflow.models.dagbundle import DagBundleModel
+from airflow.models.team import Team
+from airflow.models.variable import Variable
+from airflow.models.xcom import XComModel
+from airflow.utils.state import TaskInstanceState
+from airflow.utils.types import DagRunType
+
+from tests_common.test_utils.db import clear_db_connections, clear_db_dags, 
clear_db_variables, clear_db_runs, clear_db_teams, clear_db_dag_bundles
+
+pytestmark = pytest.mark.db_test
+
+
+@pytest.fixture(autouse=True)
+def setup_method():
+    clear_db_variables()
+    clear_db_connections()
+    clear_db_dags()
+    clear_db_runs()
+    clear_db_teams()
+    clear_db_dag_bundles()
+    yield
+    clear_db_variables()
+    clear_db_connections()
+    clear_db_dags()
+    clear_db_runs()
+    clear_db_teams()
+    clear_db_dag_bundles()
+
+
+def setup_dag_run(session, dag_id, run_id, bundle_name="test_bundle"):
+    from airflow.utils import timezone
+    from airflow.models.dagrun import DagRun
+    from airflow.models.dag import DagModel
+    from airflow.models.dagbundle import DagBundleModel
+    from airflow.models.dag_version import DagVersion
+
+    bundle = session.get(DagBundleModel, bundle_name)
+    if not bundle:
+        bundle = DagBundleModel(name=bundle_name)
+        session.add(bundle)
+        session.flush()
+
+    dag = session.get(DagModel, dag_id)
+    if not dag:
+        dag = DagModel(dag_id=dag_id, bundle_name=bundle_name)
+        session.add(dag)
+        session.flush()
+
+    dv = session.scalar(select(DagVersion).where(DagVersion.dag_id == dag_id))
+    if not dv:
+        dv = DagVersion(dag_id=dag_id, bundle_name=bundle_name)
+        session.add(dv)
+        session.flush()
+
+    dr = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, 
DagRun.run_id == run_id))
+    if not dr:
+        dr = DagRun(
+            dag_id=dag_id,
+            run_id=run_id,
+            run_type=DagRunType.MANUAL,
+            logical_date=timezone.utcnow(),
+        )
+        session.add(dr)
+        session.flush()
+    return dr, dv
+
+
+def create_task_instance(session, dag_id, task_id, run_id, ti_id=None, 
bundle_name="test_bundle"):

Review Comment:
   There are already helpers that will create a task instance in tests. Use one 
of those instead please.



##########
airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_access_control.py:
##########


Review Comment:
   This shouldn't be written in a new test file; instead, the methods should 
be added to the existing test_conn/test_variable/test_xcom.py etc.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to