This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 875387afa5 Refactor unneeded jumps in providers (#33833)
875387afa5 is described below
commit 875387afa53c207364fa20b515d154100b5d0a8d
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Fri Sep 1 16:01:15 2023 +0000
Refactor unneeded jumps in providers (#33833)
---
airflow/providers/amazon/aws/hooks/datasync.py | 25 +++++-------
airflow/providers/amazon/aws/hooks/sagemaker.py | 12 +++---
airflow/providers/amazon/aws/sensors/sqs.py | 8 ++--
airflow/providers/amazon/aws/utils/sqs.py | 12 ++----
.../providers/cncf/kubernetes/utils/delete_from.py | 4 +-
airflow/providers/google/cloud/hooks/bigquery.py | 36 ++++++++---------
airflow/providers/google/cloud/hooks/datafusion.py | 6 +--
airflow/providers/google/cloud/hooks/gcs.py | 14 ++-----
.../providers/google/cloud/log/gcs_task_handler.py | 4 +-
.../providers/google/cloud/operators/compute.py | 45 ++++++++++------------
airflow/providers/google/cloud/operators/gcs.py | 10 ++---
.../microsoft/azure/hooks/data_factory.py | 4 +-
.../microsoft/azure/triggers/data_factory.py | 8 ++--
airflow/providers/openlineage/utils/utils.py | 14 +++----
airflow/providers/smtp/hooks/smtp.py | 24 ++++++------
.../log/elasticmock/fake_elasticsearch.py | 17 ++++----
.../cloud/log/test_stackdriver_task_handler.py | 7 ++--
17 files changed, 108 insertions(+), 142 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/datasync.py
b/airflow/providers/amazon/aws/hooks/datasync.py
index 841255ed95..493f722dc7 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -301,25 +301,18 @@ class DataSyncHook(AwsBaseHook):
if not task_execution_arn:
raise AirflowBadRequest("task_execution_arn not specified")
- status = None
- iterations = max_iterations
- while status is None or status in
self.TASK_EXECUTION_INTERMEDIATE_STATES:
+ for _ in range(max_iterations):
task_execution =
self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
status = task_execution["Status"]
self.log.info("status=%s", status)
- iterations -= 1
- if status in self.TASK_EXECUTION_FAILURE_STATES:
- break
if status in self.TASK_EXECUTION_SUCCESS_STATES:
- break
- if iterations <= 0:
- break
+ return True
+ elif status in self.TASK_EXECUTION_FAILURE_STATES:
+ return False
+ elif status is None or status in
self.TASK_EXECUTION_INTERMEDIATE_STATES:
+ time.sleep(self.wait_interval_seconds)
+ else:
+ raise AirflowException(f"Unknown status: {status}") # Should never happen
time.sleep(self.wait_interval_seconds)
-
- if status in self.TASK_EXECUTION_SUCCESS_STATES:
- return True
- if status in self.TASK_EXECUTION_FAILURE_STATES:
- return False
- if iterations <= 0:
+ else:
raise AirflowTaskTimeout("Max iterations exceeded!")
- raise AirflowException(f"Unknown status: {status}") # Should never happen
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py
b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 40b9b55c54..0dfcbca0f3 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -252,12 +252,12 @@ class SageMakerHook(AwsBaseHook):
]
events: list[Any | None] = []
for event_stream in event_iters:
- if not event_stream:
- events.append(None)
- continue
- try:
- events.append(next(event_stream))
- except StopIteration:
+ if event_stream:
+ try:
+ events.append(next(event_stream))
+ except StopIteration:
+ events.append(None)
+ else:
events.append(None)
while any(events):
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py
b/airflow/providers/amazon/aws/sensors/sqs.py
index aca950375a..0cb3604bb4 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -204,12 +204,12 @@ class SqsSensor(BaseSensorOperator):
if "Successful" not in response:
raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")
- if not message_batch:
+ if message_batch:
+ context["ti"].xcom_push(key="messages", value=message_batch)
+ return True
+ else:
return False
- context["ti"].xcom_push(key="messages", value=message_batch)
- return True
-
@deprecated(reason="use `hook` property instead.")
def get_hook(self) -> SqsHook:
"""Create and return an SqsHook."""
diff --git a/airflow/providers/amazon/aws/utils/sqs.py
b/airflow/providers/amazon/aws/utils/sqs.py
index ea0c7afea1..0ae5e7ac98 100644
--- a/airflow/providers/amazon/aws/utils/sqs.py
+++ b/airflow/providers/amazon/aws/utils/sqs.py
@@ -79,13 +79,9 @@ def filter_messages_jsonpath(messages,
message_filtering_match_values, message_f
# Body is a string, deserialize to an object and then parse
body = json.loads(body)
results = jsonpath_expr.find(body)
- if not results:
- continue
- if message_filtering_match_values is None:
+ if results and (
+ message_filtering_match_values is None
+ or any(result.value in message_filtering_match_values for result
in results)
+ ):
filtered_messages.append(message)
- continue
- for result in results:
- if result.value in message_filtering_match_values:
- filtered_messages.append(message)
- break
return filtered_messages
diff --git a/airflow/providers/cncf/kubernetes/utils/delete_from.py
b/airflow/providers/cncf/kubernetes/utils/delete_from.py
index 2b28169ca6..663242fad1 100644
--- a/airflow/providers/cncf/kubernetes/utils/delete_from.py
+++ b/airflow/providers/cncf/kubernetes/utils/delete_from.py
@@ -81,9 +81,7 @@ def delete_from_yaml(
**kwargs,
):
for yml_document in yaml_objects:
- if yml_document is None:
- continue
- else:
+ if yml_document is not None:
delete_from_dict(
k8s_client=k8s_client,
data=yml_document,
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index f9540a6c99..18ad159a98 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2206,27 +2206,25 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
if param_name == "schemaUpdateOptions" and param:
self.log.info("Adding experimental 'schemaUpdateOptions': %s",
schema_update_options)
- if param_name != "destinationTable":
- continue
-
- for key in ["projectId", "datasetId", "tableId"]:
- if key not in configuration["query"]["destinationTable"]:
- raise ValueError(
- "Not correct 'destinationTable' in "
- "api_resource_configs. 'destinationTable' "
- "must be a dict with {'projectId':'', "
- "'datasetId':'', 'tableId':''}"
+ if param_name == "destinationTable":
+ for key in ["projectId", "datasetId", "tableId"]:
+ if key not in configuration["query"]["destinationTable"]:
+ raise ValueError(
+ "Not correct 'destinationTable' in "
+ "api_resource_configs. 'destinationTable' "
+ "must be a dict with {'projectId':'', "
+ "'datasetId':'', 'tableId':''}"
+ )
+ else:
+ configuration["query"].update(
+ {
+ "allowLargeResults": allow_large_results,
+ "flattenResults": flatten_results,
+ "writeDisposition": write_disposition,
+ "createDisposition": create_disposition,
+ }
)
- configuration["query"].update(
- {
- "allowLargeResults": allow_large_results,
- "flattenResults": flatten_results,
- "writeDisposition": write_disposition,
- "createDisposition": create_disposition,
- }
- )
-
if (
"useLegacySql" in configuration["query"]
and configuration["query"]["useLegacySql"]
diff --git a/airflow/providers/google/cloud/hooks/datafusion.py
b/airflow/providers/google/cloud/hooks/datafusion.py
index b0e44081c7..dcf51357c6 100644
--- a/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/airflow/providers/google/cloud/hooks/datafusion.py
@@ -371,12 +371,12 @@ class DataFusionHook(GoogleBaseHook):
self._check_response_status_and_data(
response, f"Deleting a pipeline failed with code {response.status}: {response.data}"
)
- if response.status == 200:
- break
except ConflictException as exc:
self.log.info(exc)
sleep(time_to_wait)
- continue
+ else:
+ if response.status == 200:
+ break
def list_pipelines(
self,
diff --git a/airflow/providers/google/cloud/hooks/gcs.py
b/airflow/providers/google/cloud/hooks/gcs.py
index b27bedafa6..05279583da 100644
--- a/airflow/providers/google/cloud/hooks/gcs.py
+++ b/airflow/providers/google/cloud/hooks/gcs.py
@@ -361,7 +361,6 @@ class GCSHook(GoogleBaseHook):
# Wait with exponential backoff scheme before retrying.
timeout_seconds = 2 ** (num_file_attempts - 1)
time.sleep(timeout_seconds)
- continue
def download_as_byte_array(
self,
@@ -508,28 +507,23 @@ class GCSHook(GoogleBaseHook):
:param f: Callable that should be retried.
"""
- num_file_attempts = 0
-
- while num_file_attempts < num_max_attempts:
+ for attempt in range(1, 1 + num_max_attempts):
try:
- num_file_attempts += 1
f()
-
except GoogleCloudError as e:
- if num_file_attempts == num_max_attempts:
+ if attempt == num_max_attempts:
self.log.error(
"Upload attempt of object: %s from %s has failed. Attempt: %s, max %s.",
object_name,
object_name,
- num_file_attempts,
+ attempt,
num_max_attempts,
)
raise e
# Wait with exponential backoff scheme before retrying.
- timeout_seconds = 2 ** (num_file_attempts - 1)
+ timeout_seconds = 2 ** (attempt - 1)
time.sleep(timeout_seconds)
- continue
client = self.get_conn()
bucket = client.bucket(bucket_name, user_project=user_project)
diff --git a/airflow/providers/google/cloud/log/gcs_task_handler.py
b/airflow/providers/google/cloud/log/gcs_task_handler.py
index 79055f0847..5cede21863 100644
--- a/airflow/providers/google/cloud/log/gcs_task_handler.py
+++ b/airflow/providers/google/cloud/log/gcs_task_handler.py
@@ -243,9 +243,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
old_log = blob.download_as_bytes().decode()
log = "\n".join([old_log, log]) if old_log else log
except Exception as e:
- if self.no_log_found(e):
- pass
- else:
+ if not self.no_log_found(e):
log += self._add_message(
f"Error checking for previous log; if exists, may be overwritten: {e}"
)
diff --git a/airflow/providers/google/cloud/operators/compute.py
b/airflow/providers/google/cloud/operators/compute.py
index 2abc46d35a..8379904f35 100644
--- a/airflow/providers/google/cloud/operators/compute.py
+++ b/airflow/providers/google/cloud/operators/compute.py
@@ -174,14 +174,13 @@ class
ComputeEngineInsertInstanceOperator(ComputeEngineBaseOperator):
def check_body_fields(self) -> None:
required_params = ["machine_type", "disks", "network_interfaces"]
for param in required_params:
- if param in self.body:
- continue
- readable_param = param.replace("_", " ")
- raise AirflowException(
- f"The body '{self.body}' should contain at least {readable_param} for the new operator "
- f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
- f"for more details about body fields description."
- )
+ if param not in self.body:
+ readable_param = param.replace("_", " ")
+ raise AirflowException(
+ f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+ f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+ f"for more details about body fields description."
+ )
def _validate_inputs(self) -> None:
super()._validate_inputs()
@@ -915,14 +914,13 @@ class
ComputeEngineInsertInstanceTemplateOperator(ComputeEngineBaseOperator):
def check_body_fields(self) -> None:
required_params = ["machine_type", "disks", "network_interfaces"]
for param in required_params:
- if param in self.body["properties"]:
- continue
- readable_param = param.replace("_", " ")
- raise AirflowException(
- f"The body '{self.body}' should contain at least {readable_param} for the new operator "
- f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
- f"for more details about body fields description."
- )
+ if param not in self.body["properties"]:
+ readable_param = param.replace("_", " ")
+ raise AirflowException(
+ f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+ f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+ f"for more details about body fields description."
+ )
def _validate_all_body_fields(self) -> None:
if self._field_validator:
@@ -1500,14 +1498,13 @@ class
ComputeEngineInsertInstanceGroupManagerOperator(ComputeEngineBaseOperator)
def check_body_fields(self) -> None:
required_params = ["base_instance_name", "target_size",
"instance_template"]
for param in required_params:
- if param in self.body:
- continue
- readable_param = param.replace("_", " ")
- raise AirflowException(
- f"The body '{self.body}' should contain at least {readable_param} for the new operator "
- f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
- f"for more details about body fields description."
- )
+ if param not in self.body:
+ readable_param = param.replace("_", " ")
+ raise AirflowException(
+ f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+ f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+ f"for more details about body fields description."
+ )
def _validate_all_body_fields(self) -> None:
if self._field_validator:
diff --git a/airflow/providers/google/cloud/operators/gcs.py
b/airflow/providers/google/cloud/operators/gcs.py
index 9b95032b42..bb257f94b3 100644
--- a/airflow/providers/google/cloud/operators/gcs.py
+++ b/airflow/providers/google/cloud/operators/gcs.py
@@ -797,9 +797,8 @@ class
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
num_max_attempts=self.download_num_attempts,
)
except GoogleCloudError:
- if self.download_continue_on_fail:
- continue
- raise
+ if not self.download_continue_on_fail:
+ raise
self.log.info("Starting the transformation")
cmd = [self.transform_script] if isinstance(self.transform_script,
str) else self.transform_script
@@ -847,9 +846,8 @@ class
GCSTimeSpanFileTransformOperator(GoogleCloudBaseOperator):
)
files_uploaded.append(str(upload_file_name))
except GoogleCloudError:
- if self.upload_continue_on_fail:
- continue
- raise
+ if not self.upload_continue_on_fail:
+ raise
return files_uploaded
diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py
b/airflow/providers/microsoft/azure/hooks/data_factory.py
index 7301ace03e..b4516ccfd7 100644
--- a/airflow/providers/microsoft/azure/hooks/data_factory.py
+++ b/airflow/providers/microsoft/azure/hooks/data_factory.py
@@ -855,8 +855,8 @@ class AzureDataFactoryHook(BaseHook):
except ServiceRequestError:
if executed_after_token_refresh:
self.refresh_conn()
- continue
- raise
+ else:
+ raise
return pipeline_run_status in expected_statuses
diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py
b/airflow/providers/microsoft/azure/triggers/data_factory.py
index 1ce5484008..e087e3556d 100644
--- a/airflow/providers/microsoft/azure/triggers/data_factory.py
+++ b/airflow/providers/microsoft/azure/triggers/data_factory.py
@@ -103,8 +103,8 @@ class ADFPipelineRunStatusSensorTrigger(BaseTrigger):
if executed_after_token_refresh:
await hook.refresh_conn()
executed_after_token_refresh = False
- continue
- raise
+ else:
+ raise
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e)})
@@ -207,8 +207,8 @@ class AzureDataFactoryTrigger(BaseTrigger):
if executed_after_token_refresh:
await hook.refresh_conn()
executed_after_token_refresh = False
- continue
- raise
+ else:
+ raise
yield TriggerEvent(
{
diff --git a/airflow/providers/openlineage/utils/utils.py
b/airflow/providers/openlineage/utils/utils.py
index 07484d994d..837c5eb615 100644
--- a/airflow/providers/openlineage/utils/utils.py
+++ b/airflow/providers/openlineage/utils/utils.py
@@ -208,16 +208,14 @@ class InfoJsonEncodable(dict):
raise Exception("Don't use both includes and excludes.")
if self.includes:
for field in self.includes:
- if field in self._fields or not hasattr(self.obj, field):
- continue
- setattr(self, field, getattr(self.obj, field))
- self._fields.append(field)
+ if field not in self._fields and hasattr(self.obj, field):
+ setattr(self, field, getattr(self.obj, field))
+ self._fields.append(field)
else:
for field, val in self.obj.__dict__.items():
- if field in self._fields or field in self.excludes or field in
self.renames:
- continue
- setattr(self, field, val)
- self._fields.append(field)
+ if field not in self._fields and field not in self.excludes
and field not in self.renames:
+ setattr(self, field, val)
+ self._fields.append(field)
class DagInfo(InfoJsonEncodable):
diff --git a/airflow/providers/smtp/hooks/smtp.py
b/airflow/providers/smtp/hooks/smtp.py
index 0f2c689c37..bacb3ca3e7 100644
--- a/airflow/providers/smtp/hooks/smtp.py
+++ b/airflow/providers/smtp/hooks/smtp.py
@@ -87,14 +87,14 @@ class SmtpHook(BaseHook):
try:
self.smtp_client = self._build_client()
except smtplib.SMTPServerDisconnected:
- if attempt < self.smtp_retry_limit:
- continue
- raise AirflowException("Unable to connect to smtp server")
- if self.smtp_starttls:
- self.smtp_client.starttls()
- if self.smtp_user and self.smtp_password:
- self.smtp_client.login(self.smtp_user, self.smtp_password)
- break
+ if attempt == self.smtp_retry_limit:
+ raise AirflowException("Unable to connect to smtp server")
+ else:
+ if self.smtp_starttls:
+ self.smtp_client.starttls()
+ if self.smtp_user and self.smtp_password:
+ self.smtp_client.login(self.smtp_user,
self.smtp_password)
+ break
return self
@@ -234,10 +234,10 @@ class SmtpHook(BaseHook):
from_addr=from_email, to_addrs=recipients,
msg=mime_msg.as_string()
)
except smtplib.SMTPServerDisconnected as e:
- if attempt < self.smtp_retry_limit:
- continue
- raise e
- break
+ if attempt == self.smtp_retry_limit:
+ raise e
+ else:
+ break
def _build_mime_message(
self,
diff --git
a/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
b/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
index b37608232d..e65a403a63 100644
--- a/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
+++ b/tests/providers/elasticsearch/log/elasticmock/fake_elasticsearch.py
@@ -331,9 +331,8 @@ class FakeElasticsearch(Elasticsearch):
i = 0
for searchable_index in searchable_indexes:
for document in self.__documents_dict[searchable_index]:
- if searchable_doc_types and document.get("_type") not in
searchable_doc_types:
- continue
- i += 1
+ if not searchable_doc_types or document.get("_type") in
searchable_doc_types:
+ i += 1
result = {"count": i, "_shards": {"successful": 1, "failed": 0,
"total": 1}}
return result
@@ -457,13 +456,11 @@ class FakeElasticsearch(Elasticsearch):
def find_document_in_searchable_index(self, matches, must,
searchable_doc_types, searchable_index):
for document in self.__documents_dict[searchable_index]:
- if searchable_doc_types and document.get("_type") not in
searchable_doc_types:
- continue
-
- if "match_phrase" in must:
- self.match_must_phrase(document, matches, must)
- else:
- matches.append(document)
+ if not searchable_doc_types or document.get("_type") in
searchable_doc_types:
+ if "match_phrase" in must:
+ self.match_must_phrase(document, matches, must)
+ else:
+ matches.append(document)
@staticmethod
def match_must_phrase(document, matches, must):
diff --git a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
index 0abb714cc5..cb240b7009 100644
--- a/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
+++ b/tests/providers/google/cloud/log/test_stackdriver_task_handler.py
@@ -41,10 +41,9 @@ def clean_stackdriver_handlers():
yield
for handler_ref in reversed(logging._handlerList[:]):
handler = handler_ref()
- if not isinstance(handler, StackdriverTaskHandler):
- continue
- logging._removeHandlerRef(handler_ref)
- del handler
+ if isinstance(handler, StackdriverTaskHandler):
+ logging._removeHandlerRef(handler_ref)
+ del handler
@pytest.mark.usefixtures("clean_stackdriver_handlers")