This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e59137781b Fix bug in task_sdk's `parse` function (#44056)
6e59137781b is described below

commit 6e59137781b8e1b935c28a003e34c9925d00040a
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Nov 15 12:58:04 2024 +0000

    Fix bug in task_sdk's `parse` function (#44056)
    
    This function was added but not used or tested in the first PR that added the
    task_runner code, and in the final throes of that PR we swapped from msgspec
    to pydantic, and in doing so introduced a runtime error from Pydantic as it
    tried to look at the type hints of the `task: BaseOperator` property.
    
    The fix here is to call model_construct to skip pydantic validations, which is
    safe here as the TI (which RuntimeTI inherits from) was validated when the
    StartupDetails object was parsed+created.
    
    And this time add tests for the function too.
    
    In order to test this I have created the first very simple test dag in the
    SDK and configured pytest to skip that entire directory when looking for
    tests.
---
 task_sdk/src/airflow/sdk/definitions/dag.py        |  4 +++-
 .../src/airflow/sdk/execution_time/task_runner.py  |  5 ++--
 task_sdk/tests/conftest.py                         |  9 +++++++
 task_sdk/tests/dags/super_basic.py                 | 28 ++++++++++++++++++++++
 task_sdk/tests/execution_time/test_task_runner.py  | 18 +++++++++++++-
 5 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py 
b/task_sdk/src/airflow/sdk/definitions/dag.py
index 745581b45d7..d427ddde798 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -659,7 +659,9 @@ class DAG:
 
     def resolve_template_files(self):
         for t in self.tasks:
-            t.resolve_template_files()
+            # TODO: TaskSDK: move this on to BaseOperator and remove the check?
+            if hasattr(t, "resolve_template_files"):
+                t.resolve_template_files()
 
     def get_template_env(self, *, force_sandboxed: bool = False) -> 
jinja2.Environment:
         """Build a Jinja2 environment."""
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 382e29c59b6..c952207bca5 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -43,7 +43,7 @@ class RuntimeTaskInstance(TaskInstance):
 
 def parse(what: StartupDetails) -> RuntimeTaskInstance:
     # TODO: Task-SDK:
-    # Using DagBag here is aoubt 98% wrong, but it'll do for now
+    # Using DagBag here is about 98% wrong, but it'll do for now
 
     from airflow.models.dagbag import DagBag
 
@@ -64,7 +64,8 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
     task = dag.task_dict[what.ti.task_id]
     if not isinstance(task, BaseOperator):
         raise TypeError(f"task is of the wrong type, got {type(task)}, wanted 
{BaseOperator}")
-    return RuntimeTaskInstance(**what.ti.model_dump(exclude_unset=True), 
task=task)
+
+    return 
RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True), 
task=task)
 
 
 @attrs.define()
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index dffd1370f4e..9e03cb07963 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import os
+from pathlib import Path
 from typing import TYPE_CHECKING, NoReturn
 
 import pytest
@@ -43,6 +44,9 @@ def pytest_addhooks(pluginmanager: 
pytest.PytestPluginManager):
 def pytest_configure(config: pytest.Config) -> None:
     config.inicfg["airflow_deprecations_ignore"] = []
 
+    # Always skip looking for tests in these folders!
+    config.addinivalue_line("norecursedirs", "tests/test_dags")
+
 
 class LogCapture:
     # Like structlog.typing.LogCapture, but that doesn't add log_level in to 
the event dict
@@ -62,6 +66,11 @@ class LogCapture:
         raise DropEvent
 
 
+@pytest.fixture
+def test_dags_dir():
+    return Path(__file__).parent.joinpath("dags")
+
+
 @pytest.fixture
 def captured_logs():
     import structlog
diff --git a/task_sdk/tests/dags/super_basic.py 
b/task_sdk/tests/dags/super_basic.py
new file mode 100644
index 00000000000..afd0a9296d5
--- /dev/null
+++ b/task_sdk/tests/dags/super_basic.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.sdk import BaseOperator, dag
+
+
+@dag()
+def super_basic():
+    BaseOperator(task_id="a")
+
+
+super_basic()
diff --git a/task_sdk/tests/execution_time/test_task_runner.py 
b/task_sdk/tests/execution_time/test_task_runner.py
index 5a90701cb2c..c634ba1255f 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -18,12 +18,15 @@
 from __future__ import annotations
 
 import uuid
+from pathlib import Path
 from socket import socketpair
 
 import pytest
+from uuid6 import uuid7
 
+from airflow.sdk.api.datamodels.ti import TaskInstance
 from airflow.sdk.execution_time.comms import StartupDetails
-from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
 
 
 class TestCommsDecoder:
@@ -54,3 +57,16 @@ class TestCommsDecoder:
         assert decoder.request_socket is not None
         assert decoder.request_socket.writable()
         assert decoder.request_socket.fileno() == w2.fileno()
+
+
+def test_parse(test_dags_dir: Path):
+    what = StartupDetails(
+        ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic", 
run_id="c", try_number=1),
+        file=str(test_dags_dir / "super_basic.py"),
+        requests_fd=0,
+    )
+
+    ti = parse(what)
+
+    assert ti.task
+    assert ti.task.dag

Reply via email to