This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 26964f8a8e feat(providers/dbt): add reuse_existing_run for allowing
DbtCloudRunJobOperator to reuse existing run (#37474)
26964f8a8e is described below
commit 26964f8a8e740115d40c608b153fa28d6f5979bf
Author: Wei Lee <[email protected]>
AuthorDate: Tue Feb 20 09:37:57 2024 +0800
feat(providers/dbt): add reuse_existing_run for allowing
DbtCloudRunJobOperator to reuse existing run (#37474)
---
airflow/providers/dbt/cloud/hooks/dbt.py | 15 +++++++
airflow/providers/dbt/cloud/operators/dbt.py | 40 ++++++++++++-----
tests/providers/dbt/cloud/hooks/test_dbt.py | 23 ++++++++--
tests/providers/dbt/cloud/operators/test_dbt.py | 57 +++++++++++++++++++++++++
4 files changed, 122 insertions(+), 13 deletions(-)
diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py
b/airflow/providers/dbt/cloud/hooks/dbt.py
index a375aa27c7..85eba8da04 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -109,6 +109,7 @@ class DbtCloudJobRunStatus(Enum):
SUCCESS = 10
ERROR = 20
CANCELLED = 30
+ NON_TERMINAL_STATUSES = (QUEUED, STARTING, RUNNING)
TERMINAL_STATUSES = (SUCCESS, ERROR, CANCELLED)
@classmethod
@@ -460,6 +461,20 @@ class DbtCloudHook(HttpHook):
paginate=True,
)
+ @fallback_to_default_account
+ def get_job_runs(self, account_id: int | None = None, payload: dict[str,
Any] | None = None) -> Response:
+ """
+ Retrieve metadata for runs of dbt Cloud jobs in an account.
+
+ :param account_id: Optional. The ID of a dbt Cloud account.
+ :param payload: Optional. Query parameters.
+ :return: The request response.
+ """
+ return self._run_and_get_response(
+ endpoint=f"{account_id}/runs/",
+ payload=payload,
+ )
+
@fallback_to_default_account
def get_job_run(
self, run_id: int, account_id: int | None = None, include_related:
list[str] | None = None
diff --git a/airflow/providers/dbt/cloud/operators/dbt.py
b/airflow/providers/dbt/cloud/operators/dbt.py
index 0b56e88e01..0b31ba3014 100644
--- a/airflow/providers/dbt/cloud/operators/dbt.py
+++ b/airflow/providers/dbt/cloud/operators/dbt.py
@@ -73,6 +73,8 @@ class DbtCloudRunJobOperator(BaseOperator):
Used only if ``wait_for_termination`` is True. Defaults to 60 seconds.
:param additional_run_config: Optional. Any additional parameters that
should be included in the API
request when triggering the job.
+ :param reuse_existing_run: Flag to determine whether to reuse an existing non-terminal
job run. If set to
+ true and a non-terminal job run is found, the latest such run is used without
triggering a new job run.
:param deferrable: Run operator in the deferrable mode
:return: The ID of the triggered dbt Cloud job run.
"""
@@ -102,6 +104,7 @@ class DbtCloudRunJobOperator(BaseOperator):
timeout: int = 60 * 60 * 24 * 7,
check_interval: int = 60,
additional_run_config: dict[str, Any] | None = None,
+ reuse_existing_run: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
**kwargs,
) -> None:
@@ -117,6 +120,7 @@ class DbtCloudRunJobOperator(BaseOperator):
self.check_interval = check_interval
self.additional_run_config = additional_run_config or {}
self.run_id: int | None = None
+ self.reuse_existing_run = reuse_existing_run
self.deferrable = deferrable
def execute(self, context: Context):
@@ -125,16 +129,32 @@ class DbtCloudRunJobOperator(BaseOperator):
f"Triggered via Apache Airflow by task {self.task_id!r} in the
{self.dag.dag_id} DAG."
)
- trigger_job_response = self.hook.trigger_job_run(
- account_id=self.account_id,
- job_id=self.job_id,
- cause=self.trigger_reason,
- steps_override=self.steps_override,
- schema_override=self.schema_override,
- additional_run_config=self.additional_run_config,
- )
- self.run_id = trigger_job_response.json()["data"]["id"]
- job_run_url = trigger_job_response.json()["data"]["href"]
+ non_terminal_runs = None
+ if self.reuse_existing_run:
+ non_terminal_runs = self.hook.get_job_runs(
+ account_id=self.account_id,
+ payload={
+ "job_definition_id": self.job_id,
+ "status": DbtCloudJobRunStatus.NON_TERMINAL_STATUSES,
+ "order_by": "-created_at",
+ },
+ ).json()["data"]
+ if non_terminal_runs:
+ self.run_id = non_terminal_runs[0]["id"]
+ job_run_url = non_terminal_runs[0]["href"]
+
+ if not self.reuse_existing_run or not non_terminal_runs:
+ trigger_job_response = self.hook.trigger_job_run(
+ account_id=self.account_id,
+ job_id=self.job_id,
+ cause=self.trigger_reason,
+ steps_override=self.steps_override,
+ schema_override=self.schema_override,
+ additional_run_config=self.additional_run_config,
+ )
+ self.run_id = trigger_job_response.json()["data"]["id"]
+ job_run_url = trigger_job_response.json()["data"]["href"]
+
# Push the ``job_run_url`` value to XCom regardless of what happens
during execution so that the job
# run can be monitored via the operator link.
context["ti"].xcom_push(key="job_run_url", value=job_run_url)
diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py
b/tests/providers/dbt/cloud/hooks/test_dbt.py
index 20e1313965..39d31a444b 100644
--- a/tests/providers/dbt/cloud/hooks/test_dbt.py
+++ b/tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -439,6 +439,21 @@ class TestDbtCloudHook:
},
)
+ @pytest.mark.parametrize(
+ argnames="conn_id, account_id",
+ argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+ ids=["default_account", "explicit_account"],
+ )
+ @patch.object(DbtCloudHook, "run")
+ def test_get_job_runs(self, mock_http_run, conn_id, account_id):
+ hook = DbtCloudHook(conn_id)
+ hook.get_job_runs(account_id=account_id)
+
+ assert hook.method == "GET"
+
+ _account_id = account_id or DEFAULT_ACCOUNT_ID
+ hook.run.assert_called_once_with(endpoint=f"{_account_id}/runs/",
data=None)
+
@pytest.mark.parametrize(
argnames="conn_id, account_id",
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
@@ -493,9 +508,11 @@ class TestDbtCloudHook:
argnames=("job_run_status", "expected_status", "expected_output"),
argvalues=wait_for_job_run_status_test_args,
ids=[
- f"run_status_{argval[0]}_expected_{argval[1]}"
- if isinstance(argval[1], int)
- else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+ (
+ f"run_status_{argval[0]}_expected_{argval[1]}"
+ if isinstance(argval[1], int)
+ else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+ )
for argval in wait_for_job_run_status_test_args
],
)
diff --git a/tests/providers/dbt/cloud/operators/test_dbt.py
b/tests/providers/dbt/cloud/operators/test_dbt.py
index b4c1aa89e7..90465602dc 100644
--- a/tests/providers/dbt/cloud/operators/test_dbt.py
+++ b/tests/providers/dbt/cloud/operators/test_dbt.py
@@ -307,6 +307,63 @@ class TestDbtCloudRunJobOperator:
mock_get_job_run.assert_not_called()
+ @patch.object(DbtCloudHook, "get_job_runs")
+ @patch.object(DbtCloudHook, "trigger_job_run")
+ @pytest.mark.parametrize(
+ "conn_id, account_id",
+ [(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
+ ids=["default_account", "explicit_account"],
+ )
+ def test_execute_no_wait_for_termination_and_reuse_existing_run(
+ self, mock_run_job, mock_get_jobs_run, conn_id, account_id
+ ):
+ mock_get_jobs_run.return_value.json.return_value = {
+ "data": [
+ {
+ "id": 10000,
+ "status": 1,
+ "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
+ account_id=DEFAULT_ACCOUNT_ID, project_id=PROJECT_ID,
run_id=RUN_ID
+ ),
+ },
+ {
+ "id": 10001,
+ "status": 2,
+ "href": EXPECTED_JOB_RUN_OP_EXTRA_LINK.format(
+ account_id=DEFAULT_ACCOUNT_ID, project_id=PROJECT_ID,
run_id=RUN_ID
+ ),
+ },
+ ]
+ }
+
+ operator = DbtCloudRunJobOperator(
+ task_id=TASK_ID,
+ dbt_cloud_conn_id=conn_id,
+ account_id=account_id,
+ trigger_reason=None,
+ dag=self.dag,
+ wait_for_termination=False,
+ reuse_existing_run=True,
+ **self.config,
+ )
+
+ assert operator.dbt_cloud_conn_id == conn_id
+ assert operator.job_id == self.config["job_id"]
+ assert operator.account_id == account_id
+ assert operator.check_interval == self.config["check_interval"]
+ assert operator.timeout == self.config["timeout"]
+ assert not operator.wait_for_termination
+ assert operator.steps_override == self.config["steps_override"]
+ assert operator.schema_override == self.config["schema_override"]
+ assert operator.additional_run_config ==
self.config["additional_run_config"]
+
+ with patch.object(DbtCloudHook, "get_job_run") as mock_get_job_run:
+ operator.execute(context=self.mock_context)
+
+ mock_run_job.assert_not_called()
+
+ mock_get_job_run.assert_not_called()
+
@patch.object(DbtCloudHook, "trigger_job_run")
@pytest.mark.parametrize(
"conn_id, account_id",