This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v1-10-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a22dba100401f934f907eab150f12efaae428819 Author: Denis Evseev <[email protected]> AuthorDate: Wed Sep 16 01:40:41 2020 +0300 Fix operator field update for SerializedBaseOperator (#10924) Co-authored-by: Denis Evseev <[email protected]> Co-authored-by: Kaxil Naik <[email protected]> (cherry picked from commit f7da7d94b4ac6dc59fb50a4f4abba69776aac798) (cherry picked from commit cfc9732d71ae5b4b65077f2ba9cd51180a6c4548) --- airflow/models/taskinstance.py | 2 +- airflow/sensors/external_task_sensor.py | 14 ++++++++++++++ tests/models/test_taskinstance.py | 18 ++++++++++++++++++ tests/sensors/test_external_task_sensor.py | 21 +++++++++++++++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ae296ba..7c1caef 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -512,7 +512,7 @@ class TaskInstance(Base, LoggingMixin): self.run_as_user = task.run_as_user self.max_tries = task.retries self.executor_config = task.executor_config - self.operator = task.__class__.__name__ + self.operator = task.task_type @provide_session def clear_xcom_data(self, session=None): diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py index 2dc0875..b759a71 100644 --- a/airflow/sensors/external_task_sensor.py +++ b/airflow/sensors/external_task_sensor.py @@ -201,6 +201,9 @@ class ExternalTaskMarker(DummyOperator): template_fields = ['external_dag_id', 'external_task_id', 'execution_date'] ui_color = '#19647e' + # The _serialized_fields are lazily loaded when get_serialized_fields() method is called + __serialized_fields = None + @apply_defaults def __init__(self, external_dag_id, @@ -222,3 +225,14 @@ class ExternalTaskMarker(DummyOperator): if recursion_depth <= 0: raise ValueError("recursion_depth should be a positive integer") self.recursion_depth = recursion_depth + + @classmethod + def get_serialized_fields(cls): + """Serialized ExternalTaskMarker contain exactly these fields + templated_fields .""" + if not cls.__serialized_fields: + cls.__serialized_fields = frozenset( + super(ExternalTaskMarker, cls).get_serialized_fields() | { + "recursion_depth" + } + ) + return cls.__serialized_fields diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 0c416a4..3b05bbb 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -40,6 +40,7 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator from airflow.ti_deps.dep_context import REQUEUEABLE_DEPS, RUNNABLE_STATES, RUNNING_DEPS +from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.ti_deps.deps.base_ti_dep import TIDepStatus from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone @@ -1560,6 +1561,23 @@ class TaskInstanceTest(unittest.TestCase): with create_session() as session: session.query(RenderedTaskInstanceFields).delete() + def test_operator_field_with_serialization(self): + + dag = DAG('test_queries', start_date=DEFAULT_DATE) + task = DummyOperator(task_id='op', dag=dag) + self.assertEqual(task.task_type, 'DummyOperator') + + # Verify that ti.operator field renders correctly "without" Serialization + ti = TI(task=task, execution_date=datetime.datetime.now()) + self.assertEqual(ti.operator, "DummyOperator") + + serialized_op = SerializedBaseOperator.serialize_operator(task) + deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) + self.assertEqual(deserialized_op.task_type, 'DummyOperator') + # Verify that ti.operator field renders correctly "with" Serialization + ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now()) + self.assertEqual(ser_ti.operator, "DummyOperator") + @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) def test_refresh_from_task(pool_override): diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index e2a58ec..00d5835 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -27,6 +27,7 @@ from airflow.operators.bash_operator import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.sensors.external_task_sensor import ExternalTaskMarker, ExternalTaskSensor from airflow.sensors.time_sensor import TimeSensor +from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.state import State from airflow.utils.timezone import datetime @@ -339,6 +340,26 @@ exit 0 ) +class TestExternalTaskMarker(unittest.TestCase): + def test_serialized_fields(self): + self.assertTrue({"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields())) + + def test_serialized_external_task_marker(self): + dag = DAG('test_serialized_external_task_marker', start_date=DEFAULT_DATE) + task = ExternalTaskMarker( + task_id="parent_task", + external_dag_id="external_task_marker_child", + external_task_id="child_task1", + dag=dag + ) + + serialized_op = SerializedBaseOperator.serialize_operator(task) + deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op) + self.assertEqual(deserialized_op.task_type, 'ExternalTaskMarker') + self.assertEqual(getattr(deserialized_op, 'external_dag_id'), 'external_task_marker_child') + self.assertEqual(getattr(deserialized_op, 'external_task_id'), 'child_task1') + + @pytest.fixture def dag_bag_ext(): """
