This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 20b7cfc395 respect soft_fail argument when exception is raised for
google sensors (#34501)
20b7cfc395 is described below
commit 20b7cfc3956e404fe1a6d4ed9e363fca7161ede2
Author: Wei Lee <[email protected]>
AuthorDate: Tue Sep 26 11:17:14 2023 +0800
respect soft_fail argument when exception is raised for google sensors
(#34501)
---
airflow/providers/google/cloud/sensors/bigquery.py | 23 +++++-
.../providers/google/cloud/sensors/bigquery_dts.py | 8 ++-
.../google/cloud/sensors/cloud_composer.py | 13 +++-
airflow/providers/google/cloud/sensors/dataflow.py | 32 ++++++---
airflow/providers/google/cloud/sensors/dataform.py | 8 ++-
.../providers/google/cloud/sensors/datafusion.py | 14 +++-
airflow/providers/google/cloud/sensors/dataplex.py | 41 ++++++++---
airflow/providers/google/cloud/sensors/dataproc.py | 32 +++++++--
.../google/cloud/sensors/dataproc_metastore.py | 14 +++-
airflow/providers/google/cloud/sensors/gcs.py | 34 +++++++--
airflow/providers/google/cloud/sensors/looker.py | 33 +++++----
airflow/providers/google/cloud/sensors/pubsub.py | 5 +-
.../providers/google/cloud/sensors/workflows.py | 10 ++-
.../google/cloud/sensors/test_bigquery.py | 81 ++++++++++++++++------
.../google/cloud/sensors/test_bigtable.py | 12 +++-
.../google/cloud/sensors/test_cloud_composer.py | 12 ++--
.../google/cloud/sensors/test_dataflow.py | 26 +++++--
.../google/cloud/sensors/test_datafusion.py | 18 +++--
.../google/cloud/sensors/test_dataplex.py | 10 ++-
.../google/cloud/sensors/test_dataproc.py | 42 ++++++++---
.../cloud/sensors/test_dataproc_metastore.py | 25 +++----
tests/providers/google/cloud/sensors/test_gcs.py | 73 ++++++++++++++-----
.../providers/google/cloud/sensors/test_looker.py | 30 +++++---
.../providers/google/cloud/sensors/test_pubsub.py | 18 +++--
.../google/cloud/sensors/test_workflows.py | 10 ++-
25 files changed, 461 insertions(+), 163 deletions(-)
diff --git a/airflow/providers/google/cloud/sensors/bigquery.py
b/airflow/providers/google/cloud/sensors/bigquery.py
index d4f15fac1b..d1492fe248 100644
--- a/airflow/providers/google/cloud/sensors/bigquery.py
+++ b/airflow/providers/google/cloud/sensors/bigquery.py
@@ -23,7 +23,7 @@ from datetime import timedelta
from typing import TYPE_CHECKING, Any, Sequence
from airflow.configuration import conf
-from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.triggers.bigquery import (
BigQueryTableExistenceTrigger,
@@ -141,8 +141,16 @@ class BigQueryTableExistenceSensor(BaseSensorOperator):
if event:
if event["status"] == "success":
return event["message"]
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ message = "No event received in trigger callback"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
class BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
@@ -248,8 +256,17 @@ class
BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
if event:
if event["status"] == "success":
return event["message"]
+
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ message = "No event received in trigger callback"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
class BigQueryTableExistenceAsyncSensor(BigQueryTableExistenceSensor):
diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py
b/airflow/providers/google/cloud/sensors/bigquery_dts.py
index 34198d2819..b4926b3b95 100644
--- a/airflow/providers/google/cloud/sensors/bigquery_dts.py
+++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py
@@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Sequence
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.bigquery_datatransfer_v1 import TransferState
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.bigquery_dts import
BiqQueryDataTransferServiceHook
from airflow.sensors.base import BaseSensorOperator
@@ -140,5 +140,9 @@ class
BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
self.log.info("Status of %s run: %s", self.run_id, str(run.state))
if run.state in (TransferState.FAILED, TransferState.CANCELLED):
- raise AirflowException(f"Transfer {self.run_id} did not succeed")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Transfer {self.run_id} did not succeed"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
return run.state in self.expected_statuses
diff --git a/airflow/providers/google/cloud/sensors/cloud_composer.py
b/airflow/providers/google/cloud/sensors/cloud_composer.py
index ecd717aa5a..1873b51d68 100644
--- a/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ b/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -21,7 +21,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.triggers.cloud_composer import
CloudComposerExecutionTrigger
from airflow.sensors.base import BaseSensorOperator
@@ -90,5 +90,14 @@ class CloudComposerEnvironmentSensor(BaseSensorOperator):
if event:
if event.get("operation_done"):
return event["operation_done"]
+
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ message = "No event received in trigger callback"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
diff --git a/airflow/providers/google/cloud/sensors/dataflow.py
b/airflow/providers/google/cloud/sensors/dataflow.py
index 187b4c0007..c9f32588d5 100644
--- a/airflow/providers/google/cloud/sensors/dataflow.py
+++ b/airflow/providers/google/cloud/sensors/dataflow.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Sequence
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
DataflowHook,
@@ -106,7 +106,11 @@ class DataflowJobStatusSensor(BaseSensorOperator):
if job_status in self.expected_statuses:
return True
elif job_status in DataflowJobStatus.TERMINAL_STATES:
- raise AirflowException(f"Job with id '{self.job_id}' is already in
terminal state: {job_status}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
return False
@@ -178,9 +182,11 @@ class DataflowJobMetricsSensor(BaseSensorOperator):
)
job_status = job["currentState"]
if job_status in DataflowJobStatus.TERMINAL_STATES:
- raise AirflowException(
- f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
- )
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
result = self.hook.fetch_job_metrics_by_id(
job_id=self.job_id,
@@ -257,9 +263,11 @@ class DataflowJobMessagesSensor(BaseSensorOperator):
)
job_status = job["currentState"]
if job_status in DataflowJobStatus.TERMINAL_STATES:
- raise AirflowException(
- f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
- )
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
result = self.hook.fetch_job_messages_by_id(
job_id=self.job_id,
@@ -336,9 +344,11 @@ class
DataflowJobAutoScalingEventsSensor(BaseSensorOperator):
)
job_status = job["currentState"]
if job_status in DataflowJobStatus.TERMINAL_STATES:
- raise AirflowException(
- f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
- )
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = f"Job with id '{self.job_id}' is already in terminal
state: {job_status}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
result = self.hook.fetch_job_autoscaling_events_by_id(
job_id=self.job_id,
diff --git a/airflow/providers/google/cloud/sensors/dataform.py
b/airflow/providers/google/cloud/sensors/dataform.py
index 965e9c5fe2..45c74627a7 100644
--- a/airflow/providers/google/cloud/sensors/dataform.py
+++ b/airflow/providers/google/cloud/sensors/dataform.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Sequence
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataform import DataformHook
from airflow.sensors.base import BaseSensorOperator
@@ -95,9 +95,13 @@ class
DataformWorkflowInvocationStateSensor(BaseSensorOperator):
workflow_status = workflow_invocation.state
if workflow_status is not None:
if self.failure_statuses and workflow_status in
self.failure_statuses:
- raise AirflowException(
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = (
f"Workflow Invocation with id
'{self.workflow_invocation_id}' "
f"state is: {workflow_status}. Terminating sensor..."
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
return workflow_status in self.expected_statuses
diff --git a/airflow/providers/google/cloud/sensors/datafusion.py
b/airflow/providers/google/cloud/sensors/datafusion.py
index 8297d60f44..b151a6fae7 100644
--- a/airflow/providers/google/cloud/sensors/datafusion.py
+++ b/airflow/providers/google/cloud/sensors/datafusion.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Iterable, Sequence
-from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.exceptions import AirflowException, AirflowNotFoundException,
AirflowSkipException
from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
from airflow.sensors.base import BaseSensorOperator
@@ -109,15 +109,23 @@ class
CloudDataFusionPipelineStateSensor(BaseSensorOperator):
)
pipeline_status = pipeline_workflow["status"]
except AirflowNotFoundException:
- raise AirflowException("Specified Pipeline ID was not found.")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = "Specified Pipeline ID was not found."
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
except AirflowException:
pass # Because the pipeline may not be visible in system yet
if pipeline_status is not None:
if self.failure_statuses and pipeline_status in
self.failure_statuses:
- raise AirflowException(
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = (
f"Pipeline with id '{self.pipeline_id}' state is:
{pipeline_status}. "
f"Terminating sensor..."
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
self.log.debug(
"Current status of the pipeline workflow for %s: %s.",
self.pipeline_id, pipeline_status
diff --git a/airflow/providers/google/cloud/sensors/dataplex.py
b/airflow/providers/google/cloud/sensors/dataplex.py
index c00373f947..ee0ffc7410 100644
--- a/airflow/providers/google/cloud/sensors/dataplex.py
+++ b/airflow/providers/google/cloud/sensors/dataplex.py
@@ -24,11 +24,12 @@ if TYPE_CHECKING:
from google.api_core.retry import Retry
from airflow.utils.context import Context
+
from google.api_core.exceptions import GoogleAPICallError
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.dataplex_v1.types import DataScanJob
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataplex import (
AirflowDataQualityScanException,
AirflowDataQualityScanResultTimeoutException,
@@ -116,7 +117,11 @@ class DataplexTaskStateSensor(BaseSensorOperator):
task_status = task.state
if task_status == TaskState.DELETING:
- raise AirflowException(f"Task is going to be deleted
{self.dataplex_task_id}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Task is going to be deleted {self.dataplex_task_id}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
self.log.info("Current status of the Dataplex task %s => %s",
self.dataplex_task_id, task_status)
@@ -196,9 +201,13 @@ class
DataplexDataQualityJobStatusSensor(BaseSensorOperator):
if self.result_timeout:
duration = self._duration()
if duration > self.result_timeout:
- raise AirflowDataQualityScanResultTimeoutException(
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = (
f"Timeout: Data Quality scan {self.job_id} is not ready
after {self.result_timeout}s"
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowDataQualityScanResultTimeoutException(message)
hook = DataplexHook(
gcp_conn_id=self.gcp_conn_id,
@@ -217,22 +226,36 @@ class
DataplexDataQualityJobStatusSensor(BaseSensorOperator):
metadata=self.metadata,
)
except GoogleAPICallError as e:
- raise AirflowException(
- f"Error occurred when trying to retrieve Data Quality scan
job: {self.data_scan_id}", e
- )
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Error occurred when trying to retrieve Data Quality
scan job: {self.data_scan_id}"
+ if self.soft_fail:
+ raise AirflowSkipException(message, e)
+ raise AirflowException(message, e)
job_status = job.state
self.log.info(
"Current status of the Dataplex Data Quality scan job %s => %s",
self.job_id, job_status
)
if job_status == DataScanJob.State.FAILED:
- raise AirflowException(f"Data Quality scan job failed:
{self.job_id}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Data Quality scan job failed: {self.job_id}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
if job_status == DataScanJob.State.CANCELLED:
- raise AirflowException(f"Data Quality scan job cancelled:
{self.job_id}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Data Quality scan job cancelled: {self.job_id}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
if self.fail_on_dq_failure:
if job_status == DataScanJob.State.SUCCEEDED and not
job.data_quality_result.passed:
- raise AirflowDataQualityScanException(
+ # TODO: remove this if check when min_airflow_version is set
to higher than 2.7.1
+ message = (
f"Data Quality job {self.job_id} execution failed due to
failure of its scanning "
f"rules: {self.data_scan_id}"
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowDataQualityScanException(message)
return job_status == DataScanJob.State.SUCCEEDED
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py
b/airflow/providers/google/cloud/sensors/dataproc.py
index b3f87a83b5..2acd695dba 100644
--- a/airflow/providers/google/cloud/sensors/dataproc.py
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Sequence
from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import Batch, JobStatus
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.sensors.base import BaseSensorOperator
@@ -83,10 +83,14 @@ class DataprocJobSensor(BaseSensorOperator):
duration = self._duration()
self.log.info("DURATION RUN: %f", duration)
if duration > self.wait_timeout:
- raise AirflowException(
+ # TODO: remove this if check when min_airflow_version is
set to higher than 2.7.1
+ message = (
f"Timeout: dataproc job {self.dataproc_job_id} "
f"is not ready after {self.wait_timeout}s"
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
self.log.info("Retrying. Dataproc API returned server error
when waiting for job: %s", err)
return False
else:
@@ -94,13 +98,21 @@ class DataprocJobSensor(BaseSensorOperator):
state = job.status.state
if state == JobStatus.State.ERROR:
- raise AirflowException(f"Job failed:\n{job}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Job failed:\n{job}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif state in {
JobStatus.State.CANCELLED,
JobStatus.State.CANCEL_PENDING,
JobStatus.State.CANCEL_STARTED,
}:
- raise AirflowException(f"Job was cancelled:\n{job}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Job was cancelled:\n{job}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif JobStatus.State.DONE == state:
self.log.debug("Job %s completed successfully.",
self.dataproc_job_id)
return True
@@ -171,12 +183,20 @@ class DataprocBatchSensor(BaseSensorOperator):
state = batch.state
if state == Batch.State.FAILED:
- raise AirflowException("Batch failed")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = "Batch failed"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif state in {
Batch.State.CANCELLED,
Batch.State.CANCELLING,
}:
- raise AirflowException("Batch was cancelled.")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = "Batch was cancelled."
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif state == Batch.State.SUCCEEDED:
self.log.debug("Batch %s completed successfully.", self.batch_id)
return True
diff --git a/airflow/providers/google/cloud/sensors/dataproc_metastore.py
b/airflow/providers/google/cloud/sensors/dataproc_metastore.py
index c50c8f1a8b..ccb2226452 100644
--- a/airflow/providers/google/cloud/sensors/dataproc_metastore.py
+++ b/airflow/providers/google/cloud/sensors/dataproc_metastore.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataproc_metastore import
DataprocMetastoreHook
from airflow.providers.google.cloud.hooks.gcs import parse_json_from_gcs
from airflow.sensors.base import BaseSensorOperator
@@ -95,13 +95,21 @@ class MetastoreHivePartitionSensor(BaseSensorOperator):
self.log.info("Extracting result manifest")
manifest: dict = parse_json_from_gcs(gcp_conn_id=self.gcp_conn_id,
file_uri=result_manifest_uri)
if not (manifest and isinstance(manifest, dict)):
- raise AirflowException(
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = (
f"Failed to extract result manifest. "
f"Expected not empty dict, but this was received: {manifest}"
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
if manifest.get("status", {}).get("code") != 0:
- raise AirflowException(f"Request failed:
{manifest.get('message')}")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"Request failed: {manifest.get('message')}"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
# Extract actual query results
result_base_uri = result_manifest_uri.rsplit("/", 1)[0]
diff --git a/airflow/providers/google/cloud/sensors/gcs.py
b/airflow/providers/google/cloud/sensors/gcs.py
index 2fb220ab27..453bb3bf44 100644
--- a/airflow/providers/google/cloud/sensors/gcs.py
+++ b/airflow/providers/google/cloud/sensors/gcs.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence
from google.cloud.storage.retry import DEFAULT_RETRY
from airflow.configuration import conf
-from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.triggers.gcs import (
GCSBlobTrigger,
@@ -125,6 +125,9 @@ class GCSObjectExistenceSensor(BaseSensorOperator):
Relies on trigger to throw an exception, otherwise it assumes
execution was successful.
"""
if event["status"] == "error":
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
self.log.info("File %s was found in bucket %s.", self.object,
self.bucket)
return event["message"]
@@ -259,8 +262,16 @@ class GCSObjectUpdateSensor(BaseSensorOperator):
"Checking last updated time for object %s in bucket : %s",
self.object, self.bucket
)
return event["message"]
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ message = "No event received in trigger callback"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
class GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator):
@@ -347,6 +358,9 @@ class
GCSObjectsWithPrefixExistenceSensor(BaseSensorOperator):
self.log.info("Resuming from trigger and checking status")
if event["status"] == "success":
return event["matches"]
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
@@ -476,10 +490,14 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator):
)
return False
- raise AirflowException(
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = (
"Illegal behavior: objects were deleted in "
f"{os.path.join(self.bucket, self.prefix)} between pokes."
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
if self.last_activity_time:
self.inactivity_seconds = (get_time() -
self.last_activity_time).total_seconds()
@@ -549,5 +567,13 @@ class GCSUploadSessionCompleteSensor(BaseSensorOperator):
if event:
if event["status"] == "success":
return event["message"]
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
- raise AirflowException("No event received in trigger callback")
+
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ message = "No event received in trigger callback"
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
diff --git a/airflow/providers/google/cloud/sensors/looker.py
b/airflow/providers/google/cloud/sensors/looker.py
index e75d0fb665..5525734627 100644
--- a/airflow/providers/google/cloud/sensors/looker.py
+++ b/airflow/providers/google/cloud/sensors/looker.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.looker import JobStatus, LookerHook
from airflow.sensors.base import BaseSensorOperator
@@ -50,11 +50,14 @@ class LookerCheckPdtBuildSensor(BaseSensorOperator):
self.hook: LookerHook | None = None
def poke(self, context: Context) -> bool:
-
self.hook = LookerHook(looker_conn_id=self.looker_conn_id)
if not self.materialization_id:
- raise AirflowException("Invalid `materialization_id`.")
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = "Invalid `materialization_id`."
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
# materialization_id is templated var pulling output from start task
status_dict =
self.hook.pdt_build_status(materialization_id=self.materialization_id)
@@ -62,17 +65,23 @@ class LookerCheckPdtBuildSensor(BaseSensorOperator):
if status == JobStatus.ERROR.value:
msg = status_dict["message"]
- raise AirflowException(
- f'PDT materialization job failed. Job id:
{self.materialization_id}. Message:\n"{msg}"'
- )
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f'PDT materialization job failed. Job id:
{self.materialization_id}. Message:\n"{msg}"'
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif status == JobStatus.CANCELLED.value:
- raise AirflowException(
- f"PDT materialization job was cancelled. Job id:
{self.materialization_id}."
- )
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"PDT materialization job was cancelled. Job id:
{self.materialization_id}."
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif status == JobStatus.UNKNOWN.value:
- raise AirflowException(
- f"PDT materialization job has unknown status. Job id:
{self.materialization_id}."
- )
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = f"PDT materialization job has unknown status. Job id:
{self.materialization_id}."
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
elif status == JobStatus.DONE.value:
self.log.debug(
"PDT materialization job completed successfully. Job id: %s.",
self.materialization_id
diff --git a/airflow/providers/google/cloud/sensors/pubsub.py
b/airflow/providers/google/cloud/sensors/pubsub.py
index 7bd07a08e5..b4b9288934 100644
--- a/airflow/providers/google/cloud/sensors/pubsub.py
+++ b/airflow/providers/google/cloud/sensors/pubsub.py
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence
from google.cloud.pubsub_v1.types import ReceivedMessage
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.pubsub import PubSubHook
from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
from airflow.sensors.base import BaseSensorOperator
@@ -174,6 +174,9 @@ class PubSubPullSensor(BaseSensorOperator):
self.log.info("Sensor pulls messages: %s", event["message"])
return event["message"]
self.log.info("Sensor failed: %s", event["message"])
+ # TODO: remove this if check when min_airflow_version is set to higher
than 2.7.1
+ if self.soft_fail:
+ raise AirflowSkipException(event["message"])
raise AirflowException(event["message"])
def _default_message_callback(
diff --git a/airflow/providers/google/cloud/sensors/workflows.py
b/airflow/providers/google/cloud/sensors/workflows.py
index 712e328bdd..7f97fafdbb 100644
--- a/airflow/providers/google/cloud/sensors/workflows.py
+++ b/airflow/providers/google/cloud/sensors/workflows.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Sequence
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.workflows.executions_v1beta import Execution
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.workflows import WorkflowsHook
from airflow.sensors.base import BaseSensorOperator
@@ -100,10 +100,14 @@ class WorkflowExecutionSensor(BaseSensorOperator):
state = execution.state
if state in self.failure_states:
- raise AirflowException(
+ # TODO: remove this if check when min_airflow_version is set to
higher than 2.7.1
+ message = (
f"Execution {self.execution_id} for workflow
{self.execution_id} "
- f"failed and is in `{state}` state",
+ f"failed and is in `{state}` state"
)
+ if self.soft_fail:
+ raise AirflowSkipException(message)
+ raise AirflowException(message)
if state in self.success_states:
self.log.info(
diff --git a/tests/providers/google/cloud/sensors/test_bigquery.py
b/tests/providers/google/cloud/sensors/test_bigquery.py
index 5fe40227c5..ec489329fb 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery.py
@@ -20,7 +20,12 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, TaskDeferred
+from airflow.exceptions import (
+ AirflowException,
+ AirflowProviderDeprecationWarning,
+ AirflowSkipException,
+ TaskDeferred,
+)
from airflow.providers.google.cloud.sensors.bigquery import (
BigQueryTableExistenceAsyncSensor,
BigQueryTableExistencePartitionAsyncSensor,
@@ -100,16 +105,20 @@ class TestBigqueryTableExistenceSensor:
exc.value.trigger, BigQueryTableExistenceTrigger
), "Trigger is not a BigQueryTableExistenceTrigger"
- def test_execute_deferred_failure(self):
- """Tests that an AirflowException is raised in case of error event"""
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_execute_deferred_failure(self, soft_fail, expected_exception):
+ """Tests that an expected exception is raised in case of error event"""
task = BigQueryTableExistenceSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
deferrable=True,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context={}, event={"status": "error",
"message": "test failure message"})
def test_execute_complete(self):
@@ -126,15 +135,19 @@ class TestBigqueryTableExistenceSensor:
task.execute_complete(context={}, event={"status": "success",
"message": "Job completed"})
mock_log_info.assert_called_with("Sensor checks existence of table:
%s", table_uri)
- def test_execute_defered_complete_event_none(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_execute_defered_complete_event_none(self, soft_fail,
expected_exception):
"""Asserts that logging occurs as expected"""
task = BigQueryTableExistenceSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context={}, event=None)
@@ -206,7 +219,10 @@ class TestBigqueryTablePartitionExistenceSensor:
exc.value.trigger, BigQueryTablePartitionExistenceTrigger
), "Trigger is not a BigQueryTablePartitionExistenceTrigger"
- def test_execute_with_deferrable_mode_execute_failure(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_execute_with_deferrable_mode_execute_failure(self, soft_fail,
expected_exception):
"""Tests that an AirflowException is raised in case of error event"""
task = BigQueryTablePartitionExistenceSensor(
task_id="test_task_id",
@@ -215,11 +231,15 @@ class TestBigqueryTablePartitionExistenceSensor:
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
deferrable=True,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context={}, event={"status": "error",
"message": "test failure message"})
- def test_execute_complete_event_none(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_execute_complete_event_none(self, soft_fail, expected_exception):
"""Asserts that logging occurs as expected"""
task = BigQueryTablePartitionExistenceSensor(
task_id="task-id",
@@ -228,8 +248,9 @@ class TestBigqueryTablePartitionExistenceSensor:
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
deferrable=True,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="No event received in
trigger callback"):
+ with pytest.raises(expected_exception, match="No event received in
trigger callback"):
task.execute_complete(context={}, event=None)
def test_execute_complete(self):
@@ -287,16 +308,20 @@ class TestBigQueryTableExistenceAsyncSensor:
exc.value.trigger, BigQueryTableExistenceTrigger
), "Trigger is not a BigQueryTableExistenceTrigger"
- def test_big_query_table_existence_sensor_async_execute_failure(self):
- """Tests that an AirflowException is raised in case of error event"""
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_big_query_table_existence_sensor_async_execute_failure(self,
soft_fail, expected_exception):
+ """Tests that an expected_exception is raised in case of error event"""
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.depcrecation_message):
task = BigQueryTableExistenceAsyncSensor(
task_id="task-id",
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context={}, event={"status": "error",
"message": "test failure message"})
def test_big_query_table_existence_sensor_async_execute_complete(self):
@@ -313,7 +338,10 @@ class TestBigQueryTableExistenceAsyncSensor:
task.execute_complete(context={}, event={"status": "success",
"message": "Job completed"})
mock_log_info.assert_called_with("Sensor checks existence of table:
%s", table_uri)
- def test_big_query_sensor_async_execute_complete_event_none(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_big_query_sensor_async_execute_complete_event_none(self,
soft_fail, expected_exception):
"""Asserts that logging occurs as expected"""
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.depcrecation_message):
task = BigQueryTableExistenceAsyncSensor(
@@ -321,8 +349,9 @@ class TestBigQueryTableExistenceAsyncSensor:
project_id=TEST_PROJECT_ID,
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context={}, event=None)
@@ -355,8 +384,13 @@ class TestBigQueryTableExistencePartitionAsyncSensor:
exc.value.trigger, BigQueryTablePartitionExistenceTrigger
), "Trigger is not a BigQueryTablePartitionExistenceTrigger"
- def
test_big_query_table_existence_partition_sensor_async_execute_failure(self):
- """Tests that an AirflowException is raised in case of error event"""
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_big_query_table_existence_partition_sensor_async_execute_failure(
+ self, soft_fail, expected_exception
+ ):
+ """Tests that an expected exception is raised in case of error event"""
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.depcrecation_message):
task = BigQueryTableExistencePartitionAsyncSensor(
task_id="test_task_id",
@@ -364,11 +398,17 @@ class TestBigQueryTableExistencePartitionAsyncSensor:
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context={}, event={"status": "error",
"message": "test failure message"})
- def
test_big_query_table_existence_partition_sensor_async_execute_complete_event_none(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def
test_big_query_table_existence_partition_sensor_async_execute_complete_event_none(
+ self, soft_fail, expected_exception
+ ):
"""Asserts that logging occurs as expected"""
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.depcrecation_message):
task = BigQueryTableExistencePartitionAsyncSensor(
@@ -377,8 +417,9 @@ class TestBigQueryTableExistencePartitionAsyncSensor:
dataset_id=TEST_DATASET_ID,
table_id=TEST_TABLE_ID,
partition_id=TEST_PARTITION_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="No event received in
trigger callback"):
+ with pytest.raises(expected_exception, match="No event received in
trigger callback"):
task.execute_complete(context={}, event=None)
def
test_big_query_table_existence_partition_sensor_async_execute_complete(self):
diff --git a/tests/providers/google/cloud/sensors/test_bigtable.py
b/tests/providers/google/cloud/sensors/test_bigtable.py
index 37bd5eaf8f..dea84fb9f9 100644
--- a/tests/providers/google/cloud/sensors/test_bigtable.py
+++ b/tests/providers/google/cloud/sensors/test_bigtable.py
@@ -24,7 +24,7 @@ import pytest
from google.cloud.bigtable.instance import Instance
from google.cloud.bigtable.table import ClusterState
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.sensors.bigtable import
BigtableTableReplicationCompletedSensor
PROJECT_ID = "test_project_id"
@@ -35,6 +35,9 @@ IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
class BigtableWaitForTableReplicationTest:
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@pytest.mark.parametrize(
"missing_attribute, project_id, instance_id, table_id",
[
@@ -43,8 +46,10 @@ class BigtableWaitForTableReplicationTest:
],
)
@mock.patch("airflow.providers.google.cloud.sensors.bigtable.BigtableHook")
- def test_empty_attribute(self, missing_attribute, project_id, instance_id,
table_id, mock_hook):
- with pytest.raises(AirflowException) as ctx:
+ def test_empty_attribute(
+ self, missing_attribute, project_id, instance_id, table_id, mock_hook,
soft_fail, expected_exception
+ ):
+ with pytest.raises(expected_exception) as ctx:
BigtableTableReplicationCompletedSensor(
project_id=project_id,
instance_id=instance_id,
@@ -52,6 +57,7 @@ class BigtableWaitForTableReplicationTest:
task_id="id",
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
err = ctx.value
assert str(err) == f"Empty parameter: {missing_attribute}"
diff --git a/tests/providers/google/cloud/sensors/test_cloud_composer.py
b/tests/providers/google/cloud/sensors/test_cloud_composer.py
index 8062da44b9..f6f3e81a40 100644
--- a/tests/providers/google/cloud/sensors/test_cloud_composer.py
+++ b/tests/providers/google/cloud/sensors/test_cloud_composer.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.exceptions import AirflowException, AirflowSkipException,
TaskDeferred
from airflow.providers.google.cloud.sensors.cloud_composer import
CloudComposerEnvironmentSensor
from airflow.providers.google.cloud.triggers.cloud_composer import
CloudComposerExecutionTrigger
@@ -48,15 +48,19 @@ class TestCloudComposerEnvironmentSensor:
exc.value.trigger, CloudComposerExecutionTrigger
), "Trigger is not a CloudComposerExecutionTrigger"
- def test_cloud_composer_existence_sensor_async_execute_failure(self):
- """Tests that an AirflowException is raised in case of error event."""
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_cloud_composer_existence_sensor_async_execute_failure(self,
soft_fail, expected_exception):
+ """Tests that an expected exception is raised in case of error
event."""
task = CloudComposerEnvironmentSensor(
task_id="task_id",
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
operation_name=TEST_OPERATION_NAME,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="No event received in
trigger callback"):
+ with pytest.raises(expected_exception, match="No event received in
trigger callback"):
task.execute_complete(context={}, event=None)
def test_cloud_composer_existence_sensor_async_execute_complete(self):
diff --git a/tests/providers/google/cloud/sensors/test_dataflow.py
b/tests/providers/google/cloud/sensors/test_dataflow.py
index 36d8840c81..d669b2b111 100644
--- a/tests/providers/google/cloud/sensors/test_dataflow.py
+++ b/tests/providers/google/cloud/sensors/test_dataflow.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
from airflow.providers.google.cloud.sensors.dataflow import (
DataflowJobAutoScalingEventsSensor,
@@ -71,8 +71,11 @@ class TestDataflowJobStatusSensor:
job_id=TEST_JOB_ID, project_id=TEST_PROJECT_ID,
location=TEST_LOCATION
)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook")
- def test_poke_raise_exception(self, mock_hook):
+ def test_poke_raise_exception(self, mock_hook, soft_fail,
expected_exception):
mock_get_job = mock_hook.return_value.get_job
task = DataflowJobStatusSensor(
task_id=TEST_TASK_ID,
@@ -82,11 +85,12 @@ class TestDataflowJobStatusSensor:
project_id=TEST_PROJECT_ID,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState":
DataflowJobStatus.JOB_STATE_CANCELLED}
with pytest.raises(
- AirflowException,
+ expected_exception,
match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
f"{DataflowJobStatus.JOB_STATE_CANCELLED}",
):
@@ -182,8 +186,11 @@ class DataflowJobMessagesSensorTest:
)
callback.assert_called_once_with(mock_fetch_job_messages_by_id.return_value)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook")
- def test_poke_raise_exception(self, mock_hook):
+ def test_poke_raise_exception(self, mock_hook, soft_fail,
expected_exception):
mock_get_job = mock_hook.return_value.get_job
mock_fetch_job_messages_by_id =
mock_hook.return_value.fetch_job_messages_by_id
callback = mock.MagicMock()
@@ -197,11 +204,12 @@ class DataflowJobMessagesSensorTest:
project_id=TEST_PROJECT_ID,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState":
DataflowJobStatus.JOB_STATE_DONE}
with pytest.raises(
- AirflowException,
+ expected_exception,
match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
f"{DataflowJobStatus.JOB_STATE_DONE}",
):
@@ -255,8 +263,11 @@ class DataflowJobAutoScalingEventsSensorTest:
)
callback.assert_called_once_with(mock_fetch_job_autoscaling_events_by_id.return_value)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.dataflow.DataflowHook")
- def test_poke_raise_exception_on_terminal_state(self, mock_hook):
+ def test_poke_raise_exception_on_terminal_state(self, mock_hook,
soft_fail, expected_exception):
mock_get_job = mock_hook.return_value.get_job
mock_fetch_job_autoscaling_events_by_id =
mock_hook.return_value.fetch_job_autoscaling_events_by_id
callback = mock.MagicMock()
@@ -270,11 +281,12 @@ class DataflowJobAutoScalingEventsSensorTest:
project_id=TEST_PROJECT_ID,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
mock_get_job.return_value = {"id": TEST_JOB_ID, "currentState":
DataflowJobStatus.JOB_STATE_DONE}
with pytest.raises(
- AirflowException,
+ expected_exception,
match=f"Job with id '{TEST_JOB_ID}' is already in terminal state: "
f"{DataflowJobStatus.JOB_STATE_DONE}",
):
diff --git a/tests/providers/google/cloud/sensors/test_datafusion.py
b/tests/providers/google/cloud/sensors/test_datafusion.py
index 32dcfbb050..6de24b7943 100644
--- a/tests/providers/google/cloud/sensors/test_datafusion.py
+++ b/tests/providers/google/cloud/sensors/test_datafusion.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.exceptions import AirflowException, AirflowNotFoundException,
AirflowSkipException
from airflow.providers.google.cloud.hooks.datafusion import PipelineStates
from airflow.providers.google.cloud.sensors.datafusion import
CloudDataFusionPipelineStateSensor
@@ -74,8 +74,11 @@ class TestCloudDataFusionPipelineStateSensor:
instance_name=INSTANCE_NAME, location=LOCATION,
project_id=PROJECT_ID
)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook")
- def test_assertion(self, mock_hook):
+ def test_assertion(self, mock_hook, soft_fail, expected_exception):
mock_hook.return_value.get_instance.return_value = {"apiEndpoint":
INSTANCE_URL}
task = CloudDataFusionPipelineStateSensor(
@@ -89,17 +92,21 @@ class TestCloudDataFusionPipelineStateSensor:
location=LOCATION,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
with pytest.raises(
- AirflowException,
+ expected_exception,
match=f"Pipeline with id '{PIPELINE_ID}' state is: FAILED.
Terminating sensor...",
):
mock_hook.return_value.get_pipeline_workflow.return_value =
{"status": "FAILED"}
task.poke(mock.MagicMock())
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.datafusion.DataFusionHook")
- def test_not_found_exception(self, mock_hook):
+ def test_not_found_exception(self, mock_hook, soft_fail,
expected_exception):
mock_hook.return_value.get_instance.return_value = {"apiEndpoint":
INSTANCE_URL}
mock_hook.return_value.get_pipeline_workflow.side_effect =
AirflowNotFoundException()
@@ -114,10 +121,11 @@ class TestCloudDataFusionPipelineStateSensor:
location=LOCATION,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
with pytest.raises(
- AirflowException,
+ expected_exception,
match="Specified Pipeline ID was not found.",
):
task.poke(mock.MagicMock())
diff --git a/tests/providers/google/cloud/sensors/test_dataplex.py
b/tests/providers/google/cloud/sensors/test_dataplex.py
index 18f5b68b9d..20a4de4ff0 100644
--- a/tests/providers/google/cloud/sensors/test_dataplex.py
+++ b/tests/providers/google/cloud/sensors/test_dataplex.py
@@ -22,7 +22,7 @@ import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.dataplex_v1.types import DataScanJob
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataplex import
AirflowDataQualityScanResultTimeoutException
from airflow.providers.google.cloud.sensors.dataplex import (
DataplexDataQualityJobStatusSensor,
@@ -81,8 +81,11 @@ class TestDataplexTaskStateSensor:
assert result
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPLEX_HOOK)
- def test_deleting(self, mock_hook):
+ def test_deleting(self, mock_hook, soft_fail, expected_exception):
task = self.create_task(TaskState.DELETING)
mock_hook.return_value.get_task.return_value = task
@@ -95,9 +98,10 @@ class TestDataplexTaskStateSensor:
api_version=API_VERSION,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="Task is going to be
deleted"):
+ with pytest.raises(expected_exception, match="Task is going to be
deleted"):
sensor.poke(context={})
mock_hook.return_value.get_task.assert_called_once_with(
diff --git a/tests/providers/google/cloud/sensors/test_dataproc.py
b/tests/providers/google/cloud/sensors/test_dataproc.py
index 9080705ebd..f123976be9 100644
--- a/tests/providers/google/cloud/sensors/test_dataproc.py
+++ b/tests/providers/google/cloud/sensors/test_dataproc.py
@@ -23,7 +23,7 @@ import pytest
from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import Batch, JobStatus
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.sensors.dataproc import
DataprocBatchSensor, DataprocJobSensor
from airflow.version import version as airflow_version
@@ -66,8 +66,11 @@ class TestDataprocJobSensor:
)
assert ret
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_error(self, mock_hook):
+ def test_error(self, mock_hook, soft_fail, expected_exception):
job = self.create_job(JobStatus.State.ERROR)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job
@@ -79,9 +82,10 @@ class TestDataprocJobSensor:
dataproc_job_id=job_id,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="Job failed"):
+ with pytest.raises(expected_exception, match="Job failed"):
sensor.poke(context={})
mock_hook.return_value.get_job.assert_called_once_with(
@@ -109,8 +113,11 @@ class TestDataprocJobSensor:
)
assert not ret
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_cancelled(self, mock_hook):
+ def test_cancelled(self, mock_hook, soft_fail, expected_exception):
job = self.create_job(JobStatus.State.CANCELLED)
job_id = "job_id"
mock_hook.return_value.get_job.return_value = job
@@ -122,8 +129,9 @@ class TestDataprocJobSensor:
dataproc_job_id=job_id,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="Job was cancelled"):
+ with pytest.raises(expected_exception, match="Job was cancelled"):
sensor.poke(context={})
mock_hook.return_value.get_job.assert_called_once_with(
@@ -163,8 +171,11 @@ class TestDataprocJobSensor:
result = sensor.poke(context={})
assert not result
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_wait_timeout_raise_exception(self, mock_hook):
+ def test_wait_timeout_raise_exception(self, mock_hook, soft_fail,
expected_exception):
job_id = "job_id"
mock_hook.return_value.get_job.side_effect = ServerError("Job are not
ready")
@@ -176,12 +187,13 @@ class TestDataprocJobSensor:
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
wait_timeout=300,
+ soft_fail=soft_fail,
)
sensor._duration = Mock()
sensor._duration.return_value = 301
- with pytest.raises(AirflowException, match="Timeout: dataproc job
job_id is not ready after 300s"):
+ with pytest.raises(expected_exception, match="Timeout: dataproc job
job_id is not ready after 300s"):
sensor.poke(context={})
@@ -212,8 +224,11 @@ class TestDataprocBatchSensor:
)
assert ret
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_cancelled(self, mock_hook):
+ def test_cancelled(self, mock_hook, soft_fail, expected_exception):
batch = self.create_batch(Batch.State.CANCELLED)
mock_hook.return_value.get_batch.return_value = batch
@@ -224,16 +239,20 @@ class TestDataprocBatchSensor:
batch_id="batch_id",
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="Batch was cancelled."):
+ with pytest.raises(expected_exception, match="Batch was cancelled."):
sensor.poke(context={})
mock_hook.return_value.get_batch.assert_called_once_with(
batch_id="batch_id", region=GCP_LOCATION, project_id=GCP_PROJECT
)
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
- def test_error(self, mock_hook):
+ def test_error(self, mock_hook, soft_fail, expected_exception):
batch = self.create_batch(Batch.State.FAILED)
mock_hook.return_value.get_batch.return_value = batch
@@ -244,9 +263,10 @@ class TestDataprocBatchSensor:
batch_id="batch_id",
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="Batch failed"):
+ with pytest.raises(expected_exception, match="Batch failed"):
sensor.poke(context={})
mock_hook.return_value.get_batch.assert_called_once_with(
diff --git a/tests/providers/google/cloud/sensors/test_dataproc_metastore.py
b/tests/providers/google/cloud/sensors/test_dataproc_metastore.py
index 117210e2ba..435ceac661 100644
--- a/tests/providers/google/cloud/sensors/test_dataproc_metastore.py
+++ b/tests/providers/google/cloud/sensors/test_dataproc_metastore.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.sensors.dataproc_metastore import
MetastoreHivePartitionSensor
DATAPROC_METASTORE_SENSOR_PATH =
"airflow.providers.google.cloud.sensors.dataproc_metastore.{}"
@@ -106,14 +106,14 @@ class TestMetastoreHivePartitionSensor:
)
assert sensor.poke(context={}) == expected_result
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@pytest.mark.parametrize("empty_manifest", [dict(), list(), tuple(), None,
""])
@mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("DataprocMetastoreHook"))
@mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("parse_json_from_gcs"))
def test_poke_empty_manifest(
- self,
- mock_parse_json_from_gcs,
- mock_hook,
- empty_manifest,
+ self, mock_parse_json_from_gcs, mock_hook, empty_manifest, soft_fail,
expected_exception
):
mock_parse_json_from_gcs.return_value = empty_manifest
@@ -124,18 +124,18 @@ class TestMetastoreHivePartitionSensor:
table=TEST_TABLE,
partitions=[PARTITION_1],
gcp_conn_id=GCP_CONN_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
sensor.poke(context={})
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("DataprocMetastoreHook"))
@mock.patch(DATAPROC_METASTORE_SENSOR_PATH.format("parse_json_from_gcs"))
- def test_poke_wrong_status(
- self,
- mock_parse_json_from_gcs,
- mock_hook,
- ):
+ def test_poke_wrong_status(self, mock_parse_json_from_gcs, mock_hook,
soft_fail, expected_exception):
error_message = "Test error message"
mock_parse_json_from_gcs.return_value = {"code": 1, "message":
error_message}
@@ -146,7 +146,8 @@ class TestMetastoreHivePartitionSensor:
table=TEST_TABLE,
partitions=[PARTITION_1],
gcp_conn_id=GCP_CONN_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match=f"Request failed:
{error_message}"):
+ with pytest.raises(expected_exception, match=f"Request failed:
{error_message}"):
sensor.poke(context={})
diff --git a/tests/providers/google/cloud/sensors/test_gcs.py
b/tests/providers/google/cloud/sensors/test_gcs.py
index 4bb3646152..422cd8f71a 100644
--- a/tests/providers/google/cloud/sensors/test_gcs.py
+++ b/tests/providers/google/cloud/sensors/test_gcs.py
@@ -24,7 +24,12 @@ import pendulum
import pytest
from google.cloud.storage.retry import DEFAULT_RETRY
-from airflow.exceptions import AirflowProviderDeprecationWarning,
AirflowSensorTimeout, TaskDeferred
+from airflow.exceptions import (
+ AirflowProviderDeprecationWarning,
+ AirflowSensorTimeout,
+ AirflowSkipException,
+ TaskDeferred,
+)
from airflow.models.dag import DAG, AirflowException
from airflow.providers.google.cloud.sensors.gcs import (
GCSObjectExistenceAsyncSensor,
@@ -135,7 +140,10 @@ class TestGoogleCloudStorageObjectSensor:
task.execute(context)
assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not
a GCSBlobTrigger"
- def test_gcs_object_existence_sensor_deferred_execute_failure(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_gcs_object_existence_sensor_deferred_execute_failure(self,
soft_fail, expected_exception):
"""Tests that an AirflowException is raised in case of error event
when deferrable is set to True"""
task = GCSObjectExistenceSensor(
task_id="task-id",
@@ -143,8 +151,9 @@ class TestGoogleCloudStorageObjectSensor:
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
deferrable=True,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context=None, event={"status": "error",
"message": "test failure message"})
def test_gcs_object_existence_sensor_execute_complete(self):
@@ -185,7 +194,10 @@ class TestGoogleCloudStorageObjectAsyncSensor:
task.execute(context)
assert isinstance(exc.value.trigger, GCSBlobTrigger), "Trigger is not
a GCSBlobTrigger"
- def test_gcs_object_existence_async_sensor_execute_failure(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_gcs_object_existence_async_sensor_execute_failure(self,
soft_fail, expected_exception):
"""Tests that an AirflowException is raised in case of error event"""
with pytest.warns(AirflowProviderDeprecationWarning,
match=self.depcrecation_message):
task = GCSObjectExistenceAsyncSensor(
@@ -193,8 +205,9 @@ class TestGoogleCloudStorageObjectAsyncSensor:
bucket=TEST_BUCKET,
object=TEST_OBJECT,
google_cloud_conn_id=TEST_GCP_CONN_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
task.execute_complete(context=None, event={"status": "error",
"message": "test failure message"})
def test_gcs_object_existence_async_sensor_execute_complete(self):
@@ -289,10 +302,13 @@ class TestGCSObjectUpdateAsyncSensor:
exc.value.trigger, GCSCheckBlobUpdateTimeTrigger
), "Trigger is not a GCSCheckBlobUpdateTimeTrigger"
- def test_gcs_object_update_async_sensor_execute_failure(self, context):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_gcs_object_update_async_sensor_execute_failure(self, context,
soft_fail, expected_exception):
"""Tests that an AirflowException is raised in case of error event"""
-
- with pytest.raises(AirflowException):
+ self.OPERATOR.soft_fail = soft_fail
+ with pytest.raises(expected_exception):
self.OPERATOR.execute_complete(
context=context, event={"status": "error", "message": "test
failure message"}
)
@@ -364,13 +380,21 @@ class TestGoogleCloudStoragePrefixSensor:
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET,
prefix=TEST_PREFIX)
assert response == generated_messages
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowSensorTimeout),
(True, AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
- def test_execute_timeout(self, mock_hook):
+ def test_execute_timeout(self, mock_hook, soft_fail, expected_exception):
task = GCSObjectsWithPrefixExistenceSensor(
- task_id="task-id", bucket=TEST_BUCKET, prefix=TEST_PREFIX,
poke_interval=0, timeout=1
+ task_id="task-id",
+ bucket=TEST_BUCKET,
+ prefix=TEST_PREFIX,
+ poke_interval=0,
+ timeout=1,
+ soft_fail=soft_fail,
)
mock_hook.return_value.list.return_value = []
- with pytest.raises(AirflowSensorTimeout):
+ with pytest.raises(expected_exception):
task.execute(mock.MagicMock)
mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET,
prefix=TEST_PREFIX)
@@ -410,10 +434,15 @@ class TestGCSObjectsWithPrefixExistenceAsyncSensor:
self.OPERATOR.execute(mock.MagicMock())
assert isinstance(exc.value.trigger, GCSPrefixBlobTrigger), "Trigger
is not a GCSPrefixBlobTrigger"
- def
test_gcs_object_with_prefix_existence_async_sensor_execute_failure(self,
context):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_gcs_object_with_prefix_existence_async_sensor_execute_failure(
+ self, context, soft_fail, expected_exception
+ ):
"""Tests that an AirflowException is raised in case of error event"""
-
- with pytest.raises(AirflowException):
+ self.OPERATOR.soft_fail = soft_fail
+ with pytest.raises(expected_exception):
self.OPERATOR.execute_complete(
context=context, event={"status": "error", "message": "test
failure message"}
)
@@ -461,10 +490,14 @@ class TestGCSUploadSessionCompleteSensor:
)
assert mock_hook.return_value == self.sensor.hook
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.gcs.get_time",
mock_time)
- def test_files_deleted_between_pokes_throw_error(self):
+ def test_files_deleted_between_pokes_throw_error(self, soft_fail,
expected_exception):
+ self.sensor.soft_fail = soft_fail
self.sensor.is_bucket_updated({"a", "b"})
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
self.sensor.is_bucket_updated({"a"})
@mock.patch("airflow.providers.google.cloud.sensors.gcs.get_time",
mock_time)
@@ -549,10 +582,14 @@ class TestGCSUploadSessionCompleteAsyncSensor:
exc.value.trigger, GCSUploadSessionTrigger
), "Trigger is not a GCSUploadSessionTrigger"
- def test_gcs_upload_session_complete_sensor_execute_failure(self, context):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_gcs_upload_session_complete_sensor_execute_failure(self, context,
soft_fail, expected_exception):
"""Tests that an AirflowException is raised in case of error event"""
- with pytest.raises(AirflowException):
+ self.OPERATOR.soft_fail = soft_fail
+ with pytest.raises(expected_exception):
self.OPERATOR.execute_complete(
context=context, event={"status": "error", "message": "test
failure message"}
)
diff --git a/tests/providers/google/cloud/sensors/test_looker.py
b/tests/providers/google/cloud/sensors/test_looker.py
index 567340f086..8e35234055 100644
--- a/tests/providers/google/cloud/sensors/test_looker.py
+++ b/tests/providers/google/cloud/sensors/test_looker.py
@@ -20,7 +20,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.looker import JobStatus
from airflow.providers.google.cloud.sensors.looker import
LookerCheckPdtBuildSensor
@@ -51,8 +51,11 @@ class TestLookerCheckPdtBuildSensor:
# assert we got a response
assert ret
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(SENSOR_PATH.format("LookerHook"))
- def test_error(self, mock_hook):
+ def test_error(self, mock_hook, soft_fail, expected_exception):
mock_hook.return_value.pdt_build_status.return_value = {
"status": JobStatus.ERROR.value,
"message": "test",
@@ -63,9 +66,10 @@ class TestLookerCheckPdtBuildSensor:
task_id=TASK_ID,
looker_conn_id=LOOKER_CONN_ID,
materialization_id=TEST_JOB_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="PDT materialization job
failed"):
+ with pytest.raises(expected_exception, match="PDT materialization job
failed"):
sensor.poke(context={})
# assert hook.pdt_build_status called once
@@ -89,8 +93,11 @@ class TestLookerCheckPdtBuildSensor:
# assert we got NO response
assert not ret
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(SENSOR_PATH.format("LookerHook"))
- def test_cancelled(self, mock_hook):
+ def test_cancelled(self, mock_hook, soft_fail, expected_exception):
mock_hook.return_value.pdt_build_status.return_value = {"status":
JobStatus.CANCELLED.value}
# run task in mock context
@@ -98,22 +105,23 @@ class TestLookerCheckPdtBuildSensor:
task_id=TASK_ID,
looker_conn_id=LOOKER_CONN_ID,
materialization_id=TEST_JOB_ID,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException, match="PDT materialization job
was cancelled"):
+ with pytest.raises(expected_exception, match="PDT materialization job
was cancelled"):
sensor.poke(context={})
# assert hook.pdt_build_status called once
mock_hook.return_value.pdt_build_status.assert_called_once_with(materialization_id=TEST_JOB_ID)
- def test_empty_materialization_id(self):
-
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_empty_materialization_id(self, soft_fail, expected_exception):
# run task in mock context
sensor = LookerCheckPdtBuildSensor(
- task_id=TASK_ID,
- looker_conn_id=LOOKER_CONN_ID,
- materialization_id="",
+ task_id=TASK_ID, looker_conn_id=LOOKER_CONN_ID,
materialization_id="", soft_fail=soft_fail
)
- with pytest.raises(AirflowException, match="^Invalid
`materialization_id`.$"):
+ with pytest.raises(expected_exception, match="^Invalid
`materialization_id`.$"):
sensor.poke(context={})
diff --git a/tests/providers/google/cloud/sensors/test_pubsub.py
b/tests/providers/google/cloud/sensors/test_pubsub.py
index 88fa3e296f..dcecce218e 100644
--- a/tests/providers/google/cloud/sensors/test_pubsub.py
+++ b/tests/providers/google/cloud/sensors/test_pubsub.py
@@ -23,7 +23,7 @@ from unittest import mock
import pytest
from google.cloud.pubsub_v1.types import ReceivedMessage
-from airflow.exceptions import AirflowException, AirflowSensorTimeout,
TaskDeferred
+from airflow.exceptions import AirflowException, AirflowSensorTimeout,
AirflowSkipException, TaskDeferred
from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor
from airflow.providers.google.cloud.triggers.pubsub import PubsubPullTrigger
@@ -98,19 +98,23 @@ class TestPubSubPullSensor:
)
assert generated_dicts == response
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowSensorTimeout),
(True, AirflowSkipException))
+ )
@mock.patch("airflow.providers.google.cloud.sensors.pubsub.PubSubHook")
- def test_execute_timeout(self, mock_hook):
+ def test_execute_timeout(self, mock_hook, soft_fail, expected_exception):
operator = PubSubPullSensor(
task_id=TASK_ID,
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
poke_interval=0,
timeout=1,
+ soft_fail=soft_fail,
)
mock_hook.return_value.pull.return_value = []
- with pytest.raises(AirflowSensorTimeout):
+ with pytest.raises(expected_exception):
operator.execute({})
mock_hook.return_value.pull.assert_called_once_with(
project_id=TEST_PROJECT,
@@ -173,7 +177,10 @@ class TestPubSubPullSensor:
task.execute(context={})
assert isinstance(exc.value.trigger, PubsubPullTrigger), "Trigger is
not a PubsubPullTrigger"
- def test_pubsub_pull_sensor_async_execute_should_throw_exception(self):
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
+ def test_pubsub_pull_sensor_async_execute_should_throw_exception(self,
soft_fail, expected_exception):
"""Tests that an AirflowException is raised in case of error event"""
operator = PubSubPullSensor(
@@ -182,9 +189,10 @@ class TestPubSubPullSensor:
project_id=TEST_PROJECT,
subscription=TEST_SUBSCRIPTION,
deferrable=True,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
operator.execute_complete(
context=mock.MagicMock(), event={"status": "error", "message":
"test failure message"}
)
diff --git a/tests/providers/google/cloud/sensors/test_workflows.py
b/tests/providers/google/cloud/sensors/test_workflows.py
index 12d66ac62d..232c1db0e0 100644
--- a/tests/providers/google/cloud/sensors/test_workflows.py
+++ b/tests/providers/google/cloud/sensors/test_workflows.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
from google.cloud.workflows.executions_v1beta import Execution
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.sensors.workflows import
WorkflowExecutionSensor
BASE_PATH = "airflow.providers.google.cloud.sensors.workflows.{}"
@@ -90,8 +90,11 @@ class TestWorkflowExecutionSensor:
assert result is False
+ @pytest.mark.parametrize(
+ "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ )
@mock.patch(BASE_PATH.format("WorkflowsHook"))
- def test_poke_failure(self, mock_hook):
+ def test_poke_failure(self, mock_hook, soft_fail, expected_exception):
mock_hook.return_value.get_execution.return_value =
mock.MagicMock(state=Execution.State.FAILED)
op = WorkflowExecutionSensor(
task_id="test_task",
@@ -104,6 +107,7 @@ class TestWorkflowExecutionSensor:
metadata=METADATA,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
+ soft_fail=soft_fail,
)
- with pytest.raises(AirflowException):
+ with pytest.raises(expected_exception):
op.poke({})