This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 579a8b87fc openlineage: extend custom_run_facets to also be executed 
on complete and fail (#40953)
579a8b87fc is described below

commit 579a8b87fc3d4a737bae11049c0607aaf2a8b8fb
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Jul 23 17:44:05 2024 +0200

    openlineage: extend custom_run_facets to also be executed on complete and 
fail (#40953)
    
    Signed-off-by: Kacper Muda <[email protected]>
---
 airflow/providers/openlineage/plugins/adapter.py   | 24 ++++--
 airflow/providers/openlineage/plugins/listener.py  | 10 ++-
 airflow/providers/openlineage/utils/utils.py       | 29 ++++---
 .../guides/developer.rst                           | 81 ++++++++++----------
 .../providers/openlineage/plugins/test_adapter.py  | 12 +++
 .../providers/openlineage/plugins/test_listener.py | 41 +++++++---
 .../openlineage/utils/custom_facet_fixture.py      | 58 +++++---------
 tests/providers/openlineage/utils/test_utils.py    | 88 +++++++---------------
 8 files changed, 169 insertions(+), 174 deletions(-)

diff --git a/airflow/providers/openlineage/plugins/adapter.py 
b/airflow/providers/openlineage/plugins/adapter.py
index 1d0317228b..7405556088 100644
--- a/airflow/providers/openlineage/plugins/adapter.py
+++ b/airflow/providers/openlineage/plugins/adapter.py
@@ -243,6 +243,7 @@ class OpenLineageAdapter(LoggingMixin):
         parent_run_id: str | None,
         end_time: str,
         task: OperatorLineage,
+        run_facets: dict[str, RunFacet] | None = None,  # Custom run facets
     ) -> RunEvent:
         """
         Emit openlineage event of type COMPLETE.
@@ -254,7 +255,11 @@ class OpenLineageAdapter(LoggingMixin):
         :param parent_run_id: identifier of job spawning this task
         :param end_time: time of task completion
         :param task: metadata container with information extracted from 
operator
+        :param run_facets: custom run facets
         """
+        run_facets = run_facets or {}
+        if task:
+            run_facets = {**task.run_facets, **run_facets}
         event = RunEvent(
             eventType=RunState.COMPLETE,
             eventTime=end_time,
@@ -263,7 +268,7 @@ class OpenLineageAdapter(LoggingMixin):
                 job_name=job_name,
                 parent_job_name=parent_job_name,
                 parent_run_id=parent_run_id,
-                run_facets=task.run_facets,
+                run_facets=run_facets,
             ),
             job=self._build_job(job_name, job_type=_JOB_TYPE_TASK, 
job_facets=task.job_facets),
             inputs=task.inputs,
@@ -280,6 +285,7 @@ class OpenLineageAdapter(LoggingMixin):
         parent_run_id: str | None,
         end_time: str,
         task: OperatorLineage,
+        run_facets: dict[str, RunFacet] | None = None,  # Custom run facets
         error: str | BaseException | None = None,
     ) -> RunEvent:
         """
@@ -292,20 +298,22 @@ class OpenLineageAdapter(LoggingMixin):
         :param parent_run_id: identifier of job spawning this task
         :param end_time: time of task completion
         :param task: metadata container with information extracted from 
operator
+        :param run_facets: custom run facets
         :param error: error
         """
-        error_facet = {}
+        run_facets = run_facets or {}
+        if task:
+            run_facets = {**task.run_facets, **run_facets}
+
         if error:
             stack_trace = None
             if isinstance(error, BaseException) and error.__traceback__:
                 import traceback
 
                 stack_trace = 
"\\n".join(traceback.format_exception(type(error), error, error.__traceback__))
-            error_facet = {
-                "errorMessage": error_message_run.ErrorMessageRunFacet(
-                    message=str(error), programmingLanguage="python", 
stackTrace=stack_trace
-                )
-            }
+            run_facets["errorMessage"] = 
error_message_run.ErrorMessageRunFacet(
+                message=str(error), programmingLanguage="python", 
stackTrace=stack_trace
+            )
 
         event = RunEvent(
             eventType=RunState.FAIL,
@@ -315,7 +323,7 @@ class OpenLineageAdapter(LoggingMixin):
                 job_name=job_name,
                 parent_job_name=parent_job_name,
                 parent_run_id=parent_run_id,
-                run_facets={**task.run_facets, **error_facet},
+                run_facets=run_facets,
             ),
             job=self._build_job(job_name, job_type=_JOB_TYPE_TASK, 
job_facets=task.job_facets),
             inputs=task.inputs,
