MrGeorgeOwl commented on code in PR #27833:
URL: https://github.com/apache/airflow/pull/27833#discussion_r1039384748


##########
airflow/providers/google/cloud/operators/bigquery_dts.py:
##########
@@ -307,5 +317,82 @@ def execute(self, context: Context):
         result = StartManualTransferRunsResponse.to_dict(response)
         run_id = get_object_id(result["runs"][0])
         self.xcom_push(context, key="run_id", value=run_id)
-        self.log.info("Transfer run %s submitted successfully.", run_id)
-        return result
+
+        if not self.deferrable:
+            result = self._wait_for_transfer_to_be_done(
+                run_id=run_id,
+                transfer_config_id=transfer_config["config_id"],
+            )
+            self.log.info("Transfer run %s submitted successfully.", run_id)
+            return result
+
+        self.defer(
+            trigger=BigQueryDataTransferRunTrigger(
+                project_id=self.project_id,
+                config_id=transfer_config["config_id"],
+                run_id=run_id,
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_completed",
+        )
+
+    def _get_hook(self) -> BiqQueryDataTransferServiceHook:
+        if self._hook is None:
+            self._hook = BiqQueryDataTransferServiceHook(
+                gcp_conn_id=self.gcp_conn_id,
+                impersonation_chain=self.impersonation_chain,
+                location=self.location,
+            )
+        return self._hook
+
+    def _wait_for_transfer_to_be_done(self, run_id: str, transfer_config_id: 
str, interval: int = 10):
+        if interval < 0:
+            raise ValueError("Interval must be > 0")
+
+        while True:
+            transfer_run: TransferRun = self._get_hook().get_transfer_run(
+                run_id=run_id,
+                transfer_config_id=transfer_config_id,
+                project_id=self.project_id,
+                retry=self.retry,
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
+            state = transfer_run.state
+
+            if self._job_is_done(state):
+                if state == TransferState.FAILED or state == 
TransferState.CANCELLED:
+                    raise AirflowException(f"Transfer run was finished with 
{state} status.")
+
+                result = TransferRun.to_dict(transfer_run)
+                return result
+
+            self.log.info("Transfer run is still working, waiting for %s 
seconds...", interval)
+            self.log.info("Transfer run status: %s", state)
+            time.sleep(interval)
+
+    @staticmethod
+    def _job_is_done(state: TransferState) -> bool:
+        finished_job_statuses = [
+            state.SUCCEEDED,
+            state.CANCELLED,
+            state.FAILED,
+        ]
+
+        return state in finished_job_statuses
+
+    def execute_completed(self, context: Context, event: dict):
+        """Method to be executed after invoked trigger in defer method 
finishes its job."""
+        if event["status"] == "failed" or event["status"] == "cancelled":
+            self.log.error("Trigger finished its work with status: %s.", 
event["status"])
+            raise AirflowException(event["message"])
+
+        self.log.info(
+            "%s finished with message: %s",
+            event["run_id"],
+            event["message"],
+        )
+
+        return event["run_id"]

Review Comment:
   Yeah, you are right. I will rework that logic.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to