This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 153250b58ae AIP-72: Add some basic Task Context keys (#44894)
153250b58ae is described below

commit 153250b58aea2ed46cd456a16aa027f0cfa1d68e
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Dec 13 02:55:12 2024 +0530

    AIP-72: Add some basic Task Context keys (#44894)
    
    Part of https://github.com/apache/airflow/issues/44481. This adds some readily available context keys.
---
 .../src/airflow/sdk/execution_time/task_runner.py  | 44 +++++++++++++++++++++-
 task_sdk/tests/dags/super_basic_run.py             |  2 +
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index d210e0011fe..c01677ce1a7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -23,7 +23,7 @@ import os
 import sys
 from datetime import datetime, timezone
 from io import FileIO
-from typing import TYPE_CHECKING, Generic, TextIO, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TextIO, TypeVar
 
 import attrs
 import structlog
@@ -49,6 +49,45 @@ class RuntimeTaskInstance(TaskInstance):
 
     task: BaseOperator
 
+    def get_template_context(self):
+        context: dict[str, Any] = {
+            "dag": self.task.dag,
+            "inlets": self.task.inlets,
+            "map_index_template": self.task.map_index_template,
+            "outlets": self.task.outlets,
+            "run_id": self.run_id,
+            "task": self.task,
+            "task_instance": self,
+            "ti": self,
+            # "dag_run": dag_run,
+            # "data_interval_end": timezone.coerce_datetime(data_interval.end),
+            # "data_interval_start": timezone.coerce_datetime(data_interval.start),
+            # "outlet_events": OutletEventAccessors(),
+            # "ds": ds,
+            # "ds_nodash": ds_nodash,
+            # "expanded_ti_count": expanded_ti_count,
+            # "inlet_events": InletEventsAccessors(task.inlets, session=session),
+            # "logical_date": logical_date,
+            # "macros": macros,
+            # "params": validated_params,
+            # "prev_data_interval_start_success": get_prev_data_interval_start_success(),
+            # "prev_data_interval_end_success": get_prev_data_interval_end_success(),
+            # "prev_start_date_success": get_prev_start_date_success(),
+            # "prev_end_date_success": get_prev_end_date_success(),
+            # "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
+            # "test_mode": task_instance.test_mode,
+            # "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
+            # "ts": ts,
+            # "ts_nodash": ts_nodash,
+            # "ts_nodash_with_tz": ts_nodash_with_tz,
+            # "var": {
+            #     "json": VariableAccessor(deserialize_json=True),
+            #     "value": VariableAccessor(deserialize_json=False),
+            # },
+            # "conn": ConnectionAccessor(),
+        }
+        return context
+
 
 def parse(what: StartupDetails) -> RuntimeTaskInstance:
     # TODO: Task-SDK:
@@ -195,7 +234,8 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         # TODO: pre execute etc.
         # TODO next_method to support resuming from deferred
         # TODO: Get a real context object
-        ti.task.execute({"task_instance": ti})  # type: ignore[attr-defined]
+        context = ti.get_template_context()
+        ti.task.execute(context)  # type: ignore[attr-defined]
         msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
     except TaskDeferred as defer:
         classpath, trigger_kwargs = defer.trigger.serialize()
diff --git a/task_sdk/tests/dags/super_basic_run.py b/task_sdk/tests/dags/super_basic_run.py
index 2988d85418a..87d2a682022 100644
--- a/task_sdk/tests/dags/super_basic_run.py
+++ b/task_sdk/tests/dags/super_basic_run.py
@@ -25,6 +25,8 @@ class CustomOperator(BaseOperator):
     def execute(self, context):
         task_id = context["task_instance"].task_id
         print(f"Hello World {task_id}!")
+        assert context["task_instance"].try_number == 1
+        assert context["dag"].dag_id == "super_basic_run"
 
 
 @dag()

Reply via email to