This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e29f2d7e58 Improved mocking of Context with TaskInstance for use in unit tests msgraph operator (#39157)
e29f2d7e58 is described below

commit e29f2d7e588234a5db3bb140b7846a1e028533f8
Author: David Blain <i...@dabla.be>
AuthorDate: Tue Apr 23 11:58:02 2024 +0200

    Improved mocking of Context with TaskInstance for use in unit tests msgraph operator (#39157)
    
    
    ---------
    
    Co-authored-by: David Blain <david.bl...@infrabel.be>
---
 tests/providers/microsoft/azure/base.py | 38 ++-------------------------------
 tests/providers/microsoft/conftest.py   | 28 +++++++++++++++++-------
 2 files changed, 22 insertions(+), 44 deletions(-)

diff --git a/tests/providers/microsoft/azure/base.py b/tests/providers/microsoft/azure/base.py
index 4cda62858e..cad6c1449f 100644
--- a/tests/providers/microsoft/azure/base.py
+++ b/tests/providers/microsoft/azure/base.py
@@ -19,57 +19,23 @@ from __future__ import annotations
 import asyncio
 from contextlib import contextmanager
 from copy import deepcopy
-from datetime import datetime
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import TYPE_CHECKING, Any
 from unittest.mock import patch
 
 from kiota_http.httpx_request_adapter import HttpxRequestAdapter
 
 from airflow.exceptions import TaskDeferred
-from airflow.models import Operator, TaskInstance
 from airflow.providers.microsoft.azure.hooks.msgraph import 
KiotaRequestAdapterHook
-from airflow.utils.session import NEW_SESSION
-from airflow.utils.xcom import XCOM_RETURN_KEY
 from tests.providers.microsoft.conftest import get_airflow_connection, 
mock_context
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
+    from airflow.models import Operator
     from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class MockedTaskInstance(TaskInstance):
-    values = {}
-
-    def xcom_pull(
-        self,
-        task_ids: Iterable[str] | str | None = None,
-        dag_id: str | None = None,
-        key: str = XCOM_RETURN_KEY,
-        include_prior_dates: bool = False,
-        session: Session = NEW_SESSION,
-        *,
-        map_indexes: Iterable[int] | int | None = None,
-        default: Any | None = None,
-    ) -> Any:
-        self.task_id = task_ids
-        self.dag_id = dag_id
-        return self.values.get(f"{task_ids}_{dag_id}_{key}")
-
-    def xcom_push(
-        self,
-        key: str,
-        value: Any,
-        execution_date: datetime | None = None,
-        session: Session = NEW_SESSION,
-    ) -> None:
-        self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
-
-
 class Base:
     def teardown_method(self, method):
         KiotaRequestAdapterHook.cached_request_adapters.clear()
-        MockedTaskInstance.values.clear()
 
     @contextmanager
     def patch_hook_and_request_adapter(self, response):
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index 78d8748a89..dfba931023 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -29,6 +29,7 @@ from httpx import Response
 from msgraph_core import APIVersion
 
 from airflow.models import Connection
+from airflow.utils.context import Context
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -103,7 +104,7 @@ def mock_response(status_code, content: Any = None) -> Response:
     return response
 
 
-def mock_context(task):
+def mock_context(task) -> Context:
     from datetime import datetime
 
     from airflow.models import TaskInstance
@@ -111,9 +112,20 @@ def mock_context(task):
     from airflow.utils.state import TaskInstanceState
     from airflow.utils.xcom import XCOM_RETURN_KEY
 
+    values = {}
+
     class MockedTaskInstance(TaskInstance):
-        def __init__(self):
-            super().__init__(task=task, run_id="run_id", 
state=TaskInstanceState.RUNNING)
+        def __init__(
+            self,
+            task,
+            execution_date: datetime | None = None,
+            run_id: str | None = "run_id",
+            state: str | None = TaskInstanceState.RUNNING,
+            map_index: int = -1,
+        ):
+            super().__init__(
+                task=task, execution_date=execution_date, run_id=run_id, 
state=state, map_index=map_index
+            )
             self.values = {}
 
         def xcom_pull(
@@ -127,9 +139,7 @@ def mock_context(task):
             map_indexes: Iterable[int] | int | None = None,
             default: Any | None = None,
         ) -> Any:
-            self.task_id = task_ids
-            self.dag_id = dag_id
-            return self.values.get(f"{task_ids}_{dag_id}_{key}")
+            return values.get(f"{task_ids or self.task_id}_{dag_id or 
self.dag_id}_{key}")
 
         def xcom_push(
             self,
@@ -138,9 +148,11 @@ def mock_context(task):
             execution_date: datetime | None = None,
             session: Session = NEW_SESSION,
         ) -> None:
-            self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+            values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+
+    values["ti"] = MockedTaskInstance(task=task)
 
-    return {"ti": MockedTaskInstance()}
+    return Context(values)
 
 
 def load_json(*locations: Iterable[str]):

Reply via email to