This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 153250b58ae AIP-72: Add some basic Task Context keys (#44894)
153250b58ae is described below
commit 153250b58aea2ed46cd456a16aa027f0cfa1d68e
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Dec 13 02:55:12 2024 +0530
AIP-72: Add some basic Task Context keys (#44894)
Part of https://github.com/apache/airflow/issues/44481. This adds some
readily available context keys.
---
.../src/airflow/sdk/execution_time/task_runner.py | 44 +++++++++++++++++++++-
task_sdk/tests/dags/super_basic_run.py | 2 +
2 files changed, 44 insertions(+), 2 deletions(-)
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d210e0011fe..c01677ce1a7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,7 +23,7 @@ import os
import sys
from datetime import datetime, timezone
from io import FileIO
-from typing import TYPE_CHECKING, Generic, TextIO, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TextIO, TypeVar
import attrs
import structlog
@@ -49,6 +49,45 @@ class RuntimeTaskInstance(TaskInstance):
task: BaseOperator
+ def get_template_context(self):  # builds the context dict that run() passes to task.execute()
+ context: dict[str, Any] = {  # only readily available keys are populated for now; the rest stay commented out
+ "dag": self.task.dag,
+ "inlets": self.task.inlets,
+ "map_index_template": self.task.map_index_template,
+ "outlets": self.task.outlets,
+ "run_id": self.run_id,
+ "task": self.task,
+ "task_instance": self,
+ "ti": self,  # short alias for "task_instance" (same object)
+ # "dag_run": dag_run,
+ # "data_interval_end": timezone.coerce_datetime(data_interval.end),
+ # "data_interval_start": timezone.coerce_datetime(data_interval.start),
+ # "outlet_events": OutletEventAccessors(),
+ # "ds": ds,
+ # "ds_nodash": ds_nodash,
+ # "expanded_ti_count": expanded_ti_count,
+ # "inlet_events": InletEventsAccessors(task.inlets, session=session),
+ # "logical_date": logical_date,
+ # "macros": macros,
+ # "params": validated_params,
+ # "prev_data_interval_start_success": get_prev_data_interval_start_success(),
+ # "prev_data_interval_end_success": get_prev_data_interval_end_success(),
+ # "prev_start_date_success": get_prev_start_date_success(),
+ # "prev_end_date_success": get_prev_end_date_success(),
+ # "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
+ # "test_mode": task_instance.test_mode,
+ # "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
+ # "ts": ts,
+ # "ts_nodash": ts_nodash,
+ # "ts_nodash_with_tz": ts_nodash_with_tz,
+ # "var": {
+ # "json": VariableAccessor(deserialize_json=True),
+ # "value": VariableAccessor(deserialize_json=False),
+ # },
+ # "conn": ConnectionAccessor(),
+ }
+ return context
+
def parse(what: StartupDetails) -> RuntimeTaskInstance:
# TODO: Task-SDK:
@@ -195,7 +234,8 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: pre execute etc.
# TODO next_method to support resuming from deferred
# TODO: Get a real context object
- ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined]
+ context = ti.get_template_context()
+ ti.task.execute(context) # type: ignore[attr-defined]
msg = TaskState(state=TerminalTIState.SUCCESS,
end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
diff --git a/task_sdk/tests/dags/super_basic_run.py b/task_sdk/tests/dags/super_basic_run.py
index 2988d85418a..87d2a682022 100644
--- a/task_sdk/tests/dags/super_basic_run.py
+++ b/task_sdk/tests/dags/super_basic_run.py
@@ -25,6 +25,8 @@ class CustomOperator(BaseOperator):
def execute(self, context):
task_id = context["task_instance"].task_id
print(f"Hello World {task_id}!")
+ assert context["task_instance"].try_number == 1  # this test run is expected to be the task's first try
+ assert context["dag"].dag_id == "super_basic_run"  # the new "dag" context key must resolve to this test DAG
@dag()