This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 95a930bc0a Consolidate import and usage of itertools (#33479)
95a930bc0a is described below

commit 95a930bc0a720c5548e4fa2e1f74e25f12e9ae1d
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Mon Aug 21 05:36:29 2023 +0000

    Consolidate import and usage of itertools (#33479)
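
    The pattern throughout: drop aliased imports ("import itertools as it")
    and from-imports ("from itertools import chain") in favour of a single
    plain "import itertools" with fully qualified call sites. A minimal
    sketch of the convention (illustrative values, not taken from the diff):

        import itertools

        # one canonical spelling at every call site
        merged = list(itertools.chain([1, 2], [3]))
        pairs = list(itertools.product("ab", [0, 1]))
        assert merged == [1, 2, 3]
        assert pairs == [("a", 0), ("a", 1), ("b", 0), ("b", 1)]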
---
 airflow/configuration.py                                   |  6 +++---
 airflow/decorators/base.py                                 |  4 ++--
 airflow/lineage/__init__.py                                |  3 +--
 airflow/providers/amazon/aws/hooks/batch_client.py         |  4 ++--
 airflow/providers/amazon/aws/triggers/batch.py             |  4 ++--
 airflow/providers/cncf/kubernetes/utils/pod_manager.py     |  4 ++--
 airflow/utils/helpers.py                                   |  6 +++---
 airflow/www/decorators.py                                  |  6 +++---
 dev/check_files.py                                         |  4 ++--
 docs/build_docs.py                                         |  6 +++---
 docs/exts/docs_build/fetch_inventories.py                  |  4 ++--
 docs/exts/docs_build/lint_checks.py                        |  6 +++---
 .../ci/pre_commit/pre_commit_check_deferrable_default.py   |  2 +-
 .../ci/pre_commit/pre_commit_sort_installed_providers.py   |  3 +--
 scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py |  3 +--
 scripts/in_container/run_provider_yaml_files_check.py      | 14 +++++++-------
 tests/always/test_project_structure.py                     | 10 ++++------
 tests/models/test_taskmixin.py                             | 10 +++++++---
 tests/providers/amazon/aws/hooks/test_batch_waiters.py     |  8 +++-----
 tests/providers/apache/hive/transfers/test_s3_to_hive.py   |  6 +++---
 tests/providers/apache/spark/hooks/test_spark_sql.py       |  4 ++--
 .../cloud/operators/test_cloud_storage_transfer_service.py |  9 ++++-----
 tests/system/conftest.py                                   |  4 ++--
 tests/utils/test_helpers.py                                |  6 +++---
 24 files changed, 66 insertions(+), 70 deletions(-)

diff --git a/airflow/configuration.py b/airflow/configuration.py
index 8f6f703ed6..3c9c6975ea 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import datetime
 import functools
 import io
-import itertools as it
+import itertools
 import json
 import logging
 import multiprocessing
@@ -473,7 +473,7 @@ class AirflowConfigParser(ConfigParser):
 
         :return: list of section names
         """
-        return list(dict.fromkeys(it.chain(self.configuration_description, self.sections())))
+        return list(dict.fromkeys(itertools.chain(self.configuration_description, self.sections())))
 
     def get_options_including_defaults(self, section: str) -> list[str]:
         """
@@ -485,7 +485,7 @@ class AirflowConfigParser(ConfigParser):
         """
         my_own_options = self.options(section) if self.has_section(section) else []
         all_options_from_defaults = self.configuration_description.get(section, {}).get("options", {})
-        return list(dict.fromkeys(it.chain(all_options_from_defaults, my_own_options)))
+        return list(dict.fromkeys(itertools.chain(all_options_from_defaults, my_own_options)))
 
     def optionxform(self, optionstr: str) -> str:
         """
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index c4a4b0ed61..750e1fa1e7 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -17,9 +17,9 @@
 from __future__ import annotations
 
 import inspect
+import itertools
 import warnings
 from functools import cached_property
-from itertools import chain
 from textwrap import dedent
 from typing import (
     Any,
@@ -226,7 +226,7 @@ class DecoratedOperator(BaseOperator):
     def execute(self, context: Context):
         # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
         #  as well
-        for arg in chain(self.op_args, self.op_kwargs.values()):
+        for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
             if isinstance(arg, Dataset):
                 self.inlets.append(arg)
         return_value = super().execute(context)
diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index a2fcdf4ed5..4843da12fc 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -18,7 +18,6 @@
 """Provides lineage support functions."""
 from __future__ import annotations
 
-import itertools
 import logging
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
@@ -142,7 +141,7 @@ def prepare_lineage(func: T) -> T:
                 _inlets = self.xcom_pull(
                     context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
                 )
-                self.inlets.extend(itertools.chain.from_iterable(_inlets))
+                self.inlets.extend(i for it in _inlets for i in it)
 
         elif self.inlets:
             raise AttributeError("inlets is not a list, operator, string or attr annotated object")
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 74dbef1eac..26304ed367 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -26,7 +26,7 @@ A client for AWS Batch services.
 """
 from __future__ import annotations
 
-import itertools as it
+import itertools
 from random import uniform
 from time import sleep
 from typing import Callable
@@ -488,7 +488,7 @@ class BatchClientHook(AwsBaseHook):
 
         # cross stream names with options (i.e. attempts X nodes) to generate all log infos
         result = []
-        for stream, option in it.product(stream_names, log_options):
+        for stream, option in itertools.product(stream_names, log_options):
             result.append(
                 {
                     "awslogs_stream_name": stream,
diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
index 774ce3c4bf..900040afa2 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 import asyncio
-import itertools as it
+import itertools
 from functools import cached_property
 from typing import Any
 
@@ -162,7 +162,7 @@ class BatchSensorTrigger(BaseTrigger):
         """
         async with self.hook.async_conn as client:
             waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
-            for attempt in it.count(1):
+            for attempt in itertools.count(1):
                 try:
                     await waiter.wait(
                         jobs=[self.job_id],
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 139befdbff..1c2e0ab597 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import enum
-import itertools as it
+import itertools
 import json
 import logging
 import math
@@ -628,7 +628,7 @@ class PodManager(LoggingMixin):
 
     def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None:
         self.log.info("Checking if xcom sidecar container is started.")
-        for attempt in it.count():
+        for attempt in itertools.count():
             if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME):
                 self.log.info("The xcom sidecar container is started.")
                 break
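
Both this hunk and the batch trigger hunk above keep the itertools.count
idiom: an unbounded attempt counter driving a polling loop. A self-contained
sketch of the pattern (wait_for, check and the timings are made-up names,
not Airflow API):

    import itertools
    import time

    def wait_for(check, delay=0.1, max_attempts=5):
        """Poll `check` until it returns True, numbering attempts with count()."""
        for attempt in itertools.count(1):  # 1, 2, 3, ... no counter to increment by hand
            if check():
                return attempt
            if attempt >= max_attempts:
                raise TimeoutError(f"gave up after {attempt} attempts")
            time.sleep(delay)

    assert wait_for(lambda: True) == 1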
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index e07608030d..e55a8e0044 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -18,12 +18,12 @@
 from __future__ import annotations
 
 import copy
+import itertools
 import re
 import signal
 import warnings
 from datetime import datetime
 from functools import reduce
-from itertools import filterfalse, tee
 from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast
 
 from lazy_object_proxy import Proxy
@@ -216,8 +216,8 @@ def merge_dicts(dict1: dict, dict2: dict) -> dict:
 
 def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
     """Use a predicate to partition entries into false entries and true entries."""
-    iter_1, iter_2 = tee(iterable)
-    return filterfalse(pred, iter_1), filter(pred, iter_2)
+    iter_1, iter_2 = itertools.tee(iterable)
+    return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)
 
 
 def chain(*args, **kwargs):
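
For reference, the partition helper above shown standalone, with a usage
example (same two-line body as the Airflow source; the driver code is added
here for illustration):

    import itertools
    from typing import Callable, Iterable, TypeVar

    T = TypeVar("T")

    def partition(pred: Callable[[T], bool], iterable: Iterable[T]):
        """Use a predicate to partition entries into false entries and true entries."""
        iter_1, iter_2 = itertools.tee(iterable)  # two independent iterators over one input
        return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)

    odds, evens = partition(lambda n: n % 2 == 0, range(6))
    assert list(odds) == [1, 3, 5]
    assert list(evens) == [0, 2, 4]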
diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py
index 975910fe50..94c1c34921 100644
--- a/airflow/www/decorators.py
+++ b/airflow/www/decorators.py
@@ -19,10 +19,10 @@ from __future__ import annotations
 
 import functools
 import gzip
+import itertools
 import json
 import logging
 from io import BytesIO as IO
-from itertools import chain
 from typing import Callable, TypeVar, cast
 
 import pendulum
@@ -94,7 +94,7 @@ def action_logging(func: Callable | None = None, event: str | None = None) -> Ca
                 fields_skip_logging = {"csrf_token", "_csrf_token"}
                 extra_fields = [
                     (k, secrets_masker.redact(v, k))
-                    for k, v in chain(request.values.items(multi=True), request.view_args.items())
+                    for k, v in itertools.chain(request.values.items(multi=True), request.view_args.items())
                     if k not in fields_skip_logging
                 ]
                 if event and event.startswith("variable."):
@@ -102,7 +102,7 @@ def action_logging(func: Callable | None = None, event: str | None = None) -> Ca
                 if event and event.startswith("connection."):
                     extra_fields = _mask_connection_fields(extra_fields)
 
-                params = {k: v for k, v in chain(request.values.items(), request.view_args.items())}
+                params = {k: v for k, v in itertools.chain(request.values.items(), request.view_args.items())}
 
                 log = Log(
                     event=event or f.__name__,
diff --git a/dev/check_files.py b/dev/check_files.py
index 52260c1da2..f50875cc01 100644
--- a/dev/check_files.py
+++ b/dev/check_files.py
@@ -16,9 +16,9 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import os
 import re
-from itertools import product
 
 import rich_click as click
 from rich import print
@@ -141,7 +141,7 @@ def check_release(files: list[str], version: str):
 
 
 def expand_name_variations(files):
-    return sorted(base + suffix for base, suffix in product(files, ["", ".asc", ".sha512"]))
+    return sorted(base + suffix for base, suffix in itertools.product(files, ["", ".asc", ".sha512"]))
 
 
 def check_upgrade_check(files: list[str], version: str):
diff --git a/docs/build_docs.py b/docs/build_docs.py
index 84ddb4d22c..417854c8d1 100755
--- a/docs/build_docs.py
+++ b/docs/build_docs.py
@@ -23,11 +23,11 @@ Builds documentation and runs spell checking
 from __future__ import annotations
 
 import argparse
+import itertools
 import multiprocessing
 import os
 import sys
 from collections import defaultdict
-from itertools import filterfalse, tee
 from typing import Callable, Iterable, NamedTuple, TypeVar
 
 from rich.console import Console
@@ -74,8 +74,8 @@ T = TypeVar("T")
 
 def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
     """Use a predicate to partition entries into false entries and true entries"""
-    iter_1, iter_2 = tee(iterable)
-    return filterfalse(pred, iter_1), filter(pred, iter_2)
+    iter_1, iter_2 = itertools.tee(iterable)
+    return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)
 
 
 def _promote_new_flags():
diff --git a/docs/exts/docs_build/fetch_inventories.py b/docs/exts/docs_build/fetch_inventories.py
index 6db368f0d0..9576a82b32 100644
--- a/docs/exts/docs_build/fetch_inventories.py
+++ b/docs/exts/docs_build/fetch_inventories.py
@@ -19,11 +19,11 @@ from __future__ import annotations
 import concurrent
 import concurrent.futures
 import datetime
+import itertools
 import os
 import shutil
 import sys
 import traceback
-from itertools import repeat
 from tempfile import NamedTemporaryFile
 from typing import Iterator
 
@@ -142,7 +142,7 @@ def fetch_inventories():
     with requests.Session() as session, concurrent.futures.ThreadPoolExecutor(DEFAULT_POOLSIZE) as pool:
         download_results: Iterator[tuple[str, bool]] = pool.map(
             _fetch_file,
-            repeat(session, len(to_download)),
+            itertools.repeat(session, len(to_download)),
             (pkg_name for pkg_name, _, _ in to_download),
             (url for _, url, _ in to_download),
             (path for _, _, path in to_download),
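
The repeat() above hands the same Session to every _fetch_file call, since
Executor.map zips its argument iterables positionally. A minimal sketch of
that pattern (fetch and the data are hypothetical):

    import concurrent.futures
    import itertools

    def fetch(session, name):
        # stand-in for _fetch_file: first argument is shared, second varies per task
        return f"{session}:{name}"

    names = ["a", "b", "c"]
    with concurrent.futures.ThreadPoolExecutor() as pool:
        results = list(pool.map(fetch, itertools.repeat("S", len(names)), names))
    assert results == ["S:a", "S:b", "S:c"]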
diff --git a/docs/exts/docs_build/lint_checks.py b/docs/exts/docs_build/lint_checks.py
index ef254e0976..e536feb680 100644
--- a/docs/exts/docs_build/lint_checks.py
+++ b/docs/exts/docs_build/lint_checks.py
@@ -17,10 +17,10 @@
 from __future__ import annotations
 
 import ast
+import itertools
 import os
 import re
 from glob import glob
-from itertools import chain
 from typing import Iterable
 
 from docs.exts.docs_build.docs_builder import ALL_PROVIDER_YAMLS
@@ -87,7 +87,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
             operator_names=find_existing_guide_operator_names(
                 f"{DOCS_DIR}/apache-airflow/howto/operator/**/*.rst"
             ),
-            python_module_paths=chain(
+            python_module_paths=itertools.chain(
                 glob(f"{ROOT_PACKAGE_DIR}/operators/*.py"),
                 glob(f"{ROOT_PACKAGE_DIR}/sensors/*.py"),
             ),
@@ -101,7 +101,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
         }
 
         # Extract all potential python modules that can contain operators
-        python_module_paths = chain(
+        python_module_paths = itertools.chain(
             glob(f"{provider['package-dir']}/**/operators/*.py", 
recursive=True),
             glob(f"{provider['package-dir']}/**/sensors/*.py", recursive=True),
             glob(f"{provider['package-dir']}/**/transfers/*.py", 
recursive=True),
diff --git a/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py b/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py
index 784d25d522..8373385f0d 100755
--- a/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py
+++ b/scripts/ci/pre_commit/pre_commit_check_deferrable_default.py
@@ -76,7 +76,7 @@ def iter_check_deferrable_default_errors(module_filename: str) -> Iterator[str]:
         args = node.args
         arguments = reversed([*args.args, *args.kwonlyargs])
         defaults = reversed([*args.defaults, *args.kw_defaults])
-        for argument, default in itertools.zip_longest(arguments, defaults, fillvalue=None):
+        for argument, default in zip(arguments, defaults):
             if argument is None or default is None:
                 continue
             if argument.arg != "deferrable" or _is_valid_deferrable_default(default):
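
This hunk is a behavioural simplification, not just a rename: zip() stops
at the shorter of the two reversed sequences, so arguments without defaults
fall off the end instead of being padded with None and skipped. A toy
illustration (the names are made up):

    import itertools

    arguments = ["d", "c", "b"]  # reversed parameter names; "b" has no default
    defaults = [True, 0]         # reversed default values

    # zip truncates to the pairs that actually have a default ...
    assert list(zip(arguments, defaults)) == [("d", True), ("c", 0)]
    # ... where zip_longest padded the tail with None, forcing an explicit skip:
    assert list(itertools.zip_longest(arguments, defaults))[-1] == ("b", None)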
diff --git a/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py b/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
index e2bd0e2921..7ab17c5dd8 100755
--- a/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
+++ b/scripts/ci/pre_commit/pre_commit_sort_installed_providers.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import itertools
 from pathlib import Path
 
 if __name__ not in ("__main__", "__mp_main__"):
@@ -35,7 +34,7 @@ def stable_sort(x):
 
 
 def sort_uniq(sequence):
-    return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
+    return sorted(set(sequence), key=stable_sort)
 
 
 if __name__ == "__main__":
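
In both sort scripts, groupby over a sorted sequence collapses runs of equal
lines, so taking the head of each group deduplicates; the rewrite gets the
same result directly from a set (assuming hashable items). A sketch, with a
guessed stand-in for the scripts' stable_sort key:

    import itertools

    def stable_sort(x):
        return x.casefold(), x  # assumed key: case-insensitive, ties broken exactly

    sequence = ["b", "A", "a", "b"]

    old = [x for x, _ in itertools.groupby(sorted(sequence, key=stable_sort))]
    new = sorted(set(sequence), key=stable_sort)
    assert old == new == ["A", "a", "b"]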
diff --git a/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py b/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
index f9eb8b4a06..41d7a3ce42 100755
--- a/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
+++ b/scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import itertools
 from pathlib import Path
 
 if __name__ not in ("__main__", "__mp_main__"):
@@ -35,7 +34,7 @@ def stable_sort(x):
 
 
 def sort_uniq(sequence):
-    return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
+    return sorted(set(sequence), key=stable_sort)
 
 
 if __name__ == "__main__":
diff --git a/scripts/in_container/run_provider_yaml_files_check.py b/scripts/in_container/run_provider_yaml_files_check.py
index ae523eb4ba..3b5b321135 100755
--- a/scripts/in_container/run_provider_yaml_files_check.py
+++ b/scripts/in_container/run_provider_yaml_files_check.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import importlib
 import inspect
+import itertools
 import json
 import os
 import pathlib
@@ -27,7 +28,6 @@ import sys
 import textwrap
 from collections import Counter
 from enum import Enum
-from itertools import chain, product
 from typing import Any, Iterable
 
 import jsonschema
@@ -219,7 +219,7 @@ def check_if_objects_exist_and_belong_to_package(
 def parse_module_data(provider_data, resource_type, yaml_file_path):
     package_dir = ROOT_DIR.joinpath(yaml_file_path).parent
     provider_package = pathlib.Path(yaml_file_path).parent.as_posix().replace("/", ".")
-    py_files = chain(
+    py_files = itertools.chain(
         package_dir.glob(f"**/{resource_type}/*.py"),
         package_dir.glob(f"{resource_type}/*.py"),
         package_dir.glob(f"**/{resource_type}/**/*.py"),
@@ -233,7 +233,7 @@ def parse_module_data(provider_data, resource_type, yaml_file_path):
 def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict[str, dict]):
     print("Checking completeness of list of {sensors, hooks, operators, triggers}")
     print(" -- {sensors, hooks, operators, triggers} - Expected modules (left) : Current modules (right)")
-    for (yaml_file_path, provider_data), resource_type in product(
+    for (yaml_file_path, provider_data), resource_type in itertools.product(
         yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
     ):
         expected_modules, provider_package, resource_data = parse_module_data(
@@ -257,7 +257,7 @@ def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict
 
 def check_duplicates_in_integrations_names_of_hooks_sensors_operators(yaml_files: dict[str, dict]):
     print("Checking for duplicates in list of {sensors, hooks, operators, triggers}")
-    for (yaml_file_path, provider_data), resource_type in product(
+    for (yaml_file_path, provider_data), resource_type in itertools.product(
         yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
     ):
         resource_data = provider_data.get(resource_type, [])
@@ -362,7 +362,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
     print("Detect unregistered integrations")
     all_integration_names = set(get_all_integration_names(yaml_files))
 
-    for (yaml_file_path, provider_data), resource_type in product(
+    for (yaml_file_path, provider_data), resource_type in itertools.product(
         yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
     ):
         resource_data = provider_data.get(resource_type, [])
@@ -374,7 +374,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
                 f"Invalid values: {invalid_names}"
             )
 
-    for (yaml_file_path, provider_data), key in product(
+    for (yaml_file_path, provider_data), key in itertools.product(
         yaml_files.items(), ["source-integration-name", 
"target-integration-name"]
     ):
         resource_data = provider_data.get("transfers", [])
@@ -409,7 +409,7 @@ def check_doc_files(yaml_files: dict[str, dict]):
     console.print("[yellow]Suspended providers:[/]")
     console.print(suspended_providers)
 
-    expected_doc_files = chain(
+    expected_doc_files = itertools.chain(
         DOCS_DIR.glob("apache-airflow-providers-*/operators/**/*.rst"),
         DOCS_DIR.glob("apache-airflow-providers-*/transfer/**/*.rst"),
     )
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 116f8f99d2..a518c9f3d2 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -455,12 +455,10 @@ class TestDockerProviderProjectStructure(ExampleCoverageTest):
 class TestOperatorsHooks:
     def test_no_illegal_suffixes(self):
         illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"]
-        files = itertools.chain(
-            *(
-                glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
-                for resource_type in ["operators", "hooks", "sensors", "example_dags"]
-                for part in ["airflow", "tests"]
-            )
+        files = itertools.chain.from_iterable(
+            glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
+            for resource_type in ["operators", "hooks", "sensors", "example_dags"]
+            for part in ["airflow", "tests"]
         )
 
         invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))]
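
Besides flattening with from_iterable, the rewrite above avoids the
chain(*generator) form, which star-unpacks and therefore exhausts the whole
generator up front; chain.from_iterable pulls one inner iterable at a time.
Both produce the same flat sequence:

    import itertools

    eager = list(itertools.chain(*([n, n * 10] for n in range(3))))
    lazy = list(itertools.chain.from_iterable([n, n * 10] for n in range(3)))
    assert eager == lazy == [0, 0, 1, 10, 2, 20]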
diff --git a/tests/models/test_taskmixin.py b/tests/models/test_taskmixin.py
index 95aefd0faa..2435d6711a 100644
--- a/tests/models/test_taskmixin.py
+++ b/tests/models/test_taskmixin.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from itertools import product
+import itertools
 
 import pytest
 
@@ -67,7 +67,9 @@ def make_task(name, type_, setup_=False, teardown_=False):
         return my_task.override(task_id=name)()
 
 
[email protected]("setup_type, work_type, teardown_type", product(*3 * 
[["classic", "taskflow"]]))
[email protected](
+    "setup_type, work_type, teardown_type", itertools.product(["classic", 
"taskflow"], repeat=3)
+)
 def test_as_teardown(dag_maker, setup_type, work_type, teardown_type):
     """
     Check that as_teardown works properly as implemented in PlainXComArg
@@ -98,7 +100,9 @@ def test_as_teardown(dag_maker, setup_type, work_type, teardown_type):
     assert get_task_attr(t1, "upstream_task_ids") == {"w1", "s1"}
 
 
[email protected]("setup_type, work_type, teardown_type", product(*3 * 
[["classic", "taskflow"]]))
[email protected](
+    "setup_type, work_type, teardown_type", itertools.product(["classic", 
"taskflow"], repeat=3)
+)
 def test_as_teardown_oneline(dag_maker, setup_type, work_type, teardown_type):
     """
     Check that as_teardown implementations work properly. Tests all combinations of taskflow and classic.
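
The parametrize change above replaces product(*3 * [[...]]), which builds
three copies of the list and star-unpacks them, with the equivalent but far
plainer repeat=3 keyword. A quick check of the equivalence:

    import itertools

    old = list(itertools.product(*3 * [["classic", "taskflow"]]))
    new = list(itertools.product(["classic", "taskflow"], repeat=3))
    assert old == new
    assert len(new) == 2 ** 3  # 8 setup/work/teardown combinations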
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index cdf581417c..206ce68857 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -274,11 +274,9 @@ class TestBatchJobWaiters:
         self.mock_describe_jobs.side_effect = [
             # Emulate change job status before one of expected states.
             # SUBMITTED -> PENDING -> RUNNABLE -> STARTING
-            *itertools.chain(
-                *[
-                    itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3)
-                    for inter_status in INTERMEDIATE_STATES
-                ]
+            *itertools.chain.from_iterable(
+                itertools.repeat(self.describe_jobs_response(job_id=job_id, status=inter_status), 3)
+                for inter_status in INTERMEDIATE_STATES
             ),
             # Expected status
             self.describe_jobs_response(job_id=job_id, status=status),
diff --git a/tests/providers/apache/hive/transfers/test_s3_to_hive.py b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
index c84a78828e..3f674ec3fa 100644
--- a/tests/providers/apache/hive/transfers/test_s3_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_s3_to_hive.py
@@ -20,10 +20,10 @@ from __future__ import annotations
 import bz2
 import errno
 import filecmp
+import itertools
 import logging
 import shutil
 from gzip import GzipFile
-from itertools import product
 from tempfile import NamedTemporaryFile, mkdtemp
 from unittest import mock
 
@@ -204,7 +204,7 @@ class TestS3ToHiveTransfer:
             )
 
         # Testing txt, zip, bz2 files with and without header row
-        for (ext, has_header) in product([".txt", ".gz", ".bz2", ".GZ"], [True, False]):
+        for ext, has_header in itertools.product([".txt", ".gz", ".bz2", ".GZ"], [True, False]):
             self.kwargs["headers"] = has_header
             self.kwargs["check_headers"] = has_header
             logging.info("Testing %s format %s header", ext, "with" if 
has_header else "without")
@@ -242,7 +242,7 @@ class TestS3ToHiveTransfer:
         # Only testing S3ToHiveTransfer calls S3Hook.select_key with
         # the right parameters and its execute method succeeds here,
         # since Moto doesn't support select_object_content as of 1.3.2.
-        for (ext, has_header) in product([".txt", ".gz", ".GZ"], [True, False]):
+        for ext, has_header in itertools.product([".txt", ".gz", ".GZ"], [True, False]):
             input_compressed = ext.lower() != ".txt"
             key = self.s3_key + ext
 
diff --git a/tests/providers/apache/spark/hooks/test_spark_sql.py b/tests/providers/apache/spark/hooks/test_spark_sql.py
index 1666c51946..9bd46e6ce3 100644
--- a/tests/providers/apache/spark/hooks/test_spark_sql.py
+++ b/tests/providers/apache/spark/hooks/test_spark_sql.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import io
-from itertools import dropwhile
+import itertools
 from unittest.mock import call, patch
 
 import pytest
@@ -32,7 +32,7 @@ from tests.test_utils.db import clear_db_connections
 
 def get_after(sentinel, iterable):
     """Get the value after `sentinel` in an `iterable`"""
-    truncated = dropwhile(lambda el: el != sentinel, iterable)
+    truncated = itertools.dropwhile(lambda el: el != sentinel, iterable)
     next(truncated)
     return next(truncated)
 
diff --git a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
index 7684d8476c..3b63251510 100644
--- a/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/operators/test_cloud_storage_transfer_service.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import itertools
 from copy import deepcopy
 from datetime import date, time
 from unittest import mock
@@ -220,10 +219,10 @@ class TestTransferJobValidator:
     @pytest.mark.parametrize(
         "transfer_spec",
         [
-            dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items(), SOURCE_HTTP.items())),
-            dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items())),
-            dict(itertools.chain(SOURCE_AWS.items(), SOURCE_HTTP.items())),
-            dict(itertools.chain(SOURCE_GCS.items(), SOURCE_HTTP.items())),
+            {**SOURCE_AWS, **SOURCE_GCS, **SOURCE_HTTP},
+            {**SOURCE_AWS, **SOURCE_GCS},
+            {**SOURCE_AWS, **SOURCE_HTTP},
+            {**SOURCE_GCS, **SOURCE_HTTP},
         ],
     )
     def test_verify_data_source(self, transfer_spec):
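
Here dict(itertools.chain(a.items(), b.items())) becomes the dict-unpacking
merge {**a, **b}; both build one dict, with later sources winning on key
clashes. A sketch with placeholder values for the SOURCE_* constants:

    import itertools

    SOURCE_AWS = {"awsS3DataSource": 1}  # placeholder payloads
    SOURCE_GCS = {"gcsDataSource": 2}

    via_chain = dict(itertools.chain(SOURCE_AWS.items(), SOURCE_GCS.items()))
    via_unpack = {**SOURCE_AWS, **SOURCE_GCS}
    assert via_chain == via_unpack == {"awsS3DataSource": 1, "gcsDataSource": 2}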
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 154e7c208f..58eca1287c 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -16,9 +16,9 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import os
 import re
-from itertools import chain
 from pathlib import Path
 from unittest import mock
 
@@ -41,7 +41,7 @@ def provider_env_vars():
 
 @pytest.fixture(autouse=True)
 def skip_if_env_var_not_set(provider_env_vars):
-    for env in chain(REQUIRED_ENV_VARS, provider_env_vars):
+    for env in itertools.chain(REQUIRED_ENV_VARS, provider_env_vars):
         if env not in os.environ:
             pytest.skip(f"Missing required environment variable {env}")
             return
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index 9d6020874c..c3c370060a 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -17,8 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
 import re
-from itertools import product
 
 import pytest
 
@@ -264,7 +264,7 @@ class TestHelpers:
                 expected = True if true + truthy == 1 else False
                 assert exactly_one(*sample) is expected
 
-        for row in product(range(4), range(4), range(4), range(4)):
+        for row in itertools.product(range(4), repeat=4):
             assert_exactly_one(*row)
 
     def test_exactly_one_should_fail(self):
@@ -295,7 +295,7 @@ class TestHelpers:
                 expected = True if true + truthy in (0, 1) else False
                 assert at_most_one(*sample) is expected
 
-        for row in product(range(4), range(4), range(4), range(4), range(4)):
+        for row in itertools.product(range(4), repeat=4):
             print(row)
             assert_at_most_one(*row)
 
