This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 875387afa5 Refactor unneeded jumps in providers (#33833)
875387afa5 is described below

commit 875387afa53c207364fa20b515d154100b5d0a8d
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Fri Sep 1 16:01:15 2023 +0000

    Refactor unneeded jumps in providers (#33833)
---
 airflow/providers/amazon/aws/hooks/datasync.py     | 25 +++++-------
 airflow/providers/amazon/aws/hooks/sagemaker.py    | 12 +++---
 airflow/providers/amazon/aws/sensors/sqs.py        |  8 ++--
 airflow/providers/amazon/aws/utils/sqs.py          | 12 ++----
 .../providers/cncf/kubernetes/utils/delete_from.py |  4 +-
 airflow/providers/google/cloud/hooks/bigquery.py   | 36 ++++++++---------
 airflow/providers/google/cloud/hooks/datafusion.py |  6 +--
 airflow/providers/google/cloud/hooks/gcs.py        | 14 ++-----
 .../providers/google/cloud/log/gcs_task_handler.py |  4 +-
 .../providers/google/cloud/operators/compute.py    | 45 ++++++++++------------
 airflow/providers/google/cloud/operators/gcs.py    | 10 ++---
 .../microsoft/azure/hooks/data_factory.py          |  4 +-
 .../microsoft/azure/triggers/data_factory.py       |  8 ++--
 airflow/providers/openlineage/utils/utils.py       | 14 +++----
 airflow/providers/smtp/hooks/smtp.py               | 24 ++++++------
 .../log/elasticmock/fake_elasticsearch.py          | 17 ++++----
 .../cloud/log/test_stackdriver_task_handler.py     |  7 ++--
 17 files changed, 108 insertions(+), 142 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/datasync.py 
b/airflow/providers/amazon/aws/hooks/datasync.py
index 841255ed95..493f722dc7 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -301,25 +301,18 @@ class DataSyncHook(AwsBaseHook):
         if not task_execution_arn:
             raise AirflowBadRequest("task_execution_arn not specified")
 
-        status = None
-        iterations = max_iterations
-        while status is None or status in 
self.TASK_EXECUTION_INTERMEDIATE_STATES:
+        for _ in range(max_iterations):
             task_execution = 
self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
             status = task_execution["Status"]
             self.log.info("status=%s", status)
-            iterations -= 1
-            if status in self.TASK_EXECUTION_FAILURE_STATES:
-                break
             if status in self.TASK_EXECUTION_SUCCESS_STATES:
-                break
-            if iterations <= 0:
-                break
+                return True
+            elif status in self.TASK_EXECUTION_FAILURE_STATES:
+                return False
+            elif status is None or status in 
self.TASK_EXECUTION_INTERMEDIATE_STATES:
+                time.sleep(self.wait_interval_seconds)
+            else:
+                raise AirflowException(f"Unknown status: {status}")  # Should 
never happen
             time.sleep(self.wait_interval_seconds)
-
-        if status in self.TASK_EXECUTION_SUCCESS_STATES:
-            return True
-        if status in self.TASK_EXECUTION_FAILURE_STATES:
-            return False
-        if iterations <= 0:
+        else:
             raise AirflowTaskTimeout("Max iterations exceeded!")
-        raise AirflowException(f"Unknown status: {status}")  # Should never 
happen
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py 
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 40b9b55c54..0dfcbca0f3 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -252,12 +252,12 @@ class SageMakerHook(AwsBaseHook):
         ]
         events: list[Any | None] = []
         for event_stream in event_iters:
-            if not event_stream:
-                events.append(None)
-                continue
-            try:
-                events.append(next(event_stream))
-            except StopIteration:
+            if event_stream:
+                try:
+                    events.append(next(event_stream))
+                except StopIteration:
+                    events.append(None)
+            else:
                 events.append(None)
 
         while any(events):
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py 
b/airflow/providers/amazon/aws/sensors/sqs.py
index aca950375a..0cb3604bb4 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -204,12 +204,12 @@ class SqsSensor(BaseSensorOperator):
 
                 if "Successful" not in response:
                     raise AirflowException(f"Delete SQS Messages failed 
{response} for messages {messages}")
-        if not message_batch:
+        if message_batch:
+            context["ti"].xcom_push(key="messages", value=message_batch)
+            return True
+        else:
             return False
 
-        context["ti"].xcom_push(key="messages", value=message_batch)
-        return True
-
     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> SqsHook:
         """Create and return an SqsHook."""
diff --git a/airflow/providers/amazon/aws/utils/sqs.py 
b/airflow/providers/amazon/aws/utils/sqs.py
index ea0c7afea1..0ae5e7ac98 100644
--- a/airflow/providers/amazon/aws/utils/sqs.py
+++ b/airflow/providers/amazon/aws/utils/sqs.py
@@ -79,13 +79,9 @@ def filter_messages_jsonpath(messages, 
message_filtering_match_values, message_f
         # Body is a string, deserialize to an object and then parse
         body = json.loads(body)
         results = jsonpath_expr.find(body)
-        if not results:
-            continue
-        if message_filtering_match_values is None:
+        if results and (
+            message_filtering_match_values is None
+            or any(result.value in message_filtering_match_values for result 
in results)
+        ):
             filtered_messages.append(message)
-            continue
-        for result in results:
-            if result.value in message_filtering_match_values:
-                filtered_messages.append(message)
-                break
     return filtered_messages
diff --git a/airflow/providers/cncf/kubernetes/utils/delete_from.py 
b/airflow/providers/cncf/kubernetes/utils/delete_from.py
index 2b28169ca6..663242fad1 100644
--- a/airflow/providers/cncf/kubernetes/utils/delete_from.py
+++ b/airflow/providers/cncf/kubernetes/utils/delete_from.py
@@ -81,9 +81,7 @@ def delete_from_yaml(
     **kwargs,
 ):
     for yml_document in yaml_objects:
-        if yml_document is None:
-            continue
-        else:
+        if yml_document is not None:
             delete_from_dict(
                 k8s_client=k8s_client,
                 data=yml_document,
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py 
b/airflow/providers/google/cloud/hooks/bigquery.py
index f9540a6c99..18ad159a98 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2206,27 +2206,25 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
             if param_name == "schemaUpdateOptions" and param:
                 self.log.info("Adding experimental 'schemaUpdateOptions': %s", 
schema_update_options)
 
-            if param_name != "destinationTable":
-                continue
-
-            for key in ["projectId", "datasetId", "tableId"]:
-                if key not in configuration["query"]["destinationTable"]:
-                    raise ValueError(
-                        "Not correct 'destinationTable' in "
-                        "api_resource_configs. 'destinationTable' "
-                        "must be a dict with {'projectId':'', "
-                        "'datasetId':'', 'tableId':''}"
+            if param_name == "destinationTable":
+                for key in ["projectId", "datasetId", "tableId"]:
+                    if key not in configuration["query"]["destinationTable"]:
+                        raise ValueError(
+                            "Not correct 'destinationTable' in "
+                            "api_resource_configs. 'destinationTable' "
+                            "must be a dict with {'projectId':'', "
+                            "'datasetId':'', 'tableId':''}"
+                        )
+                else:
+                    configuration["query"].update(
+                        {
+                            "allowLargeResults": allow_large_results,
+                            "flattenResults": flatten_results,
+                            "writeDisposition": write_disposition,
+                            "createDisposition": create_disposition,
+                        }
                     )
 
-            configuration["query"].update(
-                {
-                    "allowLargeResults": allow_large_results,
-                    "flattenResults": flatten_results,
-                    "writeDisposition": write_disposition,
-                    "createDisposition": create_disposition,
-                }
-            )
-
         if (
             "useLegacySql" in configuration["query"]
             and configuration["query"]["useLegacySql"]
diff --git a/airflow/providers/google/cloud/hooks/datafusion.py 
b/airflow/providers/google/cloud/hooks/datafusion.py
index b0e44081c7..dcf51357c6 100644
--- a/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/airflow/providers/google/cloud/hooks/datafusion.py
@@ -371,12 +371,12 @@ class DataFusionHook(GoogleBaseHook):
                 self._check_response_status_and_data(
                     response, f"Deleting a pipeline failed with code 
{response.status}: {response.data}"
                 )
-                if response.status == 200:
-                    break
             except ConflictException as exc:
                 self.log.info(exc)
                 sleep(time_to_wait)
-                continue
+            else:
+                if response.status == 200:
+                    break
 
     def list_pipelines(
         self,
diff --git a/airflow/providers/google/cloud/hooks/gcs.py 
b/airflow/providers/google/cloud/hooks/gcs.py
index b27bedafa6..05279583da 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -361,7 +361,6 @@ class GCSHook(GoogleBaseHook):
                 # Wait with exponential backoff scheme before retrying.
                 timeout_seconds = 2 ** (num_file_attempts - 1)
                 time.sleep(timeout_seconds)
-                continue
 
     def download_as_byte_array(
         self,
@@ -508,28 +507,23 @@ class GCSHook(GoogleBaseHook):
 
             :param f: Callable that should be retried.
             """
-            num_file_attempts = 0
-
-            while num_file_attempts < num_max_attempts:
+            for attempt in range(1, 1 + num_max_attempts):
                 try:
-                    num_file_attempts += 1
                     f()
-
                 except GoogleCloudError as e:
-                    if num_file_attempts == num_max_attempts:
+                    if attempt == num_max_attempts:
                         self.log.error(
                             "Upload attempt of object: %s from %s has failed. 
Attempt: %s, max %s.",
                             object_name,
                             object_name,
-                            num_file_attempts,
+                            attempt,
                             num_max_attempts,
                         )
                         raise e
 
                     # Wait with exponential backoff scheme before retrying.
-                    timeout_seconds = 2 ** (num_file_attempts - 1)
+                    timeout_seconds = 2 ** (attempt - 1)
                     time.sleep(timeout_seconds)
-                    continue
 
         client = self.get_conn()
         bucket = client.bucket(bucket_name, user_project=user_project)
diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py 
b/airflow/providers/google/cloud/log/gcs_task_handler.py
index 79055f0847..5cede21863 100644
--- a/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -243,9 +243,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
             old_log = blob.download_as_bytes().decode()
             log = "\n".join([old_log, log]) if old_log else log
         except Exception as e:
-            if self.no_log_found(e):
-                pass
-            else:
+            if not self.no_log_found(e):
                 log += self._add_message(
                     f"Error checking for previous log; if exists, may be 
overwritten: {e}"
                 )
diff --git a/airflow/providers/google/cloud/operators/compute.py 
b/airflow/providers/google/cloud/operators/compute.py
index 2abc46d35a..8379904f35 100644
--- a/airflow/providers/google/cloud/operators/compute.py
+++ b/airflow/providers/google/cloud/operators/compute.py
@@ -174,14 +174,13 @@ class 
ComputeEngineInsertInstanceOperator(ComputeEngineBaseOperator):
     def check_body_fields(self) -> None:
         required_params = ["machine_type", "disks", "network_interfaces"]
         for param in required_params:
-            if param in self.body:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least 
{readable_param} for the new operator "
-                f"in the '{param}' field. Check 
(google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least 
{readable_param} for the new operator "
+                    f"in the '{param}' field. Check 
(google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )
 
     def _validate_inputs(self) -> None:
         super()._validate_inputs()
@@ -915,14 +914,13 @@ class 
ComputeEngineInsertInstanceTemplateOperator(ComputeEngineBaseOperator):
     def check_body_fields(self) -> None:
         required_params = ["machine_type", "disks", "network_interfaces"]
         for param in required_params:
-            if param in self.body["properties"]:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least 
{readable_param} for the new operator "
-                f"in the '{param}' field. Check 
(google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body["properties"]:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least 
{readable_param} for the new operator "
+                    f"in the '{param}' field. Check 
(google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )
 
     def _validate_all_body_fields(self) -> None:
         if self._field_validator:
@@ -1500,14 +1498,13 @@ class 
ComputeEngineInsertInstanceGroupManagerOperator(ComputeEngineBaseOperator)
     def check_body_fields(self) -> None:
         required_params = ["base_instance_name", "target_size", 
"instance_template"]
         for param in required_params:
-            if param in self.body:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least 
{readable_param} for the new operator "
-                f"in the '{param}' field. Check 
(google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least 
{readable_param} for the new operator "
+                    f"in the '{param}' field. Check 
(google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )
 
     def _validate_all_body_fields(self) -> None:
         if self._field_validator:
diff --git a/airflow/providers/google/cloud/operators/gcs.py 
b/airflow/providers/google/cloud/operators/gcs.py
index 9b95032b42..bb257f94b3 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -797,9 +797,8 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
                         num_max_attempts=self.download_num_attempts,
                     )
                 except GoogleCloudError:
-                    if self.download_continue_on_fail:
-                        continue
-                    raise
+                    if not self.download_continue_on_fail:
+                        raise
 
             self.log.info("Starting the transformation")
             cmd = [self.transform_script] if isinstance(self.transform_script, 
str) else self.transform_script
@@ -847,9 +846,8 @@ class 
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
                     )
                     files_uploaded.append(str(upload_file_name))
                 except GoogleCloudError:
-                    if self.upload_continue_on_fail:
-                        continue
-                    raise
+                    if not self.upload_continue_on_fail:
+                        raise
 
             return files_uploaded
 
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py 
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index 7301ace03e..b4516ccfd7 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -855,8 +855,8 @@ class AzureDataFactoryHook(BaseHook):
             except ServiceRequestError:
                 if executed_after_token_refresh:
                     self.refresh_conn()
-                    continue
-                raise
+                else:
+                    raise
 
         return pipeline_run_status in expected_statuses
 
diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py 
b/airflow/providers/microsoft/azure/triggers/data_factory.py
index 1ce5484008..e087e3556d 100644
--- a/airflow/providers/microsoft/azure/triggers/data_factory.py
+++ b/airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -103,8 +103,8 @@ class ADFPipelineRunStatusSensorTrigger(BaseTrigger):
                     if executed_after_token_refresh:
                         await hook.refresh_conn()
                         executed_after_token_refresh = False
-                        continue
-                    raise
+                    else:
+                        raise
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
 
@@ -207,8 +207,8 @@ class AzureDataFactoryTrigger(BaseTrigger):
                         if executed_after_token_refresh:
                             await hook.refresh_conn()
                             executed_after_token_refresh = False
-                            continue
-                        raise
+                        else:
+                            raise
 
                 yield TriggerEvent(
                     {
diff --git a/airflow/providers/openlineage/utils/utils.py 
b/airflow/providers/openlineage/utils/utils.py
index 07484d994d..837c5eb615 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -208,16 +208,14 @@ class InfoJsonEncodable(dict):
             raise Exception("Don't use both includes and excludes.")
         if self.includes:
             for field in self.includes:
-                if field in self._fields or not hasattr(self.obj, field):
-                    continue
-                setattr(self, field, getattr(self.obj, field))
-                self._fields.append(field)
+                if field not in self._fields and hasattr(self.obj, field):
+                    setattr(self, field, getattr(self.obj, field))
+                    self._fields.append(field)
         else:
             for field, val in self.obj.__dict__.items():
-                if field in self._fields or field in self.excludes or field in 
self.renames:
-                    continue
-                setattr(self, field, val)
-                self._fields.append(field)
+                if field not in self._fields and field not in self.excludes 
and field not in self.renames:
+                    setattr(self, field, val)
+                    self._fields.append(field)
 
 
 class DagInfo(InfoJsonEncodable):
diff --git a/airflow/providers/smtp/hooks/smtp.py 
b/airflow/providers/smtp/hooks/smtp.py
index 0f2c689c37..bacb3ca3e7 100644
--- a/airflow/providers/smtp/hooks/smtp.py
+++ b/airflow/providers/smtp/hooks/smtp.py
@@ -87,14 +87,14 @@ class SmtpHook(BaseHook):
                 try:
                     self.smtp_client = self._build_client()
                 except smtplib.SMTPServerDisconnected:
-                    if attempt < self.smtp_retry_limit:
-                        continue
-                    raise AirflowException("Unable to connect to smtp server")
-                if self.smtp_starttls:
-                    self.smtp_client.starttls()
-                if self.smtp_user and self.smtp_password:
-                    self.smtp_client.login(self.smtp_user, self.smtp_password)
-                break
+                    if attempt == self.smtp_retry_limit:
+                        raise AirflowException("Unable to connect to smtp 
server")
+                else:
+                    if self.smtp_starttls:
+                        self.smtp_client.starttls()
+                    if self.smtp_user and self.smtp_password:
+                        self.smtp_client.login(self.smtp_user, 
self.smtp_password)
+                    break
 
         return self
 
@@ -234,10 +234,10 @@ class SmtpHook(BaseHook):
                         from_addr=from_email, to_addrs=recipients, 
msg=mime_msg.as_string()
                     )
                 except smtplib.SMTPServerDisconnected as e:
-                    if attempt < self.smtp_retry_limit:
-                        continue
-                    raise e
-                break
+                    if attempt == self.smtp_retry_limit:
+                        raise e
+                else:
+                    break
 
     def _build_mime_message(
         self,
diff --git 
a/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py 
b/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
index b37608232d..e65a403a63 100644
--- a/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
+++ b/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
@@ -331,9 +331,8 @@ class FakeElasticsearch(Elasticsearch):
         i = 0
         for searchable_index in searchable_indexes:
             for document in self.__documents_dict[searchable_index]:
-                if searchable_doc_types and document.get("_type") not in 
searchable_doc_types:
-                    continue
-                i += 1
+                if not searchable_doc_types or document.get("_type") in 
searchable_doc_types:
+                    i += 1
         result = {"count": i, "_shards": {"successful": 1, "failed": 0, 
"total": 1}}
 
         return result
@@ -457,13 +456,11 @@ class FakeElasticsearch(Elasticsearch):
 
     def find_document_in_searchable_index(self, matches, must, 
searchable_doc_types, searchable_index):
         for document in self.__documents_dict[searchable_index]:
-            if searchable_doc_types and document.get("_type") not in 
searchable_doc_types:
-                continue
-
-            if "match_phrase" in must:
-                self.match_must_phrase(document, matches, must)
-            else:
-                matches.append(document)
+            if not searchable_doc_types or document.get("_type") in 
searchable_doc_types:
+                if "match_phrase" in must:
+                    self.match_must_phrase(document, matches, must)
+                else:
+                    matches.append(document)
 
     @staticmethod
     def match_must_phrase(document, matches, must):
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py 
b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
index 0abb714cc5..cb240b7009 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
@@ -41,10 +41,9 @@ def clean_stackdriver_handlers():
     yield
     for handler_ref in reversed(logging._handlerList[:]):
         handler = handler_ref()
-        if not isinstance(handler, StackdriverTaskHandler):
-            continue
-        logging._removeHandlerRef(handler_ref)
-        del handler
+        if isinstance(handler, StackdriverTaskHandler):
+            logging._removeHandlerRef(handler_ref)
+            del handler
 
 
 @pytest.mark.usefixtures("clean_stackdriver_handlers")

Reply via email to