This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6e59137781b Fix bug in task_sdk's `parse` function (#44056)
6e59137781b is described below
commit 6e59137781b8e1b935c28a003e34c9925d00040a
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Nov 15 12:58:04 2024 +0000
Fix bug in task_sdk's `parse` function (#44056)
This function was added but not used or tested in the first PR that added the
task_runner code, and in the final throes of that PR we swapped from msgspec
to pydantic, and in doing so introduced a runtime error from Pydantic as it
tried to look at the type hints of the `task: BaseOperator` property.
The fix here is to call model_construct to skip pydantic validations, which
is safe here as the TI (which RuntimeTI inherits from) was validated when the
StartupDetails object was parsed+created.
And this time add tests for the function too.
In order to test this I have created the first very simple test dag in the
SDK and configured pytest to skip that entire directory when looking for
tests.
---
task_sdk/src/airflow/sdk/definitions/dag.py | 4 +++-
.../src/airflow/sdk/execution_time/task_runner.py | 5 ++--
task_sdk/tests/conftest.py | 9 +++++++
task_sdk/tests/dags/super_basic.py | 28 ++++++++++++++++++++++
task_sdk/tests/execution_time/test_task_runner.py | 18 +++++++++++++-
5 files changed, 60 insertions(+), 4 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py
b/task_sdk/src/airflow/sdk/definitions/dag.py
index 745581b45d7..d427ddde798 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -659,7 +659,9 @@ class DAG:
def resolve_template_files(self):
for t in self.tasks:
- t.resolve_template_files()
+ # TODO: TaskSDK: move this on to BaseOperator and remove the check?
+ if hasattr(t, "resolve_template_files"):
+ t.resolve_template_files()
def get_template_env(self, *, force_sandboxed: bool = False) ->
jinja2.Environment:
"""Build a Jinja2 environment."""
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index 382e29c59b6..c952207bca5 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -43,7 +43,7 @@ class RuntimeTaskInstance(TaskInstance):
def parse(what: StartupDetails) -> RuntimeTaskInstance:
# TODO: Task-SDK:
- # Using DagBag here is aoubt 98% wrong, but it'll do for now
+ # Using DagBag here is about 98% wrong, but it'll do for now
from airflow.models.dagbag import DagBag
@@ -64,7 +64,8 @@ def parse(what: StartupDetails) -> RuntimeTaskInstance:
task = dag.task_dict[what.ti.task_id]
if not isinstance(task, BaseOperator):
raise TypeError(f"task is of the wrong type, got {type(task)}, wanted
{BaseOperator}")
- return RuntimeTaskInstance(**what.ti.model_dump(exclude_unset=True),
task=task)
+
+ return
RuntimeTaskInstance.model_construct(**what.ti.model_dump(exclude_unset=True),
task=task)
@attrs.define()
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index dffd1370f4e..9e03cb07963 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import os
+from pathlib import Path
from typing import TYPE_CHECKING, NoReturn
import pytest
@@ -43,6 +44,9 @@ def pytest_addhooks(pluginmanager:
pytest.PytestPluginManager):
def pytest_configure(config: pytest.Config) -> None:
config.inicfg["airflow_deprecations_ignore"] = []
+ # Always skip looking for tests in these folders!
+ config.addinivalue_line("norecursedirs", "tests/test_dags")
+
class LogCapture:
# Like structlog.typing.LogCapture, but that doesn't add log_level in to
the event dict
@@ -62,6 +66,11 @@ class LogCapture:
raise DropEvent
+@pytest.fixture
+def test_dags_dir():
+ return Path(__file__).parent.joinpath("dags")
+
+
@pytest.fixture
def captured_logs():
import structlog
diff --git a/task_sdk/tests/dags/super_basic.py
b/task_sdk/tests/dags/super_basic.py
new file mode 100644
index 00000000000..afd0a9296d5
--- /dev/null
+++ b/task_sdk/tests/dags/super_basic.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.sdk import BaseOperator, dag
+
+
+@dag()
+def super_basic():
+ BaseOperator(task_id="a")
+
+
+super_basic()
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index 5a90701cb2c..c634ba1255f 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -18,12 +18,15 @@
from __future__ import annotations
import uuid
+from pathlib import Path
from socket import socketpair
import pytest
+from uuid6 import uuid7
+from airflow.sdk.api.datamodels.ti import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails
-from airflow.sdk.execution_time.task_runner import CommsDecoder
+from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
class TestCommsDecoder:
@@ -54,3 +57,16 @@ class TestCommsDecoder:
assert decoder.request_socket is not None
assert decoder.request_socket.writable()
assert decoder.request_socket.fileno() == w2.fileno()
+
+
+def test_parse(test_dags_dir: Path):
+ what = StartupDetails(
+ ti=TaskInstance(id=uuid7(), task_id="a", dag_id="super_basic",
run_id="c", try_number=1),
+ file=str(test_dags_dir / "super_basic.py"),
+ requests_fd=0,
+ )
+
+ ti = parse(what)
+
+ assert ti.task
+ assert ti.task.dag