This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d7e14ba0d6 Fix Serialization error in TaskCallbackRequest (#25471)
d7e14ba0d6 is described below
commit d7e14ba0d612d8315238f9d0cba4ef8c44b6867c
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Tue Aug 2 22:50:40 2022 +0100
Fix Serialization error in TaskCallbackRequest (#25471)
The way `SimpleTaskInstance` is serialized in the `TaskCallbackRequest` class leads
to a JSON serialization error when the task instance has a start_date or end_date.
Since task instances (TIs) always have a start_date, serialization would always fail.
This PR fixes this with a new `as_dict` method on `SimpleTaskInstance` that
converts start_date/end_date to ISO-format strings before serialization.
---
airflow/callbacks/callback_requests.py | 2 +-
airflow/models/taskinstance.py | 10 ++++++++++
tests/callbacks/test_callback_requests.py | 21 +++++++++++++++++----
3 files changed, 28 insertions(+), 5 deletions(-)
diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py
index 8112589cd0..b04a201c08 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -75,7 +75,7 @@ class TaskCallbackRequest(CallbackRequest):
def to_json(self) -> str:
dict_obj = self.__dict__.copy()
- dict_obj["simple_task_instance"] =
dict_obj["simple_task_instance"].__dict__
+ dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()  # as_dict() turns start_date/end_date into ISO strings so json.dumps below does not fail on datetimes
return json.dumps(dict_obj)
@classmethod
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index d83ad11b04..e52976e359 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2631,6 +2631,16 @@ class SimpleTaskInstance:
return self.__dict__ == other.__dict__
return NotImplemented
def as_dict(self) -> dict:
    """Return a JSON-serializable copy of this instance's attributes.

    ``start_date`` and ``end_date`` are converted to ISO-8601 strings with
    ``isoformat()`` so the result can be passed to ``json.dumps``. Falsy
    values (e.g. ``None``) and values that are already strings are left as-is.
    The instance itself is not modified; a shallow copy of ``__dict__`` is
    returned.
    """
    new_dict = dict(self.__dict__)
    # Only these two attributes need converting, so look them up directly
    # rather than scanning (and mutating) the dict while iterating over it.
    for key in ("start_date", "end_date"):
        val = new_dict.get(key)
        if not val or isinstance(val, str):
            continue
        new_dict[key] = val.isoformat()
    return new_dict
+
@classmethod
def from_ti(cls, ti: TaskInstance) -> "SimpleTaskInstance":
return cls(
diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py
index 286d64eaa1..3764f19c4c 100644
--- a/tests/callbacks/test_callback_requests.py
+++ b/tests/callbacks/test_callback_requests.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import unittest
from datetime import datetime
from parameterized import parameterized
@@ -29,6 +28,7 @@ from airflow.callbacks.callback_requests import (
from airflow.models.dag import DAG
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.operators.bash import BashOperator
+from airflow.utils import timezone
from airflow.utils.state import State
TI = TaskInstance(
@@ -38,7 +38,7 @@ TI = TaskInstance(
)
-class TestCallbackRequest(unittest.TestCase):
+class TestCallbackRequest:
@parameterized.expand(
[
(CallbackRequest(full_filepath="filepath", msg="task_failure"),
CallbackRequest),
@@ -64,7 +64,20 @@ class TestCallbackRequest(unittest.TestCase):
)
def test_from_json(self, input, request_class):
json_str = input.to_json()
-
result = request_class.from_json(json_str=json_str)
+ assert result == input  # plain assert: the class no longer subclasses unittest.TestCase
- self.assertEqual(result, input)
def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create_task_instance):
    """A TaskCallbackRequest whose task instance has start/end dates round-trips via JSON.

    Regression test for #25471: datetimes on the SimpleTaskInstance made
    ``to_json`` raise a JSON serialization error.
    """
    ti = create_task_instance()
    ti.start_date = timezone.utcnow()
    ti.end_date = timezone.utcnow()
    session.merge(ti)
    session.flush()
    # Named ``request`` rather than ``input`` to avoid shadowing the builtin.
    request = TaskCallbackRequest(
        full_filepath="filepath",
        simple_task_instance=SimpleTaskInstance.from_ti(ti),
        is_failure_callback=True,
    )
    json_str = request.to_json()
    result = TaskCallbackRequest.from_json(json_str)
    assert result == request