This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ce5c00fbed5d7179c15059916758217f3c9d8f51
Author: Kamil Breguła <[email protected]>
AuthorDate: Thu Dec 31 18:07:32 2020 +0100

    Support google-cloud-bigquery-datatransfer>=3.0.0 (#13337)
    
    (cherry picked from commit 9de71270838ad3cc59043f1ab0bb6ca97af13622)
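
    The 3.x client changes the calling convention throughout: client methods
    take a single ``request`` mapping instead of individual keyword arguments,
    the ``project_path``/``project_transfer_config_path``/``project_run_path``
    helpers are replaced by plain resource-name strings, ``MessageToDict`` and
    ``ParseDict`` give way to each message's ``to_dict`` classmethod, and run
    states are exposed as the ``TransferState`` enum.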
---
 airflow/providers/google/ADDITIONAL_INFO.md        |  1 +
 .../cloud/example_dags/example_bigquery_dts.py     | 20 ++++------
 .../providers/google/cloud/hooks/bigquery_dts.py   | 45 ++++++++++++++--------
 .../google/cloud/operators/bigquery_dts.py         | 12 +++---
 .../providers/google/cloud/sensors/bigquery_dts.py | 35 ++++++++++++-----
 setup.py                                           |  2 +-
 .../google/cloud/hooks/test_bigquery_dts.py        | 39 ++++++++-----------
 .../google/cloud/operators/test_bigquery_dts.py    | 37 +++++++++++++-----
 .../google/cloud/sensors/test_bigquery_dts.py      | 39 ++++++++++++++++---
 9 files changed, 142 insertions(+), 88 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index b54b240..eca05df 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -29,6 +29,7 @@ Details are covered in the UPDATING.md files for each library, but there are som
 
 | Library name | Previous constraints | Current constraints | |
 | --- | --- | --- | --- |
+| [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=1.0.0,<2.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0``  | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
diff --git a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py
index 260dc5d..da13c9d 100644
--- a/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py
+++ b/airflow/providers/google/cloud/example_dags/example_bigquery_dts.py
@@ -22,9 +22,6 @@ Example Airflow DAG that creates and deletes Bigquery data transfer configuratio
 import os
 import time
 
-from google.cloud.bigquery_datatransfer_v1.types import TransferConfig
-from google.protobuf.json_format import ParseDict
-
 from airflow import models
 from airflow.providers.google.cloud.operators.bigquery_dts import (
     BigQueryCreateDataTransferOperator,
@@ -55,16 +52,13 @@ PARAMS = {
     "file_format": "CSV",
 }
 
-TRANSFER_CONFIG = ParseDict(
-    {
-        "destination_dataset_id": GCP_DTS_BQ_DATASET,
-        "display_name": "GCS Test Config",
-        "data_source_id": "google_cloud_storage",
-        "schedule_options": schedule_options,
-        "params": PARAMS,
-    },
-    TransferConfig(),
-)
+TRANSFER_CONFIG = {
+    "destination_dataset_id": GCP_DTS_BQ_DATASET,
+    "display_name": "GCS Test Config",
+    "data_source_id": "google_cloud_storage",
+    "schedule_options": schedule_options,
+    "params": PARAMS,
+}
 
 # [END howto_bigquery_dts_create_args]
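
With the proto-plus message types in 3.0.0, a plain dict is accepted anywhere a
``TransferConfig`` is expected, which is why the example DAG no longer needs
``ParseDict``. A minimal sketch of the equivalence (field values are placeholders):

    from google.cloud.bigquery_datatransfer_v1 import TransferConfig

    as_message = TransferConfig(
        display_name="GCS Test Config",
        data_source_id="google_cloud_storage",
    )
    # to_dict is a classmethod on every proto-plus message type.
    as_dict = TransferConfig.to_dict(as_message)
    assert as_dict["display_name"] == "GCS Test Config"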
 
diff --git a/airflow/providers/google/cloud/hooks/bigquery_dts.py b/airflow/providers/google/cloud/hooks/bigquery_dts.py
index 2d8d12b..37d42ef 100644
--- a/airflow/providers/google/cloud/hooks/bigquery_dts.py
+++ b/airflow/providers/google/cloud/hooks/bigquery_dts.py
@@ -27,7 +27,6 @@ from google.cloud.bigquery_datatransfer_v1.types import (
     TransferConfig,
     TransferRun,
 )
-from google.protobuf.json_format import MessageToDict, ParseDict
 from googleapiclient.discovery import Resource
 
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -71,7 +70,7 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :param config: Data transfer configuration to create.
        :type config: Union[dict, google.cloud.bigquery_datatransfer_v1.types.TransferConfig]
         """
-        config = MessageToDict(config) if isinstance(config, TransferConfig) else config
+        config = TransferConfig.to_dict(config) if isinstance(config, TransferConfig) else config
         new_config = copy(config)
         schedule_options = new_config.get("schedule_options")
         if schedule_options:
@@ -80,7 +79,11 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
                 schedule_options["disable_auto_scheduling"] = True
         else:
             new_config["schedule_options"] = {"disable_auto_scheduling": True}
-        return ParseDict(new_config, TransferConfig())
+        # HACK: TransferConfig.to_dict returns invalid representation
+        # See: https://github.com/googleapis/python-bigquery-datatransfer/issues/90
+        if isinstance(new_config.get('user_id'), str):
+            new_config['user_id'] = int(new_config['user_id'])
+        return TransferConfig(**new_config)
 
     def get_conn(self) -> DataTransferServiceClient:
         """
@@ -129,14 +132,16 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
        :return: A ``google.cloud.bigquery_datatransfer_v1.types.TransferConfig`` instance.
         """
         client = self.get_conn()
-        parent = client.project_path(project_id)
+        parent = f"projects/{project_id}"
         return client.create_transfer_config(
-            parent=parent,
-            transfer_config=self._disable_auto_scheduling(transfer_config),
-            authorization_code=authorization_code,
+            request={
+                'parent': parent,
+                'transfer_config': self._disable_auto_scheduling(transfer_config),
+                'authorization_code': authorization_code,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -169,8 +174,10 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
         :return: None
         """
         client = self.get_conn()
-        name = client.project_transfer_config_path(project=project_id, transfer_config=transfer_config_id)
-        return client.delete_transfer_config(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/transferConfigs/{transfer_config_id}"
+        return client.delete_transfer_config(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
 
     @GoogleBaseHook.fallback_to_default_project_id
     def start_manual_transfer_runs(
@@ -216,14 +223,16 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
        :return: An ``google.cloud.bigquery_datatransfer_v1.types.StartManualTransferRunsResponse`` instance.
         """
         client = self.get_conn()
-        parent = client.project_transfer_config_path(project=project_id, transfer_config=transfer_config_id)
+        parent = f"projects/{project_id}/transferConfigs/{transfer_config_id}"
         return client.start_manual_transfer_runs(
-            parent=parent,
-            requested_time_range=requested_time_range,
-            requested_run_time=requested_run_time,
+            request={
+                'parent': parent,
+                'requested_time_range': requested_time_range,
+                'requested_run_time': requested_run_time,
+            },
             retry=retry,
             timeout=timeout,
-            metadata=metadata,
+            metadata=metadata or (),
         )
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -259,5 +268,7 @@ class BiqQueryDataTransferServiceHook(GoogleBaseHook):
        :return: An ``google.cloud.bigquery_datatransfer_v1.types.TransferRun`` instance.
         """
         client = self.get_conn()
-        name = client.project_run_path(project=project_id, transfer_config=transfer_config_id, run=run_id)
-        return client.get_transfer_run(name=name, retry=retry, timeout=timeout, metadata=metadata)
+        name = f"projects/{project_id}/transferConfigs/{transfer_config_id}/runs/{run_id}"
+        return client.get_transfer_run(
+            request={'name': name}, retry=retry, timeout=timeout, metadata=metadata or ()
+        )
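
The hook changes above follow the general migration pattern of the regenerated
Google clients: per-argument calls collapse into a single ``request`` mapping,
and the removed ``*_path`` helpers become hand-built resource names. A minimal
sketch, assuming application-default credentials and a placeholder project:

    from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient

    client = DataTransferServiceClient()
    project_id = "example-project"  # placeholder

    # Old (<2.0.0): parent = client.project_path(project_id)
    # New (>=3.0.0): build the resource name by hand.
    parent = f"projects/{project_id}"
    config = {"display_name": "demo", "data_source_id": "google_cloud_storage"}
    client.create_transfer_config(request={"parent": parent, "transfer_config": config})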
diff --git a/airflow/providers/google/cloud/operators/bigquery_dts.py b/airflow/providers/google/cloud/operators/bigquery_dts.py
index e941bd4..656fc77 100644
--- a/airflow/providers/google/cloud/operators/bigquery_dts.py
+++ b/airflow/providers/google/cloud/operators/bigquery_dts.py
@@ -19,7 +19,7 @@
 from typing import Optional, Sequence, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.protobuf.json_format import MessageToDict
+from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig
 
 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook, get_object_id
@@ -110,7 +110,7 @@ class BigQueryCreateDataTransferOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(response)
+        result = TransferConfig.to_dict(response)
         self.log.info("Created DTS transfer config %s", get_object_id(result))
        self.xcom_push(context, key="transfer_config_id", value=get_object_id(result))
         return result
@@ -289,10 +289,8 @@ class BigQueryDataTransferServiceStartTransferRunsOperator(BaseOperator):
             timeout=self.timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(response)
-        run_id = None
-        if 'runs' in result:
-            run_id = get_object_id(result['runs'][0])
-            self.xcom_push(context, key="run_id", value=run_id)
+        result = StartManualTransferRunsResponse.to_dict(response)
+        run_id = get_object_id(result['runs'][0])
+        self.xcom_push(context, key="run_id", value=run_id)
         self.log.info('Transfer run %s submitted successfully.', run_id)
         return result
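
``MessageToDict`` from ``google.protobuf.json_format`` does not understand the
proto-plus wrappers, hence the switch to the per-message ``to_dict``
classmethod. ``to_dict`` emits every field, so ``result['runs']`` is always a
key (the operator assumes at least one run was started). A short sketch:

    from google.cloud.bigquery_datatransfer_v1 import (
        StartManualTransferRunsResponse,
        TransferRun,
    )

    response = StartManualTransferRunsResponse(
        runs=[TransferRun(name="projects/p/transferConfigs/c/runs/r")]
    )
    result = StartManualTransferRunsResponse.to_dict(response)
    print(result["runs"][0]["name"])  # nested messages become nested dicts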
diff --git a/airflow/providers/google/cloud/sensors/bigquery_dts.py b/airflow/providers/google/cloud/sensors/bigquery_dts.py
index 5b851ed..49e124c 100644
--- a/airflow/providers/google/cloud/sensors/bigquery_dts.py
+++ b/airflow/providers/google/cloud/sensors/bigquery_dts.py
@@ -19,7 +19,7 @@
 from typing import Optional, Sequence, Set, Tuple, Union
 
 from google.api_core.retry import Retry
-from google.protobuf.json_format import MessageToDict
+from google.cloud.bigquery_datatransfer_v1 import TransferState
 
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
 from airflow.sensors.base import BaseSensorOperator
@@ -81,7 +81,9 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
         *,
         run_id: str,
         transfer_config_id: str,
-        expected_statuses: Union[Set[str], str] = 'SUCCEEDED',
+        expected_statuses: Union[
+            Set[Union[str, TransferState, int]], str, TransferState, int
+        ] = TransferState.SUCCEEDED,
         project_id: Optional[str] = None,
         gcp_conn_id: str = "google_cloud_default",
         retry: Optional[Retry] = None,
@@ -96,13 +98,29 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
         self.retry = retry
         self.request_timeout = request_timeout
         self.metadata = metadata
-        self.expected_statuses = (
-            {expected_statuses} if isinstance(expected_statuses, str) else expected_statuses
-        )
+        self.expected_statuses = self._normalize_state_list(expected_statuses)
         self.project_id = project_id
         self.gcp_cloud_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
 
+    def _normalize_state_list(self, states) -> Set[TransferState]:
+        states = {states} if isinstance(states, (str, TransferState, int)) else states
+        result = set()
+        for state in states:
+            if isinstance(state, str):
+                result.add(TransferState[state.upper()])
+            elif isinstance(state, int):
+                result.add(TransferState(state))
+            elif isinstance(state, TransferState):
+                result.add(state)
+            else:
+                raise TypeError(
+                    f"Unsupported type. "
+                    f"Expected: str, int, 
google.cloud.bigquery_datatransfer_v1.TransferState."
+                    f"Current type: {type(state)}"
+                )
+        return result
+
     def poke(self, context: dict) -> bool:
         hook = BiqQueryDataTransferServiceHook(
             gcp_conn_id=self.gcp_cloud_conn_id,
@@ -116,8 +134,5 @@ class BigQueryDataTransferServiceTransferRunSensor(BaseSensorOperator):
             timeout=self.request_timeout,
             metadata=self.metadata,
         )
-        result = MessageToDict(run)
-        state = result["state"]
-        self.log.info("Status of %s run: %s", self.run_id, state)
-
-        return state in self.expected_statuses
+        self.log.info("Status of %s run: %s", self.run_id, str(run.state))
+        return run.state in self.expected_statuses
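
``TransferState`` is an int-valued proto enum, so ``run.state`` can be compared
against enum members directly, and ``_normalize_state_list`` can fold names,
integers, and members into one representation. For illustration:

    from google.cloud.bigquery_datatransfer_v1 import TransferState

    # Name lookup, value lookup, and the member itself all coincide.
    assert TransferState["SUCCEEDED"] is TransferState.SUCCEEDED
    assert TransferState(TransferState.SUCCEEDED.value) is TransferState.SUCCEEDED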
diff --git a/setup.py b/setup.py
index 3df9e47..628ecd1 100644
--- a/setup.py
+++ b/setup.py
@@ -284,7 +284,7 @@ google = [
     'google-auth>=1.0.0,<2.0.0',
     'google-auth-httplib2>=0.0.1',
     'google-cloud-automl>=0.4.0,<2.0.0',
-    'google-cloud-bigquery-datatransfer>=0.4.0,<2.0.0',
+    'google-cloud-bigquery-datatransfer>=3.0.0,<4.0.0',
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
     'google-cloud-datacatalog>=1.0.0,<2.0.0',
diff --git a/tests/providers/google/cloud/hooks/test_bigquery_dts.py b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
index 64ad79c..b53cb76 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery_dts.py
@@ -20,9 +20,7 @@ import unittest
 from copy import deepcopy
 from unittest import mock
 
-from google.cloud.bigquery_datatransfer_v1 import DataTransferServiceClient
 from google.cloud.bigquery_datatransfer_v1.types import TransferConfig
-from google.protobuf.json_format import ParseDict
 
 from airflow.providers.google.cloud.hooks.bigquery_dts import BiqQueryDataTransferServiceHook
 from airflow.version import version
@@ -33,21 +31,18 @@ PROJECT_ID = "id"
 
 PARAMS = {
     "field_delimiter": ",",
-    "max_bad_records": "0",
-    "skip_leading_rows": "1",
+    "max_bad_records": 0,
+    "skip_leading_rows": 1,
     "data_path_template": "bucket",
     "destination_table_name_template": "name",
     "file_format": "CSV",
 }
 
-TRANSFER_CONFIG = ParseDict(
-    {
-        "destination_dataset_id": "dataset",
-        "display_name": "GCS Test Config",
-        "data_source_id": "google_cloud_storage",
-        "params": PARAMS,
-    },
-    TransferConfig(),
+TRANSFER_CONFIG = TransferConfig(
+    destination_dataset_id="dataset",
+    display_name="GCS Test Config",
+    data_source_id="google_cloud_storage",
+    params=PARAMS,
 )
 
 TRANSFER_CONFIG_ID = "id1234"
@@ -77,14 +72,12 @@ class BigQueryDataTransferHookTestCase(unittest.TestCase):
     )
     def test_create_transfer_config(self, service_mock):
        self.hook.create_transfer_config(transfer_config=TRANSFER_CONFIG, project_id=PROJECT_ID)
-        parent = DataTransferServiceClient.project_path(PROJECT_ID)
+        parent = f"projects/{PROJECT_ID}"
         expected_config = deepcopy(TRANSFER_CONFIG)
         expected_config.schedule_options.disable_auto_scheduling = True
         service_mock.assert_called_once_with(
-            parent=parent,
-            transfer_config=expected_config,
-            authorization_code=None,
-            metadata=None,
+            request=dict(parent=parent, transfer_config=expected_config, authorization_code=None),
+            metadata=(),
             retry=None,
             timeout=None,
         )
@@ -96,8 +89,8 @@ class BigQueryDataTransferHookTestCase(unittest.TestCase):
    def test_delete_transfer_config(self, service_mock):
        self.hook.delete_transfer_config(transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID)
 
-        name = DataTransferServiceClient.project_transfer_config_path(PROJECT_ID, TRANSFER_CONFIG_ID)
-        service_mock.assert_called_once_with(name=name, metadata=None, retry=None, timeout=None)
+        name = f"projects/{PROJECT_ID}/transferConfigs/{TRANSFER_CONFIG_ID}"
+        service_mock.assert_called_once_with(request=dict(name=name), metadata=(), retry=None, timeout=None)
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.bigquery_dts."
@@ -106,12 +99,10 @@ class BigQueryDataTransferHookTestCase(unittest.TestCase):
    def test_start_manual_transfer_runs(self, service_mock):
        self.hook.start_manual_transfer_runs(transfer_config_id=TRANSFER_CONFIG_ID, project_id=PROJECT_ID)
 
-        parent = DataTransferServiceClient.project_transfer_config_path(PROJECT_ID, TRANSFER_CONFIG_ID)
+        parent = f"projects/{PROJECT_ID}/transferConfigs/{TRANSFER_CONFIG_ID}"
         service_mock.assert_called_once_with(
-            parent=parent,
-            requested_time_range=None,
-            requested_run_time=None,
-            metadata=None,
+            request=dict(parent=parent, requested_time_range=None, requested_run_time=None),
+            metadata=(),
             retry=None,
             timeout=None,
         )
diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py
index 4d42352..d6071fa 100644
--- a/tests/providers/google/cloud/operators/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py
@@ -18,6 +18,8 @@
 import unittest
 from unittest import mock
 
+from google.cloud.bigquery_datatransfer_v1 import StartManualTransferRunsResponse, TransferConfig, TransferRun
+
 from airflow.providers.google.cloud.operators.bigquery_dts import (
     BigQueryCreateDataTransferOperator,
     BigQueryDataTransferServiceStartTransferRunsOperator,
@@ -39,20 +41,23 @@ TRANSFER_CONFIG = {
 
 TRANSFER_CONFIG_ID = "id1234"
 
-NAME = "projects/123abc/locations/321cba/transferConfig/1a2b3c"
+TRANSFER_CONFIG_NAME = "projects/123abc/locations/321cba/transferConfig/1a2b3c"
+RUN_NAME = "projects/123abc/locations/321cba/transferConfig/1a2b3c/runs/123"
 
 
class BigQueryCreateDataTransferOperatorTestCase(unittest.TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook")
-    @mock.patch("airflow.providers.google.cloud.operators.bigquery_dts.get_object_id")
-    def test_execute(self, mock_name, mock_hook):
-        mock_name.return_value = TRANSFER_CONFIG_ID
-        mock_xcom = mock.MagicMock()
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{'return_value.create_transfer_config.return_value': TransferConfig(name=TRANSFER_CONFIG_NAME)},
+    )
+    def test_execute(self, mock_hook):
         op = BigQueryCreateDataTransferOperator(
            transfer_config=TRANSFER_CONFIG, project_id=PROJECT_ID, task_id="id"
         )
-        op.xcom_push = mock_xcom
-        op.execute(None)
+        ti = mock.MagicMock()
+
+        op.execute({'ti': ti})
+
         mock_hook.return_value.create_transfer_config.assert_called_once_with(
             authorization_code=None,
             metadata=None,
@@ -61,6 +66,7 @@ class BigQueryCreateDataTransferOperatorTestCase(unittest.TestCase):
             retry=None,
             timeout=None,
         )
+        ti.xcom_push.assert_called_once_with(execution_date=None, key='transfer_config_id', value='1a2b3c')
 
 
 class BigQueryDeleteDataTransferConfigOperatorTestCase(unittest.TestCase):
@@ -80,12 +86,22 @@ class BigQueryDeleteDataTransferConfigOperatorTestCase(unittest.TestCase):
 
 
class BigQueryDataTransferServiceStartTransferRunsOperatorTestCase(unittest.TestCase):
-    @mock.patch("airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook")
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{
+            'return_value.start_manual_transfer_runs.return_value': StartManualTransferRunsResponse(
+                runs=[TransferRun(name=RUN_NAME)]
+            )
+        },
+    )
     def test_execute(self, mock_hook):
         op = BigQueryDataTransferServiceStartTransferRunsOperator(
            transfer_config_id=TRANSFER_CONFIG_ID, task_id="id", project_id=PROJECT_ID
         )
-        op.execute(None)
+        ti = mock.MagicMock()
+
+        op.execute({'ti': ti})
+
        mock_hook.return_value.start_manual_transfer_runs.assert_called_once_with(
             transfer_config_id=TRANSFER_CONFIG_ID,
             project_id=PROJECT_ID,
@@ -95,3 +111,4 @@ class BigQueryDataTransferServiceStartTransferRunsOperatorTestCase(unittest.Test
             retry=None,
             timeout=None,
         )
+        ti.xcom_push.assert_called_once_with(execution_date=None, key='run_id', value='123')
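
The rewritten tests lean on ``unittest.mock`` resolving dotted keyword names
into nested attribute configuration; ``mock.patch`` forwards such kwargs to the
``MagicMock`` it creates. A self-contained illustration of the pattern:

    from unittest import mock

    # Dotted keys configure child mocks in one shot.
    mock_hook = mock.MagicMock(
        **{"return_value.get_transfer_run.return_value.state": "SUCCEEDED"}
    )
    hook = mock_hook(gcp_conn_id="google_cloud_default")
    assert hook.get_transfer_run(run_id="run").state == "SUCCEEDED"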
diff --git a/tests/providers/google/cloud/sensors/test_bigquery_dts.py b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
index 92a116e..c8a0548 100644
--- a/tests/providers/google/cloud/sensors/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/sensors/test_bigquery_dts.py
@@ -19,6 +19,8 @@
 import unittest
 from unittest import mock
 
+from google.cloud.bigquery_datatransfer_v1 import TransferState
+
 from airflow.providers.google.cloud.sensors.bigquery_dts import BigQueryDataTransferServiceTransferRunSensor
 
 TRANSFER_CONFIG_ID = "config_id"
@@ -27,20 +29,45 @@ PROJECT_ID = "project_id"
 
 
class TestBigQueryDataTransferServiceTransferRunSensor(unittest.TestCase):
-    @mock.patch("airflow.providers.google.cloud.sensors.bigquery_dts.BiqQueryDataTransferServiceHook")
     @mock.patch(
-        "airflow.providers.google.cloud.sensors.bigquery_dts.MessageToDict",
-        return_value={"state": "success"},
+        "airflow.providers.google.cloud.sensors.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{'return_value.get_transfer_run.return_value.state': TransferState.FAILED},
+    )
+    def test_poke_returns_false(self, mock_hook):
+        op = BigQueryDataTransferServiceTransferRunSensor(
+            transfer_config_id=TRANSFER_CONFIG_ID,
+            run_id=RUN_ID,
+            task_id="id",
+            project_id=PROJECT_ID,
+            expected_statuses={"SUCCEEDED"},
+        )
+        result = op.poke({})
+
+        self.assertEqual(result, False)
+        mock_hook.return_value.get_transfer_run.assert_called_once_with(
+            transfer_config_id=TRANSFER_CONFIG_ID,
+            run_id=RUN_ID,
+            project_id=PROJECT_ID,
+            metadata=None,
+            retry=None,
+            timeout=None,
+        )
+
+    @mock.patch(
+        "airflow.providers.google.cloud.sensors.bigquery_dts.BiqQueryDataTransferServiceHook",
+        **{'return_value.get_transfer_run.return_value.state': TransferState.SUCCEEDED},
     )
-    def test_poke(self, mock_msg_to_dict, mock_hook):
+    def test_poke_returns_true(self, mock_hook):
         op = BigQueryDataTransferServiceTransferRunSensor(
             transfer_config_id=TRANSFER_CONFIG_ID,
             run_id=RUN_ID,
             task_id="id",
             project_id=PROJECT_ID,
-            expected_statuses={"success"},
+            expected_statuses={"SUCCEEDED"},
         )
-        op.poke(None)
+        result = op.poke({})
+
+        self.assertEqual(result, True)
         mock_hook.return_value.get_transfer_run.assert_called_once_with(
             transfer_config_id=TRANSFER_CONFIG_ID,
             run_id=RUN_ID,
