This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2d374f71bc8 Remove TaskInstance and TaskLogReader unused methods (#59922)
2d374f71bc8 is described below
commit 2d374f71bc81202204ac0208df07b07c280668fa
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Tue Jan 6 01:18:57 2026 +0800
Remove TaskInstance and TaskLogReader unused methods (#59922)
---
airflow-core/newsfragments/59835.significant.rst | 7 +-
.../src/airflow/cli/commands/task_command.py | 56 +++++---
airflow-core/src/airflow/models/taskinstance.py | 160 +--------------------
airflow-core/src/airflow/serialization/enums.py | 1 -
.../airflow/serialization/serialized_objects.py | 24 +---
airflow-core/src/airflow/utils/context.py | 76 +++-------
airflow-core/src/airflow/utils/helpers.py | 36 +----
airflow-core/src/airflow/utils/log/log_reader.py | 26 ----
airflow-core/tests/unit/core/test_core.py | 10 +-
.../tests/unit/models/test_taskinstance.py | 37 +++--
.../tests/unit/utils/log/test_log_reader.py | 57 --------
airflow-core/tests/unit/utils/test_helpers.py | 19 ---
devel-common/src/tests_common/pytest_plugin.py | 42 +++++-
.../src/tests_common/test_utils/taskinstance.py | 64 +++++++--
.../aws/executors/aws_lambda/lambda_executor.py | 8 +-
.../amazon/aws/executors/batch/batch_executor.py | 7 +-
.../amazon/aws/executors/ecs/ecs_executor.py | 7 +-
.../providers/amazon/aws/hooks/sagemaker.py | 7 +-
.../executors/aws_lambda/test_lambda_executor.py | 6 +-
.../tests/unit/amazon/aws/hooks/test_appflow.py | 5 +-
.../unit/amazon/aws/operators/test_appflow.py | 5 +-
.../tests/unit/amazon/aws/operators/test_athena.py | 10 +-
.../amazon/aws/operators/test_cloud_formation.py | 6 +-
.../unit/amazon/aws/operators/test_datasync.py | 28 ++--
.../tests/unit/amazon/aws/operators/test_dms.py | 18 +--
.../amazon/aws/operators/test_emr_add_steps.py | 8 +-
.../aws/operators/test_emr_create_job_flow.py | 10 +-
.../aws/operators/test_emr_modify_cluster.py | 6 +-
.../tests/unit/amazon/aws/operators/test_rds.py | 6 +-
.../tests/unit/amazon/aws/operators/test_s3.py | 2 +-
.../amazon/aws/operators/test_sagemaker_base.py | 8 +-
.../tests/unit/amazon/aws/sensors/test_ecs.py | 5 +-
.../tests/unit/amazon/aws/sensors/test_rds.py | 6 +-
.../tests/unit/amazon/aws/sensors/test_s3.py | 10 +-
.../unit/amazon/aws/transfers/test_mongo_to_s3.py | 2 +-
.../unit/apache/druid/operators/test_druid.py | 13 +-
.../decorators/test_kubernetes_commons.py | 10 +-
.../unit/cncf/kubernetes/operators/test_pod.py | 6 +-
.../io/tests/unit/common/io/xcom/test_backend.py | 6 +-
.../cloud/tests/unit/dbt/cloud/hooks/test_dbt.py | 5 +-
.../tests/unit/docker/decorators/test_docker.py | 30 +---
.../tests/unit/github/operators/test_github.py | 5 +-
.../tests/unit/github/sensors/test_github.py | 5 +-
.../tests/unit/google/cloud/hooks/test_gcs.py | 6 +-
.../mysql/tests/unit/mysql/hooks/test_mysql.py | 13 +-
.../unit/slack/transfers/test_sql_to_slack.py | 5 +-
.../slack/transfers/test_sql_to_slack_webhook.py | 5 +-
.../tests/unit/standard/decorators/test_bash.py | 30 +++-
.../tests/unit/standard/decorators/test_python.py | 74 +++++-----
.../tests/unit/standard/operators/test_python.py | 43 ++----
.../tests/unit/standard/sensors/test_time.py | 5 +-
.../check_template_context_variable_in_sync.py | 34 +----
task-sdk/src/airflow/sdk/definitions/dag.py | 12 +-
.../src/airflow/sdk/definitions/mappedoperator.py | 10 +-
54 files changed, 335 insertions(+), 767 deletions(-)
diff --git a/airflow-core/newsfragments/59835.significant.rst
b/airflow-core/newsfragments/59835.significant.rst
index c55ff910611..02b9d709a49 100644
--- a/airflow-core/newsfragments/59835.significant.rst
+++ b/airflow-core/newsfragments/59835.significant.rst
@@ -1,5 +1,6 @@
Methods removed from TaskInstance
-On class ``TaskInstance``, functions ``run()``, ``render_templates()``, and
-private members related to them have been removed. The class has been
-considered internal since 3.0, and should not be relied on in user code.
+On class ``TaskInstance``, functions ``run()``, ``render_templates()``,
+``get_template_context()``, and private members related to them have been
+removed. The class has been considered internal since 3.0, and should not be
+relied on in user code.
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py
b/airflow-core/src/airflow/cli/commands/task_command.py
index 5e87912b5bc..3cdd5c9f619 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -62,6 +62,8 @@ if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
+ from airflow.sdk import Context
+ from airflow.sdk.types import Operator as SdkOperator
from airflow.serialization.definitions.mappedoperator import Operator
CreateIfNecessary = Literal[False, "db", "memory"]
@@ -224,6 +226,24 @@ def _get_ti(
return ti, dr_created
+def _get_template_context(ti: TaskInstance, task: SdkOperator) -> Context:
+ from airflow.api_fastapi.execution_api.datamodels.taskinstance import
DagRun, TaskInstance, TIRunContext
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ runtime_ti = RuntimeTaskInstance.model_construct(
+ **TaskInstance.model_validate(ti,
from_attributes=True).model_dump(exclude_unset=True),
+ task=task,
+ _ti_context_from_server=TIRunContext(
+ dag_run=DagRun.model_validate(ti.dag_run, from_attributes=True),
+ max_tries=ti.max_tries,
+ variables=[],
+ connections=[],
+ xcom_keys_to_clear=[],
+ ),
+ )
+ return runtime_ti.get_template_context()
+
+
class TaskCommandMarker:
"""Marker for listener hooks, to properly detect from which component they
are called."""
@@ -441,27 +461,21 @@ def task_render(args, dag: DAG | None = None) -> None:
create_if_necessary="memory",
)
- with create_session() as session:
- context = ti.get_template_context(session=session)
- task = sdk_dag.get_task(args.task_id)
- # TODO (GH-52141): After sdk separation, ti.get_template_context()
would
- # contain serialized operators, but we need the real operators for
- # rendering. This does not make sense and eventually we should rewrite
- # this entire function so "ti" is a RuntimeTaskInstance instead, but
for
- # now we'll just manually fix it to contain the right objects.
- context["task"] = context["ti"].task = task
- task.render_template_fields(context)
- for attr in context["task"].template_fields:
- print(
- textwrap.dedent(
- f"""\
- #
----------------------------------------------------------
- # property: {attr}
- #
----------------------------------------------------------
- """
- )
- + str(getattr(context["task"], attr)) # This shouldn't be
dedented.
- )
+ task = sdk_dag.get_task(args.task_id)
+ context = _get_template_context(ti, task)
+ task.render_template_fields(context)
+ for attr in task.template_fields:
+ print(
+ textwrap.dedent(
+ f"""\
+ # ----------------------------------------------------------
+ # property: {attr}
+ # ----------------------------------------------------------
+ """
+ ),
+ getattr(context["task"], attr), # This shouldn't be dedented.
+ sep="",
+ )
@cli_utils.action_cli(check_db=False)
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 61c6d6f3691..d41849f3e20 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -27,13 +27,11 @@ import uuid
from collections import defaultdict
from collections.abc import Collection, Iterable
from datetime import datetime, timedelta
-from functools import cache
from typing import TYPE_CHECKING, Any
from urllib.parse import quote
import attrs
import dill
-import lazy_object_proxy
import uuid6
from sqlalchemy import (
JSON,
@@ -72,7 +70,7 @@ from airflow._shared.timezones import timezone
from airflow.assets.manager import asset_manager
from airflow.configuration import conf
from airflow.listeners.listener import get_listener_manager
-from airflow.models.asset import AssetEvent, AssetModel
+from airflow.models.asset import AssetModel
from airflow.models.base import Base, StringID, TaskInstanceDependencies
from airflow.models.dag_version import DagVersion
@@ -106,7 +104,6 @@ if TYPE_CHECKING:
from datetime import datetime
from typing import Literal
- import pendulum
from sqlalchemy.engine import Connection as SAConnection, Engine
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Update
@@ -115,7 +112,6 @@ if TYPE_CHECKING:
from airflow.api_fastapi.execution_api.datamodels.asset import AssetProfile
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
- from airflow.sdk import Context
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.mappedoperator import Operator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
@@ -1567,160 +1563,6 @@ class TaskInstance(Base, LoggingMixin):
return bool(self.task.retries and self.try_number <= self.max_tries)
- # TODO (GH-52141): We should remove this entire function (only makes sense
at runtime).
- def get_template_context(
- self,
- session: Session | None = None,
- ignore_param_exceptions: bool = True,
- ) -> Context:
- """
- Return TI Context.
-
- :param session: SQLAlchemy ORM Session
- :param ignore_param_exceptions: flag to suppress value exceptions
while initializing the ParamsDict
- """
- # Do not use provide_session here -- it expunges everything on exit!
- if not session:
- session = settings.get_session()()
-
- from airflow.exceptions import NotMapped
- from airflow.sdk.api.datamodels._generated import (
- DagRun as DagRunSDK,
- PrevSuccessfulDagRunResponse,
- TIRunContext,
- )
- from airflow.sdk.definitions.param import process_params
- from airflow.sdk.execution_time.context import InletEventsAccessors
- from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
- from airflow.serialization.definitions.mappedoperator import
get_mapped_ti_count
- from airflow.utils.context import (
- ConnectionAccessor,
- OutletEventAccessors,
- VariableAccessor,
- )
-
- if TYPE_CHECKING:
- assert session
-
- def _get_dagrun(session: Session) -> DagRun:
- dag_run = self.get_dagrun(session)
- if dag_run in session:
- return dag_run
- # The dag_run may not be attached to the session anymore since the
- # code base is over-zealous with use of session.expunge_all().
- # Re-attach it if the relation is not loaded so we can load it
when needed.
- info: Any = inspect(dag_run)
- if info.attrs.consumed_asset_events.loaded_value is not NO_VALUE:
- return dag_run
- # If dag_run is not flushed to db at all (e.g. CLI commands using
- # in-memory objects for ad-hoc operations), just set the value
manually.
- if not info.has_identity:
- dag_run.consumed_asset_events = []
- return dag_run
- return session.merge(dag_run, load=False)
-
- task: Any = self.task
- dag = task.dag
- dag_run = _get_dagrun(session)
-
- validated_params = process_params(dag, task, dag_run.conf,
suppress_exception=ignore_param_exceptions)
- runtime_ti = RuntimeTaskInstance.model_construct(
- id=self.id,
- task_id=self.task_id,
- dag_id=self.dag_id,
- run_id=self.run_id,
- try_numer=self.try_number,
- map_index=self.map_index,
- task=self.task,
- max_tries=self.max_tries,
- hostname=self.hostname,
- _ti_context_from_server=TIRunContext(
- dag_run=DagRunSDK.model_validate(dag_run,
from_attributes=True),
- max_tries=self.max_tries,
- should_retry=self.is_eligible_to_retry(),
- ),
- start_date=self.start_date,
- dag_version_id=self.dag_version_id,
- )
-
- context: Context = runtime_ti.get_template_context()
-
- @cache # Prevent multiple database access.
- def _get_previous_dagrun_success() -> PrevSuccessfulDagRunResponse:
- dr_from_db = self.get_previous_dagrun(state=DagRunState.SUCCESS,
session=session)
- if dr_from_db:
- return PrevSuccessfulDagRunResponse.model_validate(dr_from_db,
from_attributes=True)
- return PrevSuccessfulDagRunResponse()
-
- def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
- return
timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_start)
-
- def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
- return
timezone.coerce_datetime(_get_previous_dagrun_success().data_interval_end)
-
- def get_prev_start_date_success() -> pendulum.DateTime | None:
- return
timezone.coerce_datetime(_get_previous_dagrun_success().start_date)
-
- def get_prev_end_date_success() -> pendulum.DateTime | None:
- return
timezone.coerce_datetime(_get_previous_dagrun_success().end_date)
-
- def get_triggering_events() -> dict[str, list[AssetEvent]]:
- asset_events = dag_run.consumed_asset_events
- triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
- for event in asset_events:
- if event.asset:
- triggering_events[event.asset.uri].append(event)
-
- return triggering_events
-
- # NOTE: If you add to this dict, make sure to also update the
following:
- # * Context in task-sdk/src/airflow/sdk/definitions/context.py
- # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
- # * Table in docs/apache-airflow/templates-ref.rst
-
- context.update(
- {
- "outlet_events": OutletEventAccessors(),
- "inlet_events": InletEventsAccessors(task.inlets),
- "params": validated_params,
- "prev_data_interval_start_success":
get_prev_data_interval_start_success(),
- "prev_data_interval_end_success":
get_prev_data_interval_end_success(),
- "prev_start_date_success": get_prev_start_date_success(),
- "prev_end_date_success": get_prev_end_date_success(),
- "test_mode": self.test_mode,
- # ti/task_instance are added here for ti.xcom_{push,pull}
- "task_instance": self,
- "ti": self,
- "triggering_asset_events":
lazy_object_proxy.Proxy(get_triggering_events),
- "var": {
- "json": VariableAccessor(deserialize_json=True),
- "value": VariableAccessor(deserialize_json=False),
- },
- "conn": ConnectionAccessor(),
- }
- )
-
- try:
- expanded_ti_count: int | None = get_mapped_ti_count(task,
self.run_id, session=session)
- context["expanded_ti_count"] = expanded_ti_count
- if expanded_ti_count:
- setattr(
- self,
- "_upstream_map_indexes",
- {
- upstream.task_id:
self.get_relevant_upstream_map_indexes(
- upstream,
- expanded_ti_count,
- session=session,
- )
- for upstream in task.upstream_list
- },
- )
- except NotMapped:
- pass
-
- return context
-
def set_duration(self) -> None:
"""Set task instance duration."""
if self.end_date and self.start_date:
diff --git a/airflow-core/src/airflow/serialization/enums.py
b/airflow-core/src/airflow/serialization/enums.py
index 484b6ed5ad9..5fdd4d66987 100644
--- a/airflow-core/src/airflow/serialization/enums.py
+++ b/airflow-core/src/airflow/serialization/enums.py
@@ -63,7 +63,6 @@ class DagAttributeTypes(str, Enum):
ASSET_UNIQUE_KEY = "asset_unique_key"
ASSET_ALIAS_UNIQUE_KEY = "asset_alias_unique_key"
CONNECTION = "connection"
- TASK_CONTEXT = "task_context"
ARG_NOT_SET = "arg_not_set"
TASK_CALLBACK_REQUEST = "task_callback_request"
DAG_CALLBACK_REQUEST = "dag_callback_request"
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index 53b6377b96e..f3e5c74cfe4 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -40,7 +40,6 @@ import pydantic
from dateutil import relativedelta
from pendulum.tz.timezone import FixedTimezone, Timezone
-from airflow import macros
from airflow._shared.module_loading import import_string, qualname
from airflow._shared.timezones.timezone import from_timestamp, parse_timezone,
utcnow
from airflow.callbacks.callback_requests import DagCallbackRequest,
TaskCallbackRequest
@@ -100,7 +99,6 @@ from airflow.task.priority_strategy import (
from airflow.timetables.base import DagRunInfo, Timetable
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
-from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
from airflow.utils.db import LazySelectSequence
if TYPE_CHECKING:
@@ -654,12 +652,6 @@ class BaseSerialization:
elif isinstance(var, MappedArgument):
data = {"input": encode_expand_input(var._input), "key": var._key}
return cls._encode(data, type_=DAT.MAPPED_ARGUMENT)
- elif var.__class__ == Context:
- d = {}
- for k, v in var.items():
- obj = cls.serialize(v, strict=strict)
- d[str(k)] = obj
- return cls._encode(d, type_=DAT.TASK_CONTEXT)
else:
return cls.default_serialization(strict, var)
@@ -686,21 +678,7 @@ class BaseSerialization:
raise ValueError(f"The encoded_var should be dict and is
{type(encoded_var)}")
var = encoded_var[Encoding.VAR]
type_ = encoded_var[Encoding.TYPE]
- if type_ == DAT.TASK_CONTEXT:
- d = {}
- for k, v in var.items():
- if k == "task": # todo: add `_encode` of Operator so we don't
need this
- continue
- d[k] = cls.deserialize(v)
- d["task"] = d["task_instance"].task # todo: add `_encode` of
Operator so we don't need this
- d["macros"] = macros
- d["var"] = {
- "json": VariableAccessor(deserialize_json=True),
- "value": VariableAccessor(deserialize_json=False),
- }
- d["conn"] = ConnectionAccessor()
- return Context(**d)
- elif type_ == DAT.DICT:
+ if type_ == DAT.DICT:
return {k: cls.deserialize(v) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
return decode_outlet_event_accessors(var)
diff --git a/airflow-core/src/airflow/utils/context.py
b/airflow-core/src/airflow/utils/context.py
index 02e324b083b..4783becb562 100644
--- a/airflow-core/src/airflow/utils/context.py
+++ b/airflow-core/src/airflow/utils/context.py
@@ -32,51 +32,16 @@ from airflow.sdk.execution_time.context import (
VariableAccessor as VariableAccessorSDK,
)
from airflow.serialization.definitions.notset import NOTSET, is_arg_set
-from airflow.utils.deprecation_tools import DeprecatedImportWarning
+from airflow.utils.deprecation_tools import DeprecatedImportWarning,
add_deprecated_classes
from airflow.utils.session import create_session
-# NOTE: Please keep this in sync with the following:
-# * Context in task-sdk/src/airflow/sdk/definitions/context.py
-# * Table in docs/apache-airflow/templates-ref.rst
-KNOWN_CONTEXT_KEYS: set[str] = {
- "conn",
- "dag",
- "dag_run",
- "data_interval_end",
- "data_interval_start",
- "ds",
- "ds_nodash",
- "expanded_ti_count",
- "exception",
- "inlets",
- "inlet_events",
- "logical_date",
- "macros",
- "map_index_template",
- "outlets",
- "outlet_events",
- "params",
- "prev_data_interval_start_success",
- "prev_data_interval_end_success",
- "prev_start_date_success",
- "prev_end_date_success",
- "reason",
- "run_id",
- "start_date",
- "task",
- "task_reschedule_count",
- "task_instance",
- "task_instance_key_str",
- "test_mode",
- "templates_dict",
- "ti",
- "triggering_asset_events",
- "ts",
- "ts_nodash",
- "ts_nodash_with_tz",
- "try_number",
- "var",
-}
+warnings.warn(
+ "Module airflow.utils.context is deprecated and will be removed in the "
+ "future. Use airflow.sdk.execution_time.context if you are using the "
+ "classes inside an Airflow task.",
+ DeprecatedImportWarning,
+ stacklevel=2,
+)
class VariableAccessor(VariableAccessorSDK):
@@ -140,17 +105,14 @@ class OutletEventAccessors(OutletEventAccessorsSDK):
return Asset(name=asset.name, uri=asset.uri, group=asset.group,
extra=asset.extra)
-def __getattr__(name: str):
- if name in ("Context", "context_copy_partial", "context_merge"):
- warnings.warn(
- "Importing Context from airflow.utils.context is deprecated and
will "
- "be removed in the future. Please import it from airflow.sdk
instead.",
- DeprecatedImportWarning,
- stacklevel=2,
- )
-
- import airflow.sdk.definitions.context as sdk
-
- return getattr(sdk, name)
-
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+add_deprecated_classes(
+ {
+ __name__: {
+ "KNOWN_CONTEXT_KEYS": "airflow.sdk.definitions.context",
+ "Context": "airflow.sdk.definitions.context",
+ "context_copy_partial": "airflow.sdk.definitions.context",
+ "context_merge": "airflow.sdk.definitions.context",
+ },
+ },
+ package=__name__,
+)
diff --git a/airflow-core/src/airflow/utils/helpers.py
b/airflow-core/src/airflow/utils/helpers.py
index 5b977ff75d1..a48780022a6 100644
--- a/airflow-core/src/airflow/utils/helpers.py
+++ b/airflow-core/src/airflow/utils/helpers.py
@@ -23,7 +23,7 @@ import re
import signal
from collections.abc import Callable, Generator, Iterable, MutableMapping
from functools import cache
-from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, TypeVar, overload
from urllib.parse import urljoin
from lazy_object_proxy import Proxy
@@ -39,7 +39,6 @@ if TYPE_CHECKING:
import jinja2
from airflow.models.taskinstance import TaskInstance
- from airflow.sdk.definitions.context import Context
CT = TypeVar("CT", str, datetime)
@@ -160,39 +159,6 @@ def log_filename_template_renderer() -> Callable[..., str]:
return f_str_format
-def _render_template_to_string(template: jinja2.Template, context: Context) ->
str:
- """
- Render a Jinja template to string using the provided context.
-
- This is a private utility function specifically for log filename rendering.
- It ensures templates are rendered as strings rather than native Python
objects.
- """
- return render_template(template, cast("MutableMapping[str, Any]",
context), native=False)
-
-
-def render_log_filename(ti: TaskInstance, try_number, filename_template) ->
str:
- """
- Given task instance, try_number, filename_template, return the rendered
log filename.
-
- :param ti: task instance
- :param try_number: try_number of the task
- :param filename_template: filename template, which can be jinja template or
- python string template
- """
- filename_template, filename_jinja_template =
parse_template_string(filename_template)
- if filename_jinja_template:
- jinja_context = ti.get_template_context()
- jinja_context["try_number"] = try_number
- return _render_template_to_string(filename_jinja_template,
jinja_context)
-
- return filename_template.format(
- dag_id=ti.dag_id,
- task_id=ti.task_id,
- logical_date=ti.logical_date.isoformat(),
- try_number=try_number,
- )
-
-
def convert_camel_to_snake(camel_str: str) -> str:
"""Convert CamelCase to snake_case."""
return CAMELCASE_TO_SNAKE_CASE_REGEX.sub(r"_\1", camel_str).lower()
diff --git a/airflow-core/src/airflow/utils/log/log_reader.py
b/airflow-core/src/airflow/utils/log/log_reader.py
index 9ebc3a9050c..99576559e96 100644
--- a/airflow-core/src/airflow/utils/log/log_reader.py
+++ b/airflow-core/src/airflow/utils/log/log_reader.py
@@ -25,17 +25,13 @@ from functools import cached_property
from typing import TYPE_CHECKING
from airflow.configuration import conf
-from airflow.utils.helpers import render_log_filename
from airflow.utils.log.file_task_handler import FileTaskHandler,
StructuredLogMessage
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
-from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
from typing import TypeAlias
- from sqlalchemy.orm.session import Session
-
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.utils.log.file_task_handler import LogHandlerOutputStream,
LogMetadata
@@ -190,25 +186,3 @@ class TaskLogReader:
return False
return self.log_handler.supports_external_link
-
- @provide_session
- def render_log_filename(
- self,
- ti: TaskInstance | TaskInstanceHistory,
- try_number: int | None = None,
- *,
- session: Session = NEW_SESSION,
- ) -> str:
- """
- Render the log attachment filename.
-
- :param ti: The task instance
- :param try_number: The task try number
- """
- dagrun = ti.get_dagrun(session=session)
- attachment_filename = render_log_filename(
- ti=ti,
- try_number="all" if try_number is None else try_number,
-
filename_template=dagrun.get_log_template(session=session).filename,
- )
- return attachment_filename
diff --git a/airflow-core/tests/unit/core/test_core.py
b/airflow-core/tests/unit/core/test_core.py
index aa341560553..66b70ddcfa2 100644
--- a/airflow-core/tests/unit/core/test_core.py
+++ b/airflow-core/tests/unit/core/test_core.py
@@ -27,6 +27,7 @@ from airflow.sdk import BaseOperator
from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_dags, clear_db_runs
+from tests_common.test_utils.taskinstance import get_template_context
pytestmark = pytest.mark.db_test
@@ -87,14 +88,13 @@ class TestCore:
params={"key_2": "value_2_new", "key_3": "value_3"},
)
task2 = EmptyOperator(task_id="task2")
- dr = dag_maker.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- )
+ dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
ti1 = dag_maker.run_ti(task1.task_id, dr)
ti2 = dag_maker.run_ti(task2.task_id, dr)
+ ti1.dag_run = ti2.dag_run = dr
- context1 = ti1.get_template_context()
- context2 = ti2.get_template_context()
+ context1 = get_template_context(ti1, task1)
+ context2 = get_template_context(ti2, task2)
assert context1["params"] == {"key_1": "value_1", "key_2":
"value_2_new", "key_3": "value_3"}
assert context2["params"] == {"key_1": "value_1", "key_2":
"value_2_old"}
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 5de8ca8db3c..e607efd1a44 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -71,9 +71,8 @@ from airflow.providers.standard.operators.hitl import (
)
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.python import PythonSensor
-from airflow.sdk import DAG, BaseOperator, BaseSensorOperator, Metadata, task,
task_group
+from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator,
BaseSensorOperator, Metadata, task, task_group
from airflow.sdk.api.datamodels._generated import AssetEventResponse,
AssetResponse
-from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.sdk.definitions.param import process_params
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.execution_time.comms import AssetEventsResult
@@ -444,6 +443,7 @@ class TestTaskInstance:
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
dag=dag_maker.dag,
)
+ dag_maker.sync_dag_to_db()
ti2 = _create_task_instance(task=task2, run_id=ti.run_id,
dag_version_id=ti.dag_version_id)
session.add(ti2)
session.flush()
@@ -552,18 +552,19 @@ class TestTaskInstance:
retry_delay=datetime.timedelta(seconds=0),
)
- def run_with_error(ti):
+ def run_with_error():
with contextlib.suppress(AirflowException):
- dag_maker.run_ti(ti.task_id, ti.dag_run)
- ti.refresh_from_db(session)
+ dag_maker.run_ti(ti.task_id, dag_run)
+ return session.get(TaskInstance, ti.id)
- ti =
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
+ dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow())
+ ti = dag_run.task_instances[0]
assert ti.try_number == 0
session.get(TaskInstance, ti.id).try_number += 1
session.commit()
# first run -- up for retry
- run_with_error(ti)
+ ti = run_with_error()
assert ti.state == State.UP_FOR_RETRY
assert ti.try_number == 1
@@ -571,7 +572,7 @@ class TestTaskInstance:
session.commit()
# second run -- fail
- run_with_error(ti)
+ ti = run_with_error()
assert ti.state == State.FAILED
assert ti.try_number == 2
@@ -579,14 +580,13 @@ class TestTaskInstance:
# clearing it first
dag.clear()
- ti.refresh_from_db(session)
+ ti.refresh_from_db()
ti.try_number += 1
session.add(ti)
session.commit()
# third run -- up for retry
- run_with_error(ti)
- ti.refresh_from_db()
+ ti = run_with_error()
assert ti.state == State.UP_FOR_RETRY
assert ti.try_number == 3
@@ -594,8 +594,7 @@ class TestTaskInstance:
session.commit()
# fourth run -- fail
- run_with_error(ti)
- ti.refresh_from_db()
+ ti = run_with_error()
assert ti.state == State.FAILED
assert ti.try_number == 4
assert RenderedTaskInstanceFields.get_templated_fields(ti) ==
expected_rendered_ti_fields
@@ -2094,9 +2093,7 @@ class TestTaskInstance:
Test that when a task that produces asset has ran, that changing the
consumer
dag asset will not cause primary key blank-out
"""
- from airflow.sdk.definitions.asset import Asset
-
- with dag_maker(schedule=None, serialized=True) as dag1:
+ with dag_maker(schedule=None, serialized=False) as dag1:
@task(outlets=Asset("test/1"))
def test_task1():
@@ -2107,7 +2104,7 @@ class TestTaskInstance:
dr1 = dag_maker.create_dagrun()
test_task1 = dag1.get_task("test_task1")
- with dag_maker(dag_id="testdag", schedule=[Asset("test/1")],
serialized=True):
+ with dag_maker(dag_id="testdag", schedule=[Asset("test/1")]):
@task
def test_task2():
@@ -2119,7 +2116,7 @@ class TestTaskInstance:
run_task_instance(ti, dag1.get_task(ti.task_id))
# Change the asset.
- with dag_maker(dag_id="testdag", schedule=[Asset("test2/1")],
serialized=True):
+ with dag_maker(dag_id="testdag", schedule=[Asset("test2/1")]):
@task
def test_task2():
@@ -2644,7 +2641,7 @@ class TestTaskInstanceRecordTaskMapXComPush:
@pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2},
"abc"])
def test_not_recorded_if_leaf(self, dag_maker, xcom_value):
"""Return value should not be recorded if there are no downstreams."""
- with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+ with dag_maker(dag_id="test_not_recorded_for_unused",
serialized=False) as dag:
@dag.task()
def push_something():
@@ -2660,7 +2657,7 @@ class TestTaskInstanceRecordTaskMapXComPush:
@pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2},
"abc"])
def test_not_recorded_if_not_used(self, dag_maker, xcom_value):
"""Return value should not be recorded if no downstreams are mapped."""
- with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+ with dag_maker(dag_id="test_not_recorded_for_unused",
serialized=False) as dag:
@dag.task()
def push_something():
diff --git a/airflow-core/tests/unit/utils/log/test_log_reader.py
b/airflow-core/tests/unit/utils/log/test_log_reader.py
index d1279f71323..addcb5ebdee 100644
--- a/airflow-core/tests/unit/utils/log/test_log_reader.py
+++ b/airflow-core/tests/unit/utils/log/test_log_reader.py
@@ -17,23 +17,18 @@
from __future__ import annotations
import copy
-import datetime
import os
import sys
import tempfile
import types
-from typing import TYPE_CHECKING
from unittest import mock
-import pendulum
import pytest
from airflow import settings
from airflow._shared.timezones import timezone
from airflow.config_templates.airflow_local_settings import
DEFAULT_LOGGING_CONFIG
from airflow.models.tasklog import LogTemplate
-from airflow.providers.standard.operators.python import PythonOperator
-from airflow.timetables.base import DataInterval
from airflow.utils.log.log_reader import TaskLogReader
from airflow.utils.log.logging_mixin import ExternalLoggingMixin
from airflow.utils.state import TaskInstanceState
@@ -46,10 +41,6 @@ from tests_common.test_utils.file_task_handler import
convert_list_to_stream
pytestmark = pytest.mark.db_test
-if TYPE_CHECKING:
- from airflow.models import DagRun
-
-
class TestLogView:
DAG_ID = "dag_log_reader"
TASK_ID = "task_log_reader"
@@ -292,54 +283,6 @@ class TestLogView:
mock_prop.return_value = True
assert task_log_reader.supports_external_link
- def test_task_log_filename_unique(self, dag_maker):
- """
- Ensure the default log_filename_template produces a unique filename.
-
- See discussion in apache/airflow#19058 [1]_ for how uniqueness may
- change in a future Airflow release. For now, the logical date is used
- to distinguish DAG runs. This test should be modified when the logical
- date is no longer used to ensure uniqueness.
-
- [1]: https://github.com/apache/airflow/issues/19058
- """
- dag_id = "test_task_log_filename_ts_corresponds_to_logical_date"
- task_id = "echo_run_type"
-
- def echo_run_type(dag_run: DagRun, **kwargs):
- print(dag_run.run_type)
-
- with dag_maker(dag_id, start_date=self.DEFAULT_DATE,
schedule="@daily"):
- PythonOperator(task_id=task_id, python_callable=echo_run_type)
-
- start = pendulum.datetime(2021, 1, 1)
- end = start + datetime.timedelta(days=1)
- trigger_time = end + datetime.timedelta(hours=4, minutes=29) #
Arbitrary.
-
- # Create two DAG runs that have the same data interval, but not the
same
- # logical date, to check if they correctly use different log files.
- scheduled_dagrun: DagRun = dag_maker.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- logical_date=start,
- data_interval=DataInterval(start, end),
- )
- manual_dagrun: DagRun = dag_maker.create_dagrun(
- run_type=DagRunType.MANUAL,
- logical_date=trigger_time,
- data_interval=DataInterval(start, end),
- )
-
- scheduled_ti = scheduled_dagrun.get_task_instance(task_id)
- manual_ti = manual_dagrun.get_task_instance(task_id)
- assert scheduled_ti is not None
- assert manual_ti is not None
-
-
scheduled_ti.refresh_from_task(dag_maker.serialized_dag.get_task(task_id))
- manual_ti.refresh_from_task(dag_maker.serialized_dag.get_task(task_id))
-
- reader = TaskLogReader()
- assert reader.render_log_filename(scheduled_ti, 1) !=
reader.render_log_filename(manual_ti, 1)
-
@pytest.mark.parametrize(
("state", "try_number", "expected_event", "use_self_ti"),
[
diff --git a/airflow-core/tests/unit/utils/test_helpers.py
b/airflow-core/tests/unit/utils/test_helpers.py
index 6179acadfb9..8e16d118698 100644
--- a/airflow-core/tests/unit/utils/test_helpers.py
+++ b/airflow-core/tests/unit/utils/test_helpers.py
@@ -23,7 +23,6 @@ from typing import TYPE_CHECKING
import pytest
-from airflow._shared.timezones import timezone
from airflow.exceptions import AirflowException
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.serialization.definitions.notset import NOTSET
@@ -55,24 +54,6 @@ def clear_db():
class TestHelpers:
- @pytest.mark.db_test
- @pytest.mark.usefixtures("clear_db")
- def test_render_log_filename(self, create_task_instance):
- try_number = 1
- dag_id = "test_render_log_filename_dag"
- task_id = "test_render_log_filename_task"
- logical_date = timezone.datetime(2016, 1, 1)
-
- ti = create_task_instance(dag_id=dag_id, task_id=task_id,
logical_date=logical_date)
- filename_template = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{
try_number }}.log"
-
- ts = ti.get_template_context()["ts"]
- expected_filename = f"{dag_id}/{task_id}/{ts}/{try_number}.log"
-
- rendered_filename = helpers.render_log_filename(ti, try_number,
filename_template)
-
- assert rendered_filename == expected_filename
-
def test_chunks(self):
with pytest.raises(ValueError, match=CHUNK_SIZE_POSITIVE_INT):
list(helpers.chunks([1, 2, 3], 0))
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index 3dc2cec736e..b7e79451a85 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -807,6 +807,8 @@ class DagMaker(Generic[Dag], Protocol):
def get_serialized_data(self) -> dict[str, Any]: ...
+ def sync_dag_to_db(self) -> None: ...
+
def create_dagrun(
self,
run_id: str = ...,
@@ -818,6 +820,15 @@ class DagMaker(Generic[Dag], Protocol):
def create_dagrun_after(self, dagrun: DagRun, **kwargs) -> DagRun: ...
+ def create_ti(
+ self,
+ task_id: str,
+ dag_run: DagRun | None = ...,
+ dag_run_kwargs: dict | None = ...,
+ map_index: int = ...,
+ **kwargs,
+ ) -> TaskInstance: ...
+
def run_ti(
self,
task_id: str,
@@ -1047,6 +1058,9 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
self._bag_dag_compat(dag)
self.session.flush()
+ def sync_dag_to_db(self):
+ self._make_serdag(self.dag)
+
def create_dagrun(self, *, logical_date=NOTSET, **kwargs):
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
@@ -1155,15 +1169,15 @@ def dag_maker(request) -> Generator[DagMaker, None,
None]:
**kwargs,
)
- def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None,
map_index=-1, **kwargs):
+ def create_ti(self, task_id, dag_run=None, dag_run_kwargs=None,
map_index=-1):
"""
- Create a dagrun and run a specific task instance with proper task
refresh.
+ Create a specific task instance with proper task refresh.
+
+ This is a convenience method for creating a single task instance:
- This is a convenience method for running a single task instance:
1. Create a dagrun if it does not exist
2. Get the specific task instance by task_id
3. Refresh the task instance from the DAG task
- 4. Run the task instance
Returns the created TaskInstance.
"""
@@ -1178,7 +1192,23 @@ def dag_maker(request) -> Generator[DagMaker, None,
None]:
f"Task instance with task_id '{task_id}' not found in dag
run. "
f"Available task_ids: {available_task_ids}"
)
+ if AIRFLOW_V_3_2_PLUS:
+ ti.refresh_from_task(self.serialized_dag.get_task(task_id))
+ else:
+ ti.refresh_from_task(self.dag.get_task(task_id))
+ return ti
+
+ def run_ti(self, task_id, dag_run=None, dag_run_kwargs=None,
map_index=-1, **kwargs):
+ """
+ Run a specific task instance with proper task refresh.
+ This is a convenience method for running a single task instance:
+
+ 1. Call ``create_ti()`` to obtain the task instance
+ 2. Run the task instance
+
+ Returns the created and run TaskInstance.
+ """
task = self.dag.get_task(task_id)
if not AIRFLOW_V_3_1_PLUS:
# Airflow <3.1 has a bug for DecoratedOperator has an unused
signature for
@@ -1192,13 +1222,13 @@ def dag_maker(request) -> Generator[DagMaker, None,
None]:
#
^^^^^^^^^^^^^^
# E AttributeError: '_PythonDecoratedOperator' object has no
attribute 'xcom_push'
task.xcom_push = lambda *args, **kwargs: None
+
+ ti = self.create_ti(task_id, dag_run=dag_run,
dag_run_kwargs=dag_run_kwargs, map_index=map_index)
if AIRFLOW_V_3_2_PLUS:
from tests_common.test_utils.taskinstance import
run_task_instance
- ti.refresh_from_task(self.serialized_dag.get_task(task_id))
run_task_instance(ti, task, **kwargs)
else:
- ti.refresh_from_task(task)
ti.run(**kwargs)
return ti
diff --git a/devel-common/src/tests_common/test_utils/taskinstance.py
b/devel-common/src/tests_common/test_utils/taskinstance.py
index 7f79819aee3..b82ad5a7fc7 100644
--- a/devel-common/src/tests_common/test_utils/taskinstance.py
+++ b/devel-common/src/tests_common/test_utils/taskinstance.py
@@ -22,6 +22,7 @@ import copy
from typing import TYPE_CHECKING
from airflow.models.taskinstance import TaskInstance
+from airflow.utils.session import NEW_SESSION
from tests_common.test_utils.compat import SerializedBaseOperator,
SerializedMappedOperator
from tests_common.test_utils.dag import create_scheduler_dag
@@ -36,9 +37,10 @@ if TYPE_CHECKING:
from uuid import UUID
from jinja2 import Environment
+ from sqlalchemy.orm import Session
from airflow.sdk import Context
- from airflow.sdk.types import Operator as SdkOperator
+ from airflow.sdk.types import Operator as SdkOperator,
RuntimeTaskInstanceProtocol
from airflow.serialization.definitions.mappedoperator import Operator as
SerializedOperator
__all__ = ["TaskInstanceWrapper", "create_task_instance",
"render_template_fields", "run_task_instance"]
@@ -62,16 +64,15 @@ class TaskInstanceWrapper:
def __copy__(self):
return TaskInstanceWrapper(copy.copy(self.__dict__["__ti"]),
copy.copy(self.__dict__["__task"]))
- def run(self, **kwargs) -> None:
- from tests_common.test_utils.taskinstance import run_task_instance
-
- run_task_instance(self.__dict__["__ti"], self.__dict__["__task"],
**kwargs)
+ def run(self, **kwargs) -> RuntimeTaskInstanceProtocol:
+ return run_task_instance(self.__dict__["__ti"],
self.__dict__["__task"], **kwargs)
def render_templates(self, **kwargs) -> SdkOperator:
- from tests_common.test_utils.taskinstance import render_template_fields
-
return render_template_fields(self.__dict__["__ti"],
self.__dict__["__task"], **kwargs)
+ def get_template_context(self) -> Context:
+ return get_template_context(self.__dict__["__ti"],
self.__dict__["__task"])
+
def create_task_instance(
task: SdkOperator | SerializedOperator,
@@ -113,42 +114,75 @@ def run_task_instance(
ignore_ti_state: bool = False,
mark_success: bool = False,
session=None,
-):
+) -> RuntimeTaskInstanceProtocol:
+ session_kwargs = {"session": session} if session else {}
if not AIRFLOW_V_3_2_PLUS:
ti.refresh_from_task(task) # type: ignore[arg-type]
- ti.run()
+ ti.run(**session_kwargs)
return ti
- kwargs = {"session": session} if session else {}
if not ti.check_and_change_state_before_execution(
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
mark_success=mark_success,
- **kwargs,
+ **session_kwargs,
):
return ti
from airflow.sdk.definitions.dag import _run_task
- taskrun_result = _run_task(ti=ti, task=task)
+ # Session handling is a mess in tests; use a fresh ti to run the task.
+ new_ti = TaskInstance.get_task_instance(
+ dag_id=ti.dag_id,
+ run_id=ti.run_id,
+ task_id=ti.task_id,
+ map_index=ti.map_index,
+ **session_kwargs,
+ )
+ # Some tests don't even save the ti at all, in which case new_ti is None.
+ taskrun_result = _run_task(ti=new_ti or ti, task=task)
+ ti.refresh_from_db(**session_kwargs) # Some tests expect side effects.
if not taskrun_result:
- return None
+ raise RuntimeError("task failed to finish with a result")
if error := taskrun_result.error:
raise error
return taskrun_result.ti
+def get_template_context(ti: TaskInstance, task: SdkOperator, *, session:
Session = NEW_SESSION) -> Context:
+ if not AIRFLOW_V_3_2_PLUS:
+ ti.refresh_from_task(task) # type: ignore[arg-type]
+ return ti.get_template_context(session=session)
+
+ from airflow.cli.commands.task_command import _get_template_context
+ from airflow.utils.context import ConnectionAccessor, VariableAccessor
+
+ # TODO: Move these to test_utils too.
+ context = _get_template_context(ti, task)
+ context["ti"].__dict__.update(xcom_push=ti.xcom_push,
xcom_pull=ti.xcom_pull) # Avoid execution API.
+ context.update( # type: ignore[call-arg] #
https://github.com/python/mypy/issues/17750
+ conn=ConnectionAccessor(),
+ test_mode=ti.test_mode,
+ var={
+ "json": VariableAccessor(deserialize_json=True),
+ "value": VariableAccessor(deserialize_json=False),
+ },
+ )
+ return context
+
+
def render_template_fields(
ti: TaskInstance,
task: SdkOperator,
*,
context: Context | None = None,
jinja_env: Environment | None = None,
+ session: Session = NEW_SESSION,
) -> SdkOperator:
if AIRFLOW_V_3_2_PLUS:
- task.render_template_fields(context or ti.get_template_context(),
jinja_env)
+ task.render_template_fields(context or get_template_context(ti, task),
jinja_env)
return task
ti.refresh_from_task(task) # type: ignore[arg-type]
- ti.render_templates(context, jinja_env)
+ ti.render_templates(context or ti.get_template_context(session=session),
jinja_env)
return ti.task # type: ignore[return-value]
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
index f65113ad2a6..67b23e0f0aa 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -40,14 +40,8 @@ from
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
)
from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-from airflow.providers.common.compat.sdk import AirflowException, Stats, conf
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+from airflow.providers.common.compat.sdk import AirflowException, Stats, conf,
timezone
if TYPE_CHECKING:
from sqlalchemy.orm import Session
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
index 8657bb07872..0b018a2ea0e 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -35,12 +35,7 @@ from
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
)
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.providers.common.compat.sdk import AirflowException, Stats, conf
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from airflow.providers.common.compat.sdk import AirflowException, Stats, conf,
timezone
from airflow.utils.helpers import merge_dicts
if TYPE_CHECKING:
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index efd54def2ff..0887dcd47a8 100644
---
a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++
b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -48,12 +48,7 @@ from
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.providers.common.compat.sdk import AirflowException, Stats
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from airflow.providers.common.compat.sdk import AirflowException, Stats,
timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
index 3929694ae37..f071ff844c9 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -35,12 +35,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import
AwsBaseHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.tags import format_tags
-from airflow.providers.common.compat.sdk import AirflowException
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from airflow.providers.common.compat.sdk import AirflowException, timezone
class LogState:
diff --git
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
index 4bce677a3c1..d1265b33a0b 100644
---
a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
+++
b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py
@@ -31,14 +31,10 @@ from airflow.providers.amazon.aws.executors.aws_lambda
import lambda_executor
from airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor import
AwsLambdaExecutor
from airflow.providers.amazon.aws.executors.aws_lambda.utils import
CONFIG_GROUP_NAME, AllLambdaConfigKeys
from airflow.providers.common.compat.sdk import AirflowException
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
from airflow.utils.state import TaskInstanceState
from airflow.version import version as airflow_version_str
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
b/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
index b8ba2d927cd..8f31e9caf30 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_appflow.py
@@ -25,10 +25,7 @@ import pytest
from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
FLOW_NAME = "flow0"
EXECUTION_ID = "ex_id"
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
index 9e21004385d..86f765808d5 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_appflow.py
@@ -32,10 +32,7 @@ from airflow.providers.amazon.aws.operators.appflow import (
AppflowRunOperator,
)
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
CONN_ID = "aws_default"
DAG_ID = "dag_id"
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
index 67d945a19ea..4b0b91f0f3f 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_athena.py
@@ -41,16 +41,12 @@ from airflow.providers.openlineage.extractors import
OperatorLineage
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance
+from tests_common.test_utils.taskinstance import create_task_instance,
get_template_context
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
TEST_DAG_ID = "unit_tests"
DEFAULT_DATE = timezone.datetime(2018, 1, 1)
ATHENA_QUERY_ID = "eac29bf8-daa1-4ffc-b19a-0db31dc3b784"
@@ -272,7 +268,7 @@ class TestAthenaOperator:
ti.dag_run = dag_run
session.add(ti)
session.commit()
- assert self.athena.execute(ti.get_template_context()) ==
ATHENA_QUERY_ID
+ assert self.athena.execute(get_template_context(ti, self.athena)) ==
ATHENA_QUERY_ID
@mock.patch.object(AthenaHook, "check_query_status",
side_effect=("SUCCEEDED",))
@mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
index 5643eb9c48b..f718220df7e 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_cloud_formation.py
@@ -28,11 +28,7 @@ from airflow.providers.amazon.aws.operators.cloud_formation
import (
CloudFormationDeleteStackOperator,
)
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
DEFAULT_DATE = timezone.datetime(2019, 1, 1)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
index c47f948482e..2addb0ce539 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_datasync.py
@@ -30,16 +30,12 @@ from airflow.providers.common.compat.sdk import
AirflowException
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance
+from tests_common.test_utils.taskinstance import create_task_instance,
get_template_context
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
TEST_DAG_ID = "unit_tests"
DEFAULT_DATE = timezone.datetime(2018, 1, 1)
@@ -373,7 +369,7 @@ class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
run_type=DagRunType.MANUAL,
state=DagRunState.RUNNING,
)
- ti = create_task_instance(task=self.datasync,
dag_version_id=dag_version.id)
+ ti = create_task_instance(task=self.datasync, run_id="test",
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=self.dag.dag_id,
@@ -386,7 +382,7 @@ class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
ti.dag_run = dag_run
session.add(ti)
session.commit()
- assert self.datasync.execute(ti.get_template_context()) is not None
+ assert self.datasync.execute(get_template_context(ti, self.datasync))
is not None
# ### Check mocks:
mock_get_conn.assert_called()
@@ -587,7 +583,7 @@ class TestDataSyncOperatorGetTasks(DataSyncTestCaseBase):
sync_dag_to_db(self.dag)
dag_version = DagVersion.get_latest_version(self.dag.dag_id)
- ti = create_task_instance(task=self.datasync,
dag_version_id=dag_version.id)
+ ti = create_task_instance(task=self.datasync, run_id="test",
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=timezone.utcnow(),
@@ -607,7 +603,7 @@ class TestDataSyncOperatorGetTasks(DataSyncTestCaseBase):
ti.dag_run = dag_run
session.add(ti)
session.commit()
- result = self.datasync.execute(ti.get_template_context())
+ result = self.datasync.execute(get_template_context(ti, self.datasync))
assert result["TaskArn"] == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -710,7 +706,7 @@ class TestDataSyncOperatorUpdate(DataSyncTestCaseBase):
sync_dag_to_db(self.dag)
dag_version = DagVersion.get_latest_version(self.dag.dag_id)
- ti = create_task_instance(task=self.datasync,
dag_version_id=dag_version.id)
+ ti = create_task_instance(task=self.datasync, run_id="test",
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=timezone.utcnow(),
@@ -730,7 +726,7 @@ class TestDataSyncOperatorUpdate(DataSyncTestCaseBase):
ti.dag_run = dag_run
session.add(ti)
session.commit()
- result = self.datasync.execute(ti.get_template_context())
+ result = self.datasync.execute(get_template_context(ti, self.datasync))
assert result["TaskArn"] == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
@@ -926,7 +922,7 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
sync_dag_to_db(self.dag)
dag_version = DagVersion.get_latest_version(self.dag.dag_id)
- ti = create_task_instance(task=self.datasync,
dag_version_id=dag_version.id)
+ ti = create_task_instance(task=self.datasync, run_id="test",
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=timezone.utcnow(),
@@ -946,7 +942,7 @@ class TestDataSyncOperator(DataSyncTestCaseBase):
ti.dag_run = dag_run
session.add(ti)
session.commit()
- assert self.datasync.execute(ti.get_template_context()) is not None
+ assert self.datasync.execute(get_template_context(ti, self.datasync))
is not None
# ### Check mocks:
mock_get_conn.assert_called()
@@ -1045,7 +1041,7 @@ class TestDataSyncOperatorDelete(DataSyncTestCaseBase):
sync_dag_to_db(self.dag)
dag_version = DagVersion.get_latest_version(self.dag.dag_id)
- ti = create_task_instance(task=self.datasync,
dag_version_id=dag_version.id)
+ ti = create_task_instance(task=self.datasync, run_id="test",
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=timezone.utcnow(),
@@ -1065,7 +1061,7 @@ class TestDataSyncOperatorDelete(DataSyncTestCaseBase):
ti.dag_run = dag_run
session.add(ti)
session.commit()
- result = self.datasync.execute(ti.get_template_context())
+ result = self.datasync.execute(get_template_context(ti, self.datasync))
assert result["TaskArn"] == self.task_arn
# ### Check mocks:
mock_get_conn.assert_called()
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
index 0a591cce72d..178a2043f98 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
@@ -47,16 +47,16 @@ from airflow.providers.common.compat.sdk import
AirflowException, TaskDeferred
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
+from tests_common.test_utils.taskinstance import (
+ create_task_instance,
+ get_template_context,
+ render_template_fields,
+)
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
if AIRFLOW_V_3_0_PLUS:
from airflow.models.dag_version import DagVersion
@@ -333,7 +333,7 @@ class TestDmsDescribeTasksOperator:
if AIRFLOW_V_3_0_PLUS:
sync_dag_to_db(self.dag)
dag_version = DagVersion.get_latest_version(self.dag.dag_id)
- ti = create_task_instance(task=describe_task,
dag_version_id=dag_version.id)
+ ti = create_task_instance(task=describe_task, run_id="test",
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.dag.dag_id,
logical_date=timezone.utcnow(),
@@ -353,7 +353,7 @@ class TestDmsDescribeTasksOperator:
ti.dag_run = dag_run
session.add(ti)
session.commit()
- marker, response = describe_task.execute(ti.get_template_context())
+ marker, response = describe_task.execute(get_template_context(ti,
describe_task))
assert marker is None
assert response == self.MOCK_RESPONSE
@@ -535,7 +535,7 @@ class TestDmsDescribeReplicationConfigsOperator:
if AIRFLOW_V_3_0_PLUS:
sync_dag_to_db(dag)
dag_version = DagVersion.get_latest_version(dag.dag_id)
- ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+ ti = create_task_instance(task=op, run_id="test",
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=dag.dag_id,
run_id="test",
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
index d1a32ee8d6f..e52eb36e7bb 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_add_steps.py
@@ -107,7 +107,6 @@ class TestEmrAddStepsOperator:
sync_dag_to_db(self.operator.dag)
dag_version =
DagVersion.get_latest_version(self.operator.dag.dag_id)
- ti = create_task_instance(task=self.operator,
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.operator.dag.dag_id,
logical_date=DEFAULT_DATE,
@@ -116,6 +115,11 @@ class TestEmrAddStepsOperator:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
+ ti = create_task_instance(
+ task=self.operator,
+ run_id=dag_run.run_id,
+ dag_version_id=dag_version.id,
+ )
else:
dag_run = DagRun(
dag_id=self.operator.dag.dag_id,
@@ -180,7 +184,6 @@ class TestEmrAddStepsOperator:
sync_dag_to_db(dag)
dag_version = DagVersion.get_latest_version(dag.dag_id)
- ti = create_task_instance(task=test_task,
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=dag.dag_id,
logical_date=timezone.utcnow(),
@@ -189,6 +192,7 @@ class TestEmrAddStepsOperator:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
+ ti = create_task_instance(task=test_task, run_id=dag_run.run_id,
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=dag.dag_id,
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
index 28b5952471f..e49f7ba7758 100644
---
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_create_job_flow.py
@@ -35,17 +35,13 @@ from airflow.providers.common.compat.sdk import TaskDeferred
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.dag import sync_dag_to_db
from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
TASK_ID = "test_task"
TEST_DAG_ID = "test_dag_id"
@@ -111,7 +107,6 @@ class TestEmrCreateJobFlowOperator:
sync_dag_to_db(self.operator.dag)
dag_version =
DagVersion.get_latest_version(self.operator.dag.dag_id)
- ti = create_task_instance(task=self.operator,
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.operator.dag_id,
logical_date=DEFAULT_DATE,
@@ -120,6 +115,7 @@ class TestEmrCreateJobFlowOperator:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
+ ti = create_task_instance(task=self.operator, run_id="test",
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=self.operator.dag_id,
@@ -165,7 +161,6 @@ class TestEmrCreateJobFlowOperator:
sync_dag_to_db(self.operator.dag)
dag_version =
DagVersion.get_latest_version(self.operator.dag.dag_id)
- ti = create_task_instance(task=self.operator,
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.operator.dag_id,
logical_date=DEFAULT_DATE,
@@ -174,6 +169,7 @@ class TestEmrCreateJobFlowOperator:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
+ ti = create_task_instance(task=self.operator, run_id="test",
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=self.operator.dag_id,
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
index a08ca1c5c99..fc012cad306 100644
---
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_modify_cluster.py
@@ -24,11 +24,7 @@ from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.emr import EmrModifyClusterOperator
from airflow.providers.common.compat.sdk import AirflowException
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
index c98b057c54e..940137f1058 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_rds.py
@@ -44,11 +44,7 @@ from airflow.providers.amazon.aws.operators.rds import (
from airflow.providers.amazon.aws.triggers.rds import RdsDbAvailableTrigger,
RdsDbStoppedTrigger
from airflow.providers.common.compat.sdk import TaskDeferred
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
if TYPE_CHECKING:
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
index 73b7cfdc679..ffd43a59c0c 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
@@ -680,7 +680,7 @@ class TestS3DeleteObjectsOperator:
sync_dag_to_db(dag)
dag_version = DagVersion.get_latest_version(dag.dag_id)
- ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+ ti = create_task_instance(task=op, run_id="test",
dag_version_id=dag_version.id)
else:
ti = TaskInstance(task=op)
ti.dag_run = dag_run
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
index 6ae2ad25b9b..e0269ae603f 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_base.py
@@ -33,16 +33,12 @@ from airflow.providers.common.compat.sdk import
AirflowException
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.dag import sync_dag_to_db
from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
CONFIG: dict = {
"key1": "1",
"key2": {"key3": "3", "key4": "4"},
@@ -221,7 +217,6 @@ class TestSageMakerExperimentOperator:
sync_dag_to_db(dag)
dag_version = DagVersion.get_latest_version(dag.dag_id)
- ti = create_task_instance(task=op, dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=dag.dag_id,
logical_date=logical_date,
@@ -230,6 +225,7 @@ class TestSageMakerExperimentOperator:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
+ ti = create_task_instance(task=op, run_id="test",
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=dag.dag_id,
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
index 903171121c5..72891b91bc4 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
@@ -37,10 +37,7 @@ from airflow.providers.amazon.aws.sensors.ecs import (
from airflow.providers.amazon.version_compat import NOTSET
from airflow.providers.common.compat.sdk import AirflowException
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
_Operator = TypeVar("_Operator")
TEST_CLUSTER_NAME = "fake-cluster"
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
index a097c5ed3d4..74f724ac398 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_rds.py
@@ -32,11 +32,7 @@ from airflow.providers.amazon.aws.sensors.rds import (
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.providers.common.compat.sdk import AirflowException
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
+from tests_common.test_utils.compat import timezone
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
DEFAULT_DATE = timezone.datetime(2019, 1, 1)
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
index f789f0cdb57..d5a697cbdb8 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
@@ -32,15 +32,11 @@ from airflow.providers.common.compat.sdk import
AirflowException
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.dag import sync_dag_to_db
from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
DEFAULT_DATE = datetime(2015, 1, 1)
@@ -149,7 +145,7 @@ class TestS3KeySensor:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
- ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+ ti = create_task_instance(task=op, run_id="test",
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=dag.dag_id,
@@ -206,7 +202,7 @@ class TestS3KeySensor:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
- ti = create_task_instance(task=op, dag_version_id=dag_version.id)
+ ti = create_task_instance(task=op, run_id="test",
dag_version_id=dag_version.id)
ti.dag_run = dag_run
rendered = render_template_fields(ti, op)
rendered.poke(None)
diff --git
a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
index f589c66fdbf..b2138d88769 100644
--- a/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/transfers/test_mongo_to_s3.py
@@ -90,7 +90,6 @@ class TestMongoToS3Operator:
sync_dag_to_db(self.dag)
dag_version =
DagVersion.get_latest_version(self.mock_operator.dag_id)
- ti = create_task_instance(self.mock_operator,
dag_version_id=dag_version.id)
dag_run = DagRun(
dag_id=self.mock_operator.dag_id,
logical_date=DEFAULT_DATE,
@@ -99,6 +98,7 @@ class TestMongoToS3Operator:
state=DagRunState.RUNNING,
run_after=timezone.utcnow(),
)
+ ti = create_task_instance(self.mock_operator, run_id="test",
dag_version_id=dag_version.id)
else:
dag_run = DagRun(
dag_id=self.mock_operator.dag_id,
diff --git
a/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
b/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
index 41eb0cf06b2..bc2b2c52a07 100644
--- a/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
+++ b/providers/apache/druid/tests/unit/apache/druid/operators/test_druid.py
@@ -24,8 +24,9 @@ import pytest
from airflow.providers.apache.druid.hooks.druid import IngestionType
from airflow.providers.apache.druid.operators.druid import DruidOperator
-from airflow.utils import timezone
-from airflow.utils.types import DagRunType
+
+from tests_common.test_utils.compat import timezone
+from tests_common.test_utils.taskinstance import get_template_context
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
@@ -65,8 +66,8 @@ def test_render_template(dag_maker):
params={"index_type": "index_hadoop", "datasource":
"datasource_prd"},
)
- ti =
dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0]
- operator.render_template_fields(ti.get_template_context())
+ ti = dag_maker.create_ti(operator.task_id)
+ operator.render_template_fields(get_template_context(ti, operator))
assert json.loads(operator.json_index_file) == RENDERED_INDEX
@@ -89,8 +90,8 @@ def test_render_template_from_file(tmp_path, dag_maker):
params={"index_type": "index_hadoop", "datasource":
"datasource_prd"},
)
- ti =
dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED).task_instances[0]
- operator.render_template_fields(ti.get_template_context())
+ ti = dag_maker.create_ti(operator.task_id)
+ operator.render_template_fields(get_template_context(ti, operator))
assert json.loads(operator.json_index_file) == RENDERED_INDEX
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
index 82c022918d3..b062ab3a310 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
@@ -21,6 +21,9 @@ from unittest import mock
import pytest
+from tests_common.test_utils.compat import timezone
+from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_rendered_ti_fields
+from tests_common.test_utils.taskinstance import get_template_context
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
if AIRFLOW_V_3_0_PLUS:
@@ -28,10 +31,6 @@ if AIRFLOW_V_3_0_PLUS:
else:
from airflow.decorators import setup, task, teardown # type:
ignore[attr-defined,no-redef]
-from airflow.utils import timezone
-
-from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_rendered_ti_fields
-
TASK_FUNCTION_NAME_ID = "task_function_name"
DEFAULT_DATE = timezone.datetime(2023, 1, 1)
DAG_ID = "k8s_deco_test_dag"
@@ -124,8 +123,7 @@ class TestKubernetesDecoratorsBase:
session = self.dag_maker.session
dag_run =
self.dag_maker.create_dagrun(run_id=f"k8s_decorator_test_{DEFAULT_DATE.date()}")
ti = dag_run.get_task_instance(task.operator.task_id, session=session)
- return_val =
task.operator.execute(context=ti.get_template_context(session=session))
-
+ return_val = task.operator.execute(context=get_template_context(ti,
task.operator, session=session))
return ti, return_val
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
index 5209fabec05..5b0b247c938 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py
@@ -51,7 +51,7 @@ from airflow.utils.types import DagRunType
from tests_common.test_utils import db
from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.taskinstance import create_task_instance
+from tests_common.test_utils.taskinstance import create_task_instance,
get_template_context
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS
if AIRFLOW_V_3_0_PLUS or AIRFLOW_V_3_1_PLUS:
@@ -259,7 +259,7 @@ class TestKubernetesPodOperator:
(ti,) = dr.task_instances
ti.map_index = map_index
self.dag_run = dr
- context = ti.get_template_context(session=self.dag_maker.session)
+ context = get_template_context(ti, operator,
session=self.dag_maker.session)
self.dag_maker.session.commit() # So 'execute' can read dr and ti.
remote_pod_mock = MagicMock()
@@ -2371,7 +2371,7 @@ class TestKubernetesPodOperatorAsync:
(ti,) = dr.task_instances
ti.map_index = map_index
self.dag_run = dr
- context = ti.get_template_context(session=self.dag_maker.session)
+ context = get_template_context(ti, operator,
session=self.dag_maker.session)
self.dag_maker.session.commit() # So 'execute' can read dr and ti.
remote_pod_mock = MagicMock()
diff --git a/providers/common/io/tests/unit/common/io/xcom/test_backend.py
b/providers/common/io/tests/unit/common/io/xcom/test_backend.py
index 8cd2cef1b38..3ad796b7e63 100644
--- a/providers/common/io/tests/unit/common/io/xcom/test_backend.py
+++ b/providers/common/io/tests/unit/common/io/xcom/test_backend.py
@@ -25,12 +25,8 @@ import airflow.models.xcom
from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
from airflow.providers.standard.operators.empty import EmptyOperator
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
from tests_common.test_utils import db
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS, XCOM_RETURN_KEY
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
index e201a73b6c6..45f940236b7 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py
@@ -39,10 +39,7 @@ from airflow.providers.dbt.cloud.hooks.dbt import (
fallback_to_default_account,
)
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
ACCOUNT_ID_CONN = "account_id_conn"
NO_ACCOUNT_ID_CONN = "no_account_id_conn"
diff --git a/providers/docker/tests/unit/docker/decorators/test_docker.py
b/providers/docker/tests/unit/docker/decorators/test_docker.py
index 8d69f5de7bc..51ddc6fe876 100644
--- a/providers/docker/tests/unit/docker/decorators/test_docker.py
+++ b/providers/docker/tests/unit/docker/decorators/test_docker.py
@@ -20,24 +20,12 @@ from importlib.util import find_spec
import pytest
-from airflow.models import TaskInstance
-from airflow.providers.common.compat.sdk import AirflowException
+from airflow.providers.common.compat.sdk import DAG, AirflowException, setup,
task, teardown
from airflow.utils.state import TaskInstanceState
+from tests_common.test_utils.compat import timezone
from tests_common.test_utils.markers import
skip_if_force_lowest_dependencies_marker
-from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS
-
-if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk import DAG, setup, task, teardown
-else:
- from airflow.decorators import setup, task, teardown # type:
ignore[attr-defined,no-redef]
- from airflow.models import DAG # type: ignore[attr-defined,no-redef]
-
-if AIRFLOW_V_3_1_PLUS:
- from airflow.sdk import timezone
-else:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.taskinstance import render_template_fields
DEFAULT_DATE = timezone.datetime(2021, 9, 1)
DILL_INSTALLED = find_spec("dill") is not None
@@ -99,17 +87,9 @@ class TestDockerDecorator:
with dag_maker():
ret = f()
- dr = dag_maker.create_dagrun()
- if AIRFLOW_V_3_0_PLUS:
- ti = create_task_instance(
- task=ret.operator,
- run_id=dr.run_id,
- dag_version_id=dr.created_dag_version_id,
- )
- else:
- ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
+ ti = dag_maker.create_ti("f")
rendered = render_template_fields(ti, ret.operator)
- assert rendered.container_name == f"python_{dr.dag_id}"
+ assert rendered.container_name == f"python_{ti.dag_id}"
assert rendered.mounts[0]["Target"] == f"/{ti.run_id}"
@pytest.mark.db_test
diff --git a/providers/github/tests/unit/github/operators/test_github.py
b/providers/github/tests/unit/github/operators/test_github.py
index 25176e222a7..9b9844faeae 100644
--- a/providers/github/tests/unit/github/operators/test_github.py
+++ b/providers/github/tests/unit/github/operators/test_github.py
@@ -25,10 +25,7 @@ from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.providers.github.operators.github import GithubOperator
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
github_client_mock = Mock(name="github_client_for_test")
diff --git a/providers/github/tests/unit/github/sensors/test_github.py
b/providers/github/tests/unit/github/sensors/test_github.py
index a9a414bae37..8001df4edce 100644
--- a/providers/github/tests/unit/github/sensors/test_github.py
+++ b/providers/github/tests/unit/github/sensors/test_github.py
@@ -25,10 +25,7 @@ from airflow.models import Connection
from airflow.models.dag import DAG
from airflow.providers.github.sensors.github import GithubTagSensor
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
github_client_mock = Mock(name="github_client_for_test")
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
b/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
index bdaee90cea2..6040115f220 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_gcs.py
@@ -39,13 +39,9 @@ from airflow.providers.common.compat.sdk import
AirflowException
from airflow.providers.google.cloud.hooks import gcs
from airflow.providers.google.cloud.hooks.gcs import
_fallback_object_url_to_object_name_and_bucket_name
from airflow.providers.google.common.consts import CLIENT_INFO
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
from airflow.version import version
+from tests_common.test_utils.compat import timezone
from unit.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
diff --git a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
index 401df681ad6..3e5a731d79d 100644
--- a/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
+++ b/providers/mysql/tests/unit/mysql/hooks/test_mysql.py
@@ -27,6 +27,10 @@ import sqlalchemy
from airflow.models.dag import DAG
from airflow.providers.common.compat.sdk import Connection
+from airflow.providers.mysql.hooks.mysql import MySqlHook
+
+from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces
+from tests_common.test_utils.compat import timezone
try:
import MySQLdb.cursors
@@ -35,15 +39,6 @@ try:
except ImportError:
MYSQL_AVAILABLE = False
-from airflow.providers.mysql.hooks.mysql import MySqlHook
-
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-
-from tests_common.test_utils.asserts import assert_equal_ignore_multiple_spaces
-
SSL_DICT = {"cert": "/tmp/client-cert.pem", "ca": "/tmp/server-ca.pem", "key":
"/tmp/client-key.pem"}
INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type,
description, host, `schema`, login, password, port, is_encrypted,
is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
diff --git a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
index 4ccec5b79cd..b33f9577784 100644
--- a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
+++ b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack.py
@@ -23,10 +23,7 @@ import pytest
from airflow.providers.common.compat.sdk import AirflowSkipException
from airflow.providers.slack.transfers.sql_to_slack import
SqlToSlackApiFileOperator
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
TEST_DAG_ID = "sql_to_slack_unit_test"
TEST_TASK_ID = "sql_to_slack_unit_test_task"
diff --git
a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
index 54cffd37e9a..beb91bf255d 100644
--- a/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
+++ b/providers/slack/tests/unit/slack/transfers/test_sql_to_slack_webhook.py
@@ -24,10 +24,7 @@ import pytest
from airflow.models import Connection
from airflow.providers.slack.transfers.sql_to_slack_webhook import
SqlToSlackWebhookOperator
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
TEST_DAG_ID = "sql_to_slack_unit_test"
TEST_TASK_ID = "sql_to_slack_unit_test_task"
diff --git a/providers/standard/tests/unit/standard/decorators/test_bash.py
b/providers/standard/tests/unit/standard/decorators/test_bash.py
index ad899c6bcfc..b4ae3847f26 100644
--- a/providers/standard/tests/unit/standard/decorators/test_bash.py
+++ b/providers/standard/tests/unit/standard/decorators/test_bash.py
@@ -29,8 +29,12 @@ from airflow.models.renderedtifields import
RenderedTaskInstanceFields
from airflow.providers.common.compat.sdk import AirflowException,
AirflowSkipException
from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_rendered_ti_fields
-from tests_common.test_utils.taskinstance import render_template_fields,
run_task_instance
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.taskinstance import (
+ get_template_context,
+ render_template_fields,
+ run_task_instance,
+)
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
if TYPE_CHECKING:
from airflow.models import TaskInstance
@@ -359,7 +363,10 @@ class TestBashDecorator:
ti = dr.task_instances[0]
with pytest.raises(AirflowException, match=f"Can not find the cwd:
{cwd_path}"):
run_task_instance(ti, bash_task.operator)
- assert ti.task.bash_command == "echo"
+ if AIRFLOW_V_3_2_PLUS:
+ assert ti.task.bash_command == "DYNAMIC (set during execution)"
+ else:
+ assert ti.task.bash_command == "echo"
def test_cwd_is_file(self, tmp_path):
"""Verify task failure for user-defined working directory that is
actually a file."""
@@ -380,7 +387,10 @@ class TestBashDecorator:
ti = dr.task_instances[0]
with pytest.raises(AirflowException, match=f"The cwd {cwd_file} must
be a directory"):
run_task_instance(ti, bash_task.operator)
- assert ti.task.bash_command == "echo"
+ if AIRFLOW_V_3_2_PLUS:
+ assert ti.task.bash_command == "DYNAMIC (set during execution)"
+ else:
+ assert ti.task.bash_command == "echo"
def test_command_not_found(self):
"""Fail task if executed command is not found on path."""
@@ -400,7 +410,10 @@ class TestBashDecorator:
AirflowException, match="Bash command failed\\. The command
returned a non-zero exit code 127\\."
):
run_task_instance(ti, bash_task.operator)
- assert ti.task.bash_command == "set -e; something-that-isnt-on-path"
+ if AIRFLOW_V_3_2_PLUS:
+ assert ti.task.bash_command == "DYNAMIC (set during execution)"
+ else:
+ assert ti.task.bash_command == "set -e;
something-that-isnt-on-path"
def test_multiple_outputs_true(self):
"""Verify setting `multiple_outputs` for a @task.bash-decorated
function is ignored."""
@@ -496,7 +509,10 @@ class TestBashDecorator:
ti = dr.task_instances[0]
with pytest.raises(AirflowException):
run_task_instance(ti, bash_task.operator)
- assert ti.task.bash_command == f"{DEFAULT_DATE.date()}; exit 1;"
+ if AIRFLOW_V_3_2_PLUS:
+ assert ti.task.bash_command == "DYNAMIC (set during execution)"
+ else:
+ assert ti.task.bash_command == f"{DEFAULT_DATE.date()}; exit 1;"
@pytest.mark.db_test
def test_templated_bash_script(self, dag_maker, tmp_path, session):
@@ -519,7 +535,7 @@ class TestBashDecorator:
task_arg = test_templated_fields_task()
ti: TaskInstance = dag_maker.create_dagrun().task_instances[0]
- context = ti.get_template_context(session=session)
+ context = get_template_context(ti, task_arg.operator, session=session)
op = render_template_fields(ti, task_arg.operator, context=context)
result = op.execute(context=context)
assert result == "test_templated_fields_task"
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py
b/providers/standard/tests/unit/standard/decorators/test_python.py
index 8b82a737739..32b2fc2615d 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import sys
import typing
from collections import namedtuple
@@ -22,11 +23,10 @@ from datetime import date
import pytest
-from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.providers.common.compat.sdk import AirflowException, XComNotFound
-from tests_common.test_utils.taskinstance import create_task_instance,
render_template_fields
+from tests_common.test_utils.taskinstance import get_template_context,
render_template_fields
from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_1,
AIRFLOW_V_3_0_PLUS,
@@ -68,6 +68,7 @@ else:
if typing.TYPE_CHECKING:
from airflow.models.dagrun import DagRun
+ from airflow.models.taskinstance import TaskInstance
pytestmark = pytest.mark.db_test
@@ -76,6 +77,15 @@ PY38 = sys.version_info >= (3, 8)
PY311 = sys.version_info >= (3, 11)
+@pytest.fixture(autouse=True)
+def clear_current_task_session():
+ try:
+ import airflow.utils.task_instance_session
+ except ModuleNotFoundError:
+ return
+ airflow.utils.task_instance_session.__current_task_instance_session = None
+
+
class TestAirflowTaskDecorator(BasePythonTest):
default_date = DEFAULT_DATE
@@ -399,20 +409,13 @@ class TestAirflowTaskDecorator(BasePythonTest):
ret = arg_task(4, date(2019, 1, 1), "dag {{dag.dag_id}} ran on
{{ds}}.", named_tuple)
dr = self.create_dag_run()
- if AIRFLOW_V_3_0_PLUS:
- ti = create_task_instance(
- task=ret.operator,
- run_id=dr.run_id,
- dag_version_id=dr.created_dag_version_id,
- )
- else:
- ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
- rendered_op_args = render_template_fields(ti, ret.operator).op_args
- assert len(rendered_op_args) == 4
- assert rendered_op_args[0] == 4
- assert rendered_op_args[1] == date(2019, 1, 1)
- assert rendered_op_args[2] == f"dag {self.dag_id} ran on
{self.ds_templated}."
- assert rendered_op_args[3] == Named(self.ds_templated, "unchanged")
+ ti = dr.get_task_instance("arg_task", session=self.dag_maker.session)
+ assert render_template_fields(ti, ret.operator).op_args == (
+ 4,
+ date(2019, 1, 1),
+ f"dag {self.dag_id} ran on {self.ds_templated}.",
+ Named(self.ds_templated, "unchanged"),
+ )
def test_python_callable_keyword_arguments_are_templatized(self):
"""Test PythonOperator op_kwargs are templatized"""
@@ -427,18 +430,12 @@ class TestAirflowTaskDecorator(BasePythonTest):
)
dr = self.create_dag_run()
- if AIRFLOW_V_3_0_PLUS:
- ti = create_task_instance(
- task=ret.operator,
- run_id=dr.run_id,
- dag_version_id=dr.created_dag_version_id,
- )
- else:
- ti = TaskInstance(task=ret.operator, run_id=dr.run_id)
- rendered_op_kwargs = render_template_fields(ti, ret.operator).op_kwargs
- assert rendered_op_kwargs["an_int"] == 4
- assert rendered_op_kwargs["a_date"] == date(2019, 1, 1)
- assert rendered_op_kwargs["a_templated_string"] == f"dag {self.dag_id}
ran on {self.ds_templated}."
+ ti = dr.get_task_instance("kwargs_task",
session=self.dag_maker.session)
+ assert render_template_fields(ti, ret.operator).op_kwargs == {
+ "an_int": 4,
+ "a_date": date(2019, 1, 1),
+ "a_templated_string": f"dag {self.dag_id} ran on
{self.ds_templated}.",
+ }
def test_manual_task_id(self):
"""Test manually setting task_id"""
@@ -821,14 +818,10 @@ def
test_mapped_decorator_unmap_merge_op_kwargs(dag_maker, session):
assert [ti.task_id for ti in dec.schedulable_tis] == ["task2"]
ti = dec.schedulable_tis[0]
- # Use the real task for unmapping to mimic actual execution path
- ti.task = dag_maker.dag.task_dict[ti.task_id]
-
- if AIRFLOW_V_3_0_PLUS:
- unmapped = ti.task.unmap((ti.get_template_context(session),))
- else:
- unmapped = ti.task.unmap((ti.get_template_context(session), session))
- assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
+ task = dag_maker.dag.task_dict[ti.task_id]
+ context = get_template_context(ti, task, session=session)
+ render_template_fields(ti, task, context=context, session=session)
+ assert set(context["task"].op_kwargs) == {"arg1", "arg2"}
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
@@ -863,11 +856,12 @@ def test_mapped_render_template_fields(dag_maker,
session):
mapped_ti.map_index = 0
mapped_ti.task = mapped.operator
assert isinstance(mapped_ti.task, MappedOperator)
-
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
- assert isinstance(mapped_ti.task, BaseOperator)
+ context = get_template_context(mapped_ti, mapped.operator, session=session)
+ mapped.operator.render_template_fields(context)
+ assert isinstance(context["task"], BaseOperator)
- assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
- assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+ assert context["task"].op_kwargs["arg1"] == "{{ ds }}"
+ assert context["task"].op_kwargs["arg2"] == "fn"
@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py
b/providers/standard/tests/unit/standard/operators/test_python.py
index f7bca58e4b9..a59c33b29dc 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -62,8 +62,9 @@ from airflow.utils.session import create_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import TriggerRule, timezone
from tests_common.test_utils.db import clear_db_runs
-from tests_common.test_utils.taskinstance import run_task_instance
+from tests_common.test_utils.taskinstance import get_template_context,
run_task_instance
from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_1,
AIRFLOW_V_3_0_PLUS,
@@ -79,16 +80,6 @@ else:
from airflow.models.baseoperator import BaseOperator # type:
ignore[no-redef]
from airflow.models.taskinstance import set_current_context # type:
ignore[attr-defined,no-redef]
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-try:
- from airflow.sdk import TriggerRule
-except ImportError:
- # Compatibility for Airflow < 3.1
- from airflow.utils.trigger_rule import TriggerRule # type:
ignore[no-redef,attr-defined]
-
if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
@@ -205,7 +196,7 @@ class BasePythonTest:
"""Create TaskInstance and run it."""
ti = self.create_ti(fn, **kwargs)
assert ti.task is not None
- ti.run()
+ ti = ti.run()
if return_ti:
return ti
return ti.task
@@ -996,11 +987,8 @@ class TestDagBundleImportInSubprocess(BasePythonTest):
dr = dag_maker.create_dagrun()
ti = dr.get_task_instance(self.task_id)
- mock_bundle_instance = mock.Mock()
- mock_bundle_instance.path = str(bundle_root)
- ti.bundle_instance = mock_bundle_instance
-
- context = ti.get_template_context()
+ context = get_template_context(ti, op)
+ context["ti"].bundle_instance = mock.Mock(path=str(bundle_root))
# Mock subprocess execution to avoid testing-environment related issues
# on the ExternalPythonOperator (Socket operation on non-socket)
@@ -1011,13 +999,10 @@ class TestDagBundleImportInSubprocess(BasePythonTest):
with mock.patch.object(op, "_read_result", return_value=None):
op.execute(context)
- assert mock_execute_subprocess.called, "_execute_in_subprocess should
have been called"
- call_kwargs = mock_execute_subprocess.call_args.kwargs
- env = call_kwargs.get("env")
- assert "PYTHONPATH" in env, "PYTHONPATH should be in env"
-
- pythonpath = env["PYTHONPATH"]
- assert str(bundle_root) in pythonpath, f"Bundle path {bundle_root}
should be in PYTHONPATH"
+ pythonpath =
mock_execute_subprocess.call_args.kwargs["env"]["PYTHONPATH"]
+ assert str(bundle_root) in pythonpath, (
+ f"Bundle path {str(bundle_root)!r} not in PYTHONPATH
{pythonpath!r}"
+ )
@pytest.mark.execution_timeout(120)
@@ -1060,7 +1045,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
return None
ti = self.run_as_task(f, return_ti=True)
- assert ti.xcom_pull() is None
+ assert TaskInstance.xcom_pull(ti) is None
def test_return_false(self):
def f():
@@ -1068,7 +1053,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
ti = self.run_as_task(f, return_ti=True)
- assert ti.xcom_pull() is False
+ assert TaskInstance.xcom_pull(ti) is False
def test_lambda(self):
with pytest.raises(
@@ -1234,7 +1219,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
return os.environ["MY_ENV_VAR"]
ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "ABCDE"},
return_ti=True)
- assert ti.xcom_pull() == "ABCDE"
+ assert TaskInstance.xcom_pull(ti) == "ABCDE"
def test_environment_variables_with_inherit_env_true(self, monkeypatch):
monkeypatch.setenv("MY_ENV_VAR", "QWERT")
@@ -1245,7 +1230,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
return os.environ["MY_ENV_VAR"]
ti = self.run_as_task(f, inherit_env=True, return_ti=True)
- assert ti.xcom_pull() == "QWERT"
+ assert TaskInstance.xcom_pull(ti) == "QWERT"
def test_environment_variables_with_inherit_env_false(self, monkeypatch):
monkeypatch.setenv("MY_ENV_VAR", "TYUIO")
@@ -1267,7 +1252,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
return os.environ["MY_ENV_VAR"]
ti = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"},
inherit_env=True, return_ti=True)
- assert ti.xcom_pull() == "EFGHI"
+ assert TaskInstance.xcom_pull(ti) == "EFGHI"
venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")
diff --git a/providers/standard/tests/unit/standard/sensors/test_time.py b/providers/standard/tests/unit/standard/sensors/test_time.py
index 8e9097d6a64..1989c35d91b 100644
--- a/providers/standard/tests/unit/standard/sensors/test_time.py
+++ b/providers/standard/tests/unit/standard/sensors/test_time.py
@@ -28,10 +28,7 @@ from airflow.providers.common.compat.sdk import TaskDeferred
from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.triggers.temporal import DateTimeTrigger
-try:
- from airflow.sdk import timezone
-except ImportError:
- from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
+from tests_common.test_utils.compat import timezone
DEFAULT_TIMEZONE = pendulum.timezone("Asia/Singapore") # UTC+08:00
DEFAULT_DATE_WO_TZ = datetime(2015, 1, 1)
diff --git a/scripts/ci/prek/check_template_context_variable_in_sync.py b/scripts/ci/prek/check_template_context_variable_in_sync.py
index 01f4c461c68..1d78b9d7242 100755
--- a/scripts/ci/prek/check_template_context_variable_in_sync.py
+++ b/scripts/ci/prek/check_template_context_variable_in_sync.py
@@ -32,14 +32,10 @@ import sys
import typing
sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve())) # make sure
common_prek_utils is imported
-from common_prek_utils import (
- AIRFLOW_CORE_ROOT_PATH,
- AIRFLOW_CORE_SOURCES_PATH,
- AIRFLOW_TASK_SDK_SOURCES_PATH,
-)
+
+from common_prek_utils import AIRFLOW_CORE_ROOT_PATH,
AIRFLOW_TASK_SDK_SOURCES_PATH
TASKRUNNER_PY = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" /
"execution_time" / "task_runner.py"
-CONTEXT_PY = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "utils" / "context.py"
CONTEXT_HINT = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" /
"definitions" / "context.py"
TEMPLATES_REF_RST = AIRFLOW_CORE_ROOT_PATH / "docs" / "templates-ref.rst"
@@ -117,23 +113,6 @@ def _iter_template_context_keys_from_original_return() -> typing.Iterator[str]:
yield from extract_keys_from_dict(sub_stmt.value)
-def _iter_template_context_keys_from_declaration() -> typing.Iterator[str]:
- context_mod = ast.parse(CONTEXT_PY.read_text("utf-8"), str(CONTEXT_PY))
- st_known_context_keys = next(
- stmt.value
- for stmt in context_mod.body
- if isinstance(stmt, ast.AnnAssign)
- and isinstance(stmt.target, ast.Name)
- and stmt.target.id == "KNOWN_CONTEXT_KEYS"
- )
- if not isinstance(st_known_context_keys, ast.Set):
- raise ValueError("'KNOWN_CONTEXT_KEYS' is not assigned a set literal")
- for expr in st_known_context_keys.elts:
- if not isinstance(expr, ast.Constant) or not isinstance(expr.value,
str):
- raise ValueError("item in 'KNOWN_CONTEXT_KEYS' set is not a str
literal")
- yield expr.value
-
-
def _iter_template_context_keys_from_type_hints() -> typing.Iterator[str]:
context_mod = ast.parse(CONTEXT_HINT.read_text("utf-8"), str(CONTEXT_HINT))
cls_context = next(
@@ -158,7 +137,7 @@ def _iter_template_context_keys_from_documentation() -> typing.Iterator[str]:
yield match.group("name")
-def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys:
set[str], docs_keys: set[str]) -> int:
+def _compare_keys(retn_keys: set[str], hint_keys: set[str], docs_keys:
set[str]) -> int:
# Added by PythonOperator and commonly used.
# Not listed in templates-ref (but in operator docs).
retn_keys.add("templates_dict")
@@ -168,10 +147,7 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
retn_keys.add("expanded_ti_count")
# TODO: These are the keys that are yet to be ported over to the Task SDK.
- retn_keys.add("inlet_events")
- retn_keys.add("params")
retn_keys.add("test_mode")
- retn_keys.add("triggering_asset_events")
# Only present in callbacks. Not listed in templates-ref (that doc is for
task execution).
retn_keys.update(("exception", "reason", "try_number"))
@@ -182,7 +158,6 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
check_candidates = [
("get_template_context()", retn_keys),
- ("KNOWN_CONTEXT_KEYS", decl_keys),
("Context type hint", hint_keys),
("templates-ref", docs_keys),
]
@@ -198,10 +173,9 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
def main() -> str | int | None:
retn_keys = set(_iter_template_context_keys_from_original_return())
- decl_keys = set(_iter_template_context_keys_from_declaration())
hint_keys = set(_iter_template_context_keys_from_type_hints())
docs_keys = set(_iter_template_context_keys_from_documentation())
- return _compare_keys(retn_keys, decl_keys, hint_keys, docs_keys)
+ return _compare_keys(retn_keys, hint_keys, docs_keys)
if __name__ == "__main__":
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index b0b8357b5e0..3351f591a95 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -1398,6 +1398,8 @@ def _run_task(
possible. This function is only meant for the `dag.test` function as a
helper function.
"""
from airflow.sdk._shared.module_loading import import_string
+ from airflow.sdk.serde import deserialize, serialize
+ from airflow.utils.session import create_session
taskrun_result: TaskRunResult | None
log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id,
ti.map_index)
@@ -1429,16 +1431,16 @@ def _run_task(
ti.task = create_scheduler_operator(taskrun_result.ti.task)
if ti.state == TaskInstanceState.DEFERRED and isinstance(msg,
DeferTask) and run_triggerer:
- from airflow.sdk.serde import deserialize, serialize
- from airflow.utils.session import create_session
-
# API Server expects the task instance to be in QUEUED state
before
# resuming from deferral.
ti.set_state(TaskInstanceState.QUEUED)
log.info("[DAG TEST] running trigger in line")
- # trigger_kwargs need to be deserialized before passing to the
trigger class since they are in serde encoded format
- kwargs = deserialize(msg.trigger_kwargs) # type:
ignore[type-var] # needed to convince mypy that trigger_kwargs is a dict or a
str because its unable to infer JsonValue
+ # trigger_kwargs need to be deserialized before passing to the
+ # trigger class since they are in serde encoded format.
+ # Ignore needed to convince mypy that trigger_kwargs is a dict
+ # or a str because its unable to infer JsonValue.
+ kwargs = deserialize(msg.trigger_kwargs) # type:
ignore[type-var]
if TYPE_CHECKING:
assert isinstance(kwargs, dict)
trigger = import_string(msg.classpath)(**kwargs)
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index 6aff9ca3c68..ef4c4aeb59a 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -752,19 +752,13 @@ class MappedOperator(AbstractOperator):
"params": params,
}
- def unmap(self, resolve: None | Mapping[str, Any]) -> BaseOperator:
+ def unmap(self, resolve: Mapping[str, Any]) -> BaseOperator:
"""
Get the "normal" Operator after applying the current mapping.
:meta private:
"""
- if isinstance(resolve, Mapping):
- kwargs = resolve
- elif resolve is not None:
- kwargs, _ = self._expand_mapped_kwargs(*resolve)
- else:
- raise RuntimeError("cannot unmap a non-serialized operator without
context")
- kwargs = self._get_unmap_kwargs(kwargs,
strict=self._disallow_kwargs_override)
+ kwargs = self._get_unmap_kwargs(resolve,
strict=self._disallow_kwargs_override)
is_setup = kwargs.pop("is_setup", False)
is_teardown = kwargs.pop("is_teardown", False)
on_failure_fail_dagrun = kwargs.pop("on_failure_fail_dagrun", False)