This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d182beec6 Use a single `with` statement with multiple contexts instead of nested `with` statements in providers (#33768)
6d182beec6 is described below

commit 6d182beec6e86b372c37fb164a31c2f8811d8c03
Author: Hussein Awala <[email protected]>
AuthorDate: Sat Aug 26 12:55:22 2023 +0200

    Use a single `with` statement with multiple contexts instead of nested `with` statements in providers (#33768)
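
    For context, a minimal sketch of the pattern applied throughout this change
    (the prefix and payload names below are illustrative only, not taken from the diff):

        from tempfile import NamedTemporaryFile, TemporaryDirectory

        # Before: one nested `with` block per context manager.
        with TemporaryDirectory(prefix="example_") as tmp_dir:
            with NamedTemporaryFile(dir=tmp_dir) as f:
                f.write(b"payload")
                f.flush()

        # After: a single `with` statement listing both context managers;
        # behavior is unchanged, one indentation level is removed.
        with TemporaryDirectory(prefix="example_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
            f.write(b"payload")
            f.flush()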
---
 airflow/providers/apache/hive/hooks/hive.py        | 124 ++++++++++-----------
 .../apache/hive/transfers/mysql_to_hive.py         |  29 +++--
 airflow/providers/apache/pig/hooks/pig.py          |  65 ++++++-----
 airflow/providers/dbt/cloud/hooks/dbt.py           |  15 +--
 airflow/providers/exasol/hooks/exasol.py           |  10 +-
 airflow/providers/google/cloud/hooks/gcs.py        |   7 +-
 .../google/cloud/hooks/kubernetes_engine.py        |  58 +++++-----
 airflow/providers/microsoft/azure/hooks/asb.py     |  53 +++++----
 .../providers/mysql/transfers/vertica_to_mysql.py  |  59 +++++-----
 airflow/providers/postgres/hooks/postgres.py       |  10 +-
 .../providers/snowflake/hooks/snowflake_sql_api.py |  11 +-
 11 files changed, 213 insertions(+), 228 deletions(-)

diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index 5b0c91083a..773ea4af7d 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -236,58 +236,55 @@ class HiveCliHook(BaseHook):
         if schema:
             hql = f"USE {schema};\n{hql}"
 
-        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
-            with NamedTemporaryFile(dir=tmp_dir) as f:
-                hql += "\n"
-                f.write(hql.encode("UTF-8"))
-                f.flush()
-                hive_cmd = self._prepare_cli_cmd()
-                env_context = get_context_from_env_var()
-                # Only extend the hive_conf if it is defined.
-                if hive_conf:
-                    env_context.update(hive_conf)
-                hive_conf_params = self._prepare_hiveconf(env_context)
-                if self.mapred_queue:
-                    hive_conf_params.extend(
-                        [
-                            "-hiveconf",
-                            f"mapreduce.job.queuename={self.mapred_queue}",
-                            "-hiveconf",
-                            f"mapred.job.queue.name={self.mapred_queue}",
-                            "-hiveconf",
-                            f"tez.queue.name={self.mapred_queue}",
-                        ]
-                    )
+        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
+            hql += "\n"
+            f.write(hql.encode("UTF-8"))
+            f.flush()
+            hive_cmd = self._prepare_cli_cmd()
+            env_context = get_context_from_env_var()
+            # Only extend the hive_conf if it is defined.
+            if hive_conf:
+                env_context.update(hive_conf)
+            hive_conf_params = self._prepare_hiveconf(env_context)
+            if self.mapred_queue:
+                hive_conf_params.extend(
+                    [
+                        "-hiveconf",
+                        f"mapreduce.job.queuename={self.mapred_queue}",
+                        "-hiveconf",
+                        f"mapred.job.queue.name={self.mapred_queue}",
+                        "-hiveconf",
+                        f"tez.queue.name={self.mapred_queue}",
+                    ]
+                )
 
-                if self.mapred_queue_priority:
-                    hive_conf_params.extend(
-                        ["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"]
-                    )
+            if self.mapred_queue_priority:
+                hive_conf_params.extend(["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"])
 
-                if self.mapred_job_name:
-                    hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])
+            if self.mapred_job_name:
+                hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])
 
-                hive_cmd.extend(hive_conf_params)
-                hive_cmd.extend(["-f", f.name])
+            hive_cmd.extend(hive_conf_params)
+            hive_cmd.extend(["-f", f.name])
 
+            if verbose:
+                self.log.info("%s", " ".join(hive_cmd))
+            sub_process: Any = subprocess.Popen(
+                hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
+            )
+            self.sub_process = sub_process
+            stdout = ""
+            for line in iter(sub_process.stdout.readline, b""):
+                line = line.decode()
+                stdout += line
                 if verbose:
-                    self.log.info("%s", " ".join(hive_cmd))
-                sub_process: Any = subprocess.Popen(
-                    hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
-                )
-                self.sub_process = sub_process
-                stdout = ""
-                for line in iter(sub_process.stdout.readline, b""):
-                    line = line.decode()
-                    stdout += line
-                    if verbose:
-                        self.log.info(line.strip())
-                sub_process.wait()
+                    self.log.info(line.strip())
+            sub_process.wait()
 
-                if sub_process.returncode:
-                    raise AirflowException(stdout)
+            if sub_process.returncode:
+                raise AirflowException(stdout)
 
-                return stdout
+            return stdout
 
     def test_hql(self, hql: str) -> None:
         """Test an hql statement using the hive cli and EXPLAIN."""
@@ -376,25 +373,26 @@ class HiveCliHook(BaseHook):
         if pandas_kwargs is None:
             pandas_kwargs = {}
 
-        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
-            with NamedTemporaryFile(dir=tmp_dir, mode="w") as f:
-                if field_dict is None:
-                    field_dict = _infer_field_types_from_df(df)
-
-                df.to_csv(
-                    path_or_buf=f,
-                    sep=delimiter,
-                    header=False,
-                    index=False,
-                    encoding=encoding,
-                    date_format="%Y-%m-%d %H:%M:%S",
-                    **pandas_kwargs,
-                )
-                f.flush()
+        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(
+            dir=tmp_dir, mode="w"
+        ) as f:
+            if field_dict is None:
+                field_dict = _infer_field_types_from_df(df)
+
+            df.to_csv(
+                path_or_buf=f,
+                sep=delimiter,
+                header=False,
+                index=False,
+                encoding=encoding,
+                date_format="%Y-%m-%d %H:%M:%S",
+                **pandas_kwargs,
+            )
+            f.flush()
 
-                return self.load_file(
-                    filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
-                )
+            return self.load_file(
+                filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
+            )
 
     def load_file(
         self,
diff --git a/airflow/providers/apache/hive/transfers/mysql_to_hive.py b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
index ee1cb082bc..bd7876efd4 100644
--- a/airflow/providers/apache/hive/transfers/mysql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -136,21 +136,20 @@ class MySqlToHiveOperator(BaseOperator):
         mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
         self.log.info("Dumping MySQL query results to local file")
         with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
-            with closing(mysql.get_conn()) as conn:
-                with closing(conn.cursor()) as cursor:
-                    cursor.execute(self.sql)
-                    csv_writer = csv.writer(
-                        f,
-                        delimiter=self.delimiter,
-                        quoting=self.quoting,
-                        quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
-                        escapechar=self.escapechar,
-                    )
-                    field_dict = {}
-                    if cursor.description is not None:
-                        for field in cursor.description:
-                            field_dict[field[0]] = self.type_map(field[1])
-                    csv_writer.writerows(cursor)  # type: ignore[arg-type]
+            with closing(mysql.get_conn()) as conn, closing(conn.cursor()) as cursor:
+                cursor.execute(self.sql)
+                csv_writer = csv.writer(
+                    f,
+                    delimiter=self.delimiter,
+                    quoting=self.quoting,
+                    quotechar=self.quotechar if self.quoting != csv.QUOTE_NONE else None,
+                    escapechar=self.escapechar,
+                )
+                field_dict = {}
+                if cursor.description is not None:
+                    for field in cursor.description:
+                        field_dict[field[0]] = self.type_map(field[1])
+                csv_writer.writerows(cursor)  # type: ignore[arg-type]
             f.flush()
             self.log.info("Loading file into Hive")
             hive.load_file(
diff --git a/airflow/providers/apache/pig/hooks/pig.py b/airflow/providers/apache/pig/hooks/pig.py
index 71c39536d3..31e6006de3 100644
--- a/airflow/providers/apache/pig/hooks/pig.py
+++ b/airflow/providers/apache/pig/hooks/pig.py
@@ -64,41 +64,40 @@ class PigCliHook(BaseHook):
         >>> ("hdfs://" in result)
         True
         """
-        with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir:
-            with NamedTemporaryFile(dir=tmp_dir) as f:
-                f.write(pig.encode("utf-8"))
-                f.flush()
-                fname = f.name
-                pig_bin = "pig"
-                cmd_extra: list[str] = []
-
-                pig_cmd = [pig_bin]
-
-                if self.pig_properties:
-                    pig_cmd.extend(self.pig_properties)
-                if pig_opts:
-                    pig_opts_list = pig_opts.split()
-                    pig_cmd.extend(pig_opts_list)
+        with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
+            f.write(pig.encode("utf-8"))
+            f.flush()
+            fname = f.name
+            pig_bin = "pig"
+            cmd_extra: list[str] = []
+
+            pig_cmd = [pig_bin]
+
+            if self.pig_properties:
+                pig_cmd.extend(self.pig_properties)
+            if pig_opts:
+                pig_opts_list = pig_opts.split()
+                pig_cmd.extend(pig_opts_list)
+
+            pig_cmd.extend(["-f", fname] + cmd_extra)
+
+            if verbose:
+                self.log.info("%s", " ".join(pig_cmd))
+            sub_process: Any = subprocess.Popen(
+                pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
+            )
+            self.sub_process = sub_process
+            stdout = ""
+            for line in iter(sub_process.stdout.readline, b""):
+                stdout += line.decode("utf-8")
+                if verbose:
+                    self.log.info(line.strip())
+            sub_process.wait()
 
-                pig_cmd.extend(["-f", fname] + cmd_extra)
+            if sub_process.returncode:
+                raise AirflowException(stdout)
 
-                if verbose:
-                    self.log.info("%s", " ".join(pig_cmd))
-                sub_process: Any = subprocess.Popen(
-                    pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
-                )
-                self.sub_process = sub_process
-                stdout = ""
-                for line in iter(sub_process.stdout.readline, b""):
-                    stdout += line.decode("utf-8")
-                    if verbose:
-                        self.log.info(line.strip())
-                sub_process.wait()
-
-                if sub_process.returncode:
-                    raise AirflowException(stdout)
-
-                return stdout
+            return stdout
 
     def kill(self) -> None:
         """Kill Pig job."""
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py
index 4a9785da3e..72446efecc 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -234,13 +234,14 @@ class DbtCloudHook(HttpHook):
         endpoint = f"{account_id}/runs/{run_id}/"
         headers, tenant = await self.get_headers_tenants_from_connection()
         url, params = self.get_request_url_params(tenant, endpoint, include_related)
-        async with aiohttp.ClientSession(headers=headers) as session:
-            async with session.get(url, params=params) as response:
-                try:
-                    response.raise_for_status()
-                    return await response.json()
-                except ClientResponseError as e:
-                    raise AirflowException(str(e.status) + ":" + e.message)
+        async with aiohttp.ClientSession(headers=headers) as session, session.get(
+            url, params=params
+        ) as response:
+            try:
+                response.raise_for_status()
+                return await response.json()
+            except ClientResponseError as e:
+                raise AirflowException(f"{e.status}:{e.message}")
 
     async def get_job_status(
         self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index ed71205ebc..ffadf46072 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -97,9 +97,8 @@ class ExasolHook(DbApiHook):
             sql statements to execute
         :param parameters: The parameters to render the SQL query with.
         """
-        with closing(self.get_conn()) as conn:
-            with closing(conn.execute(sql, parameters)) as cur:
-                return cur.fetchall()
+        with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur:
+            return cur.fetchall()
 
     def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
         """Execute the SQL and return the first resulting row.
@@ -108,9 +107,8 @@ class ExasolHook(DbApiHook):
             sql statements to execute
         :param parameters: The parameters to render the SQL query with.
         """
-        with closing(self.get_conn()) as conn:
-            with closing(conn.execute(sql, parameters)) as cur:
-                return cur.fetchone()
+        with closing(self.get_conn()) as conn, closing(conn.execute(sql, parameters)) as cur:
+            return cur.fetchone()
 
     def export_to_file(
         self,
diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py
index f489c9200b..72c555bbda 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -550,10 +550,9 @@ class GCSHook(GoogleBaseHook):
             if gzip:
                 filename_gz = filename + ".gz"
 
-                with open(filename, "rb") as f_in:
-                    with gz.open(filename_gz, "wb") as f_out:
-                        shutil.copyfileobj(f_in, f_out)
-                        filename = filename_gz
+                with open(filename, "rb") as f_in, gz.open(filename_gz, "wb") as f_out:
+                    shutil.copyfileobj(f_in, f_out)
+                    filename = filename_gz
 
             _call_with_retry(
                 partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 00df2b9b28..1837613ea0 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -493,19 +493,18 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        async with Token(scopes=self.scopes) as token:
-            async with self.get_conn(token) as connection:
-                try:
-                    v1_api = async_client.CoreV1Api(connection)
-                    await v1_api.delete_namespaced_pod(
-                        name=name,
-                        namespace=namespace,
-                        body=client.V1DeleteOptions(),
-                    )
-                except async_client.ApiException as e:
-                    # If the pod is already deleted
-                    if e.status != 404:
-                        raise
+        async with Token(scopes=self.scopes) as token, self.get_conn(token) as connection:
+            try:
+                v1_api = async_client.CoreV1Api(connection)
+                await v1_api.delete_namespaced_pod(
+                    name=name,
+                    namespace=namespace,
+                    body=client.V1DeleteOptions(),
+                )
+            except async_client.ApiException as e:
+                # If the pod is already deleted
+                if e.status != 404:
+                    raise
 
     async def read_logs(self, name: str, namespace: str):
         """Read logs inside the pod while starting containers inside.
@@ -518,20 +517,19 @@ class GKEPodAsyncHook(GoogleBaseAsyncHook):
         :param name: Name of the pod.
         :param namespace: Name of the pod's namespace.
         """
-        async with Token(scopes=self.scopes) as token:
-            async with self.get_conn(token) as connection:
-                try:
-                    v1_api = async_client.CoreV1Api(connection)
-                    logs = await v1_api.read_namespaced_pod_log(
-                        name=name,
-                        namespace=namespace,
-                        follow=False,
-                        timestamps=True,
-                    )
-                    logs = logs.splitlines()
-                    for line in logs:
-                        self.log.info("Container logs from %s", line)
-                    return logs
-                except HTTPError:
-                    self.log.exception("There was an error reading the kubernetes API.")
-                    raise
+        async with Token(scopes=self.scopes) as token, self.get_conn(token) as connection:
+            try:
+                v1_api = async_client.CoreV1Api(connection)
+                logs = await v1_api.read_namespaced_pod_log(
+                    name=name,
+                    namespace=namespace,
+                    follow=False,
+                    timestamps=True,
+                )
+                logs = logs.splitlines()
+                for line in logs:
+                    self.log.info("Container logs from %s", line)
+                return logs
+            except HTTPError:
+                self.log.exception("There was an error reading the kubernetes API.")
+                raise
diff --git a/airflow/providers/microsoft/azure/hooks/asb.py b/airflow/providers/microsoft/azure/hooks/asb.py
index 7001db46f9..80273d6f96 100644
--- a/airflow/providers/microsoft/azure/hooks/asb.py
+++ b/airflow/providers/microsoft/azure/hooks/asb.py
@@ -215,19 +215,18 @@ class MessageHook(BaseAzureServiceBusHook):
             raise ValueError("Messages list cannot be empty.")
         with self.get_conn() as service_bus_client, service_bus_client.get_queue_sender(
             queue_name=queue_name
-        ) as sender:
-            with sender:
-                if isinstance(messages, str):
-                    if not batch_message_flag:
-                        msg = ServiceBusMessage(messages)
-                        sender.send_messages(msg)
-                    else:
-                        self.send_batch_message(sender, [messages])
+        ) as sender, sender:
+            if isinstance(messages, str):
+                if not batch_message_flag:
+                    msg = ServiceBusMessage(messages)
+                    sender.send_messages(msg)
                 else:
-                    if not batch_message_flag:
-                        self.send_list_messages(sender, messages)
-                    else:
-                        self.send_batch_message(sender, messages)
+                    self.send_batch_message(sender, [messages])
+            else:
+                if not batch_message_flag:
+                    self.send_list_messages(sender, messages)
+                else:
+                    self.send_batch_message(sender, messages)
 
     @staticmethod
     def send_list_messages(sender: ServiceBusSender, messages: list[str]):
@@ -256,14 +255,13 @@ class MessageHook(BaseAzureServiceBusHook):
 
         with self.get_conn() as service_bus_client, service_bus_client.get_queue_receiver(
             queue_name=queue_name
-        ) as receiver:
-            with receiver:
-                received_msgs = receiver.receive_messages(
-                    max_message_count=max_message_count, max_wait_time=max_wait_time
-                )
-                for msg in received_msgs:
-                    self.log.info(msg)
-                    receiver.complete_message(msg)
+        ) as receiver, receiver:
+            received_msgs = receiver.receive_messages(
+                max_message_count=max_message_count, max_wait_time=max_wait_time
+            )
+            for msg in received_msgs:
+                self.log.info(msg)
+                receiver.complete_message(msg)
 
     def receive_subscription_message(
         self,
@@ -293,11 +291,10 @@ class MessageHook(BaseAzureServiceBusHook):
             raise TypeError("Topic name cannot be None.")
         with self.get_conn() as service_bus_client, service_bus_client.get_subscription_receiver(
             topic_name, subscription_name
-        ) as subscription_receiver:
-            with subscription_receiver:
-                received_msgs = subscription_receiver.receive_messages(
-                    max_message_count=max_message_count, max_wait_time=max_wait_time
-                )
-                for msg in received_msgs:
-                    self.log.info(msg)
-                    subscription_receiver.complete_message(msg)
+        ) as subscription_receiver, subscription_receiver:
+            received_msgs = subscription_receiver.receive_messages(
+                max_message_count=max_message_count, max_wait_time=max_wait_time
+            )
+            for msg in received_msgs:
+                self.log.info(msg)
+                subscription_receiver.complete_message(msg)
diff --git a/airflow/providers/mysql/transfers/vertica_to_mysql.py b/airflow/providers/mysql/transfers/vertica_to_mysql.py
index 16be186fc7..fd196315d7 100644
--- a/airflow/providers/mysql/transfers/vertica_to_mysql.py
+++ b/airflow/providers/mysql/transfers/vertica_to_mysql.py
@@ -99,17 +99,16 @@ class VerticaToMySqlOperator(BaseOperator):
         self.log.info("Done")
 
     def _non_bulk_load_transfer(self, mysql, vertica):
-        with closing(vertica.get_conn()) as conn:
-            with closing(conn.cursor()) as cursor:
-                cursor.execute(self.sql)
-                selected_columns = [d.name for d in cursor.description]
-                self.log.info("Selecting rows from Vertica...")
-                self.log.info(self.sql)
+        with closing(vertica.get_conn()) as conn, closing(conn.cursor()) as cursor:
+            cursor.execute(self.sql)
+            selected_columns = [d.name for d in cursor.description]
+            self.log.info("Selecting rows from Vertica...")
+            self.log.info(self.sql)
 
-                result = cursor.fetchall()
-                count = len(result)
+            result = cursor.fetchall()
+            count = len(result)
 
-                self.log.info("Selected rows from Vertica %s", count)
+            self.log.info("Selected rows from Vertica %s", count)
         self._run_preoperator(mysql)
         try:
             self.log.info("Inserting rows into MySQL...")
@@ -121,31 +120,29 @@ class VerticaToMySqlOperator(BaseOperator):
 
     def _bulk_load_transfer(self, mysql, vertica):
         count = 0
-        with closing(vertica.get_conn()) as conn:
-            with closing(conn.cursor()) as cursor:
-                cursor.execute(self.sql)
-                selected_columns = [d.name for d in cursor.description]
-                with NamedTemporaryFile("w", encoding="utf-8") as tmpfile:
-                    self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
-                    self.log.info(self.sql)
-
-                    csv_writer = csv.writer(tmpfile, delimiter="\t")
-                    for row in cursor.iterate():
-                        csv_writer.writerow(row)
-                        count += 1
-
-                    tmpfile.flush()
+        with closing(vertica.get_conn()) as conn, closing(conn.cursor()) as cursor:
+            cursor.execute(self.sql)
+            selected_columns = [d.name for d in cursor.description]
+            with NamedTemporaryFile("w", encoding="utf-8") as tmpfile:
+                self.log.info("Selecting rows from Vertica to local file %s...", tmpfile.name)
+                self.log.info(self.sql)
+
+                csv_writer = csv.writer(tmpfile, delimiter="\t")
+                for row in cursor.iterate():
+                    csv_writer.writerow(row)
+                    count += 1
+
+                tmpfile.flush()
         self._run_preoperator(mysql)
         try:
             self.log.info("Bulk inserting rows into MySQL...")
-            with closing(mysql.get_conn()) as conn:
-                with closing(conn.cursor()) as cursor:
-                    cursor.execute(
-                        f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
-                        f"INTO TABLE {self.mysql_table} "
-                        f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
-                    )
-                    conn.commit()
+            with closing(mysql.get_conn()) as conn, closing(conn.cursor()) as cursor:
+                cursor.execute(
+                    f"LOAD DATA LOCAL INFILE '{tmpfile.name}' "
+                    f"INTO TABLE {self.mysql_table} "
+                    f"LINES TERMINATED BY '\r\n' ({', '.join(selected_columns)})"
+                )
+                conn.commit()
             tmpfile.close()
             self.log.info("Inserted rows into MySQL %s", count)
         except (MySQLdb.Error, MySQLdb.Warning):
diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py
index b6b214e990..95e7be94cb 100644
--- a/airflow/providers/postgres/hooks/postgres.py
+++ b/airflow/providers/postgres/hooks/postgres.py
@@ -170,12 +170,10 @@ class PostgresHook(DbApiHook):
             with open(filename, "w"):
                 pass
 
-        with open(filename, "r+") as file:
-            with closing(self.get_conn()) as conn:
-                with closing(conn.cursor()) as cur:
-                    cur.copy_expert(sql, file)
-                    file.truncate(file.tell())
-                    conn.commit()
+        with open(filename, "r+") as file, closing(self.get_conn()) as conn, closing(conn.cursor()) as cur:
+            cur.copy_expert(sql, file)
+            file.truncate(file.tell())
+            conn.commit()
 
     def get_uri(self) -> str:
         """Extract the URI from the connection.
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index b10a3c670f..49cedf115a 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -271,8 +271,9 @@ class SnowflakeSqlApiHook(SnowflakeHook):
         """
         self.log.info("Retrieving status for query id %s", query_id)
         header, params, url = self.get_request_url_header_params(query_id)
-        async with aiohttp.ClientSession(headers=header) as session:
-            async with session.get(url, params=params) as response:
-                status_code = response.status
-                resp = await response.json()
-                return self._process_response(status_code, resp)
+        async with aiohttp.ClientSession(headers=header) as session, session.get(
+            url, params=params
+        ) as response:
+            status_code = response.status
+            resp = await response.json()
+            return self._process_response(status_code, resp)
