This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new e29f2d7e58 Improved mocking of Context with TaskInstance for use in unit tests msgraph operator (#39157) e29f2d7e58 is described below commit e29f2d7e588234a5db3bb140b7846a1e028533f8 Author: David Blain <i...@dabla.be> AuthorDate: Tue Apr 23 11:58:02 2024 +0200 Improved mocking of Context with TaskInstance for use in unit tests msgraph operator (#39157) --------- Co-authored-by: David Blain <david.bl...@infrabel.be> --- tests/providers/microsoft/azure/base.py | 38 ++------------------------------- tests/providers/microsoft/conftest.py | 28 +++++++++++++++++------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/tests/providers/microsoft/azure/base.py b/tests/providers/microsoft/azure/base.py index 4cda62858e..cad6c1449f 100644 --- a/tests/providers/microsoft/azure/base.py +++ b/tests/providers/microsoft/azure/base.py @@ -19,57 +19,23 @@ from __future__ import annotations import asyncio from contextlib import contextmanager from copy import deepcopy -from datetime import datetime -from typing import TYPE_CHECKING, Any, Iterable +from typing import TYPE_CHECKING, Any from unittest.mock import patch from kiota_http.httpx_request_adapter import HttpxRequestAdapter from airflow.exceptions import TaskDeferred -from airflow.models import Operator, TaskInstance from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook -from airflow.utils.session import NEW_SESSION -from airflow.utils.xcom import XCOM_RETURN_KEY from tests.providers.microsoft.conftest import get_airflow_connection, mock_context if TYPE_CHECKING: - from sqlalchemy.orm import Session - + from airflow.models import Operator from airflow.triggers.base import BaseTrigger, TriggerEvent -class MockedTaskInstance(TaskInstance): - values = {} - - def xcom_pull( - self, - task_ids: Iterable[str] | str | None = None, - dag_id: str | None = None, - key: str = XCOM_RETURN_KEY, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - *, - map_indexes: Iterable[int] | int | None = None, - default: Any | None = None, - ) -> Any: - self.task_id = task_ids - self.dag_id = dag_id - return self.values.get(f"{task_ids}_{dag_id}_{key}") - - def xcom_push( - self, - key: str, - value: Any, - execution_date: datetime | None = None, - session: Session = NEW_SESSION, - ) -> None: - self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value - - class Base: def teardown_method(self, method): KiotaRequestAdapterHook.cached_request_adapters.clear() - MockedTaskInstance.values.clear() @contextmanager def patch_hook_and_request_adapter(self, response): diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py index 78d8748a89..dfba931023 100644 --- a/tests/providers/microsoft/conftest.py +++ b/tests/providers/microsoft/conftest.py @@ -29,6 +29,7 @@ from httpx import Response from msgraph_core import APIVersion from airflow.models import Connection +from airflow.utils.context import Context if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -103,7 +104,7 @@ def mock_response(status_code, content: Any = None) -> Response: return response -def mock_context(task): +def mock_context(task) -> Context: from datetime import datetime from airflow.models import TaskInstance @@ -111,9 +112,20 @@ def mock_context(task): from airflow.utils.state import TaskInstanceState from airflow.utils.xcom import XCOM_RETURN_KEY + values = {} + class MockedTaskInstance(TaskInstance): - def __init__(self): - super().__init__(task=task, run_id="run_id", state=TaskInstanceState.RUNNING) + def __init__( + self, + task, + execution_date: datetime | None = None, + run_id: str | None = "run_id", + state: str | None = TaskInstanceState.RUNNING, + map_index: int = -1, + ): + super().__init__( + task=task, execution_date=execution_date, run_id=run_id, state=state, map_index=map_index + ) self.values = {} def xcom_pull( @@ -127,9 +139,7 @@ def mock_context(task): map_indexes: Iterable[int] | int | None = None, default: Any | None = None, ) -> Any: - self.task_id = task_ids - self.dag_id = dag_id - return self.values.get(f"{task_ids}_{dag_id}_{key}") + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") def xcom_push( self, @@ -138,9 +148,11 @@ def mock_context(task): execution_date: datetime | None = None, session: Session = NEW_SESSION, ) -> None: - self.values[f"{self.task_id}_{self.dag_id}_{key}"] = value + values[f"{self.task_id}_{self.dag_id}_{key}"] = value + + values["ti"] = MockedTaskInstance(task=task) - return {"ti": MockedTaskInstance()} + return Context(values) def load_json(*locations: Iterable[str]):