This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-7-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 28e88f412aa1553133c155420ad064e2a92596bc
Author: Jarek Potiuk <[email protected]>
AuthorDate: Wed Aug 9 01:59:03 2023 +0200

    Fix venv detection for Python virtualenv operator (#33223)
    
    This is a follow-up to #32939. It seems that importlib's find_spec does
    not cover all cases, and the previous check is also faster.
    
    Checking for the virtualenv binary on the PATH first, and only then
    falling back to find_spec, makes detection faster and also makes it work
    in the cases where find_spec does not (e.g. local development setups).
    
    (cherry picked from commit c4fe5b8b8a3b12750bb2984aca198f7bb16b6785)
---
 airflow/example_dags/example_python_operator.py          |  5 ++---
 airflow/example_dags/tutorial_taskflow_api_virtualenv.py |  4 ++--
 airflow/operators/python.py                              | 14 +++++++++++++-
 tests/operators/test_python.py                           |  6 ++++--
 4 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/airflow/example_dags/example_python_operator.py 
b/airflow/example_dags/example_python_operator.py
index 30e447840a..7c8b27de76 100644
--- a/airflow/example_dags/example_python_operator.py
+++ b/airflow/example_dags/example_python_operator.py
@@ -22,7 +22,6 @@ virtual environment.
 from __future__ import annotations
 
 import logging
-import shutil
 import sys
 import tempfile
 import time
@@ -32,7 +31,7 @@ import pendulum
 
 from airflow import DAG
 from airflow.decorators import task
-from airflow.operators.python import ExternalPythonOperator, 
PythonVirtualenvOperator
+from airflow.operators.python import ExternalPythonOperator, 
PythonVirtualenvOperator, is_venv_installed
 
 log = logging.getLogger(__name__)
 
@@ -86,7 +85,7 @@ with DAG(
         run_this >> log_the_sql >> sleeping_task
     # [END howto_operator_python_kwargs]
 
-    if not shutil.which("virtualenv"):
+    if not is_venv_installed():
         log.warning("The virtalenv_python example task requires virtualenv, 
please install it.")
     else:
         # [START howto_operator_python_venv]
diff --git a/airflow/example_dags/tutorial_taskflow_api_virtualenv.py 
b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py
index a78116d2f7..44134e4458 100644
--- a/airflow/example_dags/tutorial_taskflow_api_virtualenv.py
+++ b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py
@@ -18,14 +18,14 @@
 from __future__ import annotations
 
 import logging
-import shutil
 from datetime import datetime
 
 from airflow.decorators import dag, task
+from airflow.operators.python import is_venv_installed
 
 log = logging.getLogger(__name__)
 
-if not shutil.which("virtualenv"):
+if not is_venv_installed():
     log.warning("The tutorial_taskflow_api_virtualenv example DAG requires 
virtualenv, please install it.")
 else:
 
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 3a10cf4f73..479fb600c4 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -22,6 +22,7 @@ import inspect
 import logging
 import os
 import pickle
+import shutil
 import subprocess
 import sys
 import types
@@ -54,6 +55,17 @@ if TYPE_CHECKING:
     from pendulum.datetime import DateTime
 
 
+def is_venv_installed() -> bool:
+    """
+    Checks if the virtualenv package is installed via checking if it is on the 
path or installed as package.
+
+    :return: True if it is. Whichever way of checking it works, is fine.
+    """
+    if shutil.which("virtualenv") or importlib.util.find_spec("virtualenv"):
+        return True
+    return False
+
+
 def task(python_callable: Callable | None = None, multiple_outputs: bool | 
None = None, **kwargs):
     """Deprecated. Use :func:`airflow.decorators.task` instead.
 
@@ -540,7 +552,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
                 "major versions for PythonVirtualenvOperator. Please use 
string_args."
                 f"Sys version: {sys.version_info}. Venv version: 
{python_version}"
             )
-        if importlib.util.find_spec("virtualenv") is None:
+        if not is_venv_installed():
             raise AirflowException("PythonVirtualenvOperator requires 
virtualenv, please install it.")
         if not requirements:
             self.requirements: list[str] | str = []
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index ba296f77a4..986f3ae510 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -847,9 +847,11 @@ class 
TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
         kwargs["python_version"] = python_version
         return kwargs
 
+    @mock.patch("shutil.which")
     @mock.patch("airflow.operators.python.importlib")
-    def test_virtuenv_not_installed(self, importlib):
-        importlib.util.find_spec.return_value = None
+    def test_virtuenv_not_installed(self, importlib_mock, which_mock):
+        which_mock.return_value = None
+        importlib_mock.util.find_spec.return_value = None
         with pytest.raises(AirflowException, match="requires virtualenv"):
 
             def f():

Reply via email to