This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a0f7e61497 Allow DagParam to hold falsy values (#22964)
a0f7e61497 is described below

commit a0f7e61497d547b82edc1154d39535d79aaedff3
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Apr 13 15:48:46 2022 +0800

    Allow DagParam to hold falsy values (#22964)
---
 airflow/models/dag.py    |  2 +-
 airflow/models/param.py  | 24 ++++++++++++++----------
 tests/models/test_dag.py | 23 ++++++++++++-----------
 3 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 8d5e8eacd6..b00b4f666a 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -985,7 +985,7 @@ class DAG(LoggingMixin):
     def pickle_id(self, value: int) -> None:
         self._pickle_id = value
 
-    def param(self, name: str, default=None) -> DagParam:
+    def param(self, name: str, default: Any = NOTSET) -> DagParam:
         """
         Return a DagParam object for current dag.
 
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 22d6a02639..fcbe7a0f93 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -14,15 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import contextlib
 import copy
 import json
 import warnings
-from typing import Any, Dict, ItemsView, MutableMapping, Optional, ValuesView
+from typing import TYPE_CHECKING, Any, Dict, ItemsView, MutableMapping, Optional, ValuesView
 
 from airflow.exceptions import AirflowException, ParamValidationError
 from airflow.utils.context import Context
 from airflow.utils.types import NOTSET, ArgNotSet
 
+if TYPE_CHECKING:
+    from airflow.models.dag import DAG
+
 
 class Param:
     """
@@ -228,18 +232,18 @@ class DagParam:
     :param default: Default value used if no parameter was set.
     """
 
-    def __init__(self, current_dag, name: str, default: Optional[Any] = None):
-        if default:
+    def __init__(self, current_dag: "DAG", name: str, default: Any = NOTSET):
+        if default is not NOTSET:
             current_dag.params[name] = default
         self._name = name
         self._default = default
 
     def resolve(self, context: Context) -> Any:
         """Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
-        default = self._default
-        if not self._default:
-            default = context['params'][self._name] if self._name in context['params'] else None
-        resolved = context['dag_run'].conf.get(self._name, default)
-        if not resolved:
-            raise AirflowException(f'No value could be resolved for parameter {self._name}')
-        return resolved
+        with contextlib.suppress(KeyError):
+            return context['dag_run'].conf[self._name]
+        if self._default is not NOTSET:
+            return self._default
+        with contextlib.suppress(KeyError):
+            return context['params'][self._name]
+        raise AirflowException(f'No value could be resolved for parameter {self._name}')
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 49378b3d26..72515e4b5e 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -27,7 +27,7 @@ from contextlib import redirect_stdout
 from datetime import timedelta
 from pathlib import Path
 from tempfile import NamedTemporaryFile
-from typing import List, Optional, Sequence
+from typing import List, Optional
 from unittest import mock
 from unittest.mock import patch
 
@@ -499,7 +499,7 @@ class TestDag(unittest.TestCase):
                 task = DummyOperator(task_id='op1')
 
             task.test_field = template_file
-            task.template_fields: Sequence[str] = ('test_field',)
+            task.template_fields = ('test_field',)
             task.template_ext = ('.template',)
             task.resolve_template_files()
 
@@ -517,7 +517,7 @@ class TestDag(unittest.TestCase):
                 task = DummyOperator(task_id='op1')
 
             task.test_field = [template_file, 'some_string']
-            task.template_fields: Sequence[str] = ('test_field',)
+            task.template_fields = ('test_field',)
             task.template_ext = ('.template',)
             task.resolve_template_files()
 
@@ -2006,7 +2006,7 @@ class TestQueries(unittest.TestCase):
             )
 
 
-class TestDagDecorator(unittest.TestCase):
+class TestDagDecorator:
     DEFAULT_ARGS = {
         "owner": "test",
         "depends_on_past": True,
@@ -2017,12 +2017,10 @@ class TestDagDecorator(unittest.TestCase):
     DEFAULT_DATE = timezone.datetime(2016, 1, 1)
     VALUE = 42
 
-    def setUp(self):
-        super().setUp()
+    def setup_method(self):
         self.operator = None
 
-    def tearDown(self):
-        super().tearDown()
+    def teardown_method(self):
         clear_db_runs()
 
     def test_fileloc(self):
@@ -2118,6 +2116,7 @@ class TestDagDecorator(unittest.TestCase):
             run_id=DagRunType.MANUAL.value,
             start_date=timezone.utcnow(),
             execution_date=self.DEFAULT_DATE,
+            data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
             state=State.RUNNING,
         )
 
@@ -2145,6 +2144,7 @@ class TestDagDecorator(unittest.TestCase):
             run_id=DagRunType.MANUAL.value,
             start_date=timezone.utcnow(),
             execution_date=self.DEFAULT_DATE,
+            data_interval=(self.DEFAULT_DATE, self.DEFAULT_DATE),
             state=State.RUNNING,
             conf={'value': new_value},
         )
@@ -2153,11 +2153,12 @@ class TestDagDecorator(unittest.TestCase):
         ti = dr.get_task_instances()[0]
         assert ti.xcom_pull(), new_value
 
-    def test_set_params_for_dag(self):
+    @pytest.mark.parametrize("value", [VALUE, 0])
+    def test_set_params_for_dag(self, value):
         """Test that dag param is correctly set when using dag decorator"""
 
         @dag_decorator(default_args=self.DEFAULT_ARGS)
-        def xcom_pass_to_op(value=self.VALUE):
+        def xcom_pass_to_op(value=value):
             @task_decorator
             def return_num(num):
                 return num
@@ -2166,7 +2167,7 @@ class TestDagDecorator(unittest.TestCase):
             self.operator = xcom_arg.operator
 
         dag = xcom_pass_to_op()
-        assert dag.params['value'] == self.VALUE
+        assert dag.params['value'] == value
 
 
 @pytest.mark.parametrize("run_id, execution_date", [(None, datetime_tz(2020, 1, 1)), ('test-run-id', None)])

Reply via email to