This is an automated email from the ASF dual-hosted git repository.
joshfell pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bb4a0b39e8 Optimize deferred execution mode for DbtCloudRunJobOperator (#31188)
bb4a0b39e8 is described below
commit bb4a0b39e8c021772830c9d44e72e492e0fef4bb
Author: Phani Kumar <[email protected]>
AuthorDate: Tue May 16 01:47:43 2023 +0530
Optimize deferred execution mode for DbtCloudRunJobOperator (#31188)
---
airflow/providers/dbt/cloud/operators/dbt.py | 40 +++++++----
.../dbt/cloud/operators/test_dbt_cloud.py | 79 ++++++++++++++++++++++
2 files changed, 107 insertions(+), 12 deletions(-)
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py b/airflow/providers/dbt/cloud/operators/dbt.py
index de110e9865..d9177e6772 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
+from airflow.providers.dbt.cloud.hooks.dbt import (
+ DbtCloudHook,
+ DbtCloudJobRunException,
+ DbtCloudJobRunStatus,
+ JobRunInfo,
+)
from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
if TYPE_CHECKING:
@@ -154,17 +159,28 @@ class DbtCloudRunJobOperator(BaseOperator):
return self.run_id
else:
end_time = time.time() + self.timeout
- self.defer(
- timeout=self.execution_timeout,
- trigger=DbtCloudRunJobTrigger(
- conn_id=self.dbt_cloud_conn_id,
- run_id=self.run_id,
- end_time=end_time,
- account_id=self.account_id,
- poll_interval=self.check_interval,
- ),
- method_name="execute_complete",
- )
+ job_run_info = JobRunInfo(account_id=self.account_id, run_id=self.run_id)
+ job_run_status = self.hook.get_job_run_status(**job_run_info)
+ if not DbtCloudJobRunStatus.is_terminal(job_run_status):
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=DbtCloudRunJobTrigger(
+ conn_id=self.dbt_cloud_conn_id,
+ run_id=self.run_id,
+ end_time=end_time,
+ account_id=self.account_id,
+ poll_interval=self.check_interval,
+ ),
+ method_name="execute_complete",
+ )
+ elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
+ self.log.info("Job run %s has completed successfully.", str(self.run_id))
+ return self.run_id
+ elif job_run_status in (
+ DbtCloudJobRunStatus.CANCELLED.value,
+ DbtCloudJobRunStatus.ERROR.value,
+ ):
+ raise DbtCloudJobRunException(f"Job run {self.run_id} has failed or has been cancelled.")
else:
if self.deferrable is True:
warnings.warn(
diff --git a/tests/providers/dbt/cloud/operators/test_dbt_cloud.py b/tests/providers/dbt/cloud/operators/test_dbt_cloud.py
index 6aab53109f..4400ebd4ab 100644
--- a/tests/providers/dbt/cloud/operators/test_dbt_cloud.py
+++ b/tests/providers/dbt/cloud/operators/test_dbt_cloud.py
@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
import pytest
+from airflow.exceptions import TaskDeferred
from airflow.models import DAG, Connection
from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook, DbtCloudJobRunException, DbtCloudJobRunStatus
from airflow.providers.dbt.cloud.operators.dbt import (
@@ -28,6 +29,7 @@ from airflow.providers.dbt.cloud.operators.dbt import (
DbtCloudListJobsOperator,
DbtCloudRunJobOperator,
)
+from airflow.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
from airflow.utils import db, timezone
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
@@ -95,6 +97,83 @@ class TestDbtCloudRunJobOperator:
"additional_run_config": {"threads_override": 8},
}
+ @patch(
+ "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
+ return_value=DbtCloudJobRunStatus.SUCCESS.value,
+ )
+ @patch("airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator.defer")
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
+ def test_execute_succeeded_before_getting_deferred(
+ self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status
+ ):
+ dbt_op = DbtCloudRunJobOperator(
+ dbt_cloud_conn_id=ACCOUNT_ID_CONN,
+ task_id=TASK_ID,
+ job_id=JOB_ID,
+ check_interval=1,
+ timeout=3,
+ dag=self.dag,
+ deferrable=True,
+ )
+ dbt_op.execute(MagicMock())
+ assert not mock_defer.called
+
+ @patch(
+ "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
+ return_value=DbtCloudJobRunStatus.ERROR.value,
+ )
+ @patch("airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator.defer")
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
+ def test_execute_failed_before_getting_deferred(
+ self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status
+ ):
+ dbt_op = DbtCloudRunJobOperator(
+ dbt_cloud_conn_id=ACCOUNT_ID_CONN,
+ task_id=TASK_ID,
+ job_id=JOB_ID,
+ check_interval=1,
+ timeout=3,
+ dag=self.dag,
+ deferrable=True,
+ )
+ with pytest.raises(DbtCloudJobRunException):
+ dbt_op.execute(MagicMock())
+ assert not mock_defer.called
+
+ @pytest.mark.parametrize(
+ "status",
+ (
+ DbtCloudJobRunStatus.QUEUED.value,
+ DbtCloudJobRunStatus.STARTING.value,
+ DbtCloudJobRunStatus.RUNNING.value,
+ ),
+ )
+ @patch(
+ "airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
+ )
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
+ @patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
+ def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, status):
+ """
+ Asserts that a task is deferred and an DbtCloudRunJobTrigger will be fired
+ when the DbtCloudRunJobOperator has deferrable param set to True
+ """
+ mock_job_run_status.return_value = status
+ dbt_op = DbtCloudRunJobOperator(
+ dbt_cloud_conn_id=ACCOUNT_ID_CONN,
+ task_id=TASK_ID,
+ job_id=JOB_ID,
+ check_interval=1,
+ timeout=3,
+ dag=self.dag,
+ deferrable=True,
+ )
+ with pytest.raises(TaskDeferred) as exc:
+ dbt_op.execute(MagicMock())
+ assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"
+
@patch.object(DbtCloudHook, "trigger_job_run", return_value=MagicMock(**DEFAULT_ACCOUNT_JOB_RUN_RESPONSE))
@pytest.mark.parametrize(
"job_run_status, expected_output",