mobuchowski commented on code in PR #45294:
URL: https://github.com/apache/airflow/pull/45294#discussion_r1910383560


##########
providers/src/airflow/providers/openlineage/plugins/listener.py:
##########
@@ -87,28 +88,58 @@ def __init__(self):
         self.extractor_manager = ExtractorManager()
         self.adapter = OpenLineageAdapter()
 
-    @hookimpl
-    def on_task_instance_running(
-        self,
-        previous_state: TaskInstanceState,
-        task_instance: TaskInstance,
-        session: Session,  # This will always be QUEUED
-    ) -> None:
-        if not getattr(task_instance, "task", None) is not None:
-            self.log.warning(
-                "No task set for TI object task_id: %s - dag_id: %s - run_id 
%s",
-                task_instance.task_id,
-                task_instance.dag_id,
-                task_instance.run_id,
-            )
-            return
+    if AIRFLOW_V_3_0_PLUS:
 
-        self.log.debug("OpenLineage listener got notification about task 
instance start")
-        dagrun = task_instance.dag_run
-        task = task_instance.task
-        if TYPE_CHECKING:
-            assert task
-        dag = task.dag
+        @hookimpl
+        def on_task_instance_running(
+            self,
+            previous_state: TaskInstanceState,
+            task_instance: RuntimeTaskInstance,
+        ):
+            if not getattr(task_instance, "task", None) is not None:
+                self.log.warning(
+                    "No task set for TI object task_id: %s - dag_id: %s - 
run_id %s",
+                    task_instance.task_id,
+                    task_instance.dag_id,
+                    task_instance.run_id,
+                )
+                return

Review Comment:
   Actually, no — `RuntimeTaskInstance` prevents that, so I removed that check.



##########
providers/src/airflow/providers/openlineage/plugins/listener.py:
##########
@@ -127,35 +158,34 @@ def on_task_instance_running(
             return
 
         # Needs to be calculated outside of inner method so that it gets 
cached for usage in fork processes
+        data_interval_start = dagrun.data_interval_start
+        if isinstance(data_interval_start, datetime):
+            data_interval_start = data_interval_start.isoformat()
+        data_interval_end = dagrun.data_interval_end
+        if isinstance(data_interval_end, datetime):
+            data_interval_end = data_interval_end.isoformat()
+
         debug_facet = get_airflow_debug_facet()
 
         @print_warning(self.log)
         def on_running():
-            # that's a workaround to detect task running from deferred state
-            # we return here because Airflow 2.3 needs task from deferred state
-            if task_instance.next_method is not None:
-                return
-
-            if is_ti_rescheduled_already(task_instance):
+            context = task_instance.get_template_context()
+            if hasattr(context, "task_reschedule_count") and 
context["task_reschedule_count"] > 0:
                 self.log.debug("Skipping this instance of rescheduled task - 
START event was emitted already")
                 return
 
             parent_run_id = self.adapter.build_dag_run_id(
                 dag_id=dag.dag_id,
                 logical_date=dagrun.logical_date,
-                clear_number=dagrun.clear_number,
+                clear_number=0,

Review Comment:
   Fixed.



##########
providers/tests/openlineage/extractors/test_manager.py:
##########
@@ -324,3 +347,123 @@ def use_read():
 
     assert len(datasets.outputs) == 1
     assert datasets.outputs[0].asset == Asset(uri=path)
+
+
[email protected]
+def mock_supervisor_comms():
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
+    ) as supervisor_comms:
+        yield supervisor_comms
+
+
[email protected]
+def mocked_parse(spy_agency):
+    """
+    Fixture to set up an inline DAG and use it in a stubbed `parse` function. 
Use this fixture if you
+    want to isolate and test `parse` or `run` logic without having to define a 
DAG file.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an inline DAG with the given `dag_id` and `task` (limited to 
one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided 
`StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency`, to return the mocked 
`RuntimeTaskInstance`.
+
+    After adding the fixture in your test function signature, you can use it 
like this ::
+
+            mocked_parse(
+                StartupDetails(
+                    ti=TaskInstance(id=uuid7(), task_id="hello", 
dag_id="super_basic_run", run_id="c", try_number=1),
+                    file="",
+                    requests_fd=0,
+                ),
+                "example_dag_id",
+                CustomOperator(task_id="hello"),
+            )
+    """
+
+    def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> 
RuntimeTaskInstance:
+        from task_sdk.tests.execution_time.test_task_runner import 
get_inline_dag
+
+        dag = get_inline_dag(dag_id, task)
+        t = dag.task_dict[task.task_id]
+        ti = RuntimeTaskInstance.model_construct(
+            **what.ti.model_dump(exclude_unset=True), task=t, 
_ti_context_from_server=what.ti_context
+        )
+        spy_agency.spy_on(parse, call_fake=lambda _: ti)
+        return ti
+
+    return set_dag
+
+
[email protected]
+def make_ti_context() -> MakeTIContextCallable:
+    """Factory for creating TIRunContext objects."""
+    from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+
+    def _make_context(
+        dag_id: str = "test_dag",
+        run_id: str = "test_run",
+        logical_date: str | datetime = "2024-12-01T01:00:00Z",
+        data_interval_start: str | datetime = "2024-12-01T00:00:00Z",
+        data_interval_end: str | datetime = "2024-12-01T01:00:00Z",
+        clear_number: int = 0,
+        start_date: str | datetime = "2024-12-01T01:00:00Z",
+        run_type: str = "manual",
+        task_reschedule_count: int = 0,
+    ) -> TIRunContext:
+        return TIRunContext(
+            dag_run=DagRun(
+                dag_id=dag_id,
+                run_id=run_id,
+                logical_date=logical_date,  # type: ignore
+                data_interval_start=data_interval_start,  # type: ignore
+                data_interval_end=data_interval_end,  # type: ignore
+                clear_number=clear_number,  # type: ignore
+                start_date=start_date,  # type: ignore
+                run_type=run_type,  # type: ignore
+            ),
+            task_reschedule_count=task_reschedule_count,
+        )
+
+    return _make_context
+
+
[email protected]_test
[email protected](not AIRFLOW_V_3_0_PLUS, reason="Task SDK related test")
+def test_extractor_manager_gets_data_from_pythonoperator_tasksdk(
+    session, hook_lineage_collector, mocked_parse, make_ti_context, 
mock_supervisor_comms
+):
+    path = None
+    with tempfile.NamedTemporaryFile() as f:
+        path = f.name
+
+        def use_read():
+            storage_path = ObjectStoragePath(path)
+            with storage_path.open("w") as out:
+                out.write("test")
+
+    task = PythonOperator(task_id="test_task_extractor_pythonoperator", 
python_callable=use_read)
+
+    what = StartupDetails(
+        ti=SDKTaskInstance(
+            id=uuid7(),
+            task_id="test_task_extractor_pythonoperator",
+            dag_id="test_hookcollector_dag",
+            run_id="c",
+            try_number=1,
+            start_date=timezone.utcnow(),
+        ),
+        file="",
+        requests_fd=0,
+        ti_context=make_ti_context(),
+    )
+    ti = mocked_parse(what, "test_hookcollector_dag", task)
+
+    print(ti.__dict__)

Review Comment:
   Removed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to