diff --git a/airflow/providers/openlineage/plugins/listener.py 
b/airflow/providers/openlineage/plugins/listener.py
index a552cb283b..58ccdcad24 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -33,9 +33,10 @@ from airflow.providers.openlineage.extractors import 
ExtractorManager
 from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter, 
RunState
 from airflow.providers.openlineage.utils.utils import (
     get_airflow_job_facet,
+    get_airflow_mapped_task_facet,
     get_airflow_run_facet,
-    get_custom_facets,
     get_job_name,
+    get_user_provided_run_facets,
     is_operator_disabled,
     is_selective_lineage_enabled,
     print_warning,
@@ -43,13 +44,13 @@ from airflow.providers.openlineage.utils.utils import (
 from airflow.settings import configure_orm
 from airflow.stats import Stats
 from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
 from airflow.utils.timeout import timeout
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models import DagRun, TaskInstance
-    from airflow.utils.state import TaskInstanceState
 
 _openlineage_listener: OpenLineageListener | None = None
 _IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= 
Version("2.10.0")
@@ -163,7 +164,8 @@ class OpenLineageListener:
                 owners=dag.owner.split(", "),
                 task=task_metadata,
                 run_facets={
-                    **get_custom_facets(task_instance),
+                    **get_user_provided_run_facets(task_instance, 
TaskInstanceState.RUNNING),
+                    **get_airflow_mapped_task_facet(task_instance),
                     **get_airflow_run_facet(dagrun, dag, task_instance, task, 
task_uuid),
                 },
             )
@@ -233,6 +235,7 @@ class OpenLineageListener:
                 parent_run_id=parent_run_id,
                 end_time=end_date.isoformat(),
                 task=task_metadata,
+                run_facets=get_user_provided_run_facets(task_instance, 
TaskInstanceState.SUCCESS),
             )
             Stats.gauge(
                 f"ol.event.size.{event_type}.{operator_name}",
@@ -327,6 +330,7 @@ class OpenLineageListener:
                 parent_run_id=parent_run_id,
                 end_time=end_date.isoformat(),
                 task=task_metadata,
+                run_facets=get_user_provided_run_facets(task_instance, 
TaskInstanceState.FAILED),
                 error=error,
             )
             Stats.gauge(
diff --git a/airflow/providers/openlineage/utils/utils.py 
b/airflow/providers/openlineage/utils/utils.py
index 195d14e4e7..2e995fb7a8 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -55,10 +55,11 @@ from airflow.utils.log.secrets_masker import Redactable, 
Redacted, SecretsMasker
 from airflow.utils.module_loading import import_string
 
 if TYPE_CHECKING:
+    from openlineage.client.event_v2 import Dataset as OpenLineageDataset
     from openlineage.client.facet_v2 import RunFacet
-    from openlineage.client.run import Dataset as OpenLineageDataset
 
     from airflow.models import DagRun, TaskInstance
+    from airflow.utils.state import TaskInstanceState
 
 
 log = logging.getLogger(__name__)
@@ -81,28 +82,32 @@ def get_job_name(task: TaskInstance) -> str:
     return f"{task.dag_id}.{task.task_id}"
 
 
-def get_custom_facets(task_instance: TaskInstance | None = None) -> dict[str, 
Any]:
-    from airflow.providers.openlineage.extractors.manager import 
try_import_from_string
-
-    custom_facets = {}
+def get_airflow_mapped_task_facet(task_instance: TaskInstance) -> dict[str, 
Any]:
     # check for -1 comes from SmartSensor compatibility with dynamic task 
mapping
     # this comes from Airflow code
     if hasattr(task_instance, "map_index") and getattr(task_instance, 
"map_index") != -1:
-        custom_facets["airflow_mappedTask"] = 
AirflowMappedTaskRunFacet.from_task_instance(task_instance)
+        return {"airflow_mappedTask": 
AirflowMappedTaskRunFacet.from_task_instance(task_instance)}
+    return {}
+
+
+def get_user_provided_run_facets(ti: TaskInstance, ti_state: 
TaskInstanceState) -> dict[str, RunFacet]:
+    custom_facets = {}
 
     # Append custom run facets by executing the custom_run_facet functions.
     for custom_facet_func in conf.custom_run_facets():
         try:
-            func: Callable[[Any], dict] | None = 
try_import_from_string(custom_facet_func)
+            func: Callable[[TaskInstance, TaskInstanceState], dict[str, 
RunFacet]] | None = (
+                try_import_from_string(custom_facet_func)
+            )
             if not func:
                 log.warning(
                     "OpenLineage is unable to import custom facet function 
`%s`; will ignore it.",
                     custom_facet_func,
                 )
                 continue
-            facet: dict[str, dict[Any, Any]] | None = func(task_instance)
-            if facet and isinstance(facet, dict):
-                duplicate_facet_keys = [facet_key for facet_key in 
facet.keys() if facet_key in custom_facets]
+            facets: dict[str, RunFacet] | None = func(ti, ti_state)
+            if facets and isinstance(facets, dict):
+                duplicate_facet_keys = [facet_key for facet_key in facets if 
facet_key in custom_facets]
                 if duplicate_facet_keys:
                     log.warning(
                         "Duplicate OpenLineage custom facets key(s) found: 
`%s` from function `%s`; "
@@ -112,10 +117,10 @@ def get_custom_facets(task_instance: TaskInstance | None 
= None) -> dict[str, An
                     )
                 log.debug(
                     "Adding OpenLineage custom facet with key(s): `%s` from 
function `%s`.",
-                    tuple(facet),
+                    tuple(facets),
                     custom_facet_func,
                 )
-                custom_facets.update(facet)
+                custom_facets.update(facets)
         except Exception as exc:
             log.warning(
                 "Error processing custom facet function `%s`; will ignore it. 
Error was: %s: %s",
diff --git a/docs/apache-airflow-providers-openlineage/guides/developer.rst 
b/docs/apache-airflow-providers-openlineage/guides/developer.rst
index 8d66780190..4e9ada44c2 100644
--- a/docs/apache-airflow-providers-openlineage/guides/developer.rst
+++ b/docs/apache-airflow-providers-openlineage/guides/developer.rst
@@ -85,7 +85,7 @@ Instead of returning complete OpenLineage event, the provider 
defines ``Operator
   class OperatorLineage:
       inputs: list[Dataset] = Factory(list)
       outputs: list[Dataset] = Factory(list)
-      run_facets: dict[str, BaseFacet] = Factory(dict)
+      run_facets: dict[str, RunFacet] = Factory(dict)
       job_facets: dict[str, BaseFacet] = Factory(dict)
 
 OpenLineage integration itself takes care to enrich it with things like 
general Airflow facets, proper event time and type, creating proper OpenLineage 
RunEvent.
@@ -214,11 +214,11 @@ Both methods return ``OperatorLineage`` structure:
 
         inputs: list[Dataset] = Factory(list)
         outputs: list[Dataset] = Factory(list)
-        run_facets: dict[str, BaseFacet] = Factory(dict)
+        run_facets: dict[str, RunFacet] = Factory(dict)
         job_facets: dict[str, BaseFacet] = Factory(dict)
 
 
-Inputs and outputs are lists of plain OpenLineage datasets 
(`openlineage.client.run.Dataset`).
+Inputs and outputs are lists of plain OpenLineage datasets 
(`openlineage.client.event_v2.Dataset`).
 
 ``run_facets`` and ``job_facets`` are dictionaries of optional RunFacets and 
JobFacets that would be attached to the job - for example,
 you might want to attach ``SqlJobFacet`` if your Operator is executing SQL.
@@ -303,23 +303,20 @@ like extracting column level lineage and inputs/outputs 
from SQL query with SQL
 
 .. code-block:: python
 
+    from airflow.models.baseoperator import BaseOperator
+    from airflow.providers.openlineage.extractors.base import BaseExtractor, 
OperatorLineage
     from airflow.providers.common.compat.openlineage.facet import (
-        BaseFacet,
         Dataset,
         ExternalQueryRunFacet,
         SQLJobFacet,
     )
 
-    from airflow.models.baseoperator import BaseOperator
-    from airflow.providers.openlineage.extractors.base import BaseExtractor
-
 
     class ExampleOperator(BaseOperator):
         def __init__(self, query, bq_table_reference, s3_path) -> None:
             self.bq_table_reference = bq_table_reference
             self.s3_path = s3_path
             self.s3_file_name = s3_file_name
-            self.query = query
             self._job_id = None
 
         def execute(self, context) -> Any:
@@ -334,8 +331,8 @@ like extracting column level lineage and inputs/outputs 
from SQL query with SQL
         def _execute_extraction(self) -> OperatorLineage:
             """Define what we know before Operator's extract is called."""
             return OperatorLineage(
-                inputs=[Dataset(namespace="bigquery", 
name=self.bq_table_reference)],
-                outputs=[Dataset(namespace=self.s3_path, 
name=self.s3_file_name)],
+                inputs=[Dataset(namespace="bigquery", 
name=self.operator.bq_table_reference)],
+                outputs=[Dataset(namespace=self.operator.s3_path, 
name=self.operator.s3_file_name)],
                 job_facets={
                     "sql": SQLJobFacet(
                         query="EXPORT INTO ... OPTIONS(FORMAT=csv, SEP=';' 
...) AS SELECT * FROM ... "
@@ -343,11 +340,11 @@ like extracting column level lineage and inputs/outputs 
from SQL query with SQL
                 },
             )
 
-        def extract_on_complete(self) -> OperatorLineage:
+        def extract_on_complete(self, task_instance) -> OperatorLineage:
             """Add what we received after Operator's extract call."""
             lineage_metadata = self.extract()
             lineage_metadata.run_facets = {
-                "parent": ExternalQueryRunFacet(externalQueryId=self._job_id, 
source="bigquery")
+                "parent": 
ExternalQueryRunFacet(externalQueryId=task_instance.task._job_id, 
source="bigquery")
             }
             return lineage_metadata
 
@@ -454,42 +451,38 @@ Custom Facets
 =============
 To learn more about facets in OpenLineage, please refer to `facet 
documentation <https://openlineage.io/docs/spec/facets/>`_.
 Also check out `available facets 
<https://github.com/OpenLineage/OpenLineage/blob/main/client/python/openlineage/client/facet.py>`_
+and a blog post about `extending with facets 
<https://openlineage.io/blog/extending-with-facets/>`_.
 
 The OpenLineage spec might not contain all the facets you need to write your 
extractor,
 in which case you will have to make your own `custom facets 
<https://openlineage.io/docs/spec/facets/custom-facets>`_.
-More on creating custom facets can be found `here 
<https://openlineage.io/blog/extending-with-facets/>`_.
-
-Custom Run Facets
-=================
 
-You can inject your own custom facets in the lineage event's run facet using 
the ``custom_run_facets`` Airflow configuration.
+You can also inject your own custom facets in the lineage event's run facet 
using the ``custom_run_facets`` Airflow configuration.
 
 Steps to be taken,
 
-1. Write a function that returns the custom facet. You can write as many 
custom facet functions as needed.
+1. Write a function that returns the custom facets. You can write as many 
custom facet functions as needed.
 2. Register the functions using the ``custom_run_facets`` Airflow 
configuration.
 
-Once done, Airflow OpenLineage listener will automatically execute these 
functions during the lineage event generation
-and append their return values to the run facet in the lineage event.
+Airflow OpenLineage listener will automatically execute these functions during 
the lineage event generation and append their return values to the run facet in 
the lineage event.
 
 Writing a custom facet function
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-- **Input arguments:** The function should accept the ``TaskInstance`` as an 
input argument.
-- **Function body:** Perform the logic needed to generate the custom facet. 
The custom facet should inherit from the ``BaseFacet`` for the ``_producer`` 
and ``_schemaURL`` to be automatically added for the facet.
-- **Return value:** The custom facet to be added to the lineage event. Return 
type should be ``dict[str, dict]`` or ``None``. You may choose to return 
``None``, if you do not want to add custom facets for certain criteria.
+- **Input arguments:** The function should accept two input arguments: 
``TaskInstance`` and ``TaskInstanceState``.
+- **Function body:** Perform the logic needed to generate the custom facets. 
The custom facets must inherit from ``RunFacet`` so that ``_producer`` and 
``_schemaURL`` are automatically added to each facet.
+- **Return value:** The custom facets to be added to the lineage event. Return 
type should be ``dict[str, RunFacet]`` or ``None``. You may choose to return 
``None``, if you do not want to add custom facets for certain criteria.
 
 **Example custom facet function**
 
 .. code-block:: python
 
     import attrs
-    from airflow.models import TaskInstance
-    from airflow.providers.common.compat.openlineage.facet import BaseFacet
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceState
+    from airflow.providers.common.compat.openlineage.facet import RunFacet
 
 
     @attrs.define(slots=False)
-    class MyCustomRunFacet(BaseFacet):
+    class MyCustomRunFacet(RunFacet):
         """Define a custom facet."""
 
         name: str
@@ -499,24 +492,29 @@ Writing a custom facet function
         dagId: str
         taskId: str
         cluster: str
+        custom_metadata: dict
 
 
-    def get_my_custom_facet(task_instance: TaskInstance) -> dict[str, dict] | 
None:
+    def get_my_custom_facet(
+        task_instance: TaskInstance, ti_state: TaskInstanceState
+    ) -> dict[str, RunFacet] | None:
         operator_name = task_instance.task.operator_name
+        custom_metadata = {}
         if operator_name == "BashOperator":
-            return
+            return None
+        if ti_state == TaskInstanceState.FAILED:
+            custom_metadata["custom_key_failed"] = "custom_value"
         job_unique_name = 
f"TEST.{task_instance.dag_id}.{task_instance.task_id}"
         return {
-            "additional_run_facet": attrs.asdict(
-                MyCustomRunFacet(
-                    name="test-lineage-namespace",
-                    jobState=task_instance.state,
-                    uniqueName=job_unique_name,
-                    
displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
-                    dagId=task_instance.dag_id,
-                    taskId=task_instance.task_id,
-                    cluster="TEST",
-                )
+            "additional_run_facet": MyCustomRunFacet(
+                name="test-lineage-namespace",
+                jobState=task_instance.state,
+                uniqueName=job_unique_name,
+                displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
+                dagId=task_instance.dag_id,
+                taskId=task_instance.task_id,
+                cluster="TEST",
+                custom_metadata=custom_metadata,
             )
         }
 
@@ -540,9 +538,10 @@ a string of semicolon separated full import path to the 
functions.
 
 .. note::
 
-    - The custom facet functions are only executed at the start of the 
TaskInstance and added to the OpenLineage START event.
-    - Duplicate functions if registered, will be executed only once.
-    - When duplicate custom facet keys are returned by different functions, 
the last processed function will be added to the lineage event.
+    - The custom facet functions are executed both at the START and 
COMPLETE/FAIL of the TaskInstance and added to the corresponding OpenLineage 
event.
+    - When creating conditions on TaskInstance state, you should use the second 
argument provided (``TaskInstanceState``), which contains the state the task 
should be in. This may differ from ``ti.current_state()``, as the OpenLineage listener 
may get called before the TaskInstance's state is updated in the Airflow database.
+    - When path to a single function is registered more than once, it will 
still be executed only once.
+    - When duplicate custom facet keys are returned by multiple registered 
functions, the result of an arbitrary one of those functions will be added to the 
lineage event. Please avoid using duplicate facet keys, as they can produce 
unexpected behaviour.
 
 .. _job_hierarchy:openlineage:
 
diff --git a/tests/providers/openlineage/plugins/test_adapter.py 
b/tests/providers/openlineage/plugins/test_adapter.py
index 6ba5d9d4a3..b648bb51d3 100644
--- a/tests/providers/openlineage/plugins/test_adapter.py
+++ b/tests/providers/openlineage/plugins/test_adapter.py
@@ -354,6 +354,9 @@ def 
test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s
                 )
             },
         ),
+        run_facets={
+            "externalQuery2": 
external_query_run.ExternalQueryRunFacet(externalQueryId="999", source="source")
+        },
     )
 
     assert (
@@ -371,6 +374,9 @@ def 
test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s
                         "externalQuery": 
external_query_run.ExternalQueryRunFacet(
                             externalQueryId="123", source="source"
                         ),
+                        "externalQuery2": 
external_query_run.ExternalQueryRunFacet(
+                            externalQueryId="999", source="source"
+                        ),
                     },
                 ),
                 job=Job(
@@ -467,6 +473,9 @@ def 
test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta
             },
             job_facets={"sql": sql_job.SQLJobFacet(query="SELECT 1;")},
         ),
