This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a0f7e61497 Allow DagParam to hold falsy values (#22964)
a0f7e61497 is described below
commit a0f7e61497d547b82edc1154d39535d79aaedff3
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Apr 13 15:48:46 2022 +0800
Allow DagParam to hold falsy values (#22964)
---
airflow/models/dag.py | 2 +-
airflow/models/param.py | 24 ++++++++++++++----------
tests/models/test_dag.py | 23 ++++++++++++-----------
3 files changed, 27 insertions(+), 22 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8d5e8eacd6..b00b4f666a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -985,7 +985,7 @@ class DAG(LoggingMixin):
def pickle_id(self, value: int) -> None:
self._pickle_id = value
- def param(self, name: str, default=None) -> DagParam:
+ def param(self, name: str, default: Any = NOTSET) -> DagParam:
"""
Return a DagParam object for current dag.
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 22d6a02639..fcbe7a0f93 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -14,15 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import contextlib
import copy
import json
import warnings
-from typing import Any, Dict, ItemsView, MutableMapping, Optional, ValuesView
+from typing import TYPE_CHECKING, Any, Dict, ItemsView, MutableMapping, Optional, ValuesView
from airflow.exceptions import AirflowException, ParamValidationError
from airflow.utils.context import Context
from airflow.utils.types import NOTSET, ArgNotSet
+if TYPE_CHECKING:
+ from airflow.models.dag import DAG
+
class Param:
"""
@@ -228,18 +232,18 @@ class DagParam:
:param default: Default value used if no parameter was set.
"""
- def __init__(self, current_dag, name: str, default: Optional[Any] = None):
- if default:
+ def __init__(self, current_dag: "DAG", name: str, default: Any = NOTSET):
+ if default is not NOTSET:
current_dag.params[name] = default
self._name = name
self._default = default
def resolve(self, context: Context) -> Any:
"""Pull DagParam value from DagRun context. This method is run during
``op.execute()``."""
- default = self._default
- if not self._default:
- default = context['params'][self._name] if self._name in context['params'] else None
- resolved = context['dag_run'].conf.get(self._name, default)
- if not resolved:
- raise AirflowException(f'No value could be resolved for parameter {self._name}')
- return resolved
+ with contextlib.suppress(KeyError):
+ return context['dag_run'].conf[self._name]
+ if self._default is not NOTSET:
+ return self._default
+ with contextlib.suppress(KeyError):
+ return context['params'][self._name]
+ raise AirflowException(f'No value could be resolved for parameter {self._name}')
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 49378b3d26..72515e4b5e 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -27,7 +27,7 @@ from contextlib import redirect_stdout
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile
-from typing import List, Optional, Sequence
+from typing import List, Optional
from unittest import mock
from unittest.mock import patch
@@ -499,7 +499,7 @@ class TestDag(unittest.TestCase):
task = DummyOperator(task_id='op1')
task.test_field = template_file
- task.template_fields: Sequence[str] = ('test_field',)
+ task.template_fields = ('test_field',)
task.template_ext = ('.template',)
task.resolve_template_files()
@@ -517,7 +517,7 @@ class TestDag(unittest.TestCase):
task = DummyOperator(task_id='op1')
task.test_field = [template_file, 'some_string']
- task.template_fields: Sequence[str] = ('test_field',)
+ task.template_fields = ('test_field',)
task.template_ext = ('.template',)
task.resolve_template_files()
@@ -2006,7 +2006,7 @@ class TestQueries(unittest.TestCase):
)
-class TestDagDecorator(unittest.TestCase):
+class TestDagDecorator:
DEFAULT_ARGS = {
"owner": "test",
"depends_on_past": True,
@@ -2017,12 +2017,10 @@ class TestDagDecorator(unittest.TestCase):
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
VALUE = 42
- def setUp(self):
- super().setUp()
+ def setup_method(self):
self.operator = None
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
clear_db_runs()
def test_fileloc(self):
@@ -2118,6 +2116,7 @@ class TestDagDecorator(unittest.TestCase):
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
+ data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
state=State.RUNNING,
)
@@ -2145,6 +2144,7 @@ class TestDagDecorator(unittest.TestCase):
run_id=DagRunType.MANUAL.value,
start_date=timezone.utcnow(),
execution_date=self.DEFAULT_DATE,
+ data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
state=State.RUNNING,
conf={'value': new_value},
)
@@ -2153,11 +2153,12 @@ class TestDagDecorator(unittest.TestCase):
ti = dr.get_task_instances()[0]
assert ti.xcom_pull(), new_value
- def test_set_params_for_dag(self):
+ @pytest.mark.parametrize("value", [VALUE, 0])
+ def test_set_params_for_dag(self, value):
"""Test that dag param is correctly set when using dag decorator"""
@dag_decorator(default_args=self.DEFAULT_ARGS)
- def xcom_pass_to_op(value=self.VALUE):
+ def xcom_pass_to_op(value=value):
@task_decorator
def return_num(num):
return num
@@ -2166,7 +2167,7 @@ class TestDagDecorator(unittest.TestCase):
self.operator = xcom_arg.operator
dag = xcom_pass_to_op()
- assert dag.params['value'] == self.VALUE
+ assert dag.params['value'] == value
@pytest.mark.parametrize("run_id, execution_date", [(None, datetime_tz(2020, 1, 1)), ('test-run-id', None)])