This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 714a933479 openlineage: add `opt-in` option (#37725)
714a933479 is described below

commit 714a933479f9dc1c3ef5916e43292efc182a0857
Author: Jakub Dardzinski <[email protected]>
AuthorDate: Tue Mar 26 13:01:44 2024 +0100

    openlineage: add `opt-in` option (#37725)
    
    * Add `opt-in` option to disable OpenLineage for all DAGs/tasks
    by default and enable it selectively.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
    
    * Rename `opt_in` to `selective_enable`.
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
    
    ---------
    
    Signed-off-by: Jakub Dardzinski <[email protected]>
---
 airflow/providers/openlineage/conf.py              |   5 +
 airflow/providers/openlineage/plugins/listener.py  |  16 +++
 airflow/providers/openlineage/provider.yaml        |   8 ++
 .../openlineage/utils/selective_enable.py          |  87 ++++++++++++
 airflow/providers/openlineage/utils/utils.py       |  19 ++-
 .../guides/user.rst                                |  56 ++++++++
 .../providers/openlineage/plugins/test_listener.py | 158 ++++++++++++++++++++-
 .../openlineage/utils/test_selective_enable.py     |  72 ++++++++++
 8 files changed, 415 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/openlineage/conf.py 
b/airflow/providers/openlineage/conf.py
index ba8ce913c7..4ca42eedfd 100644
--- a/airflow/providers/openlineage/conf.py
+++ b/airflow/providers/openlineage/conf.py
@@ -51,6 +51,11 @@ def disabled_operators() -> set[str]:
     return set(operator.strip() for operator in option.split(";") if 
operator.strip())
 
 
+@cache
+def selective_enable() -> bool:
+    return conf.getboolean(_CONFIG_SECTION, "selective_enable", fallback=False)
+
+
 @cache
 def custom_extractors() -> set[str]:
     """[openlineage] extractors."""
diff --git a/airflow/providers/openlineage/plugins/listener.py 
b/airflow/providers/openlineage/plugins/listener.py
index 0d6b487f22..ba1e5a7906 100644
--- a/airflow/providers/openlineage/plugins/listener.py
+++ b/airflow/providers/openlineage/plugins/listener.py
@@ -31,6 +31,7 @@ from airflow.providers.openlineage.utils.utils import (
     get_custom_facets,
     get_job_name,
     is_operator_disabled,
+    is_selective_lineage_enabled,
     print_warning,
 )
 from airflow.stats import Stats
@@ -83,6 +84,9 @@ class OpenLineageListener:
             )
             return None
 
+        if not is_selective_lineage_enabled(task):
+            return
+
         @print_warning(self.log)
         def on_running():
             # that's a workaround to detect task running from deferred state
@@ -150,6 +154,9 @@ class OpenLineageListener:
             )
             return None
 
+        if not is_selective_lineage_enabled(task):
+            return
+
         @print_warning(self.log)
         def on_success():
             parent_run_id = OpenLineageAdapter.build_dag_run_id(dag.dag_id, 
dagrun.run_id)
@@ -202,6 +209,9 @@ class OpenLineageListener:
             )
             return None
 
+        if not is_selective_lineage_enabled(task):
+            return
+
         @print_warning(self.log)
         def on_failure():
             parent_run_id = OpenLineageAdapter.build_dag_run_id(dag.dag_id, 
dagrun.run_id)
@@ -255,6 +265,8 @@ class OpenLineageListener:
 
     @hookimpl
     def on_dag_run_running(self, dag_run: DagRun, msg: str):
+        if not is_selective_lineage_enabled(dag_run.dag):
+            return
         data_interval_start = dag_run.data_interval_start.isoformat() if 
dag_run.data_interval_start else None
         data_interval_end = dag_run.data_interval_end.isoformat() if 
