kaxil commented on code in PR #44899:
URL: https://github.com/apache/airflow/pull/44899#discussion_r1887070405
##########
task_sdk/tests/execution_time/test_task_runner.py:
##########
@@ -318,3 +318,83 @@ def __init__(self, *args, **kwargs):
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
log=mock.ANY,
)
+
+
+class TestRuntimeTaskInstance:
+ def test_get_context_without_ti_context_from_server(self, mocked_parse):
+ """Test get_template_context without ti_context_from_server."""
+ from airflow.providers.standard.operators.python import PythonOperator
+
+ task = PythonOperator(
+ task_id="hello",
+ python_callable=lambda: print("hello"),
+ )
+
+ what = StartupDetails(
+ ti=TaskInstance(
+ id=uuid7(), task_id="hello", dag_id="basic_task", run_id="test_run", try_number=1
+ ),
+ file="",
+ requests_fd=0,
+ )
+
+ runtime_ti = mocked_parse(what, "basic_skipped", task)
+ context = runtime_ti.get_template_context()
+
+ # Verify the context keys and values
+ assert context == {
+ "dag": runtime_ti.task.dag,
+ "inlets": task.inlets,
+ "map_index_template": task.map_index_template,
+ "outlets": task.outlets,
+ "run_id": "test_run",
+ "task": task,
+ "task_instance": runtime_ti,
+ "ti": runtime_ti,
+ }
+
+ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_context):
+ """Test the context keys are added when sent from API server (mocked)"""
+ from airflow.providers.standard.operators.python import PythonOperator
+ from airflow.utils import timezone
+
+ task = PythonOperator(task_id="hello", python_callable=lambda: print("hello"))
Review Comment:
Updated
##########
task_sdk/tests/execution_time/test_task_runner.py:
##########
@@ -318,3 +318,83 @@ def __init__(self, *args, **kwargs):
msg=SetRenderedFields(rendered_fields=expected_rendered_fields),
log=mock.ANY,
)
+
+
+class TestRuntimeTaskInstance:
+ def test_get_context_without_ti_context_from_server(self, mocked_parse):
+ """Test get_template_context without ti_context_from_server."""
+ from airflow.providers.standard.operators.python import PythonOperator
+
+ task = PythonOperator(
+ task_id="hello",
+ python_callable=lambda: print("hello"),
+ )
Review Comment:
Done, updated
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]