This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 95a930bc0a Consolidate import and usage of itertools (#33479)
95a930bc0a is described below
commit 95a930bc0a720c5548e4fa2e1f74e25f12e9ae1d
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Mon Aug 21 05:36:29 2023 +0000
Consolidate import and usage of itertools (#33479)
---
airflow/configuration.py | 6 +++---
airflow/decorators/base.py | 4 ++--
airflow/lineage/__init__.py | 3 +--
airflow/providers/amazon/aws/hooks/batch_client.py | 4 ++--
airflow/providers/amazon/aws/triggers/batch.py | 4 ++--
airflow/providers/cncf/kubernetes/utils/pod_manager.py | 4 ++--
airflow/utils/helpers.py | 6 +++---
airflow/www/decorators.py | 6 +++---
dev/check_files.py | 4 ++--
docs/build_docs.py | 6 +++---
docs/exts/docs_build/fetch_inventories.py | 4 ++--
docs/exts/docs_build/lint_checks.py | 6 +++---
.../ci/pre_commit/pre_commit_check_deferrable_default.py | 2 +-
.../ci/pre_commit/pre_commit_sort_installed_providers.py | 3 +--
scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py | 3 +--
scripts/in_container/run_provider_yaml_files_check.py | 14 +++++++-------
tests/always/test_project_structure.py | 10 ++++------
tests/models/test_taskmixin.py | 10 +++++++---
tests/providers/amazon/aws/hooks/test_batch_waiters.py | 8 +++-----
tests/providers/apache/hive/transfers/test_s3_to_hive.py | 6 +++---
tests/providers/apache/spark/hooks/test_spark_sql.py | 4 ++--
.../cloud/operators/test_cloud_storage_transfer_service.py | 9 ++++-----
tests/system/conftest.py | 4 ++--
tests/utils/test_helpers.py | 6 +++---
24 files changed, 66 insertions(+), 70 deletions(-)
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 8f6f703ed6..3c9c6975ea 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import datetime
import functools
import io
-import itertools as it
+import itertools
import json
import logging
import multiprocessing
@@ -473,7 +473,7 @@ class AirflowConfigParser(ConfigParser):
:return: list of section names
"""
- return list(dict.fromkeys(it.chain(self.configuration_description, self.sections())))
+ return list(dict.fromkeys(itertools.chain(self.configuration_description, self.sections())))
def get_options_including_defaults(self, section: str) -> list[str]:
"""
@@ -485,7 +485,7 @@ class AirflowConfigParser(ConfigParser):
"""
my_own_options = self.options(section) if self.has_section(section) else []
all_options_from_defaults = self.configuration_description.get(section, {}).get("options", {})
- return list(dict.fromkeys(it.chain(all_options_from_defaults, my_own_options)))
+ return list(dict.fromkeys(itertools.chain(all_options_from_defaults, my_own_options)))
def optionxform(self, optionstr: str) -> str:
"""
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index c4a4b0ed61..750e1fa1e7 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -17,9 +17,9 @@
from __future__ import annotations
import inspect
+import itertools
import warnings
from functools import cached_property
-from itertools import chain
from textwrap import dedent
from typing import (
Any,
@@ -226,7 +226,7 @@ class DecoratedOperator(BaseOperator):
def execute(self, context: Context):
# todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
# as well
- for arg in chain(self.op_args, self.op_kwargs.values()):
+ for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
if isinstance(arg, Dataset):
self.inlets.append(arg)
return_value = super().execute(context)
diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index a2fcdf4ed5..4843da12fc 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -18,7 +18,6 @@
"""Provides lineage support functions."""
from __future__ import annotations
-import itertools
import logging
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
@@ -142,7 +141,7 @@ def prepare_lineage(func: T) -> T:
_inlets = self.xcom_pull(
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
)
- self.inlets.extend(itertools.chain.from_iterable(_inlets))
+ self.inlets.extend(i for it in _inlets for i in it)
elif self.inlets:
raise AttributeError("inlets is not a list, operator, string or
attr annotated object")
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 74dbef1eac..26304ed367 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -26,7 +26,7 @@ A client for AWS Batch services.
"""
from __future__ import annotations
-import itertools as it
+import itertools
from random import uniform
from time import sleep
from typing import Callable
@@ -488,7 +488,7 @@ class BatchClientHook(AwsBaseHook):
# cross stream names with options (i.e. attempts X nodes) to generate all log infos
result = []
- for stream, option in it.product(stream_names, log_options):
+ for stream, option in itertools.product(stream_names, log_options):
result.append(
{
"awslogs_stream_name": stream,
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
index 774ce3c4bf..900040afa2 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import asyncio
-import itertools as it
+import itertools
from functools import cached_property
from typing import Any
@@ -162,7 +162,7 @@ class BatchSensorTrigger(BaseTrigger):
"""
async with self.hook.async_conn as client:
waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
- for attempt in it.count(1):
+ for attempt in itertools.count(1):
try:
await waiter.wait(
jobs=[self.job_id],
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 139befdbff..1c2e0ab597 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -18,7 +18,7 @@
from __future__ import annotations
import enum
-import itertools as it
+import itertools
import json
import logging
import math
@@ -628,7 +628,7 @@ class PodManager(LoggingMixin):
def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None:
self.log.info("Checking if xcom sidecar container is started.")
- for attempt in it.count():
+ for attempt in itertools.count():
if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME):
self.log.info("The xcom sidecar container is started.")
break
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index e07608030d..e55a8e0044 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -18,12 +18,12 @@
from __future__ import annotations
import copy
+import itertools
import re
import signal
import warnings
from datetime import datetime
from functools import reduce
-from itertools import filterfalse, tee
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast
from lazy_object_proxy import Proxy
@@ -216,8 +216,8 @@ def merge_dicts(dict1: dict, dict2: dict) -> dict:
def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
"""Use a predicate to partition entries into false entries and true entries."""
- iter_1, iter_2 = tee(iterable)
- return filterfalse(pred, iter_1), filter(pred, iter_2)
+ iter_1, iter_2 = itertools.tee(iterable)
+ return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)
def chain(*args, **kwargs):
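For reference, a standalone sketch of the tee/filterfalse partition recipe used by this helper (the sample values are illustrative only):

    import itertools

    pred = lambda n: n % 2
    iter_1, iter_2 = itertools.tee(range(10))
    falsy = list(itertools.filterfalse(pred, iter_1))  # entries where pred is false
    truthy = list(filter(pred, iter_2))                # entries where pred is true
    assert falsy == [0, 2, 4, 6, 8]
    assert truthy == [1, 3, 5, 7, 9]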
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index 975910fe50..94c1c34921 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -19,10 +19,10 @@ from __future__ import annotations
import functools
import gzip
+import itertools
import json
import logging
from io import BytesIO as IO
-from itertools import chain
from typing import Callable, TypeVar, cast
import pendulum
@@ -94,7 +94,7 @@ def action_logging(func: Callable | None = None, event: str | None = None) -> Ca
fields_skip_logging = {"csrf_token", "_csrf_token"}
extra_fields = [
(k, secrets_masker.redact(v, k))
- for k, v in chain(request.values.items(multi=True), request.view_args.items())
+ for k, v in itertools.chain(request.values.items(multi=True), request.view_args.items())
if k not in fields_skip_logging
]
if event and event.startswith("variable."):
@@ -102,7 +102,7 @@ def action_logging(func: Callable | None = None, event: str | None = None) -> Ca
if event and event.startswith("connection."):
extra_fields = _mask_connection_fields(extra_fields)
- params = {k: v for k, v in chain(request.values.items(), request.view_args.items())}
+ params = {k: v for k, v in itertools.chain(request.values.items(), request.view_args.items())}
log = Log(
event=event or f.__name__,
diff --git a/dev/check_files.py b/dev/check_files.py
index 52260c1da2..f50875cc01 100644
--- a/dev/check_files.py
+++ b/dev/check_files.py
@@ -16,9 +16,9 @@
# under the License.
from __future__ import annotations
+import itertools
import os
import re
-from itertools import product
import rich_click as click
from rich import print
@@ -141,7 +141,7 @@ def check_release(files: list[str], version: str):
def expand_name_variations(files):
- return sorted(base + suffix for base, suffix in product(files, ["", ".asc", ".sha512"]))
+ return sorted(base + suffix for base, suffix in itertools.product(files, ["", ".asc", ".sha512"]))
def check_upgrade_check(files: list[str], version: str):
diff --git a/docs/build_docs.py b/docs/build_docs.py
index 84ddb4d22c..417854c8d1 100755
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -23,11 +23,11 @@ Builds documentation and runs spell checking
from __future__ import annotations
import argparse
+import itertools
import multiprocessing
import os
import sys
from collections import defaultdict
-from itertools import filterfalse, tee
from typing import Callable, Iterable, NamedTuple, TypeVar
from rich.console import Console
@@ -74,8 +74,8 @@ T = TypeVar("T")
def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
"""Use a predicate to partition entries into false entries and true entries"""
- iter_1, iter_2 = tee(iterable)
- return filterfalse(pred, iter_1), filter(pred, iter_2)
+ iter_1, iter_2 = itertools.tee(iterable)
+ return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)
def _promote_new_flags():
diff --git a/docs/exts/docs_build/fetch_inventories.py b/docs/exts/docs_build/fetch_inventories.py
index 6db368f0d0..9576a82b32 100644
--- a/docs/exts/docs_build/fetch_inventories.py
+++ b/docs/exts/docs_build/fetch_inventories.py
@@ -19,11 +19,11 @@ from __future__ import annotations
import concurrent
import concurrent.futures
import datetime
+import itertools
import os
import shutil
import sys
import traceback
-from itertools import repeat
from tempfile import NamedTemporaryFile
from typing import Iterator
@@ -142,7 +142,7 @@ def fetch_inventories():
with requests.Session() as session, concurrent.futures.ThreadPoolExecutor(DEFAULT_POOLSIZE) as pool:
download_results: Iterator[tuple[str, bool]] = pool.map(
_fetch_file,
- repeat(session, len(to_download)),
+ itertools.repeat(session, len(to_download)),
(pkg_name for pkg_name, _, _ in to_download),
(url for _, url, _ in to_download),
(path for _, _, path in to_download),
diff --git a/docs/exts/docs_build/lint_checks.py b/docs/exts/docs_build/lint_checks.py
index ef254e0976..e536feb680 100644
--- a/docs/exts/docs_build/lint_checks.py
+++ b/docs/exts/docs_build/lint_checks.py
@@ -17,10 +17,10 @@
from __future__ import annotations
import ast
+import itertools
import os
import re
from glob import glob
-from itertools import chain
from typing import Iterable
from docs.exts.docs_build.docs_builder import ALL_PROVIDER_YAMLS
@@ -87,7 +87,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
operator_names=find_existing_guide_operator_names(
f"{DOCS_DIR}/apache-airflow/howto/operator/**/*.rst"
),
- python_module_paths=chain(
+ python_module_paths=itertools.chain(
glob(f"{ROOT_PACKAGE_DIR}/operators/*.py"),
glob(f"{ROOT_PACKAGE_DIR}/sensors/*.py"),
),
@@ -101,7 +101,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
}
# Extract all potential python modules that can contain operators
- python_module_paths = chain(
+ python_module_paths = itertools.chain(
glob(f"{provider['package-dir']}/**/operators/*.py",
recursive=True),
glob(f"{provider['package-dir']}/**/sensors/*.py", recursive=True),
glob(f"{provider['package-dir']}/**/transfers/*.py",
recursive=True),
diff --git a/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py b/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py
index 784d25d522..8373385f0d 100755
--- a/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py
+++ b/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py
@@ -76,7 +76,7 @@ def iter_check_deferrable_default_errors(module_filename: str) -> Iterator[str]:
args = node.args
arguments = reversed([*args.args, *args.kwonlyargs])
defaults = reversed([*args.defaults, *args.kw_defaults])
- for argument, default in itertools.zip_longest(arguments, defaults, fillvalue=None):
+ for argument, default in zip(arguments, defaults):
if argument is None or default is None:
continue
if argument.arg != "deferrable" or
_is_valid_deferrable_default(default):
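Because both lists are reversed, defaults align with the trailing parameters, so zip() and the previous zip_longest() end up visiting the same pairs; a sketch with a hypothetical function signature (not taken from the patch):

    import ast
    import itertools

    node = ast.parse("def f(a, b, deferrable=False, timeout=10): ...").body[0]
    args = node.args
    arguments = list(reversed([*args.args, *args.kwonlyargs]))
    defaults = list(reversed([*args.defaults, *args.kw_defaults]))

    pairs_zip = [(a.arg, d) for a, d in zip(arguments, defaults)]
    pairs_longest = [
        (a.arg, d)
        for a, d in itertools.zip_longest(arguments, defaults, fillvalue=None)
        if a is not None and d is not None
    ]
    # parameters without defaults drop out either way
    assert pairs_zip == pairs_longest
    assert [name for name, _ in pairs_zip] == ["timeout", "deferrable"]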
diff --git a/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py b/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
index e2bd0e2921..7ab17c5dd8 100755
--- a/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
+++ b/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import itertools
from pathlib import Path
if __name__ not in ("__main__", "__mp_main__"):
@@ -35,7 +34,7 @@ def stable_sort(x):
def sort_uniq(sequence):
- return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
+ return sorted(set(sequence), key=stable_sort)
if __name__ == "__main__":
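Both sort_uniq variants sort and de-duplicate the sequence; a quick sketch assuming a casefold-based key in the spirit of the script's stable_sort (sample data is illustrative):

    import itertools

    def stable_sort(x):
        return x.casefold(), x  # assumed key: case-insensitive with a stable tie-break

    sequence = ["apache-airflow", "Apache-Airflow", "apache-airflow", "zope"]
    via_groupby = [x for x, _ in itertools.groupby(sorted(sequence, key=stable_sort))]
    via_set = sorted(set(sequence), key=stable_sort)
    assert via_groupby == via_set == ["Apache-Airflow", "apache-airflow", "zope"]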
diff --git a/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py b/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
index f9eb8b4a06..41d7a3ce42 100755
--- a/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
+++ b/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import itertools
from pathlib import Path
if __name__ not in ("__main__", "__mp_main__"):
@@ -35,7 +34,7 @@ def stable_sort(x):
def sort_uniq(sequence):
- return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
+ return sorted(set(sequence), key=stable_sort)
if __name__ == "__main__":
diff --git a/scripts/in_container/run_provider_yaml_files_check.py b/scripts/in_container/run_provider_yaml_files_check.py
index ae523eb4ba..3b5b321135 100755
--- a/scripts/in_container/run_provider_yaml_files_check.py
+++ b/scripts/in_container/run_provider_yaml_files_check.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import importlib
import inspect
+import itertools
import json
import os
import pathlib
@@ -27,7 +28,6 @@ import sys
import textwrap
from collections import Counter
from enum import Enum
-from itertools import chain, product
from typing import Any, Iterable
import jsonschema
@@ -219,7 +219,7 @@ def check_if_objects_exist_and_belong_to_package(
def parse_module_data(provider_data, resource_type, yaml_file_path):
package_dir = ROOT_DIR.joinpath(yaml_file_path).parent
provider_package = pathlib.Path(yaml_file_path).parent.as_posix().replace("/", ".")
- py_files = chain(
+ py_files = itertools.chain(
package_dir.glob(f"**/{resource_type}/*.py"),
package_dir.glob(f"{resource_type}/*.py"),
package_dir.glob(f"**/{resource_type}/**/*.py"),
@@ -233,7 +233,7 @@ def parse_module_data(provider_data, resource_type, yaml_file_path):
def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict[str, dict]):
print("Checking completeness of list of {sensors, hooks, operators, triggers}")
print(" -- {sensors, hooks, operators, triggers} - Expected modules (left) : Current modules (right)")
- for (yaml_file_path, provider_data), resource_type in product(
+ for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
expected_modules, provider_package, resource_data = parse_module_data(
@@ -257,7 +257,7 @@ def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict
def check_duplicates_in_integrations_names_of_hooks_sensors_operators(yaml_files: dict[str, dict]):
print("Checking for duplicates in list of {sensors, hooks, operators, triggers}")
- for (yaml_file_path, provider_data), resource_type in product(
+ for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
resource_data = provider_data.get(resource_type, [])
@@ -362,7 +362,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
print("Detect unregistered integrations")
all_integration_names = set(get_all_integration_names(yaml_files))
- for (yaml_file_path, provider_data), resource_type in product(
+ for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
resource_data = provider_data.get(resource_type, [])
@@ -374,7 +374,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
f"Invalid values: {invalid_names}"
)
- for (yaml_file_path, provider_data), key in product(
+ for (yaml_file_path, provider_data), key in itertools.product(
yaml_files.items(), ["source-integration-name",
"target-integration-name"]
):
resource_data = provider_data.get("transfers", [])
@@ -409,7 +409,7 @@ def check_doc_files(yaml_files: dict[str, dict]):
console.print("[yellow]Suspended providers:[/]")
console.print(suspended_providers)
- expected_doc_files = chain(
+ expected_doc_files = itertools.chain(
DOCS_DIR.glob("apache-airflow-providers-*/operators/**/*.rst"),
DOCS_DIR.glob("apache-airflow-providers-*/transfer/**/*.rst"),
)
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 116f8f99d2..a518c9f3d2 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -455,12 +455,10 @@ class TestDockerProviderProjectStructure(ExampleCoverageTest):
class TestOperatorsHooks:
def test_no_illegal_suffixes(self):
illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"]
- files = itertools.chain(
- *(
- glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
- for resource_type in ["operators", "hooks", "sensors", "example_dags"]
- for part in ["airflow", "tests"]
- )
+ files = itertools.chain.from_iterable(
+ glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
+ for resource_type in ["operators", "hooks", "sensors", "example_dags"]
+ for part in ["airflow", "tests"]
)
invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))]
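chain(*iterables) and chain.from_iterable(iterables) flatten identically; from_iterable just avoids unpacking the outer generator eagerly. A small sketch with illustrative data:

    import itertools

    assert list(itertools.chain(*(range(i) for i in range(4)))) == [0, 0, 1, 0, 1, 2]
    assert list(itertools.chain.from_iterable(range(i) for i in range(4))) == [0, 0, 1, 0, 1, 2]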
diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py
index 95aefd0faa..2435d6711a 100644
--- a/tests/models/test_taskmixin.py
+++ b/tests/models/test_taskmixin.py
@@ -17,7 +17,7 @@
from __future__ import annotations
-from itertools import product
+import itertools
import pytest
@@ -67,7 +67,9 @@ def make_task(name, type_, setup_=False, teardown_=False):
return my_task.override(task_id=name)()
[email protected]("setup_type, work_type, teardown_type", product(*3 *
[["classic", "taskflow"]]))
[email protected](
+ "setup_type, work_type, teardown_type", itertools.product(["classic",
"taskflow"], repeat=3)
+)
def test_as_teardown(dag_maker, setup_type, work_type, teardown_type):
"""
Check that as_teardown works properly as implemented in PlainXComArg
@@ -98,7 +100,9 @@ def test_as_teardown(dag_maker, setup_type, work_type, teardown_type):
assert get_task_attr(t1, "upstream_task_ids") == {"w1", "s1"}
[email protected]("setup_type, work_type, teardown_type", product(*3 *
[["classic", "taskflow"]]))
[email protected](
+ "setup_type, work_type, teardown_type", itertools.product(["classic",
"taskflow"], repeat=3)
+)
def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type):
"""
Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic.
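product(*3 * [values]) and product(values, repeat=3) yield the same 2 ** 3 parameter combinations; a quick sketch:

    import itertools

    values = ["classic", "taskflow"]
    assert list(itertools.product(*3 * [values])) == list(itertools.product(values, repeat=3))
    assert len(list(itertools.product(values, repeat=3))) == 8  # 2 ** 3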
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index cdf581417c..206ce68857 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -274,11 +274,9 @@ class TestBatchJobWaiters:
self.mock_describe_jobs.side_effect = [
# Emulate change job status before one of expected states.
# SUBMITTED -> PENDING -> RUNNABLE -> STARTING
- *itertools.chain(
- *[
- itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3)
- for inter_status in INTERMEDIATE_STATES
- ]
+ *itertools.chain.from_iterable(
+ itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3)
+ for inter_status in INTERMEDIATE_STATES
),
# Expected status
self.describe_jobs_response(job_id=job_id, status=status),
diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
index c84a78828e..3f674ec3fa 100644
--- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
@@ -20,10 +20,10 @@ from __future__ import annotations
import bz2
import errno
import filecmp
+import itertools
import logging
import shutil
from gzip import GzipFile
-from itertools import product
from tempfile import NamedTemporaryFile, mkdtemp
from unittest import mock
@@ -204,7 +204,7 @@ class TestS3ToHiveTransfer:
)
# Testing txt, zip, bz2 files with and without header row
- for (ext, has_header) in product([".txt", ".gz", ".bz2", ".GZ"], [True, False]):
+ for ext, has_header in itertools.product([".txt", ".gz", ".bz2", ".GZ"], [True, False]):
self.kwargs["headers"] = has_header
self.kwargs["check_headers"] = has_header
logging.info("Testing %s format %s header", ext, "with" if
has_header else "without")
@@ -242,7 +242,7 @@ class TestS3ToHiveTransfer:
# Only testing S3ToHiveTransfer calls S3Hook.select_key with
# the right parameters and its execute method succeeds here,
# since Moto doesn't support select_object_content as of 1.3.2.
- for (ext, has_header) in product([".txt", ".gz", ".GZ"], [True, False]):
+ for ext, has_header in itertools.product([".txt", ".gz", ".GZ"], [True, False]):
input_compressed = ext.lower() != ".txt"
key = self.s3_key + ext
diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py
index 1666c51946..9bd46e6ce3 100644
--- a/tests/providers/apache/spark/hooks/test_spark_sql.py
+++ b/tests/providers/apache/spark/hooks/test_spark_sql.py
@@ -18,7 +18,7 @@
from __future__ import annotations
import io
-from itertools import dropwhile
+import itertools
from unittest.mock import call, patch
import pytest
@@ -32,7 +32,7 @@ from tests.test_utils.db import clear_db_connections
def get_after(sentinel, iterable):
"""Get the value after `sentinel` in an `iterable`"""
- truncated = dropwhile(lambda el: el != sentinel, iterable)
+ truncated = itertools.dropwhile(lambda el: el != sentinel, iterable)
next(truncated)
return next(truncated)
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 7684d8476c..3b63251510 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import itertools
from copy import deepcopy
from datetime import date, time
from unittest import mock
@@ -220,10 +219,10 @@ class TestTransferJobValidator:
@pytest.mark.parametrize(
"transfer_spec",
[
- dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items(), SOURCE_HTTP.items())),
- dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items())),
- dict(itertools.chain(SOURCE_AWS.items(), SOURCE_HTTP.items())),
- dict(itertools.chain(SOURCE_GCS.items(), SOURCE_HTTP.items())),
+ {**SOURCE_AWS, **SOURCE_GCS, **SOURCE_HTTP},
+ {**SOURCE_AWS, **SOURCE_GCS},
+ {**SOURCE_AWS, **SOURCE_HTTP},
+ {**SOURCE_GCS, **SOURCE_HTTP},
],
)
def test_verify_data_source(self, transfer_spec):
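dict(itertools.chain(a.items(), b.items())) and {**a, **b} build the same merged dict, with later keys winning in both forms; a sketch using illustrative stand-ins for the SOURCE_* fixtures:

    import itertools

    SOURCE_AWS = {"awsS3DataSource": {"bucketName": "example-aws"}}  # illustrative only
    SOURCE_GCS = {"gcsDataSource": {"bucketName": "example-gcs"}}    # illustrative only
    assert dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items())) == {**SOURCE_AWS, **SOURCE_GCS}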
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 154e7c208f..58eca1287c 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -16,9 +16,9 @@
# under the License.
from __future__ import annotations
+import itertools
import os
import re
-from itertools import chain
from pathlib import Path
from unittest import mock
@@ -41,7 +41,7 @@ def provider_env_vars():
@pytest.fixture(autouse=True)
def skip_if_env_var_not_set(provider_env_vars):
- for env in chain(REQUIRED_ENV_VARS, provider_env_vars):
+ for env in itertools.chain(REQUIRED_ENV_VARS, provider_env_vars):
if env not in os.environ:
pytest.skip(f"Missing required environment variable {env}")
return
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 9d6020874c..c3c370060a 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -17,8 +17,8 @@
# under the License.
from __future__ import annotations
+import itertools
import re
-from itertools import product
import pytest
@@ -264,7 +264,7 @@ class TestHelpers:
expected = True if true + truthy == 1 else False
assert exactly_one(*sample) is expected
- for row in product(range(4), range(4), range(4), range(4)):
+ for row in itertools.product(range(4), repeat=4):
assert_exactly_one(*row)
def test_exactly_one_should_fail(self):
@@ -295,7 +295,7 @@ class TestHelpers:
expected = True if true + truthy in (0, 1) else False
assert at_most_one(*sample) is expected
- for row in product(range(4), range(4), range(4), range(4), range(4)):
+ for row in itertools.product(range(4), repeat=4):
print(row)
assert_at_most_one(*row)