This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 579a8b87fc openlineage: extend custom_run_facets to also be executed
on complete and fail (#40953)
579a8b87fc is described below
commit 579a8b87fc3d4a737bae11049c0607aaf2a8b8fb
Author: Kacper Muda <[email protected]>
AuthorDate: Tue Jul 23 17:44:05 2024 +0200
openlineage: extend custom_run_facets to also be executed on complete and
fail (#40953)
Signed-off-by: Kacper Muda <[email protected]>
---
airflow/providers/openlineage/plugins/adapter.py | 24 ++++--
airflow/providers/openlineage/plugins/listener.py | 10 ++-
airflow/providers/openlineage/utils/utils.py | 29 ++++---
.../guides/developer.rst | 81 ++++++++++----------
.../providers/openlineage/plugins/test_adapter.py | 12 +++
.../providers/openlineage/plugins/test_listener.py | 41 +++++++---
.../openlineage/utils/custom_facet_fixture.py | 58 +++++---------
tests/providers/openlineage/utils/test_utils.py | 88 +++++++---------------
8 files changed, 169 insertions(+), 174 deletions(-)
diff --git a/airflow/providers/openlineage/plugins/adapter.py
b/airflow/providers/openlineage/plugins/adapter.py
index 1d0317228b..7405556088 100644
--- a/airflow/providers/openlineage/plugins/adapter.py
+++ b/airflow/providers/openlineage/plugins/adapter.py
@@ -243,6 +243,7 @@ class OpenLineageAdapter(LoggingMixin):
parent_run_id: str | None,
end_time: str,
task: OperatorLineage,
+ run_facets: dict[str, RunFacet] | None = None, # Custom run facets
) -> RunEvent:
"""
Emit openlineage event of type COMPLETE.
@@ -254,7 +255,11 @@ class OpenLineageAdapter(LoggingMixin):
:param parent_run_id: identifier of job spawning this task
:param end_time: time of task completion
:param task: metadata container with information extracted from
operator
+ :param run_facets: custom run facets
"""
+ run_facets = run_facets or {}
+ if task:
+ run_facets = {**task.run_facets, **run_facets}
event = RunEvent(
eventType=RunState.COMPLETE,
eventTime=end_time,
@@ -263,7 +268,7 @@ class OpenLineageAdapter(LoggingMixin):
job_name=job_name,
parent_job_name=parent_job_name,
parent_run_id=parent_run_id,
- run_facets=task.run_facets,
+ run_facets=run_facets,
),
job=self._build_job(job_name, job_type=_JOB_TYPE_TASK,
job_facets=task.job_facets),
inputs=task.inputs,
@@ -280,6 +285,7 @@ class OpenLineageAdapter(LoggingMixin):
parent_run_id: str | None,
end_time: str,
task: OperatorLineage,
+ run_facets: dict[str, RunFacet] | None = None, # Custom run facets
error: str | BaseException | None = None,
) -> RunEvent:
"""
@@ -292,20 +298,22 @@ class OpenLineageAdapter(LoggingMixin):
:param parent_run_id: identifier of job spawning this task
:param end_time: time of task completion
:param task: metadata container with information extracted from
operator
+ :param run_facets: custom run facets
:param error: error
"""
- error_facet = {}
+ run_facets = run_facets or {}
+ if task:
+ run_facets = {**task.run_facets, **run_facets}
+
if error:
stack_trace = None
if isinstance(error, BaseException) and error.__traceback__:
import traceback
stack_trace =
"\\n".join(traceback.format_exception(type(error), error, error.__traceback__))
- error_facet = {
- "errorMessage": error_message_run.ErrorMessageRunFacet(
- message=str(error), programmingLanguage="python",
stackTrace=stack_trace
- )
- }
+ run_facets["errorMessage"] =
error_message_run.ErrorMessageRunFacet(
+ message=str(error), programmingLanguage="python",
stackTrace=stack_trace
+ )
event = RunEvent(
eventType=RunState.FAIL,
@@ -315,7 +323,7 @@ class OpenLineageAdapter(LoggingMixin):
job_name=job_name,
parent_job_name=parent_job_name,
parent_run_id=parent_run_id,
- run_facets={**task.run_facets, **error_facet},
+ run_facets=run_facets,
),
job=self._build_job(job_name, job_type=_JOB_TYPE_TASK,
job_facets=task.job_facets),
inputs=task.inputs,
diff --git a/airflow/providers/openlineage/plugins/listener.py
b/airflow/providers/openlineage/plugins/listener.py
index a552cb283b..58ccdcad24 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -33,9 +33,10 @@ from airflow.providers.openlineage.extractors import
ExtractorManager
from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter,
RunState
from airflow.providers.openlineage.utils.utils import (
get_airflow_job_facet,
+ get_airflow_mapped_task_facet,
get_airflow_run_facet,
- get_custom_facets,
get_job_name,
+ get_user_provided_run_facets,
is_operator_disabled,
is_selective_lineage_enabled,
print_warning,
@@ -43,13 +44,13 @@ from airflow.providers.openlineage.utils.utils import (
from airflow.settings import configure_orm
from airflow.stats import Stats
from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
from airflow.utils.timeout import timeout
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models import DagRun, TaskInstance
- from airflow.utils.state import TaskInstanceState
_openlineage_listener: OpenLineageListener | None = None
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >=
Version("2.10.0")
@@ -163,7 +164,8 @@ class OpenLineageListener:
owners=dag.owner.split(", "),
task=task_metadata,
run_facets={
- **get_custom_facets(task_instance),
+ **get_user_provided_run_facets(task_instance,
TaskInstanceState.RUNNING),
+ **get_airflow_mapped_task_facet(task_instance),
**get_airflow_run_facet(dagrun, dag, task_instance, task,
task_uuid),
},
)
@@ -233,6 +235,7 @@ class OpenLineageListener:
parent_run_id=parent_run_id,
end_time=end_date.isoformat(),
task=task_metadata,
+ run_facets=get_user_provided_run_facets(task_instance,
TaskInstanceState.SUCCESS),
)
Stats.gauge(
f"ol.event.size.{event_type}.{operator_name}",
@@ -327,6 +330,7 @@ class OpenLineageListener:
parent_run_id=parent_run_id,
end_time=end_date.isoformat(),
task=task_metadata,
+ run_facets=get_user_provided_run_facets(task_instance,
TaskInstanceState.FAILED),
error=error,
)
Stats.gauge(
diff --git a/airflow/providers/openlineage/utils/utils.py
b/airflow/providers/openlineage/utils/utils.py
index 195d14e4e7..2e995fb7a8 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -55,10 +55,11 @@ from airflow.utils.log.secrets_masker import Redactable,
Redacted, SecretsMasker
from airflow.utils.module_loading import import_string
if TYPE_CHECKING:
+ from openlineage.client.event_v2 import Dataset as OpenLineageDataset
from openlineage.client.facet_v2 import RunFacet
- from openlineage.client.run import Dataset as OpenLineageDataset
from airflow.models import DagRun, TaskInstance
+ from airflow.utils.state import TaskInstanceState
log = logging.getLogger(__name__)
@@ -81,28 +82,32 @@ def get_job_name(task: TaskInstance) -> str:
return f"{task.dag_id}.{task.task_id}"
-def get_custom_facets(task_instance: TaskInstance | None = None) -> dict[str,
Any]:
- from airflow.providers.openlineage.extractors.manager import
try_import_from_string
-
- custom_facets = {}
+def get_airflow_mapped_task_facet(task_instance: TaskInstance) -> dict[str,
Any]:
# check for -1 comes from SmartSensor compatibility with dynamic task
mapping
# this comes from Airflow code
if hasattr(task_instance, "map_index") and getattr(task_instance,
"map_index") != -1:
- custom_facets["airflow_mappedTask"] =
AirflowMappedTaskRunFacet.from_task_instance(task_instance)
+ return {"airflow_mappedTask":
AirflowMappedTaskRunFacet.from_task_instance(task_instance)}
+ return {}
+
+
+def get_user_provided_run_facets(ti: TaskInstance, ti_state:
TaskInstanceState) -> dict[str, RunFacet]:
+ custom_facets = {}
# Append custom run facets by executing the custom_run_facet functions.
for custom_facet_func in conf.custom_run_facets():
try:
- func: Callable[[Any], dict] | None =
try_import_from_string(custom_facet_func)
+ func: Callable[[TaskInstance, TaskInstanceState], dict[str,
RunFacet]] | None = (
+ try_import_from_string(custom_facet_func)
+ )
if not func:
log.warning(
"OpenLineage is unable to import custom facet function
`%s`; will ignore it.",
custom_facet_func,
)
continue
- facet: dict[str, dict[Any, Any]] | None = func(task_instance)
- if facet and isinstance(facet, dict):
- duplicate_facet_keys = [facet_key for facet_key in
facet.keys() if facet_key in custom_facets]
+ facets: dict[str, RunFacet] | None = func(ti, ti_state)
+ if facets and isinstance(facets, dict):
+ duplicate_facet_keys = [facet_key for facet_key in facets if
facet_key in custom_facets]
if duplicate_facet_keys:
log.warning(
"Duplicate OpenLineage custom facets key(s) found:
`%s` from function `%s`; "
@@ -112,10 +117,10 @@ def get_custom_facets(task_instance: TaskInstance | None
= None) -> dict[str, An
)
log.debug(
"Adding OpenLineage custom facet with key(s): `%s` from
function `%s`.",
- tuple(facet),
+ tuple(facets),
custom_facet_func,
)
- custom_facets.update(facet)
+ custom_facets.update(facets)
except Exception as exc:
log.warning(
"Error processing custom facet function `%s`; will ignore it.
Error was: %s: %s",
diff --git a/docs/apache-airflow-providers-openlineage/guides/developer.rst
b/docs/apache-airflow-providers-openlineage/guides/developer.rst
index 8d66780190..4e9ada44c2 100644
--- a/docs/apache-airflow-providers-openlineage/guides/developer.rst
+++ b/docs/apache-airflow-providers-openlineage/guides/developer.rst
@@ -85,7 +85,7 @@ Instead of returning complete OpenLineage event, the provider
defines ``Operator
class OperatorLineage:
inputs: list[Dataset] = Factory(list)
outputs: list[Dataset] = Factory(list)
- run_facets: dict[str, BaseFacet] = Factory(dict)
+ run_facets: dict[str, RunFacet] = Factory(dict)
job_facets: dict[str, BaseFacet] = Factory(dict)
OpenLineage integration itself takes care to enrich it with things like
general Airflow facets, proper event time and type, creating proper OpenLineage
RunEvent.
@@ -214,11 +214,11 @@ Both methods return ``OperatorLineage`` structure:
inputs: list[Dataset] = Factory(list)
outputs: list[Dataset] = Factory(list)
- run_facets: dict[str, BaseFacet] = Factory(dict)
+ run_facets: dict[str, RunFacet] = Factory(dict)
job_facets: dict[str, BaseFacet] = Factory(dict)
-Inputs and outputs are lists of plain OpenLineage datasets
(`openlineage.client.run.Dataset`).
+Inputs and outputs are lists of plain OpenLineage datasets
(`openlineage.client.event_v2.Dataset`).
``run_facets`` and ``job_facets`` are dictionaries of optional RunFacets and
JobFacets that would be attached to the job - for example,
you might want to attach ``SqlJobFacet`` if your Operator is executing SQL.
@@ -303,23 +303,20 @@ like extracting column level lineage and inputs/outputs
from SQL query with SQL
.. code-block:: python
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.providers.openlineage.extractors.base import BaseExtractor,
OperatorLineage
from airflow.providers.common.compat.openlineage.facet import (
- BaseFacet,
Dataset,
ExternalQueryRunFacet,
SQLJobFacet,
)
- from airflow.models.baseoperator import BaseOperator
- from airflow.providers.openlineage.extractors.base import BaseExtractor
-
class ExampleOperator(BaseOperator):
def __init__(self, query, bq_table_reference, s3_path) -> None:
self.bq_table_reference = bq_table_reference
self.s3_path = s3_path
self.s3_file_name = s3_file_name
- self.query = query
self._job_id = None
def execute(self, context) -> Any:
@@ -334,8 +331,8 @@ like extracting column level lineage and inputs/outputs
from SQL query with SQL
def _execute_extraction(self) -> OperatorLineage:
"""Define what we know before Operator's extract is called."""
return OperatorLineage(
- inputs=[Dataset(namespace="bigquery",
name=self.bq_table_reference)],
- outputs=[Dataset(namespace=self.s3_path,
name=self.s3_file_name)],
+ inputs=[Dataset(namespace="bigquery",
name=self.operator.bq_table_reference)],
+ outputs=[Dataset(namespace=self.operator.s3_path,
name=self.operator.s3_file_name)],
job_facets={
"sql": SQLJobFacet(
query="EXPORT INTO ... OPTIONS(FORMAT=csv, SEP=';'
...) AS SELECT * FROM ... "
@@ -343,11 +340,11 @@ like extracting column level lineage and inputs/outputs
from SQL query with SQL
},
)
- def extract_on_complete(self) -> OperatorLineage:
+ def extract_on_complete(self, task_instance) -> OperatorLineage:
"""Add what we received after Operator's extract call."""
lineage_metadata = self.extract()
lineage_metadata.run_facets = {
- "parent": ExternalQueryRunFacet(externalQueryId=self._job_id,
source="bigquery")
+ "parent":
ExternalQueryRunFacet(externalQueryId=task_instance.task._job_id,
source="bigquery")
}
return lineage_metadata
@@ -454,42 +451,38 @@ Custom Facets
=============
To learn more about facets in OpenLineage, please refer to `facet
documentation <https://openlineage.io/docs/spec/facets/>`_.
Also check out `available facets
<https://github.com/OpenLineage/OpenLineage/blob/main/client/python/openlineage/client/facet.py>`_
+and a blog post about `extending with facets
<https://openlineage.io/blog/extending-with-facets/>`_.
The OpenLineage spec might not contain all the facets you need to write your
extractor,
in which case you will have to make your own `custom facets
<https://openlineage.io/docs/spec/facets/custom-facets>`_.
-More on creating custom facets can be found `here
<https://openlineage.io/blog/extending-with-facets/>`_.
-
-Custom Run Facets
-=================
-You can inject your own custom facets in the lineage event's run facet using
the ``custom_run_facets`` Airflow configuration.
+You can also inject your own custom facets into the run facets of the lineage event
using the ``custom_run_facets`` Airflow configuration.
Steps to be taken,
-1. Write a function that returns the custom facet. You can write as many
custom facet functions as needed.
+1. Write a function that returns the custom facets. You can write as many
custom facet functions as needed.
2. Register the functions using the ``custom_run_facets`` Airflow
configuration.
-Once done, Airflow OpenLineage listener will automatically execute these
functions during the lineage event generation
-and append their return values to the run facet in the lineage event.
+The Airflow OpenLineage listener will automatically execute these functions during
lineage event generation and append their return values to the run facets of
the lineage event.
Writing a custom facet function
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-- **Input arguments:** The function should accept the ``TaskInstance`` as an
input argument.
-- **Function body:** Perform the logic needed to generate the custom facet.
The custom facet should inherit from the ``BaseFacet`` for the ``_producer``
and ``_schemaURL`` to be automatically added for the facet.
-- **Return value:** The custom facet to be added to the lineage event. Return
type should be ``dict[str, dict]`` or ``None``. You may choose to return
``None``, if you do not want to add custom facets for certain criteria.
+- **Input arguments:** The function should accept two input arguments:
``TaskInstance`` and ``TaskInstanceState``.
+- **Function body:** Perform the logic needed to generate the custom facets.
The custom facets must inherit from ``RunFacet`` so that ``_producer`` and
``_schemaURL`` are automatically added to each facet.
+- **Return value:** The custom facets to be added to the lineage event. The return
type should be ``dict[str, RunFacet]`` or ``None``. You may return
``None`` if you do not want to add custom facets under certain conditions.
**Example custom facet function**
.. code-block:: python
import attrs
- from airflow.models import TaskInstance
- from airflow.providers.common.compat.openlineage.facet import BaseFacet
+ from airflow.models.taskinstance import TaskInstance, TaskInstanceState
+ from airflow.providers.common.compat.openlineage.facet import RunFacet
@attrs.define(slots=False)
- class MyCustomRunFacet(BaseFacet):
+ class MyCustomRunFacet(RunFacet):
"""Define a custom facet."""
name: str
@@ -499,24 +492,29 @@ Writing a custom facet function
dagId: str
taskId: str
cluster: str
+ custom_metadata: dict
- def get_my_custom_facet(task_instance: TaskInstance) -> dict[str, dict] |
None:
+ def get_my_custom_facet(
+ task_instance: TaskInstance, ti_state: TaskInstanceState
+ ) -> dict[str, RunFacet] | None:
operator_name = task_instance.task.operator_name
+ custom_metadata = {}
if operator_name == "BashOperator":
- return
+ return None
+ if ti_state == TaskInstanceState.FAILED:
+ custom_metadata["custom_key_failed"] = "custom_value"
job_unique_name =
f"TEST.{task_instance.dag_id}.{task_instance.task_id}"
return {
- "additional_run_facet": attrs.asdict(
- MyCustomRunFacet(
- name="test-lineage-namespace",
- jobState=task_instance.state,
- uniqueName=job_unique_name,
-
displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
- dagId=task_instance.dag_id,
- taskId=task_instance.task_id,
- cluster="TEST",
- )
+ "additional_run_facet": MyCustomRunFacet(
+ name="test-lineage-namespace",
+ jobState=task_instance.state,
+ uniqueName=job_unique_name,
+ displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
+ dagId=task_instance.dag_id,
+ taskId=task_instance.task_id,
+ cluster="TEST",
+ custom_metadata=custom_metadata,
)
}
@@ -540,9 +538,10 @@ a string of semicolon separated full import path to the
functions.
.. note::
- - The custom facet functions are only executed at the start of the
TaskInstance and added to the OpenLineage START event.
- - Duplicate functions if registered, will be executed only once.
- - When duplicate custom facet keys are returned by different functions,
the last processed function will be added to the lineage event.
+ - The custom facet functions are executed both at the START and at the
COMPLETE/FAIL of the TaskInstance, and their results are added to the corresponding OpenLineage
event.
+ - When creating conditions on the TaskInstance state, use the second
argument (``TaskInstanceState``), which contains the state the task
should be in. This may differ from ``ti.current_state()``, as the OpenLineage listener
may be called before the TaskInstance's state is updated in the Airflow database.
+ - When the path to a single function is registered more than once, it will
still be executed only once.
+ - When duplicate custom facet keys are returned by multiple registered
functions, the value returned by an arbitrary one of those functions will be added to the lineage
event. Please avoid using duplicate facet keys, as this can produce unexpected
behaviour.
.. _job_hierarchy:openlineage:
diff --git a/tests/providers/openlineage/plugins/test_adapter.py
b/tests/providers/openlineage/plugins/test_adapter.py
index 6ba5d9d4a3..b648bb51d3 100644
--- a/tests/providers/openlineage/plugins/test_adapter.py
+++ b/tests/providers/openlineage/plugins/test_adapter.py
@@ -354,6 +354,9 @@ def
test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s
)
},
),
+ run_facets={
+ "externalQuery2":
external_query_run.ExternalQueryRunFacet(externalQueryId="999", source="source")
+ },
)
assert (
@@ -371,6 +374,9 @@ def
test_emit_complete_event_with_additional_information(mock_stats_incr, mock_s
"externalQuery":
external_query_run.ExternalQueryRunFacet(
externalQueryId="123", source="source"
),
+ "externalQuery2":
external_query_run.ExternalQueryRunFacet(
+ externalQueryId="999", source="source"
+ ),
},
),
job=Job(
@@ -467,6 +473,9 @@ def
test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta
},
job_facets={"sql": sql_job.SQLJobFacet(query="SELECT 1;")},
),
+ run_facets={
+ "externalQuery2":
external_query_run.ExternalQueryRunFacet(externalQueryId="999", source="source")
+ },
error=ValueError("Error message"),
)
@@ -487,6 +496,9 @@ def
test_emit_failed_event_with_additional_information(mock_stats_incr, mock_sta
"externalQuery": external_query_run.ExternalQueryRunFacet(
externalQueryId="123", source="source"
),
+ "externalQuery2": external_query_run.ExternalQueryRunFacet(
+ externalQueryId="999", source="source"
+ ),
},
),
job=Job(
diff --git a/tests/providers/openlineage/plugins/test_listener.py
b/tests/providers/openlineage/plugins/test_listener.py
index 2fa8216bb4..b05a934e02 100644
--- a/tests/providers/openlineage/plugins/test_listener.py
+++ b/tests/providers/openlineage/plugins/test_listener.py
@@ -215,11 +215,16 @@ def _create_listener_and_task_instance() ->
tuple[OpenLineageListener, TaskInsta
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_custom_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_airflow_mapped_task_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
new=regular_call)
def test_adapter_start_task_is_called_with_proper_arguments(
- mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet,
mock_disabled
+ mock_get_job_name,
+ mock_get_airflow_mapped_task_facet,
+ mock_get_user_provided_run_facets,
+ mock_get_airflow_run_facet,
+ mock_disabled,
):
"""Tests that the 'start_task' method of the OpenLineageAdapter is invoked
with the correct arguments.
@@ -231,7 +236,8 @@ def test_adapter_start_task_is_called_with_proper_arguments(
"""
listener, task_instance = _create_listener_and_task_instance()
mock_get_job_name.return_value = "job_name"
- mock_get_custom_facets.return_value = {"custom_facet": 2}
+ mock_get_airflow_mapped_task_facet.return_value = {"mapped_facet": 1}
+ mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
mock_disabled.return_value = False
@@ -249,7 +255,8 @@ def test_adapter_start_task_is_called_with_proper_arguments(
owners=["Test Owner"],
task=listener.extractor_manager.extract_metadata(),
run_facets={
- "custom_facet": 2,
+ "mapped_facet": 1,
+ "custom_user_facet": 2,
"airflow_run_facet": 3,
},
)
@@ -257,9 +264,12 @@ def
test_adapter_start_task_is_called_with_proper_arguments(
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
new=regular_call)
-def test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name,
mocked_adapter, mock_disabled):
+def test_adapter_fail_task_is_called_with_proper_arguments(
+ mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter,
mock_disabled
+):
"""Tests that the 'fail_task' method of the OpenLineageAdapter is invoked
with the correct arguments.
This test ensures that the job name is accurately retrieved and included,
along with the generated
@@ -278,6 +288,7 @@ def
test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mo
mock_get_job_name.return_value = "job_name"
mocked_adapter.build_dag_run_id.side_effect = mock_dag_id
mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id
+ mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
mock_disabled.return_value = False
err = ValueError("test")
@@ -294,16 +305,18 @@ def
test_adapter_fail_task_is_called_with_proper_arguments(mock_get_job_name, mo
parent_run_id="execution_date.dag_id",
run_id="execution_date.dag_id.task_id.1",
task=listener.extractor_manager.extract_metadata(),
+ run_facets={"custom_user_facet": 2},
**expected_err_kwargs,
)
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageListener._execute",
new=regular_call)
def test_adapter_complete_task_is_called_with_proper_arguments(
- mock_get_job_name, mocked_adapter, mock_disabled
+ mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter,
mock_disabled
):
"""Tests that the 'complete_task' method of the OpenLineageAdapter is
called with the correct arguments.
@@ -324,6 +337,7 @@ def
test_adapter_complete_task_is_called_with_proper_arguments(
mock_get_job_name.return_value = "job_name"
mocked_adapter.build_dag_run_id.side_effect = mock_dag_id
mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id
+ mock_get_user_provided_run_facets.return_value = {"custom_user_facet": 2}
mock_disabled.return_value = False
listener.on_task_instance_success(None, task_instance, None)
@@ -338,6 +352,7 @@ def
test_adapter_complete_task_is_called_with_proper_arguments(
parent_run_id="execution_date.dag_id",
run_id=f"execution_date.dag_id.task_id.{EXPECTED_TRY_NUMBER_1}",
task=listener.extractor_manager.extract_metadata(),
+ run_facets={"custom_user_facet": 2},
)
@@ -464,14 +479,14 @@ def
test_listener_on_task_instance_success_is_called_after_try_number_increment(
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_airflow_run_facet")
[email protected]("airflow.providers.openlineage.plugins.listener.get_custom_facets")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
def
test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_operator(
- mock_get_job_name, mock_get_custom_facets, mock_get_airflow_run_facet,
mock_disabled
+ mock_get_job_name, mock_get_user_provided_run_facets,
mock_get_airflow_run_facet, mock_disabled
):
listener, task_instance = _create_listener_and_task_instance()
mock_get_job_name.return_value = "job_name"
- mock_get_custom_facets.return_value = {"custom_facet": 2}
+ mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
mock_get_airflow_run_facet.return_value = {"airflow_run_facet": 3}
mock_disabled.return_value = True
@@ -485,11 +500,13 @@ def
test_listener_on_task_instance_running_do_not_call_adapter_when_disabled_ope
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
def
test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_operator(
- mock_get_job_name, mocked_adapter, mock_disabled
+ mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter,
mock_disabled
):
listener, task_instance = _create_listener_and_task_instance()
+ mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
mock_disabled.return_value = True
on_task_failed_kwargs = {"error": ValueError("test")} if
AIRFLOW_V_2_10_PLUS else {}
@@ -506,11 +523,13 @@ def
test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_oper
@mock.patch("airflow.providers.openlineage.plugins.listener.is_operator_disabled")
@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
[email protected]("airflow.providers.openlineage.plugins.listener.get_user_provided_run_facets")
@mock.patch("airflow.providers.openlineage.plugins.listener.get_job_name")
def
test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_operator(
- mock_get_job_name, mocked_adapter, mock_disabled
+ mock_get_job_name, mock_get_user_provided_run_facets, mocked_adapter,
mock_disabled
):
listener, task_instance = _create_listener_and_task_instance()
+ mock_get_user_provided_run_facets.return_value = {"custom_facet": 2}
mock_disabled.return_value = True
listener.on_task_instance_success(None, task_instance, None)
diff --git a/tests/providers/openlineage/utils/custom_facet_fixture.py
b/tests/providers/openlineage/utils/custom_facet_fixture.py
index f2504888b4..6b9d0edcce 100644
--- a/tests/providers/openlineage/utils/custom_facet_fixture.py
+++ b/tests/providers/openlineage/utils/custom_facet_fixture.py
@@ -20,69 +20,47 @@ from typing import TYPE_CHECKING
import attrs
-from airflow.providers.common.compat.openlineage.facet import BaseFacet
+from airflow.providers.common.compat.openlineage.facet import RunFacet
if TYPE_CHECKING:
- from airflow.models import TaskInstance
+ from airflow.models.taskinstance import TaskInstance, TaskInstanceState
@attrs.define(slots=False)
-class MyCustomRunFacet(BaseFacet):
+class MyCustomRunFacet(RunFacet):
"""Define a custom run facet."""
name: str
- jobState: str
- uniqueName: str
- displayName: str
- dagId: str
- taskId: str
cluster: str
-def get_additional_test_facet(task_instance: TaskInstance) -> dict[str, dict]
| None:
- operator_name = task_instance.task.operator_name if task_instance.task
else None
+def get_additional_test_facet(
+ task_instance: TaskInstance, ti_state: TaskInstanceState
+) -> dict[str, RunFacet] | None:
+ operator_name = task_instance.task.operator_name if task_instance.task
else ""
if operator_name == "BashOperator":
return None
- job_unique_name = f"TEST.{task_instance.dag_id}.{task_instance.task_id}"
return {
- "additional_run_facet": attrs.asdict(
- MyCustomRunFacet(
- name="test-lineage-namespace",
- jobState=task_instance.state,
- uniqueName=job_unique_name,
- displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
- dagId=task_instance.dag_id,
- taskId=task_instance.task_id,
- cluster="TEST",
- )
+ "additional_run_facet": MyCustomRunFacet(
+ name=f"test-lineage-namespace-{ti_state}",
+ cluster=f"TEST_{task_instance.dag_id}.{task_instance.task_id}",
)
}
-def get_duplicate_test_facet_key(task_instance: TaskInstance):
- job_unique_name = f"TEST.{task_instance.dag_id}.{task_instance.task_id}"
- return {
- "additional_run_facet": attrs.asdict(
- MyCustomRunFacet(
- name="test-lineage-namespace",
- jobState=task_instance.state,
- uniqueName=job_unique_name,
- displayName=f"{task_instance.dag_id}.{task_instance.task_id}",
- dagId=task_instance.dag_id,
- taskId=task_instance.task_id,
- cluster="TEST",
- )
- )
- }
+def get_duplicate_test_facet_key(
+ task_instance: TaskInstance, ti_state: TaskInstanceState
+) -> dict[str, RunFacet] | None:
+ return get_additional_test_facet(task_instance, ti_state)
-def get_another_test_facet(task_instance: TaskInstance):
+def get_another_test_facet(task_instance, ti_state):
return {"another_run_facet": {"name": "another-lineage-namespace"}}
-def return_type_is_not_dict(task_instance: TaskInstance):
+def return_type_is_not_dict(task_instance, ti_state):
return "return type is not dict"
-def get_custom_facet_throws_exception(task_instance: TaskInstance):
- raise Exception("fake exception from custom fcet function")
+def get_custom_facet_throws_exception(task_instance, ti_state):
+ raise Exception("fake exception from custom facet function")
diff --git a/tests/providers/openlineage/utils/test_utils.py
b/tests/providers/openlineage/utils/test_utils.py
index 6f6fc104b3..d88d596d7e 100644
--- a/tests/providers/openlineage/utils/test_utils.py
+++ b/tests/providers/openlineage/utils/test_utils.py
@@ -18,14 +18,14 @@
from __future__ import annotations
import datetime
-from unittest.mock import ANY, MagicMock, patch
+from unittest.mock import MagicMock, patch
from airflow import DAG
from airflow.decorators import task
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
from airflow.models.mappedoperator import MappedOperator
-from airflow.models.taskinstance import TaskInstance
+from airflow.models.taskinstance import TaskInstance, TaskInstanceState
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
@@ -37,10 +37,10 @@ from airflow.providers.openlineage.utils.utils import (
_safe_get_dag_tree_view,
get_airflow_dag_run_facet,
get_airflow_job_facet,
- get_custom_facets,
get_fully_qualified_class_name,
get_job_name,
get_operator_class,
+ get_user_provided_run_facets,
)
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.task_group import TaskGroup
@@ -541,14 +541,14 @@ def test_get_task_groups_details_no_task_groups():
@patch("airflow.providers.openlineage.conf.custom_run_facets",
return_value=set())
-def test_get_custom_facets_with_no_function_definition(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_no_function_definition(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
assert result == {}
@@ -556,27 +556,17 @@ def test_get_custom_facets_with_no_function_definition(mock_custom_facet_funcs):
"airflow.providers.openlineage.conf.custom_run_facets",
return_value={"tests.providers.openlineage.utils.custom_facet_fixture.get_additional_test_facet"},
)
-def test_get_custom_facets_with_function_definition(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_function_definition(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
- assert result == {
- "additional_run_facet": {
- "_producer": ANY,
- "_schemaURL": ANY,
- "name": "test-lineage-namespace",
- "jobState": "running",
- "uniqueName": "TEST.test-dag.test-task",
- "displayName": "test-dag.test-task",
- "dagId": "test-dag",
- "taskId": "test-task",
- "cluster": "TEST",
- }
- }
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
+ assert len(result) == 1
+ assert result["additional_run_facet"].name == f"test-lineage-namespace-{TaskInstanceState.RUNNING}"
+ assert result["additional_run_facet"].cluster == "TEST_test-dag.test-task"
@patch(
@@ -585,7 +575,7 @@ def test_get_custom_facets_with_function_definition(mock_custom_facet_funcs):
"tests.providers.openlineage.utils.custom_facet_fixture.get_additional_test_facet",
},
)
-def test_get_custom_facets_with_return_value_as_none(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_return_value_as_none(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=BashOperator(
task_id="test-task",
@@ -594,7 +584,7 @@ def test_get_custom_facets_with_return_value_as_none(mock_custom_facet_funcs):
),
state="running",
)
- result = get_custom_facets(sample_ti)
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
assert result == {}
@@ -607,28 +597,18 @@ def test_get_custom_facets_with_return_value_as_none(mock_custom_facet_funcs):
"tests.providers.openlineage.utils.custom_facet_fixture.get_another_test_facet",
},
)
-def test_get_custom_facets_with_multiple_function_definition(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_multiple_function_definition(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
- assert result == {
- "additional_run_facet": {
- "_producer": ANY,
- "_schemaURL": ANY,
- "name": "test-lineage-namespace",
- "jobState": "running",
- "uniqueName": "TEST.test-dag.test-task",
- "displayName": "test-dag.test-task",
- "dagId": "test-dag",
- "taskId": "test-task",
- "cluster": "TEST",
- },
- "another_run_facet": {"name": "another-lineage-namespace"},
- }
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
+ assert len(result) == 2
+ assert result["additional_run_facet"].name == f"test-lineage-namespace-{TaskInstanceState.RUNNING}"
+ assert result["additional_run_facet"].cluster == "TEST_test-dag.test-task"
+ assert result["another_run_facet"] == {"name": "another-lineage-namespace"}
@patch(
@@ -638,41 +618,31 @@ def test_get_custom_facets_with_multiple_function_definition(mock_custom_facet_f
"tests.providers.openlineage.utils.custom_facet_fixture.get_duplicate_test_facet_key",
},
)
-def test_get_custom_facets_with_duplicate_facet_keys(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_duplicate_facet_keys(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
- assert result == {
- "additional_run_facet": {
- "_producer": ANY,
- "_schemaURL": ANY,
- "name": "test-lineage-namespace",
- "jobState": "running",
- "uniqueName": "TEST.test-dag.test-task",
- "displayName": "test-dag.test-task",
- "dagId": "test-dag",
- "taskId": "test-task",
- "cluster": "TEST",
- }
- }
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
+ assert len(result) == 1
+ assert result["additional_run_facet"].name == f"test-lineage-namespace-{TaskInstanceState.RUNNING}"
+ assert result["additional_run_facet"].cluster == "TEST_test-dag.test-task"
@patch(
"airflow.providers.openlineage.conf.custom_run_facets",
return_value={"invalid_function"},
)
-def test_get_custom_facets_with_invalid_function_definition(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_invalid_function_definition(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
assert result == {}
@@ -680,14 +650,14 @@ def test_get_custom_facets_with_invalid_function_definition(mock_custom_facet_fu
"airflow.providers.openlineage.conf.custom_run_facets",
return_value={"tests.providers.openlineage.utils.custom_facet_fixture.return_type_is_not_dict"},
)
-def test_get_custom_facets_with_wrong_return_type_function(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_wrong_return_type_function(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
assert result == {}
@@ -695,12 +665,12 @@ def test_get_custom_facets_with_wrong_return_type_function(mock_custom_facet_fun
"airflow.providers.openlineage.conf.custom_run_facets",
return_value={"tests.providers.openlineage.utils.custom_facet_fixture.get_custom_facet_throws_exception"},
)
-def test_get_custom_facets_with_exception(mock_custom_facet_funcs):
+def test_get_user_provided_run_facets_with_exception(mock_custom_facet_funcs):
sample_ti = TaskInstance(
task=EmptyOperator(
task_id="test-task", dag=DAG("test-dag",
start_date=datetime.datetime(2024, 7, 1))
),
state="running",
)
- result = get_custom_facets(sample_ti)
+ result = get_user_provided_run_facets(sample_ti, TaskInstanceState.RUNNING)
assert result == {}