This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new d04aa13  [Airflow 13779] use provided parameters in the wait_for_pipeline_state hook (#17137)
d04aa13 is described below

commit d04aa135268b8e0230be3af6598a3b18e8614c3c
Author: Łukasz Wyszomirski <lwyszomir...@gmail.com>
AuthorDate: Fri Aug 20 22:10:46 2021 +0200

    [Airflow 13779] use provided parameters in the wait_for_pipeline_state hook (#17137)

    I removed the wait_for_pipeline_state call from the start_pipeline hook. With that call inside
    the hook, the operator has a bug: for example, when a pipeline takes more than 300 seconds to
    start and is still in a starting state, we get an error because the pipeline is not in a
    correct state after 300 seconds. Even when we pass our own success_states and pipeline_timeout
    parameters we still get this error, so when both parameters are passed, the logic should use
    them instead of the defaults. Why we have [...]
---
 airflow/providers/google/cloud/hooks/datafusion.py | 10 +-------
 .../providers/google/cloud/operators/datafusion.py | 29 ++++++++++++----------
 .../google/cloud/hooks/test_datafusion.py          | 12 ++-------
 .../google/cloud/operators/test_datafusion.py      | 12 +++++++++
 4 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/datafusion.py b/airflow/providers/google/cloud/hooks/datafusion.py
index 4cc8f50..0058223 100644
--- a/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/airflow/providers/google/cloud/hooks/datafusion.py
@@ -464,15 +464,7 @@ class DataFusionHook(GoogleBaseHook):
             raise AirflowException(f"Starting a pipeline failed with code {response.status}")
 
         response_json = json.loads(response.data)
-        pipeline_id = response_json[0]["runId"]
-        self.wait_for_pipeline_state(
-            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
-            pipeline_name=pipeline_name,
-            pipeline_id=pipeline_id,
-            namespace=namespace,
-            instance_url=instance_url,
-        )
-        return pipeline_id
+        return response_json[0]["runId"]
 
     def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: str = "default") -> None:
         """
diff --git a/airflow/providers/google/cloud/operators/datafusion.py b/airflow/providers/google/cloud/operators/datafusion.py
index 0ba9673..b115437 100644
--- a/airflow/providers/google/cloud/operators/datafusion.py
+++ b/airflow/providers/google/cloud/operators/datafusion.py
@@ -23,7 +23,7 @@ from google.api_core.retry import exponential_sleep_generator
 from googleapiclient.errors import HttpError
 
 from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
+from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates
 
 
 class CloudDataFusionRestartInstanceOperator(BaseOperator):
@@ -808,9 +808,7 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
         self.pipeline_name = pipeline_name
-        self.success_states = success_states
         self.runtime_args = runtime_args
-        self.pipeline_timeout = pipeline_timeout
         self.namespace = namespace
         self.instance_name = instance_name
         self.location = location
@@ -820,6 +818,13 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
 
+        if success_states:
+            self.success_states = success_states
+            self.pipeline_timeout = pipeline_timeout
+        else:
+            self.success_states = SUCCESS_STATES + [PipelineStates.RUNNING]
+            self.pipeline_timeout = 5 * 60
+
     def execute(self, context: dict) -> None:
         hook = DataFusionHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -840,17 +845,15 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
             namespace=self.namespace,
             runtime_args=self.runtime_args,
         )
-
+        hook.wait_for_pipeline_state(
+            success_states=self.success_states,
+            pipeline_id=pipeline_id,
+            pipeline_name=self.pipeline_name,
+            namespace=self.namespace,
+            instance_url=api_url,
+            timeout=self.pipeline_timeout,
+        )
         self.log.info("Pipeline started")
-        if self.success_states:
-            hook.wait_for_pipeline_state(
-                success_states=self.success_states,
-                pipeline_id=pipeline_id,
-                pipeline_name=self.pipeline_name,
-                namespace=self.namespace,
-                instance_url=api_url,
-                timeout=self.pipeline_timeout,
-            )
 
 
 class CloudDataFusionStopPipelineOperator(BaseOperator):
diff --git a/tests/providers/google/cloud/hooks/test_datafusion.py b/tests/providers/google/cloud/hooks/test_datafusion.py
index 29199a2..79cbfbf 100644
--- a/tests/providers/google/cloud/hooks/test_datafusion.py
+++ b/tests/providers/google/cloud/hooks/test_datafusion.py
@@ -20,7 +20,7 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, DataFusionHook, PipelineStates
+from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
 from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
 
 API_VERSION = "v1beta1"
@@ -180,8 +180,7 @@ class TestDataFusionHook:
         assert result == data
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
-    @mock.patch(HOOK_STR.format("DataFusionHook.wait_for_pipeline_state"))
-    def test_start_pipeline(self, mock_wait_for_pipeline_state, mock_request, hook):
+    def test_start_pipeline(self, mock_request, hook):
         run_id = 1234
         mock_request.return_value = mock.MagicMock(status=200, data=f'[{{"runId":{run_id}}}]')
 
@@ -197,13 +196,6 @@ class TestDataFusionHook:
         mock_request.assert_called_once_with(
             url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", body=body
         )
-        mock_wait_for_pipeline_state.assert_called_once_with(
-            instance_url=INSTANCE_URL,
-            namespace="default",
-            pipeline_name=PIPELINE_NAME,
-            pipeline_id=run_id,
-            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
-        )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_stop_pipeline(self, mock_request, hook):
diff --git a/tests/providers/google/cloud/operators/test_datafusion.py b/tests/providers/google/cloud/operators/test_datafusion.py
index 81ceaec..466f670 100644
--- a/tests/providers/google/cloud/operators/test_datafusion.py
+++ b/tests/providers/google/cloud/operators/test_datafusion.py
@@ -18,6 +18,7 @@
 from unittest import mock
 
 from airflow import DAG
+from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, PipelineStates
 from airflow.providers.google.cloud.operators.datafusion import (
     CloudDataFusionCreateInstanceOperator,
     CloudDataFusionCreatePipelineOperator,
@@ -194,7 +195,9 @@ class TestCloudDataFusionDeletePipelineOperator:
 class TestCloudDataFusionStartPipelineOperator:
     @mock.patch(HOOK_STR)
     def test_execute(self, mock_hook):
+        PIPELINE_ID = "test_pipeline_id"
         mock_hook.return_value.get_instance.return_value = {"apiEndpoint": INSTANCE_URL}
+        mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
 
         op = CloudDataFusionStartPipelineOperator(
             task_id="test_task",
@@ -219,6 +222,15 @@ class TestCloudDataFusionStartPipelineOperator:
             runtime_args=RUNTIME_ARGS,
         )
 
+        mock_hook.return_value.wait_for_pipeline_state.assert_called_once_with(
+            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
+            pipeline_id=PIPELINE_ID,
+            pipeline_name=PIPELINE_NAME,
+            namespace=NAMESPACE,
+            instance_url=INSTANCE_URL,
+            timeout=300,
+        )
+
 
 class TestCloudDataFusionStopPipelineOperator:
     @mock.patch(HOOK_STR)
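
A minimal usage sketch (not part of this commit; the pipeline, instance, and region values below are
hypothetical) showing what the change means for DAG authors: explicitly passed success_states and
pipeline_timeout are now used as given, while omitting them falls back to
SUCCESS_STATES + [PipelineStates.RUNNING] and the 300-second default:

    from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, PipelineStates
    from airflow.providers.google.cloud.operators.datafusion import CloudDataFusionStartPipelineOperator

    start_pipeline = CloudDataFusionStartPipelineOperator(
        task_id="start_pipeline",
        pipeline_name="example_pipeline",   # hypothetical pipeline name
        instance_name="example-instance",   # hypothetical Data Fusion instance
        location="europe-west1",            # hypothetical region
        namespace="default",
        # With this commit these two parameters are respected; before, the hook always
        # waited on the defaults and could time out for slow-starting pipelines.
        success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
        pipeline_timeout=600,               # wait up to 10 minutes instead of the 300 s default
    )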