This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e29f2d7e58 Improved mocking of Context with TaskInstance for use in
unit tests msgraph operator (#39157)
e29f2d7e58 is described below
commit e29f2d7e588234a5db3bb140b7846a1e028533f8
Author: David Blain <[email protected]>
AuthorDate: Tue Apr 23 11:58:02 2024 +0200
Improved mocking of Context with TaskInstance for use in unit tests msgraph
operator (#39157)
---------
Co-authored-by: David Blain <[email protected]>
---
tests/providers/microsoft/azure/base.py | 38 ++-------------------------------
tests/providers/microsoft/conftest.py | 28 +++++++++++++++++-------
2 files changed, 22 insertions(+), 44 deletions(-)
diff --git a/tests/providers/microsoft/azure/base.py b/tests/providers/microsoft/azure/base.py
index 4cda62858e..cad6c1449f 100644
--- a/tests/providers/microsoft/azure/base.py
+++ b/tests/providers/microsoft/azure/base.py
@@ -19,57 +19,23 @@ from __future__ import annotations
import asyncio
from contextlib import contextmanager
from copy import deepcopy
-from datetime import datetime
-from typing import TYPE_CHECKING, Any, Iterable
+from typing import TYPE_CHECKING, Any
from unittest.mock import patch
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from airflow.exceptions import TaskDeferred
-from airflow.models import Operator, TaskInstance
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
-from airflow.utils.session import NEW_SESSION
-from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.providers.microsoft.conftest import get_airflow_connection, mock_context
if TYPE_CHECKING:
- from sqlalchemy.orm import Session
-
+ from airflow.models import Operator
from airflow.triggers.base import BaseTrigger, TriggerEvent
-class MockedTaskInstance(TaskInstance):
- values = {}
-
- def xcom_pull(
- self,
- task_ids: Iterable[str] | str | None = None,
- dag_id: str | None = None,
- key: str = XCOM_RETURN_KEY,
- include_prior_dates: bool = False,
- session: Session = NEW_SESSION,
- *,
- map_indexes: Iterable[int] | int | None = None,
- default: Any | None = None,
- ) -> Any:
- self.task_id = task_ids
- self.dag_id = dag_id
- return self.values.get(f"{task_ids}_{dag_id}_{key}")
-
- def xcom_push(
- self,
- key: str,
- value: Any,
- execution_date: datetime | None = None,
- session: Session = NEW_SESSION,
- ) -> None:
- self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
-
-
class Base:
def teardown_method(self, method):
KiotaRequestAdapterHook.cached_request_adapters.clear()
- MockedTaskInstance.values.clear()
@contextmanager
def patch_hook_and_request_adapter(self, response):
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index 78d8748a89..dfba931023 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -29,6 +29,7 @@ from httpx import Response
from msgraph_core import APIVersion
from airflow.models import Connection
+from airflow.utils.context import Context
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -103,7 +104,7 @@ def mock_response(status_code, content: Any = None) -> Response:
return response
-def mock_context(task):
+def mock_context(task) -> Context:
from datetime import datetime
from airflow.models import TaskInstance
@@ -111,9 +112,20 @@ def mock_context(task):
from airflow.utils.state import TaskInstanceState
from airflow.utils.xcom import XCOM_RETURN_KEY
+ values = {}
+
class MockedTaskInstance(TaskInstance):
- def __init__(self):
- super().__init__(task=task, run_id="run_id", state=TaskInstanceState.RUNNING)
+ def __init__(
+ self,
+ task,
+ execution_date: datetime | None = None,
+ run_id: str | None = "run_id",
+ state: str | None = TaskInstanceState.RUNNING,
+ map_index: int = -1,
+ ):
+ super().__init__(
+ task=task, execution_date=execution_date, run_id=run_id, state=state, map_index=map_index
+ )
self.values = {}
def xcom_pull(
@@ -127,9 +139,7 @@ def mock_context(task):
map_indexes: Iterable[int] | int | None = None,
default: Any | None = None,
) -> Any:
- self.task_id = task_ids
- self.dag_id = dag_id
- return self.values.get(f"{task_ids}_{dag_id}_{key}")
+ return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
def xcom_push(
self,
@@ -138,9 +148,11 @@ def mock_context(task):
execution_date: datetime | None = None,
session: Session = NEW_SESSION,
) -> None:
- self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+ values[f"{self.task_id}_{self.dag_id}_{key}"] = value
+
+ values["ti"] = MockedTaskInstance(task=task)
- return {"ti": MockedTaskInstance()}
+ return Context(values)
def load_json(*locations: Iterable[str]):