This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d04aa13  [Airflow 13779] use provided parameters in the 
wait_for_pipeline_state hook (#17137)
d04aa13 is described below

commit d04aa135268b8e0230be3af6598a3b18e8614c3c
Author: Ɓukasz Wyszomirski <lwyszomir...@gmail.com>
AuthorDate: Fri Aug 20 22:10:46 2021 +0200

    [Airflow 13779] use provided parameters in the wait_for_pipeline_state hook 
(#17137)
    
    I removed wait_for_pipeline_state from the start_pipeline hook. Because of this 
call, I think we have a bug in this operator: for example, when we have a pipeline 
that takes more than 300 seconds to start, it still has a starting status, and we get 
the error because this pipeline is not in the correct state after 300 seconds. Even 
when we pass our parameters success_states and pipeline_timeout we get this error in 
this case, so I think when I pass both parameters the logic should use them, not the 
default. Why we have  [...]
---
 airflow/providers/google/cloud/hooks/datafusion.py | 10 +-------
 .../providers/google/cloud/operators/datafusion.py | 29 ++++++++++++----------
 .../google/cloud/hooks/test_datafusion.py          | 12 ++-------
 .../google/cloud/operators/test_datafusion.py      | 12 +++++++++
 4 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/datafusion.py 
b/airflow/providers/google/cloud/hooks/datafusion.py
index 4cc8f50..0058223 100644
--- a/airflow/providers/google/cloud/hooks/datafusion.py
+++ b/airflow/providers/google/cloud/hooks/datafusion.py
@@ -464,15 +464,7 @@ class DataFusionHook(GoogleBaseHook):
             raise AirflowException(f"Starting a pipeline failed with code 
{response.status}")
 
         response_json = json.loads(response.data)
-        pipeline_id = response_json[0]["runId"]
-        self.wait_for_pipeline_state(
-            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
-            pipeline_name=pipeline_name,
-            pipeline_id=pipeline_id,
-            namespace=namespace,
-            instance_url=instance_url,
-        )
-        return pipeline_id
+        return response_json[0]["runId"]
 
     def stop_pipeline(self, pipeline_name: str, instance_url: str, namespace: 
str = "default") -> None:
         """
diff --git a/airflow/providers/google/cloud/operators/datafusion.py 
b/airflow/providers/google/cloud/operators/datafusion.py
index 0ba9673..b115437 100644
--- a/airflow/providers/google/cloud/operators/datafusion.py
+++ b/airflow/providers/google/cloud/operators/datafusion.py
@@ -23,7 +23,7 @@ from google.api_core.retry import exponential_sleep_generator
 from googleapiclient.errors import HttpError
 
 from airflow.models import BaseOperator
-from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
+from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, 
DataFusionHook, PipelineStates
 
 
 class CloudDataFusionRestartInstanceOperator(BaseOperator):
@@ -808,9 +808,7 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
         self.pipeline_name = pipeline_name
-        self.success_states = success_states
         self.runtime_args = runtime_args
-        self.pipeline_timeout = pipeline_timeout
         self.namespace = namespace
         self.instance_name = instance_name
         self.location = location
@@ -820,6 +818,13 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
         self.delegate_to = delegate_to
         self.impersonation_chain = impersonation_chain
 
+        if success_states:
+            self.success_states = success_states
+            self.pipeline_timeout = pipeline_timeout
+        else:
+            self.success_states = SUCCESS_STATES + [PipelineStates.RUNNING]
+            self.pipeline_timeout = 5 * 60
+
     def execute(self, context: dict) -> None:
         hook = DataFusionHook(
             gcp_conn_id=self.gcp_conn_id,
@@ -840,17 +845,15 @@ class CloudDataFusionStartPipelineOperator(BaseOperator):
             namespace=self.namespace,
             runtime_args=self.runtime_args,
         )
-
+        hook.wait_for_pipeline_state(
+            success_states=self.success_states,
+            pipeline_id=pipeline_id,
+            pipeline_name=self.pipeline_name,
+            namespace=self.namespace,
+            instance_url=api_url,
+            timeout=self.pipeline_timeout,
+        )
         self.log.info("Pipeline started")
-        if self.success_states:
-            hook.wait_for_pipeline_state(
-                success_states=self.success_states,
-                pipeline_id=pipeline_id,
-                pipeline_name=self.pipeline_name,
-                namespace=self.namespace,
-                instance_url=api_url,
-                timeout=self.pipeline_timeout,
-            )
 
 
 class CloudDataFusionStopPipelineOperator(BaseOperator):
diff --git a/tests/providers/google/cloud/hooks/test_datafusion.py 
b/tests/providers/google/cloud/hooks/test_datafusion.py
index 29199a2..79cbfbf 100644
--- a/tests/providers/google/cloud/hooks/test_datafusion.py
+++ b/tests/providers/google/cloud/hooks/test_datafusion.py
@@ -20,7 +20,7 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, 
DataFusionHook, PipelineStates
+from airflow.providers.google.cloud.hooks.datafusion import DataFusionHook
 from tests.providers.google.cloud.utils.base_gcp_mock import 
mock_base_gcp_hook_default_project_id
 
 API_VERSION = "v1beta1"
@@ -180,8 +180,7 @@ class TestDataFusionHook:
         assert result == data
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
-    @mock.patch(HOOK_STR.format("DataFusionHook.wait_for_pipeline_state"))
-    def test_start_pipeline(self, mock_wait_for_pipeline_state, mock_request, 
hook):
+    def test_start_pipeline(self, mock_request, hook):
         run_id = 1234
         mock_request.return_value = mock.MagicMock(status=200, 
data=f'[{{"runId":{run_id}}}]')
 
@@ -197,13 +196,6 @@ class TestDataFusionHook:
         mock_request.assert_called_once_with(
             url=f"{INSTANCE_URL}/v3/namespaces/default/start", method="POST", 
body=body
         )
-        mock_wait_for_pipeline_state.assert_called_once_with(
-            instance_url=INSTANCE_URL,
-            namespace="default",
-            pipeline_name=PIPELINE_NAME,
-            pipeline_id=run_id,
-            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
-        )
 
     @mock.patch(HOOK_STR.format("DataFusionHook._cdap_request"))
     def test_stop_pipeline(self, mock_request, hook):
diff --git a/tests/providers/google/cloud/operators/test_datafusion.py 
b/tests/providers/google/cloud/operators/test_datafusion.py
index 81ceaec..466f670 100644
--- a/tests/providers/google/cloud/operators/test_datafusion.py
+++ b/tests/providers/google/cloud/operators/test_datafusion.py
@@ -18,6 +18,7 @@
 from unittest import mock
 
 from airflow import DAG
+from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, 
PipelineStates
 from airflow.providers.google.cloud.operators.datafusion import (
     CloudDataFusionCreateInstanceOperator,
     CloudDataFusionCreatePipelineOperator,
@@ -194,7 +195,9 @@ class TestCloudDataFusionDeletePipelineOperator:
 class TestCloudDataFusionStartPipelineOperator:
     @mock.patch(HOOK_STR)
     def test_execute(self, mock_hook):
+        PIPELINE_ID = "test_pipeline_id"
         mock_hook.return_value.get_instance.return_value = {"apiEndpoint": 
INSTANCE_URL}
+        mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
 
         op = CloudDataFusionStartPipelineOperator(
             task_id="test_task",
@@ -219,6 +222,15 @@ class TestCloudDataFusionStartPipelineOperator:
             runtime_args=RUNTIME_ARGS,
         )
 
+        mock_hook.return_value.wait_for_pipeline_state.assert_called_once_with(
+            success_states=SUCCESS_STATES + [PipelineStates.RUNNING],
+            pipeline_id=PIPELINE_ID,
+            pipeline_name=PIPELINE_NAME,
+            namespace=NAMESPACE,
+            instance_url=INSTANCE_URL,
+            timeout=300,
+        )
+
 
 class TestCloudDataFusionStopPipelineOperator:
     @mock.patch(HOOK_STR)

Reply via email to