This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f1301daef0 Improve stability of remove_task_decorator function (#38649)
f1301daef0 is described below
commit f1301daef027a750f4060e0f26d53151af99d5f7
Author: rom sharon <[email protected]>
AuthorDate: Mon Apr 1 20:10:31 2024 +0300
Improve stability of remove_task_decorator function (#38649)
* Improve stability of remove_task_decorator function
* fix statics
* test
* remove test
---------
Co-authored-by: Sam Wheating <[email protected]>
---
airflow/utils/decorators.py | 10 ++++++++--
tests/utils/test_python_virtualenv.py | 5 +++++
2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py
index 21b6ff3412..77a5eddaf0 100644
--- a/airflow/utils/decorators.py
+++ b/airflow/utils/decorators.py
@@ -62,12 +62,18 @@ def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
:param python_source: The python source code
:param task_decorator_name: the decorator name
+
+ TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
"""
def _remove_task_decorator(py_source, decorator_name):
- if decorator_name not in py_source:
+ # if no line starts with @decorator_name, we can early exit
+ for line in py_source.split("\n"):
+ if line.startswith(decorator_name):
+ break
+ else:
return python_source
- split = python_source.split(decorator_name)
+ split = python_source.split(decorator_name, 1)
before_decorator, after_decorator = split[0], split[1]
if after_decorator[0] == "(":
after_decorator = _balance_parens(after_decorator)
diff --git a/tests/utils/test_python_virtualenv.py b/tests/utils/test_python_virtualenv.py
index 8f067fd3ae..38cda4854b 100644
--- a/tests/utils/test_python_virtualenv.py
+++ b/tests/utils/test_python_virtualenv.py
@@ -125,6 +125,11 @@ class TestPrepareVirtualenv:
res = remove_task_decorator(python_source=py_source,
task_decorator_name="@task.virtualenv")
assert res == "def f():\nimport funcsigs"
+ def test_remove_decorator_including_comment(self):
+ py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs"
+ res = remove_task_decorator(python_source=py_source,
task_decorator_name="@task.virtualenv")
+ assert res == "def f():\n# @task.virtualenv\nimport funcsigs"
+
def test_remove_decorator_nested(self):
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs"
res = remove_task_decorator(python_source=py_source,
task_decorator_name="@task.virtualenv")