+        run_facets={
+            "externalQuery2": 
external_query_run.ExternalQueryRunFacet(externalQueryId="999", source="source")
+        },
         error=ValueError("Error message"),
     )
 
@@ -487,6 +496,9 @@ def 
test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta
                     "externalQuery": external_query_run.ExternalQueryRunFacet(
                         externalQueryId="123", source="source"
                     ),
+                    "externalQuery2": external_query_run.ExternalQueryRunFacet(
+                        externalQueryId="999", source="source"
+                    ),
                 },
             ),
             job=Job(
diff --git a/tests/providers/openlineage/plugins/test_listener.py 
b/tests/providers/openlineage/plugins/test_listener.py
index 2fa8216bb4..b05a934e02 100644
--- a/tests/providers/openlineage/plugins/test_listener.py
+++ b/tests/providers/openlineage/plugins/test_listener.py
@@ -215,11 +215,16 @@ def _create_listener_and_task_instance() -> 
tuple[OpenLineageListener, TaskInsta
 
 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_custom_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
 @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
 def test_adapter_start_task_is_called_with_proper_arguments(
-    mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet, 
mock_disabled
+    mock_get_job_name,
+    mock_get_airflow_mapped_task_facet,
+    mock_get_user_provided_run_facets,
+    mock_get_airflow_run_facet,
+    mock_disabled,
 ):
     """Tests that the 'start_task' method of the OpenLineageAdapter is invoked 
with the correct arguments.
 
@@ -231,7 +236,8 @@ def test_adapter_start_task_is_called_with_proper_arguments(
     """
     listener, task_instance = _create_listener_and_task_instance()
     mock_get_job_name.return_value = "job_name"
-    mock_get_custom_facets.return_value = {"custom_facet": 2}
+    mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1}
+    mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
     mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
     mock_disabled.return_value = False
 
@@ -249,7 +255,8 @@ def test_adapter_start_task_is_called_with_proper_arguments(
         owners=["Test Owner"],
         task=listener.extractor_manager.extract_metadata(),
         run_facets={
-            "custom_facet": 2,
+            "mapped_facet": 1,
+            "custom_user_facet": 2,
             "airflow_run_facet": 3,
         },
     )
@@ -257,9 +264,12 @@ def 
test_adapter_start_task_is_called_with_proper_arguments(
 
 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
 @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
-def test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, 
mocked_adapter, mock_disabled):
+def test_adapter_fail_task_is_called_with_proper_arguments(
+    mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter, 
mock_disabled
+):
     """Tests that the 'fail_task' method of the OpenLineageAdapter is invoked 
with the correct arguments.
 
     This test ensures that the job name is accurately retrieved and included, 
along with the generated
@@ -278,6 +288,7 @@ def 
test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mo
     mock_get_job_name.return_value = "job_name"
     mocked_adapter.build_dag_run_id.side_effect = mock_dag_id
     mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id
+    mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
     mock_disabled.return_value = False
 
     err = ValueError("test")
@@ -294,16 +305,18 @@ def 
test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mo
         parent_run_id="execution_date.dag_id",
         run_id="execution_date.dag_id.task_id.1",
         task=listener.extractor_manager.extract_metadata(),
+        run_facets={"custom_user_facet": 2},
         **expected_err_kwargs,
     )
 
 
 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
 @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
 new=regular_call)
 def test_adapter_complete_task_is_called_with_proper_arguments(
-    mock_get_job_name, mocked_adapter, mock_disabled
+    mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter, 
mock_disabled
 ):
     """Tests that the 'complete_task' method of the OpenLineageAdapter is 
called with the correct arguments.
 
@@ -324,6 +337,7 @@ def 
test_adapter_complete_task_is_called_with_proper_arguments(
     mock_get_job_name.return_value = "job_name"
     mocked_adapter.build_dag_run_id.side_effect = mock_dag_id
     mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id
+    mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
     mock_disabled.return_value = False
 
     listener.on_task_instance_success(None, task_instance, None)
@@ -338,6 +352,7 @@ def 
test_adapter_complete_task_is_called_with_proper_arguments(
         parent_run_id="execution_date.dag_id",
         run_id=f"execution_date.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}",
         task=listener.extractor_manager.extract_metadata(),
+        run_facets={"custom_user_facet": 2},
     )
 
 
@@ -464,14 +479,14 @@ def 
test_listener_on_task_instance_success_is_called_after_try_number_increment(
 
 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_custom_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
 @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
 def 
test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator(
-    mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet, 
mock_disabled
+    mock_get_job_name, mock_get_user_provided_run_facets, 
mock_get_airflow_run_facet, mock_disabled
 ):
     listener, task_instance = _create_listener_and_task_instance()
     mock_get_job_name.return_value = "job_name"
-    mock_get_custom_facets.return_value = {"custom_facet": 2}
+    mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
     mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
     mock_disabled.return_value = True
 
@@ -485,11 +500,13 @@ def 
test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_ope
 
 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
 @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
 def 
test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator(
-    mock_get_job_name, mocked_adapter, mock_disabled
+    mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter, 
mock_disabled
 ):
     listener, task_instance = _create_listener_and_task_instance()
+    mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
     mock_disabled.return_value = True
 
     on_task_failed_kwargs = {"error": ValueError("test")} if 
AIRFLOW_V_2_10_PLUS else {}
@@ -506,11 +523,13 @@ def 
test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_oper
 
 
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
 
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
+@mock.patch("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
 @mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
 def 
test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator(
-    mock_get_job_name, mocked_adapter, mock_disabled
+    mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter, 
mock_disabled
 ):
     listener, task_instance = _create_listener_and_task_instance()
+    mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
     mock_disabled.return_value = True
 
     listener.on_task_instance_success(None, task_instance, None)
diff --git a/tests/providers/openlineage/utils/custom_facet_fixture.py 
b/tests/providers/openlineage/utils/custom_facet_fixture.py
index f2504888b4..6b9d0edcce 100644
--- a/tests/providers/openlineage/utils/custom_facet_fixture.py
+++ b/tests/providers/openlineage/utils/custom_facet_fixture.py
@@ -20,69 +20,47 @@ from typing import TYPE_CHECKING
 
 import attrs
 
-from airflow.providers.common.compat.openlineage.facet import BaseFacet
+from airflow.providers.common.compat.openlineage.facet import RunFacet
 
 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceState
 
 
 @attrs.define(slots=False)
-class MyCustomRunFacet(BaseFacet):
+class MyCustomRunFacet(RunFacet):
     """Define a custom run facet."""
 
     name: str
-    jobState: str
-    uniqueName: str
-    displayName: str
-    dagId: str
-    taskId: str
     cluster: str
 
 
-def get_additional_test_facet(task_instance: TaskInstance) -> dict[str, dict] 
| None:
-    operator_name = task_instance.task.operator_name if task_instance.task 
else None
+def get_additional_test_facet(
+    task_instance: TaskInstance, ti_state: TaskInstanceState
+) -> dict[str, RunFacet] | None:
+    operator_name = task_instance.task.operator_name if task_instance.task 
else ""
     if operator_name == "BashOperator":
         return None
-    job_unique_name = f"TEST.{task_instance.dag_id}.{task_instance.task_id}"
     return {
-        "additional_run_facet": attrs.asdict(
-            MyCustomRunFacet(
-                name="test-lineage-namespace",
-                jobState=task_instance.state,
-                uniqueName=job_unique_name,
-                displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
-                dagId=task_instance.dag_id,
-                taskId=task_instance.task_id,
-                cluster="TEST",
-            )
+        "additional_run_facet": MyCustomRunFacet(
+            name=f"test-lineage-namespace-{ti_state}",
+            cluster=f"TEST_{task_instance.dag_id}.{task_instance.task_id}",
         )
     }
 
 
-def get_duplicate_test_facet_key(task_instance: TaskInstance):
-    job_unique_name = f"TEST.{task_instance.dag_id}.{task_instance.task_id}"
-    return {
-        "additional_run_facet": attrs.asdict(
-            MyCustomRunFacet(
-                name="test-lineage-namespace",
-                jobState=task_instance.state,
-                uniqueName=job_unique_name,
-                displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
-                dagId=task_instance.dag_id,
-                taskId=task_instance.task_id,
-                cluster="TEST",
-            )
-        )
-    }
+def get_duplicate_test_facet_key(
+    task_instance: TaskInstance, ti_state: TaskInstanceState
+) -> dict[str, RunFacet] | None:
+    return get_additional_test_facet(task_instance, ti_state)
 
 
-def get_another_test_facet(task_instance: TaskInstance):
+def get_another_test_facet(task_instance, ti_state):
     return {"another_run_facet": {"name": "another-lineage-namespace"}}
 
 
-def return_type_is_not_dict(task_instance: TaskInstance):
+def return_type_is_not_dict(task_instance, ti_state):
     return "return type is not dict"
 
 
-def get_custom_facet_throws_exception(task_instance: TaskInstance):
-    raise Exception("fake exception from custom fcet function")
+def get_custom_facet_throws_exception(task_instance, ti_state):
+    raise Exception("fake exception from custom facet function")
diff --git a/tests/providers/openlineage/utils/test_utils.py 
b/tests/providers/openlineage/utils/test_utils.py
index 6f6fc104b3..d88d596d7e 100644
--- a/tests/providers/openlineage/utils/test_utils.py
+++ b/tests/providers/openlineage/utils/test_utils.py
@@ -18,14 +18,14 @@
 from __future__ import annotations
 
 import datetime
-from unittest.mock import ANY, MagicMock, patch
+from unittest.mock import MagicMock, patch
 
 from airflow import DAG
 from airflow.decorators import task
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagrun import DagRun
 from airflow.models.mappedoperator import MappedOperator
-from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstance import TaskInstance, TaskInstanceState
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
@@ -37,10 +37,10 @@ from airflow.providers.openlineage.utils.utils import (
     _safe_get_dag_tree_view,
     get_airflow_dag_run_facet,
     get_airflow_job_facet,
-    get_custom_facets,
     get_fully_qualified_class_name,
     get_job_name,
     get_operator_class,
+    get_user_provided_run_facets,
 )
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.task_group import TaskGroup
@@ -541,14 +541,14 @@ def test_get_task_groups_details_no_task_groups():
 
 
 @patch("airflow.providers.openlineage.conf.custom_run_facets", 
return_value=set())
-def 
test_get_custom_facets_with_no_function_definition(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_no_function_definition(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
     assert result == {}
 
 
@@ -556,27 +556,17 @@ def 
test_get_custom_facets_with_no_function_definition(mock_custom_facet_funcs):
     "airflow.providers.openlineage.conf.custom_run_facets",
     
return_value={"tests.providers.openlineage.utils.custom_facet_fixture.get_additional_test_facet"},
 )
-def test_get_custom_facets_with_function_definition(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_function_definition(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
-    assert result == {
-        "additional_run_facet": {
-            "_producer": ANY,
-            "_schemaURL": ANY,
-            "name": "test-lineage-namespace",
-            "jobState": "running",
-            "uniqueName": "TEST.test-dag.test-task",
-            "displayName": "test-dag.test-task",
-            "dagId": "test-dag",
-            "taskId": "test-task",
-            "cluster": "TEST",
-        }
-    }
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
+    assert len(result) == 1
+    assert result["additional_run_facet"].name == 
f"test-lineage-namespace-{TaskInstanceState.RUNNING}"
+    assert result["additional_run_facet"].cluster == "TEST_test-dag.test-task"
 
 
 @patch(
@@ -585,7 +575,7 @@ def 
test_get_custom_facets_with_function_definition(mock_custom_facet_funcs):
         
"tests.providers.openlineage.utils.custom_facet_fixture.get_additional_test_facet",
     },
 )
-def test_get_custom_facets_with_return_value_as_none(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_return_value_as_none(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=BashOperator(
             task_id="test-task",
@@ -594,7 +584,7 @@ def 
test_get_custom_facets_with_return_value_as_none(mock_custom_facet_funcs):
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
     assert result == {}
 
 
@@ -607,28 +597,18 @@ def 
test_get_custom_facets_with_return_value_as_none(mock_custom_facet_funcs):
         
"tests.providers.openlineage.utils.custom_facet_fixture.get_another_test_facet",
     },
 )
-def 
test_get_custom_facets_with_multiple_function_definition(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_multiple_function_definition(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
-    assert result == {
-        "additional_run_facet": {
-            "_producer": ANY,
-            "_schemaURL": ANY,
-            "name": "test-lineage-namespace",
-            "jobState": "running",
-            "uniqueName": "TEST.test-dag.test-task",
-            "displayName": "test-dag.test-task",
-            "dagId": "test-dag",
-            "taskId": "test-task",
-            "cluster": "TEST",
-        },
-        "another_run_facet": {"name": "another-lineage-namespace"},
-    }
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
+    assert len(result) == 2
+    assert result["additional_run_facet"].name == 
f"test-lineage-namespace-{TaskInstanceState.RUNNING}"
+    assert result["additional_run_facet"].cluster == "TEST_test-dag.test-task"
+    assert result["another_run_facet"] == {"name": "another-lineage-namespace"}
 
 
 @patch(
@@ -638,41 +618,31 @@ def 
test_get_custom_facets_with_multiple_function_definition(mock_custom_facet_f
         
"tests.providers.openlineage.utils.custom_facet_fixture.get_duplicate_test_facet_key",
     },
 )
-def test_get_custom_facets_with_duplicate_facet_keys(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_duplicate_facet_keys(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
-    assert result == {
-        "additional_run_facet": {
-            "_producer": ANY,
-            "_schemaURL": ANY,
-            "name": "test-lineage-namespace",
-            "jobState": "running",
-            "uniqueName": "TEST.test-dag.test-task",
-            "displayName": "test-dag.test-task",
-            "dagId": "test-dag",
-            "taskId": "test-task",
-            "cluster": "TEST",
-        }
-    }
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
+    assert len(result) == 1
+    assert result["additional_run_facet"].name == 
f"test-lineage-namespace-{TaskInstanceState.RUNNING}"
+    assert result["additional_run_facet"].cluster == "TEST_test-dag.test-task"
 
 
 @patch(
     "airflow.providers.openlineage.conf.custom_run_facets",
     return_value={"invalid_function"},
 )
-def 
test_get_custom_facets_with_invalid_function_definition(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_invalid_function_definition(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
     assert result == {}
 
 
@@ -680,14 +650,14 @@ def 
test_get_custom_facets_with_invalid_function_definition(mock_custom_facet_fu
     "airflow.providers.openlineage.conf.custom_run_facets",
     
return_value={"tests.providers.openlineage.utils.custom_facet_fixture.return_type_is_not_dict"},
 )
-def 
test_get_custom_facets_with_wrong_return_type_function(mock_custom_facet_funcs):
+def 
test_get_user_provided_run_facets_with_wrong_return_type_function(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
     assert result == {}
 
 
@@ -695,12 +665,12 @@ def 
test_get_custom_facets_with_wrong_return_type_function(mock_custom_facet_fun
     "airflow.providers.openlineage.conf.custom_run_facets",
     
return_value={"tests.providers.openlineage.utils.custom_facet_fixture.get_custom_facet_throws_exception"},
 )
-def test_get_custom_facets_with_exception(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_exception(mock_custom_facet_funcs):
     sample_ti = TaskInstance(
         task=EmptyOperator(
             task_id="test-task", dag=DAG("test-dag", 
start_date=datetime.datetime(2024, 7, 1))
         ),
         state="running",
     )
-    result = get_custom_facets(sample_ti)
+    result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
     assert result == {}


Reply via email to