dag_run.data_interval_end else None
         self.executor.submit(
@@ -267,6 +279,8 @@ class OpenLineageListener:
 
     @hookimpl
     def on_dag_run_success(self, dag_run: DagRun, msg: str):
+        if not is_selective_lineage_enabled(dag_run.dag):
+            return
         if not self.executor:
             self.log.debug("Executor have not started before 
`on_dag_run_success`")
             return
@@ -274,6 +288,8 @@ class OpenLineageListener:
 
     @hookimpl
     def on_dag_run_failed(self, dag_run: DagRun, msg: str):
+        if not is_selective_lineage_enabled(dag_run.dag):
+            return
         if not self.executor:
             self.log.debug("Executor have not started before 
`on_dag_run_failed`")
             return
diff --git a/airflow/providers/openlineage/provider.yaml 
b/airflow/providers/openlineage/provider.yaml
index aac9b43111..075f711779 100644
--- a/airflow/providers/openlineage/provider.yaml
+++ b/airflow/providers/openlineage/provider.yaml
@@ -77,6 +77,14 @@ config:
         example: 
"airflow.operators.bash.BashOperator;airflow.operators.python.PythonOperator"
         default: ""
         version_added: 1.1.0
+      selective_enable:
+        description: |
+          If this setting is enabled, OpenLineage integration won't collect 
and emit metadata,
+          unless you explicitly enable it per `DAG` or `Task` using the 
`enable_lineage` function.
+        type: boolean
+        default: "False"
+        example: ~
+        version_added: 1.7.0
       namespace:
         description: |
           Set namespace that the lineage data belongs to, so that if you use 
multiple OpenLineage producers,
diff --git a/airflow/providers/openlineage/utils/selective_enable.py 
b/airflow/providers/openlineage/utils/selective_enable.py
new file mode 100644
index 0000000000..71cb915e6d
--- /dev/null
+++ b/airflow/providers/openlineage/utils/selective_enable.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import TypeVar
+
+from airflow.models import DAG, Operator, Param
+from airflow.models.xcom_arg import XComArg
+
+ENABLE_OL_PARAM_NAME = "_selective_enable_ol"
+ENABLE_OL_PARAM = Param(True, const=True)
+DISABLE_OL_PARAM = Param(False, const=False)
+T = TypeVar("T", bound="DAG | Operator")
+
+log = logging.getLogger(__name__)
+
+
+def enable_lineage(obj: T) -> T:
+    """Set selective enable OpenLineage parameter to True.
+
+    The method also propagates param to tasks if the object is DAG.
+    """
+    if isinstance(obj, XComArg):
+        enable_lineage(obj.operator)
+        return obj
+    # propagate param to tasks
+    if isinstance(obj, DAG):
+        for task in obj.task_dict.values():
+            enable_lineage(task)
+    obj.params[ENABLE_OL_PARAM_NAME] = ENABLE_OL_PARAM
+    return obj
+
+
+def disable_lineage(obj: T) -> T:
+    """Set selective enable OpenLineage parameter to False.
+
+    The method also propagates param to tasks if the object is DAG.
+    """
+    if isinstance(obj, XComArg):
+        disable_lineage(obj.operator)
+        return obj
+    # propagate param to tasks
+    if isinstance(obj, DAG):
+        for task in obj.task_dict.values():
+            disable_lineage(task)
+    obj.params[ENABLE_OL_PARAM_NAME] = DISABLE_OL_PARAM
+    return obj
+
+
+def is_task_lineage_enabled(task: Operator) -> bool:
+    """Check if selective enable OpenLineage parameter is set to True on task 
level."""
+    if task.params.get(ENABLE_OL_PARAM_NAME) is False:
+        log.debug(
+            "OpenLineage event emission suppressed. Task for this 
functionality is selectively disabled."
+        )
+    return task.params.get(ENABLE_OL_PARAM_NAME) is True
+
+
+def is_dag_lineage_enabled(dag: DAG) -> bool:
+    """Check if DAG is selectively enabled to emit OpenLineage events.
+
+    The method also checks if selective enable parameter is set to True
+    or if any of the tasks in DAG is selectively enabled.
+    """
+    if dag.params.get(ENABLE_OL_PARAM_NAME) is False:
+        log.debug(
+            "OpenLineage event emission suppressed. DAG for this functionality 
is selectively disabled."
+        )
+    return dag.params.get(ENABLE_OL_PARAM_NAME) is True or any(
+        is_task_lineage_enabled(task) for task in dag.tasks
+    )
diff --git a/airflow/providers/openlineage/utils/utils.py 
b/airflow/providers/openlineage/utils/utils.py
index ef18933a75..62691bc3b2 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -30,16 +30,21 @@ from attrs import asdict
 # TODO: move this maybe to Airflow's logic?
 from openlineage.client.utils import RedactMixin
 
+from airflow.models import DAG, BaseOperator, MappedOperator
 from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.plugins.facets import (
     AirflowMappedTaskRunFacet,
     AirflowRunFacet,
 )
+from airflow.providers.openlineage.utils.selective_enable import (
+    is_dag_lineage_enabled,
+    is_task_lineage_enabled,
+)
 from airflow.utils.context import AirflowContextDeprecationWarning
 from airflow.utils.log.secrets_masker import Redactable, Redacted, 
SecretsMasker, should_hide_value_for_key
 
 if TYPE_CHECKING:
-    from airflow.models import DAG, BaseOperator, DagRun, MappedOperator, 
TaskInstance
+    from airflow.models import DagRun, TaskInstance
 
 
 log = logging.getLogger(__name__)
@@ -73,6 +78,18 @@ def is_operator_disabled(operator: BaseOperator | 
MappedOperator) -> bool:
     return get_fully_qualified_class_name(operator) in 
conf.disabled_operators()
 
 
+def is_selective_lineage_enabled(obj: DAG | BaseOperator | MappedOperator) -> 
bool:
+    """If selective enable is active, check whether the DAG or Task is enabled to emit 
events."""
+    if not conf.selective_enable():
+        return True
+    if isinstance(obj, DAG):
+        return is_dag_lineage_enabled(obj)
+    elif isinstance(obj, (BaseOperator, MappedOperator)):
+        return is_task_lineage_enabled(obj)
+    else:
+        raise TypeError("is_selective_lineage_enabled can only be used on DAG 
or Operator objects")
+
+
 class InfoJsonEncodable(dict):
     """
     Airflow objects might not be json-encodable overall.
diff --git a/docs/apache-airflow-providers-openlineage/guides/user.rst 
b/docs/apache-airflow-providers-openlineage/guides/user.rst
index acceafd619..e7decec1f1 100644
--- a/docs/apache-airflow-providers-openlineage/guides/user.rst
+++ b/docs/apache-airflow-providers-openlineage/guides/user.rst
@@ -189,6 +189,7 @@ If not set, it's using ``default`` namespace. Provide the 
name of the namespace
 
   AIRFLOW__OPENLINEAGE__NAMESPACE='my-team-airflow-instance'
 
+.. _options:disable:
 
 Disable
 ^^^^^^^
@@ -263,6 +264,61 @@ a string of semicolon separated Airflow Operators full 
import paths to ``extract
 
   
AIRFLOW__OPENLINEAGE__EXTRACTORS='full.path.to.ExtractorClass;full.path.to.AnotherExtractorClass'
 
+Enabling OpenLineage on DAG/task level
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+One can selectively enable OpenLineage for specific DAGs and tasks by using 
the ``selective_enable`` policy.
+To enable this policy, set the ``selective_enable`` option to True in the 
[openlineage] section of your Airflow configuration file:
+
+.. code-block:: ini
+
+    [openlineage]
+    selective_enable = True
+
+
+While ``selective_enable`` enables selective control, the ``disabled`` 
:ref:`option <options:disable>` still has precedence.
+If you set ``disabled`` to True in the configuration, OpenLineage will be 
disabled for all DAGs and tasks regardless of the ``selective_enable`` setting.
+
+Once the ``selective_enable`` policy is enabled, you can choose to enable 
OpenLineage
+for individual DAGs and tasks using the ``enable_lineage`` and 
``disable_lineage`` functions.
+
+1. Enabling Lineage on a DAG:
+
+.. code-block:: python
+
+    from airflow.providers.openlineage.utils.selective_enable import 
disable_lineage, enable_lineage
+
+    with enable_lineage(DAG(...)):
+        # Tasks within this DAG will have lineage tracking enabled
+        MyOperator(...)
+
+        AnotherOperator(...)
+
+2. Disabling Lineage on a Task within an enabled DAG:
+
+While enabling lineage on a DAG implicitly enables it for all tasks within 
that DAG, you can still selectively disable it for specific tasks:
+
+.. code-block:: python
+
+    from airflow.providers.openlineage.utils.selective_enable import 
disable_lineage, enable_lineage
+
+    with DAG(...) as dag:
+        t1 = MyOperator(...)
+        t2 = AnotherOperator(...)
+
+    # Enable lineage for the entire DAG
+    enable_lineage(dag)
+
+    # Disable lineage for task t1
+    disable_lineage(t1)
+
+Enabling lineage on the DAG level automatically enables it for all tasks 
within that DAG unless explicitly disabled per task.
+
+Enabling lineage on the task level implicitly enables lineage on its DAG.
+This is because each emitting task sends a `ParentRunFacet 
<https://openlineage.io/docs/spec/facets/run-facets/parent_run>`_,
+which requires the DAG-level lineage to be enabled in some OpenLineage backend 
systems.
+Disabling DAG-level lineage while enabling task-level lineage might cause 
errors or inconsistencies.
+
 
 Troubleshooting
 ===============
diff --git a/tests/providers/openlineage/plugins/test_listener.py 
b/tests/providers/openlineage/plugins/test_listener.py
index 69bfdabe91..c37892c1f3 100644
--- a/tests/providers/openlineage/plugins/test_listener.py
+++ b/tests/providers/openlineage/plugins/test_listener.py
@@ -29,8 +29,11 @@ import pytest
 from airflow.models import DAG, DagRun, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.operators.python import PythonOperator
+from airflow.providers.openlineage import conf
 from airflow.providers.openlineage.plugins.listener import OpenLineageListener
+from airflow.providers.openlineage.utils.selective_enable import 
disable_lineage, enable_lineage
 from airflow.utils.state import State
+from tests.test_utils.config import conf_vars
 
 pytestmark = pytest.mark.db_test
 
@@ -107,7 +110,7 @@ def _setup_mock_listener(mock_listener: mock.Mock, 
captured_try_numbers: dict[st
         ).side_effect = capture_try_number(event)
 
 
-def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) 
-> TaskInstance:
+def _create_test_dag_and_task(python_callable: Callable, scenario_name: str) 
-> tuple[DagRun, TaskInstance]:
     """Creates a test DAG and a task for a custom test scenario.
 
     :param python_callable: The Python callable to be executed by the 
PythonOperator.
@@ -132,9 +135,9 @@ def _create_test_dag_and_task(python_callable: Callable, 
scenario_name: str) ->
     )
     t = PythonOperator(task_id=f"test_task_{scenario_name}", dag=dag, 
python_callable=python_callable)
     run_id = str(uuid.uuid1())
-    dag.create_dagrun(state=State.NONE, run_id=run_id)  # type: ignore
+    dagrun = dag.create_dagrun(state=State.NONE, run_id=run_id)  # type: ignore
     task_instance = TaskInstance(t, run_id=run_id)
-    return task_instance
+    return dagrun, task_instance
 
 
 def _create_listener_and_task_instance() -> tuple[OpenLineageListener, 
TaskInstance]:
@@ -423,7 +426,7 @@ def 
test_listener_on_task_instance_failed_is_called_before_try_number_increment(
     def fail_callable(**kwargs):
         raise CustomError("Simulated task failure")
 
-    task_instance = _create_test_dag_and_task(fail_callable, "failure")
+    _, task_instance = _create_test_dag_and_task(fail_callable, "failure")
     # try_number before execution
     assert task_instance.try_number == 1
     with suppress(CustomError):
@@ -452,7 +455,7 @@ def 
test_listener_on_task_instance_success_is_called_after_try_number_increment(
     def success_callable(**kwargs):
         return None
 
-    task_instance = _create_test_dag_and_task(success_callable, "success")
+    _, task_instance = _create_test_dag_and_task(success_callable, "success")
     # try_number before execution
     assert task_instance.try_number == 1
     task_instance.run()
@@ -518,3 +521,148 @@ def 
test_listener_on_task_instance_success_do_not_call_adapter_when_disabled_ope
     mocked_adapter.build_task_instance_run_id.assert_not_called()
     listener.extractor_manager.extract_metadata.assert_not_called()
     listener.adapter.complete_task.assert_not_called()
+
+
+class TestOpenLineageSelectiveEnable:
+    def setup_method(self):
+        self.dag = DAG(
+            "test_selective_enable",
+            start_date=dt.datetime(2022, 1, 1),
+        )
+
+        def simple_callable(**kwargs):
+            return None
+
+        self.task_1 = PythonOperator(
+            task_id="test_task_selective_enable_1", dag=self.dag, 
python_callable=simple_callable
+        )
+        self.task_2 = PythonOperator(
+            task_id="test_task_selective_enable_2", dag=self.dag, 
python_callable=simple_callable
+        )
+        run_id = str(uuid.uuid1())
+        self.dagrun = self.dag.create_dagrun(state=State.NONE, run_id=run_id)  
# type: ignore
+        self.task_instance_1 = TaskInstance(self.task_1, run_id=run_id)
+        self.task_instance_2 = TaskInstance(self.task_2, run_id=run_id)
+        self.task_instance_1.dag_run = self.task_instance_2.dag_run = 
self.dagrun
+
+    @pytest.mark.parametrize(
+        "selective_enable, enable_dag, expected_call_count",
+        [
+            ("True", True, 3),
+            ("False", True, 3),
+            ("True", False, 0),
+            ("False", False, 3),
+        ],
+    )
+    def test_listener_with_dag_enabled(self, selective_enable, enable_dag, 
expected_call_count):
+        """Tests listener's behaviour with selective-enable on DAG level."""
+
+        if enable_dag:
+            enable_lineage(self.dag)
+
+        conf.selective_enable.cache_clear()
+        with conf_vars({("openlineage", "selective_enable"): 
selective_enable}):
+            listener = OpenLineageListener()
+            listener._executor = mock.Mock()
+
+            # run all three DagRun-related hooks
+            listener.on_dag_run_running(self.dagrun, msg="test running")
+            listener.on_dag_run_failed(self.dagrun, msg="test failure")
+            listener.on_dag_run_success(self.dagrun, msg="test success")
+
+        try:
+            assert expected_call_count == listener._executor.submit.call_count
+        finally:
+            conf.selective_enable.cache_clear()
+
+    @pytest.mark.parametrize(
+        "selective_enable, enable_task, expected_dag_call_count, 
expected_task_call_count",
+        [
+            ("True", True, 3, 3),
+            ("False", True, 3, 3),
+            ("True", False, 0, 0),
+            ("False", False, 3, 3),
+        ],
+    )
+    def test_listener_with_task_enabled(
+        self, selective_enable, enable_task, expected_dag_call_count, 
expected_task_call_count
+    ):
+        """Tests listener's behaviour with selective-enable on task level."""
+
+        if enable_task:
+            enable_lineage(self.task_1)
+
+        conf.selective_enable.cache_clear()
+        with conf_vars({("openlineage", "selective_enable"): 
selective_enable}):
+            listener = OpenLineageListener()
+            listener._executor = mock.Mock()
+            listener.extractor_manager = mock.Mock()
+            listener.adapter = mock.Mock()
+            try:
+                # run all three DagRun-related hooks
+                listener.on_dag_run_running(self.dagrun, msg="test running")
+                listener.on_dag_run_failed(self.dagrun, msg="test failure")
+                listener.on_dag_run_success(self.dagrun, msg="test success")
+
+                assert expected_dag_call_count == 
listener._executor.submit.call_count
+
+                # run TaskInstance-related hooks for lineage enabled task
+                listener.on_task_instance_running(None, self.task_instance_1, 
None)
+                listener.on_task_instance_success(None, self.task_instance_1, 
None)
+                listener.on_task_instance_failed(None, self.task_instance_1, 
None)
+
+                assert expected_task_call_count == 
listener.extractor_manager.extract_metadata.call_count
+
+                # run TaskInstance-related hooks for lineage disabled task
+                listener.on_task_instance_running(None, self.task_instance_2, 
None)
+                listener.on_task_instance_success(None, self.task_instance_2, 
None)
+                listener.on_task_instance_failed(None, self.task_instance_2, 
None)
+
+                # with selective-enable disabled both task_1 and task_2 should 
trigger metadata extraction
+                if selective_enable == "False":
+                    expected_task_call_count *= 2
+
+                assert expected_task_call_count == 
listener.extractor_manager.extract_metadata.call_count
+            finally:
+                conf.selective_enable.cache_clear()
+
+    @pytest.mark.parametrize(
+        "selective_enable, enable_task, expected_call_count, 
expected_task_call_count",
+        [
+            ("True", True, 3, 3),
+            ("False", True, 3, 3),
+            ("True", False, 0, 0),
+            ("False", False, 3, 3),
+        ],
+    )
+    def test_listener_with_dag_disabled_task_enabled(
+        self, selective_enable, enable_task, expected_call_count, 
expected_task_call_count
+    ):
+        """Tests listener's behaviour with selective-enable on task level with 
DAG disabled."""
+        disable_lineage(self.dag)
+
+        if enable_task:
+            enable_lineage(self.task_1)
+
+        conf.selective_enable.cache_clear()
+        with conf_vars({("openlineage", "selective_enable"): 
selective_enable}):
+            listener = OpenLineageListener()
+            listener._executor = mock.Mock()
+            listener.extractor_manager = mock.Mock()
+            listener.adapter = mock.Mock()
+
+            # run all three DagRun-related hooks
+            listener.on_dag_run_running(self.dagrun, msg="test running")
+            listener.on_dag_run_failed(self.dagrun, msg="test failure")
+            listener.on_dag_run_success(self.dagrun, msg="test success")
+
+            # run TaskInstance-related hooks for lineage enabled task
+            listener.on_task_instance_running(None, self.task_instance_1, None)
+            listener.on_task_instance_success(None, self.task_instance_1, None)
+            listener.on_task_instance_failed(None, self.task_instance_1, None)
+
+        try:
+            assert expected_call_count == listener._executor.submit.call_count
+            assert expected_task_call_count == 
listener.extractor_manager.extract_metadata.call_count
+        finally:
+            conf.selective_enable.cache_clear()
diff --git a/tests/providers/openlineage/utils/test_selective_enable.py 
b/tests/providers/openlineage/utils/test_selective_enable.py
new file mode 100644
index 0000000000..a177181489
--- /dev/null
+++ b/tests/providers/openlineage/utils/test_selective_enable.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.decorators import dag, task
+from airflow.models import DAG
+from airflow.operators.empty import EmptyOperator
+from airflow.providers.openlineage.utils.selective_enable import (
+    DISABLE_OL_PARAM,
+    ENABLE_OL_PARAM,
+    ENABLE_OL_PARAM_NAME,
+    disable_lineage,
+    enable_lineage,
+)
+
+
+class TestOpenLineageSelectiveEnable:
+    def setup_method(self):
+        @dag(dag_id="test_selective_enable_decorated_dag")
+        def decorated_dag():
+            @task
+            def decorated_task():
+                return "test"
+
+            self.decorated_task = decorated_task()
+
+        self.decorated_dag = decorated_dag()
+
+        with DAG(dag_id="test_selective_enable_dag") as self.dag:
+            self.task = EmptyOperator(task_id="test_selective_enable")
+
+    def test_enable_lineage_task_level(self):
+        assert ENABLE_OL_PARAM_NAME not in self.task.params
+        enable_lineage(self.task)
+        assert ENABLE_OL_PARAM.value == self.task.params[ENABLE_OL_PARAM_NAME]
+
+    def test_disable_lineage_task_level(self):
+        assert ENABLE_OL_PARAM_NAME not in self.task.params
+        disable_lineage(self.task)
+        assert DISABLE_OL_PARAM.value == self.task.params[ENABLE_OL_PARAM_NAME]
+
+    def test_enable_lineage_dag_level(self):
+        assert ENABLE_OL_PARAM_NAME not in self.dag.params
+        enable_lineage(self.dag)
+        assert ENABLE_OL_PARAM.value == self.dag.params[ENABLE_OL_PARAM_NAME]
+        # check if param propagates to the task
+        assert ENABLE_OL_PARAM.value == self.task.params[ENABLE_OL_PARAM_NAME]
+
+    def test_enable_lineage_decorated_dag(self):
+        enable_lineage(self.decorated_dag)
+        assert ENABLE_OL_PARAM.value == 
self.decorated_dag.params[ENABLE_OL_PARAM_NAME]
+        # check if param propagates to the task
+        assert ENABLE_OL_PARAM.value == 
self.decorated_task.operator.params[ENABLE_OL_PARAM_NAME]
+
+    def test_enable_lineage_decorated_task(self):
+        enable_lineage(self.decorated_task)
+        assert ENABLE_OL_PARAM.value == 
self.decorated_task.operator.params[ENABLE_OL_PARAM_NAME]

Reply via email to