This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a7ab95cee2da2f9d89ec1aa6dc014122b81f6456
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Sun May 10 08:57:21 2020 +0100

    Correctly store non-default Nones in serialized tasks/dags (#8772)

    The default schedule_interval for a DAG is `@daily`, so
    `schedule_interval=None` is actually not the default, but we were not
    storing _any_ null attributes previously. This meant that upon
    re-inflating the DAG the schedule_interval would become @daily.

    This fixes that problem, and extends the test to look at _all_ the
    serialized attributes in our round-trip tests, rather than just the few
    that the webserver cared about. It doesn't change the serialization
    format, it just changes what/when values were stored.

    This solution was more complex than I hoped for, but the test case in
    test_operator_subclass_changing_base_defaults is a real one that the
    round trip tests discovered from the DatabricksSubmitRunOperator -- I
    have just captured it in this test in case that specific operator
    changes in future.

    (cherry picked from commit a715aa692e88160cb8e9df4effda2440e4778c17)
---
 airflow/serialization/serialized_objects.py   |  24 +++-
 tests/serialization/test_dag_serialization.py | 176 ++++++++++++++++++--------
 2 files changed, 139 insertions(+), 61 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 917d80f..3e564ec 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -122,10 +122,16 @@ class BaseSerialization:
     @classmethod
     def _is_excluded(cls, var, attrname, instance):
         """Types excluded from serialization."""
+
+        if var is None:
+            if not cls._is_constructor_param(attrname, instance):
+                # Any instance attribute, that is not a constructor argument, we exclude None as the default
+                return True
+
+            return cls._value_is_hardcoded_default(attrname, var, instance)
         return (
-            var is None or
             isinstance(var, cls._excluded_types) or
-            cls._value_is_hardcoded_default(attrname, var)
+            cls._value_is_hardcoded_default(attrname, var, instance)
         )
 
     @classmethod
@@ -259,7 +265,12 @@ class BaseSerialization:
         return datetime.timedelta(seconds=seconds)
 
     @classmethod
-    def _value_is_hardcoded_default(cls, attrname, value):
+    def _is_constructor_param(cls, attrname, instance):
+        # pylint: disable=unused-argument
+        return attrname in cls._CONSTRUCTOR_PARAMS
+
+    @classmethod
+    def _value_is_hardcoded_default(cls, attrname, value, instance):
         """
         Return true if ``value`` is the hard-coded default for the given attribute.
         This takes in to account cases where the ``concurrency`` parameter is
@@ -273,8 +284,9 @@ class BaseSerialization:
         to account for the case where the default value of the field is None but has the
         ``field = field or {}`` set.
""" + # pylint: disable=unused-argument if attrname in cls._CONSTRUCTOR_PARAMS and \ - (cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])): + (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])): return True return False @@ -288,7 +300,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization): _decorated_fields = {'executor_config', } _CONSTRUCTOR_PARAMS = { - k: v for k, v in signature(BaseOperator).parameters.items() + k: v.default for k, v in signature(BaseOperator).parameters.items() if v.default is not v.empty } @@ -511,7 +523,7 @@ class SerializedDAG(DAG, BaseSerialization): 'access_control': '_access_control', } return { - param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items() + param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items() if v.default is not v.empty } diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 52f6e1a..e28e2b2 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -18,7 +18,9 @@ # under the License. """Unit tests for stringified DAGs.""" +from glob import glob import multiprocessing +import os import unittest import six @@ -29,13 +31,10 @@ from datetime import datetime, timedelta from parameterized import parameterized from dateutil.relativedelta import relativedelta, FR -from airflow import example_dags -from airflow.contrib import example_dags as contrib_example_dags from airflow.hooks.base_hook import BaseHook from airflow.models import DAG, Connection, DagBag, TaskInstance from airflow.models.baseoperator import BaseOperator from airflow.operators.bash_operator import BashOperator -from airflow.operators.subdag_operator import SubDagOperator from airflow.serialization.json_schema import load_dag_schema_dict from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG from airflow.utils.tests import CustomOperator, CustomOpLink, GoogleLink @@ -110,10 +109,14 @@ serialized_simple_dag_ground_truth = { }, } +ROOT_FOLDER = os.path.realpath( + os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir) +) -def make_example_dags(module): + +def make_example_dags(module_path): """Loads DAGs from a module for test.""" - dagbag = DagBag(module.__path__[0]) + dagbag = DagBag(module_path) return dagbag.dags @@ -170,22 +173,34 @@ def make_user_defined_macro_filter_dag(): return {dag.dag_id: dag} -def collect_dags(): +def collect_dags(dag_folder=None): """Collects DAGs to test.""" dags = {} dags.update(make_simple_dag()) dags.update(make_user_defined_macro_filter_dag()) - dags.update(make_example_dags(example_dags)) - dags.update(make_example_dags(contrib_example_dags)) + + if dag_folder: + if isinstance(dag_folder, (list, tuple)): + patterns = dag_folder + else: + patterns = [dag_folder] + else: + patterns = [ + "airflow/example_dags", + "airflow/contrib/example_dags", + ] + for pattern in patterns: + for directory in glob(ROOT_FOLDER + "/" + pattern): + dags.update(make_example_dags(directory)) # Filter subdags as they are stored in same row in Serialized Dag table dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag} return dags -def serialize_subprocess(queue): +def serialize_subprocess(queue, dag_folder): """Validate pickle in a subprocess.""" - dags = collect_dags() + dags = collect_dags(dag_folder) for dag in dags.values(): queue.put(SerializedDAG.to_json(dag)) queue.put(None) @@ -242,14 
+257,17 @@ class TestStringifiedDAGs(unittest.TestCase): ) return dag_dict - self.assertEqual(sorted_serialized_dag(ground_truth_dag), - sorted_serialized_dag(json_dag)) + assert sorted_serialized_dag(ground_truth_dag) == sorted_serialized_dag(json_dag) - def test_deserialization(self): + def test_deserialization_across_process(self): """A serialized DAG can be deserialized in another process.""" + + # Since we need to parse the dags twice here (once in the subprocess, + # and once here to get a DAG to compare to) we don't want to load all + # dags. queue = multiprocessing.Queue() proc = multiprocessing.Process( - target=serialize_subprocess, args=(queue,)) + target=serialize_subprocess, args=(queue, "airflow/example_dags")) proc.daemon = True proc.start() @@ -262,69 +280,100 @@ class TestStringifiedDAGs(unittest.TestCase): self.assertTrue(isinstance(dag, DAG)) stringified_dags[dag.dag_id] = dag - dags = collect_dags() - self.assertTrue(set(stringified_dags.keys()) == set(dags.keys())) + dags = collect_dags("airflow/example_dags") + assert set(stringified_dags.keys()) == set(dags.keys()) # Verify deserialized DAGs. for dag_id in stringified_dags: self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id]) - example_skip_dag = stringified_dags['example_skip_dag'] - skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1'] - self.validate_deserialized_task( - skip_operator_1_task, 'DummySkipOperator', '#e8b7e4', '#000') + def test_roundtrip_provider_example_dags(self): + dags = collect_dags([ + "airflow/providers/*/example_dags", + "airflow/providers/*/*/example_dags", + ]) - # Verify that the DAG object has 'full_filepath' attribute - # and is equal to fileloc - self.assertTrue(hasattr(example_skip_dag, 'full_filepath')) - self.assertEqual(example_skip_dag.full_filepath, example_skip_dag.fileloc) - - example_subdag_operator = stringified_dags['example_subdag_operator'] - section_1_task = example_subdag_operator.task_dict['section-1'] - self.validate_deserialized_task( - section_1_task, - SubDagOperator.__name__, - SubDagOperator.ui_color, - SubDagOperator.ui_fgcolor - ) + # Verify deserialized DAGs. + for dag in dags.values(): + serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag)) + self.validate_deserialized_dag(serialized_dag, dag) def validate_deserialized_dag(self, serialized_dag, dag): """ Verify that all example DAGs work with DAG Serialization by checking fields between Serialized Dags & non-Serialized Dags """ - fields_to_check = [ - "params", "fileloc", "max_active_runs", "concurrency", - "is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag", - "catchup", "description", "start_date", "end_date", "parent_dag", - "template_searchpath", "_access_control", "dagrun_timeout" - ] + fields_to_check = dag.get_serialized_fields() - { + # Doesn't implement __eq__ properly. Check manually + 'timezone', - # fields_to_check = dag.get_serialized_fields() + # Need to check fields in it, to exclude functions + 'default_args', + } for field in fields_to_check: - self.assertEqual(getattr(serialized_dag, field), getattr(dag, field)) + assert getattr(serialized_dag, field) == getattr(dag, field), \ + '{}.{} does not match'.format(dag.dag_id, field) - self.assertEqual( - sorted(serialized_dag.task_ids), - sorted([str(task) for task in dag.task_ids])) + if dag.default_args: + for k, v in dag.default_args.items(): + if callable(v): + # Check we stored _someting_. 
+                    assert k in serialized_dag.default_args
+                else:
+                    assert v == serialized_dag.default_args[k], \
+                        '{}.default_args[{}] does not match'.format(dag.dag_id, k)
+
+        assert serialized_dag.timezone.name == dag.timezone.name
+
+        for task_id in dag.task_ids:
+            self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
+
+        # Verify that the DAG object has 'full_filepath' attribute
+        # and is equal to fileloc
+        assert serialized_dag.full_filepath == dag.fileloc
 
-    def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
+    def validate_deserialized_task(self, serialized_task, task,):
         """Verify non-airflow operators are casted to BaseOperator."""
-        self.assertTrue(isinstance(task, SerializedBaseOperator))
-        # Verify the original operator class is recorded for UI.
-        self.assertTrue(task.task_type == task_type)
-        self.assertTrue(task.ui_color == ui_color)
-        self.assertTrue(task.ui_fgcolor == ui_fgcolor)
+        assert isinstance(serialized_task, SerializedBaseOperator)
+        assert not isinstance(task, SerializedBaseOperator)
+        assert isinstance(task, BaseOperator)
+
+        fields_to_check = task.get_serialized_fields() - {
+            # Checked separately
+            '_task_type', 'subdag',
+
+            # Type is exluded, so don't check it
+            '_log',
+
+            # List vs tuple. Check separately
+            'template_fields',
+
+            # We store the string, real dag has the actual code
+            'on_failure_callback', 'on_success_callback', 'on_retry_callback',
+
+            # Checked separately
+            'resources',
+        }
+
+        assert serialized_task.task_type == task.task_type
+        assert set(serialized_task.template_fields) == set(task.template_fields)
+
+        for field in fields_to_check:
+            assert getattr(serialized_task, field) == getattr(task, field), \
+                '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)
+
+        if serialized_task.resources is None:
+            assert task.resources is None or task.resources == []
+        else:
+            assert serialized_task.resources == task.resources
 
         # Check that for Deserialised task, task.subdag is None for all other Operators
         # except for the SubDagOperator where task.subdag is an instance of DAG object
         if task.task_type == "SubDagOperator":
-            self.assertIsNotNone(task.subdag)
-            self.assertTrue(isinstance(task.subdag, DAG))
+            assert serialized_task.subdag is not None
+            assert isinstance(serialized_task.subdag, DAG)
         else:
-            self.assertIsNone(task.subdag)
-        self.assertEqual({}, task.params)
-        self.assertEqual({}, task.executor_config)
+            assert serialized_task.subdag is None
 
     @parameterized.expand([
         (datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
@@ -650,6 +699,23 @@ class TestStringifiedDAGs(unittest.TestCase):
         dag_params = set(dag_schema.keys()) - ignored_keys
         self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
 
+    def test_operator_subclass_changing_base_defaults(self):
+        assert BaseOperator(task_id='dummy').do_xcom_push is True, \
+            "Precondition check! If this fails the test won't make sense"
+
+        class MyOperator(BaseOperator):
+            def __init__(self, do_xcom_push=False, **kwargs):
+                super(MyOperator, self).__init__(**kwargs)
+                self.do_xcom_push = do_xcom_push
+
+        op = MyOperator(task_id='dummy')
+        assert op.do_xcom_push is False
+
+        blob = SerializedBaseOperator.serialize_operator(op)
+        serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+
+        assert serialized_op.do_xcom_push is False
+
     def test_no_new_fields_added_to_base_operator(self):
         """
         This test verifies that there are no new fields added to BaseOperator. And reminds that
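
For reference, the round-trip behaviour the commit message describes can be
sketched as follows. This is an illustrative snippet only, not part of the
commit; the dag_id and dates are made up, while SerializedDAG.to_json /
from_json are the same entry points exercised by the tests above.

    from datetime import datetime

    from airflow.models import DAG
    from airflow.serialization.serialized_objects import SerializedDAG

    # `None` is not the default schedule_interval (the default is '@daily'),
    # so it must be stored explicitly rather than dropped as a null attribute.
    dag = DAG(
        dag_id="no_schedule_dag",  # hypothetical id, for illustration only
        start_date=datetime(2020, 1, 1),
        schedule_interval=None,
    )

    # Serialize and re-inflate, as the webserver does with serialized DAGs.
    roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag))

    # With this commit the explicit None survives the round trip; previously
    # the attribute was excluded and came back as the '@daily' default.
    assert roundtripped.schedule_interval is None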
