This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new f7da7d9  Fix ExternalTaskMarker serialized fields (#10924)
f7da7d9 is described below

commit f7da7d94b4ac6dc59fb50a4f4abba69776aac798
Author: Denis Evseev <[email protected]>
AuthorDate: Wed Sep 16 01:40:41 2020 +0300

    Fix ExternalTaskMarker serialized fields (#10924)
    
    Co-authored-by: Denis Evseev <[email protected]>
    Co-authored-by: Kaxil Naik <[email protected]>
---
 airflow/models/taskinstance.py             |  2 +-
 airflow/sensors/external_task_sensor.py    | 16 +++++++++++++++-
 tests/models/test_taskinstance.py          | 18 ++++++++++++++++++
 tests/sensors/test_external_task_sensor.py | 21 +++++++++++++++++++++
 4 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 76f3134..5280cb4 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -572,7 +572,7 @@ class TaskInstance(Base, LoggingMixin):     # pylint: disable=R0902,R0904
         self.run_as_user = task.run_as_user
         self.max_tries = task.retries
         self.executor_config = task.executor_config
-        self.operator = task.__class__.__name__
+        self.operator = task.task_type
 
     @provide_session
     def clear_xcom_data(self, session=None):
diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py
index ab2fb20..365761e 100644
--- a/airflow/sensors/external_task_sensor.py
+++ b/airflow/sensors/external_task_sensor.py
@@ -18,7 +18,7 @@
 
 import datetime
 import os
-from typing import Optional, Union
+from typing import FrozenSet, Optional, Union
 
 from sqlalchemy import func
 
@@ -242,6 +242,9 @@ class ExternalTaskMarker(DummyOperator):
     template_fields = ['external_dag_id', 'external_task_id', 'execution_date']
     ui_color = '#19647e'
 
+    # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
+    __serialized_fields: Optional[FrozenSet[str]] = None
+
     @apply_defaults
     def __init__(self, *,
                  external_dag_id,
@@ -262,3 +265,14 @@ class ExternalTaskMarker(DummyOperator):
         if recursion_depth <= 0:
             raise ValueError("recursion_depth should be a positive integer")
         self.recursion_depth = recursion_depth
+
+    @classmethod
+    def get_serialized_fields(cls):
+        """Serialized ExternalTaskMarker contain exactly these fields + templated_fields ."""
+        if not cls.__serialized_fields:
+            cls.__serialized_fields = frozenset(
+                super().get_serialized_fields() | {
+                    "recursion_depth"
+                }
+            )
+        return cls.__serialized_fields
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 4b30e09..21838e7 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -40,6 +40,7 @@ from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python import PythonOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.sensors.python import PythonSensor
+from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.stats import Stats
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
 from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
@@ -1751,3 +1752,20 @@ class TestRunRawTaskQueriesCount(unittest.TestCase):
             "airflow.models.taskinstance.STORE_SERIALIZED_DAGS", True
         ):
             ti._run_raw_task()
+
+    def test_operator_field_with_serialization(self):
+
+        dag = DAG('test_queries', start_date=DEFAULT_DATE)
+        task = DummyOperator(task_id='op', dag=dag)
+        self.assertEqual(task.task_type, 'DummyOperator')
+
+        # Verify that ti.operator field renders correctly "without" Serialization
+        ti = TI(task=task, execution_date=datetime.datetime.now())
+        self.assertEqual(ti.operator, "DummyOperator")
+
+        serialized_op = SerializedBaseOperator.serialize_operator(task)
+        deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+        self.assertEqual(deserialized_op.task_type, 'DummyOperator')
+        # Verify that ti.operator field renders correctly "with" Serialization
+        ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now())
+        self.assertEqual(ser_ti.operator, "DummyOperator")
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 8335e7b..a25023f 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -28,6 +28,7 @@ from airflow.operators.bash import BashOperator
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.sensors.external_task_sensor import ExternalTaskMarker, ExternalTaskSensor
 from airflow.sensors.time_sensor import TimeSensor
+from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 
@@ -396,6 +397,26 @@ exit 0
             )
 
 
+class TestExternalTaskMarker(unittest.TestCase):
+    def test_serialized_fields(self):
+        self.assertTrue({"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields()))
+
+    def test_serialized_external_task_marker(self):
+        dag = DAG('test_serialized_external_task_marker', start_date=DEFAULT_DATE)
+        task = ExternalTaskMarker(
+            task_id="parent_task",
+            external_dag_id="external_task_marker_child",
+            external_task_id="child_task1",
+            dag=dag
+        )
+
+        serialized_op = SerializedBaseOperator.serialize_operator(task)
+        deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+        self.assertEqual(deserialized_op.task_type, 'ExternalTaskMarker')
+        self.assertEqual(getattr(deserialized_op, 'external_dag_id'), 'external_task_marker_child')
+        self.assertEqual(getattr(deserialized_op, 'external_task_id'), 'child_task1')
+
+
 @pytest.fixture
 def dag_bag_ext():
     """

Reply via email to