This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4174bc7d39f feat: Add helper for any provider version check (#47909)
4174bc7d39f is described below

commit 4174bc7d39fb336c329d887878ac327d4e283f6d
Author: Kacper Muda <[email protected]>
AuthorDate: Thu Mar 20 13:10:33 2025 +0100

    feat: Add helper for any provider version check (#47909)
---
 .../src/airflow/providers/common/compat/check.py   |  99 ++++++++++++++++++++
 .../providers/common/compat/openlineage/check.py   |  34 ++++---
 .../unit/common/compat/openlineage/test_check.py   |  99 ++++++++------------
 .../compat/tests/unit/common/compat/test_check.py  | 102 +++++++++++++++++++++
 4 files changed, 253 insertions(+), 81 deletions(-)

diff --git 
a/providers/common/compat/src/airflow/providers/common/compat/check.py 
b/providers/common/compat/src/airflow/providers/common/compat/check.py
new file mode 100644
index 00000000000..e11ce29be7c
--- /dev/null
+++ b/providers/common/compat/src/airflow/providers/common/compat/check.py
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import functools
+import importlib
+from importlib import metadata
+
+from packaging.version import Version
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+
+
+def require_provider_version(provider_name: str, provider_min_version: str | None = None):
+    """
+    Enforce minimum version requirement for a specific provider.
+
+    Some providers do not explicitly require other provider packages but may offer optional features
+    that depend on them. These features are generally available starting from a specific version of
+    such a provider. This decorator helps ensure compatibility, preventing import errors and providing
+    clear messages about version requirements.
+
+    The installed version is looked up every time the decorated function is called: first from the
+    distribution metadata and, if the distribution is not found, from the ``__version__`` attribute of
+    the provider package (e.g. ``airflow.providers.openlineage``).
+
+    Args:
+        provider_name: Full name of the provider e.g., apache-airflow-providers-openlineage
+        provider_min_version: Minimum version requirement e.g., 1.0.1. It must be given; the ``None``
+            default only exists so that bare ``@require_provider_version`` usage can be detected and
+            reported with a helpful message.
+
+    Raises:
+        TypeError: If the decorator is used without parentheses (e.g., `@require_provider_version`).
+        ValueError: If either `provider_name` or `provider_min_version` is missing or empty.
+        ValueError: If full provider name (e.g., apache-airflow-providers-openlineage) is not provided.
+        AirflowOptionalProviderFeatureException: Raised by the decorated function, instead of running
+            it, when the provider cannot be found (neither as an installed distribution nor as an
+            importable package with ``__version__``) or its version is lower than
+            `provider_min_version`.
+    """
+    err_msg = (
+        "`require_provider_version` decorator must be used with two arguments: "
+        "'provider_name' and 'provider_min_version', "
+        'e.g., @require_provider_version(provider_name="apache-airflow-providers-openlineage", '
+        'provider_min_version="1.0.0")'
+    )
+    # Bare `@require_provider_version` passes the decorated function as `provider_name` and leaves
+    # `provider_min_version` at its default, so this guard is reachable only thanks to that default.
+    if callable(provider_name) and not provider_min_version:
+        raise TypeError(err_msg)
+
+    # Ensure both arguments are provided and not empty
+    if not provider_name or not provider_min_version:
+        raise ValueError(err_msg)
+
+    # Ensure full provider name is passed
+    provider_prefix = "apache-airflow-providers-"
+    if not provider_name.startswith(provider_prefix):
+        raise ValueError(
+            f"Full `provider_name` must be provided starting with 'apache-airflow-providers-', "
+            f"got `{provider_name}`."
+        )
+
+    # Fallback import path, computed once, e.g.
+    # "apache-airflow-providers-common-compat" -> "airflow.providers.common.compat"
+    provider_module_path = "airflow.providers." + provider_name.removeprefix(provider_prefix).replace(
+        "-", "."
+    )
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                provider_version: str = metadata.version(provider_name)
+            except metadata.PackageNotFoundError:
+                # No distribution metadata (e.g. provider used from sources): read the version
+                # from the provider package itself.
+                try:
+                    provider_module = importlib.import_module(provider_module_path)
+                    provider_version = provider_module.__version__
+                except (ImportError, AttributeError) as err:  # ModuleNotFoundError is an ImportError
+                    raise AirflowOptionalProviderFeatureException(
+                        f"Provider `{provider_name}` not found or has no version, "
+                        f"skipping function `{func.__name__}` execution"
+                    ) from err
+
+            # A falsy version string skips the comparison and the function is executed.
+            if provider_version and Version(provider_version) < Version(provider_min_version):
+                raise AirflowOptionalProviderFeatureException(
+                    f"Provider's `{provider_name}` version `{provider_version}` is lower than required "
+                    f"`{provider_min_version}`, skipping function `{func.__name__}` execution"
+                )
+
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git 
a/providers/common/compat/src/airflow/providers/common/compat/openlineage/check.py
 
b/providers/common/compat/src/airflow/providers/common/compat/openlineage/check.py
index 16626cb9ea5..a2cc503703b 100644
--- 
a/providers/common/compat/src/airflow/providers/common/compat/openlineage/check.py
+++ 
b/providers/common/compat/src/airflow/providers/common/compat/openlineage/check.py
@@ -23,6 +23,8 @@ from importlib import metadata
 
 from packaging.version import Version
 
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+
 log = logging.getLogger(__name__)
 
 
@@ -68,36 +70,32 @@ def require_openlineage_version(
                     try:
                         from airflow.providers.openlineage import __version__ 
as provider_version
                     except (ImportError, AttributeError, ModuleNotFoundError):
-                        log.info(
-                            "OpenLineage provider not found or has no version, 
skipping function `%s` execution",
-                            func.__name__,
+                        raise AirflowOptionalProviderFeatureException(
+                            "OpenLineage provider not found or has no version, 
"
+                            f"skipping function `{func.__name__}` execution"
                         )
-                        return None
 
                 if provider_version and Version(provider_version) < 
Version(provider_min_version):
-                    log.info(
-                        "OpenLineage provider version `%s` is lower than 
required `%s`, skipping function `%s` execution",
-                        provider_version,
-                        provider_min_version,
-                        func.__name__,
+                    raise AirflowOptionalProviderFeatureException(
+                        f"OpenLineage provider version `{provider_version}` "
+                        f"is lower than required `{provider_min_version}`, "
+                        f"skipping function `{func.__name__}` execution"
                     )
-                    return None
 
             if client_min_version:
                 try:
                     client_version: str = 
metadata.version("openlineage-python")
                 except metadata.PackageNotFoundError:
-                    log.info("OpenLineage client not found, skipping function 
`%s` execution", func.__name__)
-                    return None
+                    raise AirflowOptionalProviderFeatureException(
+                        f"OpenLineage client not found, skipping function 
`{func.__name__}` execution"
+                    )
 
                 if client_version and Version(client_version) < 
Version(client_min_version):
-                    log.info(
-                        "OpenLineage client version `%s` is lower than 
required `%s`, skipping function `%s` execution",
-                        client_version,
-                        client_min_version,
-                        func.__name__,
+                    raise AirflowOptionalProviderFeatureException(
+                        f"OpenLineage client version `{client_version}` "
+                        f"is lower than required `{client_min_version}`, "
+                        f"skipping function `{func.__name__}` execution"
                     )
-                    return None
 
             return func(*args, **kwargs)
 
diff --git 
a/providers/common/compat/tests/unit/common/compat/openlineage/test_check.py 
b/providers/common/compat/tests/unit/common/compat/openlineage/test_check.py
index 75937b3232b..43317509e5a 100644
--- a/providers/common/compat/tests/unit/common/compat/openlineage/test_check.py
+++ b/providers/common/compat/tests/unit/common/compat/openlineage/test_check.py
@@ -17,7 +17,6 @@
 
 from __future__ import annotations
 
-import logging
 import sys
 import types
 from importlib import metadata
@@ -25,6 +24,7 @@ from unittest.mock import patch
 
 import pytest
 
+from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.providers.common.compat.openlineage.check import 
require_openlineage_version
 
 
@@ -79,37 +79,32 @@ def test_no_arguments_provided():
 
 @pytest.mark.parametrize("provider_min_version", ("1.0.0", "0.9", "0", 
"0.9.9", "1.0.0.dev0", "1.0.0rc1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_provider_version_sufficient(mock_version, caplog, 
provider_min_version):
+def test_provider_version_sufficient(mock_version, provider_min_version):
     @require_openlineage_version(provider_min_version=provider_min_version)
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
     result = dummy()
     assert result == "result"
-    # No log messages about skipping should be emitted.
-    assert "skipping function" not in caplog.text
 
 
 @pytest.mark.parametrize("provider_min_version", ("1.1.0", "1.0.1.dev0", 
"1.0.1rc1", "2", "1.1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_provider_version_insufficient(mock_version, caplog, 
provider_min_version):
+def test_provider_version_insufficient(mock_version, provider_min_version):
     @require_openlineage_version(provider_min_version=provider_min_version)
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
-    result = dummy()
-    assert result is None
-
-    expected_log = (
+    expected_err = (
         f"OpenLineage provider version `1.0.0` is lower than required 
`{provider_min_version}`, "
         "skipping function `dummy` execution"
     )
-    assert expected_log in caplog.text
 
+    with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+        dummy()
 
-def test_provider_not_found(caplog):
+
+def test_provider_not_found():
     def fake_version(package):
         if package == "apache-airflow-providers-openlineage":
             raise metadata.PackageNotFoundError
@@ -124,17 +119,15 @@ def test_provider_not_found(caplog):
             def dummy():
                 return "result"
 
-            caplog.set_level(logging.INFO)
-            result = dummy()
-            assert result is None
-
-            expected_log = (
+            expected_err = (
                 "OpenLineage provider not found or has no version, skipping 
function `dummy` execution"
             )
-            assert expected_log in caplog.text
+
+            with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+                dummy()
 
 
-def test_provider_fallback_import(caplog):
+def test_provider_fallback_import():
     def fake_version(package):
         if package == "apache-airflow-providers-openlineage":
             raise metadata.PackageNotFoundError
@@ -150,45 +143,38 @@ def test_provider_fallback_import(caplog):
             def dummy():
                 return "result"
 
-            caplog.set_level(logging.INFO)
             result = dummy()
             assert result == "result"
-            assert "skipping function" not in caplog.text
 
 
 @pytest.mark.parametrize("client_min_version", ("1.0.0", "0.9", "0", "0.9.9", 
"1.0.0.dev0", "1.0.0rc1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_client_version_sufficient(mock_version, caplog, client_min_version):
+def test_client_version_sufficient(mock_version, client_min_version):
     @require_openlineage_version(client_min_version=client_min_version)
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
     result = dummy()
     assert result == "result"
-    # No log messages about skipping should be emitted.
-    assert "skipping function" not in caplog.text
 
 
 @pytest.mark.parametrize("client_min_version", ("1.1.0", "1.0.1.dev0", 
"1.0.1rc1", "2", "1.1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_client_version_insufficient(mock_version, caplog, client_min_version):
+def test_client_version_insufficient(mock_version, client_min_version):
     @require_openlineage_version(client_min_version=client_min_version)
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
-    result = dummy()
-    assert result is None
-
-    expected_log = (
+    expected_err = (
         f"OpenLineage client version `1.0.0` is lower than required 
`{client_min_version}`, "
         "skipping function `dummy` execution"
     )
-    assert expected_log in caplog.text
 
+    with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+        dummy()
 
-def test_client_version_not_found(caplog):
+
+def test_client_version_not_found():
     def fake_version(package):
         if package == "openlineage-python":
             raise metadata.PackageNotFoundError
@@ -200,81 +186,68 @@ def test_client_version_not_found(caplog):
         def dummy():
             return "result"
 
-        caplog.set_level(logging.INFO)
-        result = dummy()
-        assert result is None
-        expected_log = "OpenLineage client not found, skipping function 
`dummy` execution"
-        assert expected_log in caplog.text
+        expected_err = "OpenLineage client not found, skipping function 
`dummy` execution"
+        with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+            dummy()
 
 
 @pytest.mark.parametrize("client_min_version", ("1.1.0", "1.0.1.dev0", 
"1.0.1rc1", "2", "1.1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_client_version_insufficient_when_both_passed(mock_version, caplog, 
client_min_version):
+def test_client_version_insufficient_when_both_passed(mock_version, 
client_min_version):
     @require_openlineage_version(provider_min_version="1.0.0", 
client_min_version=client_min_version)
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
-    result = dummy()
-    assert result is None
-
-    expected_log = (
+    expected_err = (
         f"OpenLineage client version `1.0.0` is lower than required 
`{client_min_version}`, "
         "skipping function `dummy` execution"
     )
-    assert expected_log in caplog.text
+    with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+        dummy()
 
 
 @pytest.mark.parametrize("provider_min_version", ("1.1.0", "1.0.1.dev0", 
"1.0.1rc1", "2", "1.1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_provider_version_insufficient_when_both_passed(mock_version, caplog, 
provider_min_version):
+def test_provider_version_insufficient_when_both_passed(mock_version, 
provider_min_version):
     @require_openlineage_version(provider_min_version=provider_min_version, 
client_min_version="1.0.0")
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
-    result = dummy()
-    assert result is None
-
-    expected_log = (
+    expected_err = (
         f"OpenLineage provider version `1.0.0` is lower than required 
`{provider_min_version}`, "
         "skipping function `dummy` execution"
     )
-    assert expected_log in caplog.text
+    with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+        dummy()
 
 
 @pytest.mark.parametrize("client_min_version", ("1.0.0", "0.9", "0", "0.9.9", 
"1.0.0.dev0", "1.0.0rc1"))
 @pytest.mark.parametrize("provider_min_version", ("1.0.0", "0.9", "0", 
"0.9.9", "1.0.0.dev0", "1.0.0rc1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_both_versions_sufficient(mock_version, caplog, provider_min_version, 
client_min_version):
+def test_both_versions_sufficient(mock_version, provider_min_version, 
client_min_version):
     @require_openlineage_version(
         provider_min_version=provider_min_version, 
client_min_version=client_min_version
     )
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
     result = dummy()
     assert result == "result"
-    assert "skipping function" not in caplog.text
 
 
 @pytest.mark.parametrize("client_min_version", ("1.1.0", "1.0.1.dev0", 
"1.0.1rc1", "2", "1.1"))
 @pytest.mark.parametrize("provider_min_version", ("1.1.0", "1.0.1.dev0", 
"1.0.1rc1", "2", "1.1"))
 @patch("importlib.metadata.version", side_effect=_mock_version)
-def test_both_versions_insufficient(mock_version, caplog, 
provider_min_version, client_min_version):
+def test_both_versions_insufficient(mock_version, provider_min_version, 
client_min_version):
     @require_openlineage_version(
         provider_min_version=provider_min_version, 
client_min_version=client_min_version
     )
     def dummy():
         return "result"
 
-    caplog.set_level(logging.INFO)
-    result = dummy()
-    assert result is None
-
-    expected_log = (
+    expected_err = (
         f"OpenLineage provider version `1.0.0` is lower than required 
`{provider_min_version}`, "
         "skipping function `dummy` execution"
     )
-    assert expected_log in caplog.text
+    with pytest.raises(AirflowOptionalProviderFeatureException, 
match=expected_err):
+        dummy()
diff --git a/providers/common/compat/tests/unit/common/compat/test_check.py 
b/providers/common/compat/tests/unit/common/compat/test_check.py
new file mode 100644
index 00000000000..7fb4e3eb760
--- /dev/null
+++ b/providers/common/compat/tests/unit/common/compat/test_check.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from importlib import metadata
+from unittest.mock import patch
+
+import pytest
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.common.compat.check import require_provider_version
+
+
+def test_decorator_usage_without_parentheses():
+    """Passing the function straight to the factory (bare ``@`` usage) raises TypeError."""
+
+    def dummy_function():
+        pass
+
+    with pytest.raises(TypeError):
+        require_provider_version(dummy_function)
+
+
+def test_empty_provider_name_and_version():
+    """Empty provider name and version are rejected as soon as the factory is called."""
+    with pytest.raises(ValueError, match="decorator must be used with two arguments"):
+        require_provider_version("", "")
+
+
+def test_invalid_provider_name():
+    """A provider name without the ``apache-airflow-providers-`` prefix is rejected."""
+    short_name = "invalid-provider-name"
+    pattern = (
+        "Full `provider_name` must be provided starting with 'apache-airflow-providers-', "
+        "got `invalid-provider-name`."
+    )
+    with pytest.raises(ValueError, match=pattern):
+        require_provider_version(short_name, "1.0.0")
+
+
+@patch("importlib.metadata.version", return_value="0.9.9")
+def test_provider_version_lower_than_required(mock_version):
+    """An installed version below the minimum makes the call raise instead of running."""
+    guard = require_provider_version("apache-airflow-providers-mockprovider", "1.0.0")
+    expected = (
+        r"Provider's `apache-airflow-providers-mockprovider` version "
+        r"`0.9.9` is lower than required `1.0.0`"
+    )
+
+    @guard
+    def dummy_function():
+        return "Function Executed"
+
+    with pytest.raises(AirflowOptionalProviderFeatureException, match=expected):
+        dummy_function()
+
+
+@patch("importlib.metadata.version", side_effect=metadata.PackageNotFoundError)
+@patch("importlib.import_module", side_effect=ModuleNotFoundError)
+def test_provider_not_installed(mock_import, mock_version):
+    """Missing distribution metadata plus a failing fallback import make the call raise."""
+
+    def dummy_function():
+        return "Function Executed"
+
+    guarded = require_provider_version("apache-airflow-providers-mockprovider", "1.0.0")(dummy_function)
+    expected = r"Provider `apache-airflow-providers-mockprovider` not found or has no version"
+    with pytest.raises(AirflowOptionalProviderFeatureException, match=expected):
+        guarded()
+
+
+@patch("importlib.metadata.version", return_value="2.0.0")
+def test_provider_version_ok(mock_version):
+    """A version above the minimum lets the wrapped function run and return its value."""
+    guarded = require_provider_version("apache-airflow-providers-mockprovider", "1.0.0")(
+        lambda: "Function Executed"
+    )
+    assert guarded() == "Function Executed"
+
+
+@patch("importlib.import_module", return_value=type("module", (), {"__version__": "1.5.0"}))
+@patch("importlib.metadata.version", side_effect=metadata.PackageNotFoundError)
+def test_provider_dynamic_import(mock_version, mock_import):
+    """Without distribution metadata, the version comes from the provider module's ``__version__``."""
+    guarded = require_provider_version("apache-airflow-providers-mockprovider", "1.0.0")(
+        lambda: "Function Executed"
+    )
+    assert guarded() == "Function Executed"

Reply via email to