This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b703d53b774 Move Literal alias into TYPE_CHECKING block (#45345)
b703d53b774 is described below
commit b703d53b774960326b8d91963304bac3ca5d533c
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Jan 9 13:34:05 2025 +0800
Move Literal alias into TYPE_CHECKING block (#45345)
---
airflow/api_fastapi/app.py | 5 ++++-
airflow/api_fastapi/core_api/security.py | 6 ++++--
airflow/auth/managers/base_auth_manager.py | 12 +++++++-----
.../auth/managers/simple/simple_auth_manager.py | 3 ++-
.../cli/commands/remote_commands/task_command.py | 9 +++++----
airflow/models/mappedoperator.py | 4 ++--
airflow/providers_manager.py | 4 ++--
.../amazon/aws/auth_manager/aws_auth_manager.py | 7 ++++---
.../airflow/providers/amazon/aws/sensors/sqs.py | 3 ++-
.../airflow/providers/amazon/aws/triggers/sqs.py | 3 ++-
.../src/airflow/providers/amazon/aws/utils/sqs.py | 10 ++++++----
.../airflow/providers/docker/decorators/docker.py | 4 +++-
.../providers/fab/auth_manager/fab_auth_manager.py | 3 ++-
.../airflow/providers/standard/operators/python.py | 6 ++++--
.../airflow/providers/weaviate/hooks/weaviate.py | 2 +-
.../api_fastapi/core_api/routes/public/test_job.py | 22 +++++++++++++---------
tests/auth/managers/test_base_auth_manager.py | 3 ++-
tests_common/_internals/capture_warnings.py | 9 ++++++---
18 files changed, 71 insertions(+), 44 deletions(-)
diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py
index 3c9aa39ae70..9cbb190d411 100644
--- a/airflow/api_fastapi/app.py
+++ b/airflow/api_fastapi/app.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import logging
from contextlib import AsyncExitStack, asynccontextmanager
+from typing import TYPE_CHECKING
from fastapi import FastAPI
from starlette.routing import Mount
@@ -31,10 +32,12 @@ from airflow.api_fastapi.core_api.app import (
init_views,
)
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
-from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
+if TYPE_CHECKING:
+ from airflow.auth.managers.base_auth_manager import BaseAuthManager
+
log = logging.getLogger(__name__)
app: FastAPI | None = None
diff --git a/airflow/api_fastapi/core_api/security.py
b/airflow/api_fastapi/core_api/security.py
index 30470e9b5da..7aaee4d0fae 100644
--- a/airflow/api_fastapi/core_api/security.py
+++ b/airflow/api_fastapi/core_api/security.py
@@ -17,19 +17,21 @@
from __future__ import annotations
from functools import cache
-from typing import Annotated, Any, Callable
+from typing import TYPE_CHECKING, Annotated, Any, Callable
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jwt import InvalidTokenError
from airflow.api_fastapi.app import get_auth_manager
-from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import DagAccessEntity,
DagDetails
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner
+if TYPE_CHECKING:
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
diff --git a/airflow/auth/managers/base_auth_manager.py
b/airflow/auth/managers/base_auth_manager.py
index 345d22d5363..6a9ef11e3d7 100644
--- a/airflow/auth/managers/base_auth_manager.py
+++ b/airflow/auth/managers/base_auth_manager.py
@@ -18,21 +18,21 @@
from __future__ import annotations
from abc import abstractmethod
-from collections.abc import Container, Sequence
-from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
from sqlalchemy import select
from airflow.auth.managers.models.base_user import BaseUser
-from airflow.auth.managers.models.resource_details import (
- DagDetails,
-)
+from airflow.auth.managers.models.resource_details import DagDetails
from airflow.exceptions import AirflowException
from airflow.models import DagModel
+from airflow.typing_compat import Literal
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
+ from collections.abc import Container, Sequence
+
from fastapi import FastAPI
from flask import Blueprint
from flask_appbuilder.menu import MenuItem
@@ -55,6 +55,8 @@ if TYPE_CHECKING:
)
from airflow.cli.cli_config import CLICommand
+# This cannot be in the TYPE_CHECKING block since some providers import it at module level.
+# TODO: Move this inside once all providers drop Airflow 2.x support.
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
T = TypeVar("T", bound=BaseUser)
diff --git a/airflow/auth/managers/simple/simple_auth_manager.py
b/airflow/auth/managers/simple/simple_auth_manager.py
index 5c411a5202d..6ac5f342587 100644
--- a/airflow/auth/managers/simple/simple_auth_manager.py
+++ b/airflow/auth/managers/simple/simple_auth_manager.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any
from flask import session, url_for
from termcolor import colored
-from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.auth.managers.simple.views.auth import
SimpleAuthManagerAuthenticationViews
from airflow.configuration import AIRFLOW_HOME, conf
@@ -35,6 +35,7 @@ from airflow.configuration import AIRFLOW_HOME, conf
if TYPE_CHECKING:
from flask_appbuilder.menu import MenuItem
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.resource_details import (
AccessView,
AssetDetails,
diff --git a/airflow/cli/commands/remote_commands/task_command.py
b/airflow/cli/commands/remote_commands/task_command.py
index ad3c26ff56e..6c591801515 100644
--- a/airflow/cli/commands/remote_commands/task_command.py
+++ b/airflow/cli/commands/remote_commands/task_command.py
@@ -28,7 +28,7 @@ import sys
import textwrap
from collections.abc import Generator
from contextlib import contextmanager, redirect_stderr, redirect_stdout,
suppress
-from typing import TYPE_CHECKING, Protocol, Union, cast
+from typing import TYPE_CHECKING, Protocol, cast
import pendulum
from pendulum.parsing.exceptions import ParserError
@@ -50,7 +50,6 @@ from airflow.models.taskinstance import TaskReturnCode
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
-from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.cli import (
get_dag,
@@ -70,13 +69,15 @@ from airflow.utils.task_instance_session import
set_current_task_instance_sessio
from airflow.utils.types import DagRunTriggeredByType
if TYPE_CHECKING:
+ from typing import Literal
+
from sqlalchemy.orm.session import Session
from airflow.models.operator import Operator
-log = logging.getLogger(__name__)
+ CreateIfNecessary = Literal[False, "db", "memory"]
-CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]]
+log = logging.getLogger(__name__)
def _generate_temporary_run_id() -> str:
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 4d362714794..f4728095037 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -53,7 +53,6 @@ from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy,
validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.triggers.base import StartTriggerArgs
-from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
from airflow.utils.task_instance_session import
get_current_task_instance_session
@@ -62,6 +61,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY
if TYPE_CHECKING:
import datetime
+ from typing import Literal
import jinja2 # Slow import.
import pendulum
@@ -89,7 +89,7 @@ if TYPE_CHECKING:
TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback,
list[TaskStateChangeCallback]]
-ValidationSource = Union[Literal["expand"], Literal["partial"]]
+ ValidationSource = Literal["expand", "partial"]
def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource,
value: dict[str, Any]) -> None:
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 8d9f93734d7..575306a840b 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -86,12 +86,12 @@ def _ensure_prefix_for_placeholders(field_behaviors:
dict[str, Any], conn_type:
if TYPE_CHECKING:
+ from typing import Literal
from urllib.parse import SplitResult
from airflow.decorators.base import TaskDecorator
from airflow.hooks.base import BaseHook
from airflow.sdk.definitions.asset import Asset
- from airflow.typing_compat import Literal
class LazyDictWithCache(MutableMapping):
@@ -201,7 +201,7 @@ class ProviderInfo:
version: str
data: dict
- package_or_source: Literal["source"] | Literal["package"]
+ package_or_source: Literal["source", "package"]
def __post_init__(self):
if self.package_or_source not in ("source", "package"):
diff --git
a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index ab82c6042ce..88f8cef8b76 100644
---
a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++
b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, cast
from flask import session, url_for
-from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.resource_details import (
AccessView,
ConnectionDetails,
@@ -53,6 +53,7 @@ from airflow.providers.amazon.version_compat import
AIRFLOW_V_3_0_PLUS
if TYPE_CHECKING:
from flask_appbuilder.menu import MenuItem
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
@@ -326,11 +327,11 @@ class AwsAuthManager(BaseAuthManager):
for method in ["GET", "PUT"]:
if method in methods:
request: IsAuthorizedRequest = {
- "method": cast(ResourceMethod, method),
+ "method": cast("ResourceMethod", method),
"entity_type": AvpEntities.DAG,
"entity_id": dag_id,
}
- requests[dag_id][cast(ResourceMethod, method)] = request
+ requests[dag_id][cast("ResourceMethod", method)] = request
requests_list.append(request)
batch_is_authorized_results =
self.avp_facade.get_batch_is_authorized_results(
diff --git a/providers/src/airflow/providers/amazon/aws/sensors/sqs.py
b/providers/src/airflow/providers/amazon/aws/sensors/sqs.py
index 006c5bf2ad2..016e4fbd6c6 100644
--- a/providers/src/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/providers/src/airflow/providers/amazon/aws/sensors/sqs.py
@@ -30,10 +30,11 @@ from airflow.providers.amazon.aws.sensors.base_aws import
AwsBaseSensor
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType,
process_response
+from airflow.providers.amazon.aws.utils.sqs import process_response
if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
+ from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType
from airflow.utils.context import Context
diff --git a/providers/src/airflow/providers/amazon/aws/triggers/sqs.py
b/providers/src/airflow/providers/amazon/aws/triggers/sqs.py
index 31f344b9982..28c0b509d28 100644
--- a/providers/src/airflow/providers/amazon/aws/triggers/sqs.py
+++ b/providers/src/airflow/providers/amazon/aws/triggers/sqs.py
@@ -22,11 +22,12 @@ from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType,
process_response
+from airflow.providers.amazon.aws.utils.sqs import process_response
from airflow.triggers.base import BaseTrigger, TriggerEvent
if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
+ from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType
class SqsSensorTrigger(BaseTrigger):
diff --git a/providers/src/airflow/providers/amazon/aws/utils/sqs.py
b/providers/src/airflow/providers/amazon/aws/utils/sqs.py
index 293aa1b898d..3c509454655 100644
--- a/providers/src/airflow/providers/amazon/aws/utils/sqs.py
+++ b/providers/src/airflow/providers/amazon/aws/utils/sqs.py
@@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+#
from __future__ import annotations
import json
import logging
-from typing import Any
+from typing import TYPE_CHECKING, Any
import jsonpath_ng
import jsonpath_ng.ext
-from typing_extensions import Literal
-log = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from typing import Literal
+ MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"]
-MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"]
+log = logging.getLogger(__name__)
def process_response(
diff --git a/providers/src/airflow/providers/docker/decorators/docker.py
b/providers/src/airflow/providers/docker/decorators/docker.py
index 560028c16ed..77355ff03b2 100644
--- a/providers/src/airflow/providers/docker/decorators/docker.py
+++ b/providers/src/airflow/providers/docker/decorators/docker.py
@@ -20,7 +20,7 @@ import base64
import os
from collections.abc import Sequence
from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Callable
from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.exceptions import AirflowException
@@ -28,6 +28,8 @@ from airflow.providers.common.compat.standard.utils import
write_python_script
from airflow.providers.docker.operators.docker import DockerOperator
if TYPE_CHECKING:
+ from typing import Literal
+
from airflow.decorators.base import TaskDecorator
from airflow.utils.context import Context
diff --git
a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
index 6c8bf51b7ac..4c889a9c14e 100644
--- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -34,7 +34,7 @@ from sqlalchemy.orm import Session, joinedload
from starlette.middleware.wsgi import WSGIMiddleware
from airflow import __version__ as airflow_version
-from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
@@ -94,6 +94,7 @@ from airflow.utils.yaml import safe_load
from airflow.version import version
if TYPE_CHECKING:
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.cli.cli_config import (
CLICommand,
diff --git a/providers/src/airflow/providers/standard/operators/python.py
b/providers/src/airflow/providers/standard/operators/python.py
index 40a0cb7a922..25de405a80c 100644
--- a/providers/src/airflow/providers/standard/operators/python.py
+++ b/providers/src/airflow/providers/standard/operators/python.py
@@ -51,7 +51,6 @@ from airflow.providers.standard.version_compat import (
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
)
-from airflow.typing_compat import Literal
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
@@ -61,10 +60,14 @@ from airflow.utils.process_utils import
execute_in_subprocess, execute_in_subpro
log = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from typing import Literal
+
from pendulum.datetime import DateTime
from airflow.utils.context import Context
+ _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
+
@cache
def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
@@ -343,7 +346,6 @@ def _load_cloudpickle():
return cloudpickle
-_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
_SERIALIZERS: dict[_SerializerTypeDef, Any] = {
"pickle": lazy_object_proxy.Proxy(_load_pickle),
"dill": lazy_object_proxy.Proxy(_load_dill),
diff --git a/providers/src/airflow/providers/weaviate/hooks/weaviate.py
b/providers/src/airflow/providers/weaviate/hooks/weaviate.py
index 716bc3e10e6..e49cc58f024 100644
--- a/providers/src/airflow/providers/weaviate/hooks/weaviate.py
+++ b/providers/src/airflow/providers/weaviate/hooks/weaviate.py
@@ -749,7 +749,7 @@ class WeaviateHook(BaseHook):
verbose: bool = False,
) -> Sequence[dict[str, UUID | str] | None]:
"""
- create or replace objects belonging to documents.
+ Create or replace objects belonging to documents.
In real-world scenarios, information sources like Airflow docs, Stack
Overflow, or other issues
are considered 'documents' here. It's crucial to keep the database
objects in sync with these sources.
diff --git a/tests/api_fastapi/core_api/routes/public/test_job.py
b/tests/api_fastapi/core_api/routes/public/test_job.py
index f09c2b902c1..780d51f4467 100644
--- a/tests/api_fastapi/core_api/routes/public/test_job.py
+++ b/tests/api_fastapi/core_api/routes/public/test_job.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Literal
+from typing import TYPE_CHECKING
import pytest
@@ -28,15 +28,19 @@ from airflow.utils.state import JobState, State
from tests_common.test_utils.db import clear_db_jobs
from tests_common.test_utils.format_datetime import from_datetime_to_zulu
+if TYPE_CHECKING:
+ from typing import Literal
+
+ TestCase = Literal[
+ "should_report_success_for_one_working_scheduler",
+ "should_report_success_for_one_working_scheduler_with_hostname",
+ "should_report_success_for_ha_schedulers",
+ "should_ignore_not_running_jobs",
+ "should_raise_exception_for_multiple_scheduler_on_one_host",
+ ]
+
pytestmark = pytest.mark.db_test
-TESTCASE_TYPE = Literal[
- "should_report_success_for_one_working_scheduler",
- "should_report_success_for_one_working_scheduler_with_hostname",
- "should_report_success_for_ha_schedulers",
- "should_ignore_not_running_jobs",
- "should_raise_exception_for_multiple_scheduler_on_one_host",
-]
TESTCASE_ONE_SCHEDULER = "should_report_success_for_one_working_scheduler"
TESTCASE_ONE_SCHEDULER_WITH_HOSTNAME =
"should_report_success_for_one_working_scheduler_with_hostname"
TESTCASE_HA_SCHEDULERS = "should_report_success_for_ha_schedulers"
@@ -107,7 +111,7 @@ class TestJobEndpoint:
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)
@provide_session
- def setup(self, testcase: TESTCASE_TYPE, session=None) -> None:
+ def setup(self, testcase: TestCase, session=None) -> None:
"""
Setup testcase at runtime based on the `testcase` provided by
`pytest.mark.parametrize`.
"""
diff --git a/tests/auth/managers/test_base_auth_manager.py
b/tests/auth/managers/test_base_auth_manager.py
index a6480e809a8..4406ae9d436 100644
--- a/tests/auth/managers/test_base_auth_manager.py
+++ b/tests/auth/managers/test_base_auth_manager.py
@@ -21,7 +21,7 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
-from airflow.auth.managers.base_auth_manager import BaseAuthManager,
ResourceMethod
+from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
ConnectionDetails,
@@ -34,6 +34,7 @@ from airflow.exceptions import AirflowException
if TYPE_CHECKING:
from flask_appbuilder.menu import MenuItem
+ from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.resource_details import (
AccessView,
AssetDetails,
diff --git a/tests_common/_internals/capture_warnings.py
b/tests_common/_internals/capture_warnings.py
index d0cb2a03615..a8e719c24df 100644
--- a/tests_common/_internals/capture_warnings.py
+++ b/tests_common/_internals/capture_warnings.py
@@ -28,12 +28,15 @@ from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
-from typing import Callable
+from typing import TYPE_CHECKING, Callable
import pytest
-from typing_extensions import Literal
-WhenTypeDef = Literal["config", "collect", "runtest"]
+if TYPE_CHECKING:
+ from typing import Literal
+
+ WhenTypeDef = Literal["config", "collect", "runtest"]
+
TESTS_DIR = Path(__file__).parents[1].resolve()