This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3806a63bfe Add deferrable functionality to the AirbyteJobSensor and 
AirbyteTriggerSyncOperator (#36780)
3806a63bfe is described below

commit 3806a63bfe0c07d120a7181d751033c850f54997
Author: Chris Hronek <[email protected]>
AuthorDate: Wed Jan 17 12:29:57 2024 -0700

    Add deferrable functionality to the AirbyteJobSensor and 
AirbyteTriggerSyncOperator (#36780)
    
    * Add primary components
    
    * Create a drop in sensor replacement for the sensor
    
    * Add test for the async sensor
    
    * Add serialization test for airbyte trigger
    
    * Add test for is_still_running on airbyte trigger
    
    * Add test for terminal success status
    
    * Add test for cancelled status
    
    * Add additional trigger tests
    
    * Add tests for is_still_running
    
    * Add example to airbyte provider airflow docs
    
    * Update airflow docs to include async sensor
    
    * Resolve comments
    
    * Combine sensors and add deferrable parameter
    
    * Add class doc for deferrable parameter
    
    * Add deferrable parameter to the operator itself
    
    * Remove deprecated line from system tests
    
    * Add trigger to the provider yaml
    
    * Fix ci tests
---
 airflow/providers/airbyte/hooks/airbyte.py         |  58 ++++-
 airflow/providers/airbyte/operators/airbyte.py     |  56 ++++-
 airflow/providers/airbyte/provider.yaml            |   5 +
 airflow/providers/airbyte/sensors/airbyte.py       |  74 +++++-
 airflow/providers/airbyte/triggers/__init__.py     |  16 ++
 airflow/providers/airbyte/triggers/airbyte.py      | 117 ++++++++++
 .../operators/airbyte.rst                          |   7 +-
 tests/providers/airbyte/operators/test_airbyte.py  |   2 +-
 tests/providers/airbyte/triggers/__init__.py       |  16 ++
 tests/providers/airbyte/triggers/test_airbyte.py   | 253 +++++++++++++++++++++
 10 files changed, 591 insertions(+), 13 deletions(-)

diff --git a/airflow/providers/airbyte/hooks/airbyte.py 
b/airflow/providers/airbyte/hooks/airbyte.py
index a9ce336022..b8ad957a9c 100644
--- a/airflow/providers/airbyte/hooks/airbyte.py
+++ b/airflow/providers/airbyte/hooks/airbyte.py
@@ -17,12 +17,23 @@
 # under the License.
 from __future__ import annotations
 
+import base64
+import json
 import time
-from typing import Any
+from typing import TYPE_CHECKING, Any, TypeVar
+
+import aiohttp
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
 
 from airflow.exceptions import AirflowException
 from airflow.providers.http.hooks.http import HttpHook
 
+if TYPE_CHECKING:
+    from airflow.models import Connection
+
+T = TypeVar("T", bound=Any)
+
 
 class AirbyteHook(HttpHook):
     """
@@ -50,6 +61,51 @@ class AirbyteHook(HttpHook):
         super().__init__(http_conn_id=airbyte_conn_id)
         self.api_version: str = api_version
 
+    async def get_headers_tenants_from_connection(self) -> tuple[dict[str, 
Any], str]:
+        """Get Headers, tenants from the connection details."""
+        connection: Connection = await 
sync_to_async(self.get_connection)(self.http_conn_id)
+        base_url = connection.host
+
+        credentials = f"{connection.login}:{connection.password}"
+        credentials_base64 = 
base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
+
+        authorized_headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+            "authorization": f"Basic {credentials_base64}",
+        }
+
+        return authorized_headers, base_url
+
+    async def get_job_details(self, job_id: int) -> Any:
+        """
+        Uses Http async call to retrieve metadata for a specific job of an 
Airbyte Sync.
+
+        :param job_id: The ID of an Airbyte Sync Job.
+        """
+        headers, base_url = await self.get_headers_tenants_from_connection()
+        url = f"{base_url}/api/{self.api_version}/jobs/get"
+        self.log.info("URL for api request: %s", url)
+        async with aiohttp.ClientSession(headers=headers) as session:
+            async with session.post(url=url, data=json.dumps({"id": job_id})) 
as response:
+                try:
+                    response.raise_for_status()
+                    return await response.json()
+                except ClientResponseError as e:
+                    msg = f"{e.status}: {e.message} - {e.request_info}"
+                    raise AirflowException(msg)
+
+    async def get_job_status(self, job_id: int) -> str:
+        """
+        Retrieves the status for a specific job of an Airbyte Sync.
+
+        :param job_id: The ID of an Airbyte Sync Job.
+        """
+        self.log.info("Getting the status of job run %s.", job_id)
+        response = await self.get_job_details(job_id=job_id)
+        job_run_status: str = response["job"]["status"]
+        return job_run_status
+
     def wait_for_job(self, job_id: str | int, wait_seconds: float = 3, 
timeout: float | None = 3600) -> None:
         """
         Poll a job to check if it finishes.
diff --git a/airflow/providers/airbyte/operators/airbyte.py 
b/airflow/providers/airbyte/operators/airbyte.py
index 6d101662db..84a12dadfa 100644
--- a/airflow/providers/airbyte/operators/airbyte.py
+++ b/airflow/providers/airbyte/operators/airbyte.py
@@ -17,10 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import time
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -40,6 +44,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
     :param asynchronous: Optional. Flag to get job_id after submitting the job 
to the Airbyte API.
         This is useful for submitting long running jobs and
         waiting on them asynchronously using the AirbyteJobSensor. Defaults to 
False.
+    :param deferrable: Run operator in the deferrable mode.
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
     :param wait_seconds: Optional. Number of seconds between checks. Only used 
when ``asynchronous`` is False.
         Defaults to 3 seconds.
@@ -48,12 +53,14 @@ class AirbyteTriggerSyncOperator(BaseOperator):
     """
 
     template_fields: Sequence[str] = ("connection_id",)
+    ui_color = "#6C51FD"
 
     def __init__(
         self,
         connection_id: str,
         airbyte_conn_id: str = "airbyte_default",
         asynchronous: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         api_version: str = "v1",
         wait_seconds: float = 3,
         timeout: float = 3600,
@@ -66,23 +73,62 @@ class AirbyteTriggerSyncOperator(BaseOperator):
         self.api_version = api_version
         self.wait_seconds = wait_seconds
         self.asynchronous = asynchronous
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> None:
         """Create Airbyte Job and wait to finish."""
-        self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version)
-        job_object = 
self.hook.submit_sync_connection(connection_id=self.connection_id)
+        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version)
+        job_object = 
hook.submit_sync_connection(connection_id=self.connection_id)
         self.job_id = job_object.json()["job"]["id"]
+        state = job_object.json()["job"]["status"]
+        end_time = time.time() + self.timeout
 
         self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
         if not self.asynchronous:
             self.log.info("Waiting for job %s to complete", self.job_id)
-            self.hook.wait_for_job(job_id=self.job_id, 
wait_seconds=self.wait_seconds, timeout=self.timeout)
+            if self.deferrable:
+                if state in (hook.RUNNING, hook.PENDING, hook.INCOMPLETE):
+                    self.defer(
+                        timeout=self.execution_timeout,
+                        trigger=AirbyteSyncTrigger(
+                            conn_id=self.airbyte_conn_id,
+                            job_id=self.job_id,
+                            end_time=end_time,
+                            poll_interval=60,
+                        ),
+                        method_name="execute_complete",
+                    )
+                elif state == hook.SUCCEEDED:
+                    self.log.info("Job %s completed successfully", self.job_id)
+                    return
+                elif state == hook.ERROR:
+                    raise AirflowException(f"Job failed:\n{self.job_id}")
+                elif state == hook.CANCELLED:
+                    raise AirflowException(f"Job was 
cancelled:\n{self.job_id}")
+                else:
+                    raise Exception(f"Encountered unexpected state `{state}` for job_id `{self.job_id}`")
+            else:
+                hook.wait_for_job(job_id=self.job_id, 
wait_seconds=self.wait_seconds, timeout=self.timeout)
             self.log.info("Job %s completed successfully", self.job_id)
 
         return self.job_id
 
+    def execute_complete(self, context: Context, event: Any = None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+
+        Relies on trigger to throw an exception, otherwise it assumes 
execution was
+        successful.
+        """
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+
+        self.log.info("%s completed successfully.", self.task_id)
+        return None
+
     def on_kill(self):
         """Cancel the job if task is cancelled."""
+        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id)
         if self.job_id:
             self.log.info("on_kill: cancel the airbyte Job %s", self.job_id)
-            self.hook.cancel_job(self.job_id)
+            hook.cancel_job(self.job_id)
diff --git a/airflow/providers/airbyte/provider.yaml 
b/airflow/providers/airbyte/provider.yaml
index c973844dd7..4f163eedfc 100644
--- a/airflow/providers/airbyte/provider.yaml
+++ b/airflow/providers/airbyte/provider.yaml
@@ -69,6 +69,11 @@ sensors:
     python-modules:
       - airflow.providers.airbyte.sensors.airbyte
 
+triggers:
+  - integration-name: Airbyte
+    python-modules:
+      - airflow.providers.airbyte.triggers.airbyte
+
 connection-types:
   - hook-class-name: airflow.providers.airbyte.hooks.airbyte.AirbyteHook
     connection-type: airbyte
diff --git a/airflow/providers/airbyte/sensors/airbyte.py 
b/airflow/providers/airbyte/sensors/airbyte.py
index f38206246e..4556d55430 100644
--- a/airflow/providers/airbyte/sensors/airbyte.py
+++ b/airflow/providers/airbyte/sensors/airbyte.py
@@ -18,10 +18,14 @@
 """This module contains a Airbyte Job sensor."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+import time
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
 
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -34,6 +38,7 @@ class AirbyteJobSensor(BaseSensorOperator):
 
     :param airbyte_job_id: Required. Id of the Airbyte job
     :param airbyte_conn_id: Optional. The name of the Airflow connection to get
         connection information for Airbyte. Defaults to "airbyte_default".
+    :param deferrable: Run sensor in the deferrable mode.
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
     """
@@ -45,11 +50,30 @@ class AirbyteJobSensor(BaseSensorOperator):
         self,
         *,
         airbyte_job_id: int,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         airbyte_conn_id: str = "airbyte_default",
         api_version: str = "v1",
         **kwargs,
     ) -> None:
+        if deferrable:
+            if "poke_interval" not in kwargs:
+                # TODO: Remove once deprecated
+                if "polling_interval" in kwargs:
+                    kwargs["poke_interval"] = kwargs["polling_interval"]
+                    warnings.warn(
+                        "Argument `polling_interval` is deprecated and will be removed "
+                        "in a future release.  Please use `poke_interval` instead.",
+                        AirflowProviderDeprecationWarning,
+                        stacklevel=2,
+                    )
+                else:
+                    kwargs["poke_interval"] = 5
+
+                if "timeout" not in kwargs:
+                    kwargs["timeout"] = 60 * 60 * 24 * 7
+
         super().__init__(**kwargs)
+        self.deferrable = deferrable
         self.airbyte_conn_id = airbyte_conn_id
         self.airbyte_job_id = airbyte_job_id
         self.api_version = api_version
@@ -79,3 +103,49 @@ class AirbyteJobSensor(BaseSensorOperator):
 
         self.log.info("Waiting for job %s to complete.", self.airbyte_job_id)
         return False
+
+    def execute(self, context: Context) -> Any:
+        """Submits a job which generates a run_id and gets deferred."""
+        if not self.deferrable:
+            super().execute(context)
+        else:
+            hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id)
+            job = hook.get_job(job_id=(int(self.airbyte_job_id)))
+            state = job.json()["job"]["status"]
+            end_time = time.time() + self.timeout
+
+            self.log.info("Airbyte Job Id: Job %s", self.airbyte_job_id)
+
+            if state in (hook.RUNNING, hook.PENDING, hook.INCOMPLETE):
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=AirbyteSyncTrigger(
+                        conn_id=self.airbyte_conn_id,
+                        job_id=self.airbyte_job_id,
+                        end_time=end_time,
+                        poll_interval=60,
+                    ),
+                    method_name="execute_complete",
+                )
+            elif state == hook.SUCCEEDED:
+                self.log.info("%s completed successfully.", self.task_id)
+                return
+            elif state == hook.ERROR:
+                raise AirflowException(f"Job failed:\n{job}")
+            elif state == hook.CANCELLED:
+                raise AirflowException(f"Job was cancelled:\n{job}")
+            else:
+                raise Exception(f"Encountered unexpected state `{state}` for job_id `{self.airbyte_job_id}`")
+
+    def execute_complete(self, context: Context, event: Any = None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+
+        Relies on trigger to throw an exception, otherwise it assumes 
execution was
+        successful.
+        """
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+
+        self.log.info("%s completed successfully.", self.task_id)
+        return None
diff --git a/airflow/providers/airbyte/triggers/__init__.py 
b/airflow/providers/airbyte/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/airbyte/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/airbyte/triggers/airbyte.py 
b/airflow/providers/airbyte/triggers/airbyte.py
new file mode 100644
index 0000000000..06c926d681
--- /dev/null
+++ b/airflow/providers/airbyte/triggers/airbyte.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, AsyncIterator
+
+from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AirbyteSyncTrigger(BaseTrigger):
+    """
+    Triggers Airbyte Sync, makes an asynchronous HTTP call to get the status 
via a job ID.
+
+    This trigger is designed to initiate and monitor the status of Airbyte 
Sync jobs. It
+    makes use of asynchronous communication to check the progress of a job run 
over time.
+
+    :param conn_id: The connection identifier for connecting to Airbyte.
+    :param job_id: The ID of an Airbyte Sync job.
+    :param end_time: Time in seconds to wait for a job run to reach a terminal 
status. Defaults to 7 days.
+    :param poll_interval:  polling period in seconds to check for the status.
+    """
+
+    def __init__(
+        self,
+        job_id: int,
+        conn_id: str,
+        end_time: float,
+        poll_interval: float,
+    ):
+        super().__init__()
+        self.job_id = job_id
+        self.conn_id = conn_id
+        self.end_time = end_time
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes AirbyteSyncTrigger arguments and classpath."""
+        return (
+            "airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger",
+            {
+                "job_id": self.job_id,
+                "conn_id": self.conn_id,
+                "end_time": self.end_time,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Make async connection to Airbyte, polls for the pipeline run 
status."""
+        hook = AirbyteHook(airbyte_conn_id=self.conn_id)
+        try:
+            while await self.is_still_running(hook):
+                if self.end_time < time.time():
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": f"Job run {self.job_id} has not reached 
a terminal status after "
+                            f"{self.end_time} seconds.",
+                            "job_id": self.job_id,
+                        }
+                    )
+                await asyncio.sleep(self.poll_interval)
+            job_run_status = await hook.get_job_status(self.job_id)
+            if job_run_status == hook.SUCCEEDED:
+                yield TriggerEvent(
+                    {
+                        "status": "success",
+                        "message": f"Job run {self.job_id} has completed 
successfully.",
+                        "job_id": self.job_id,
+                    }
+                )
+            elif job_run_status == hook.CANCELLED:
+                yield TriggerEvent(
+                    {
+                        "status": "cancelled",
+                        "message": f"Job run {self.job_id} has been 
cancelled.",
+                        "job_id": self.job_id,
+                    }
+                )
+            else:
+                yield TriggerEvent(
+                    {
+                        "status": "error",
+                        "message": f"Job run {self.job_id} has failed.",
+                        "job_id": self.job_id,
+                    }
+                )
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e), 
"job_id": self.job_id})
+
+    async def is_still_running(self, hook: AirbyteHook) -> bool:
+        """
+        Async function to check whether the job is submitted via async API.
+
+        If job is in running state returns True if it is still running else 
return False
+        """
+        job_run_status = await hook.get_job_status(self.job_id)
+        if job_run_status in (AirbyteHook.RUNNING, AirbyteHook.PENDING, 
AirbyteHook.INCOMPLETE):
+            return True
+        return False
diff --git a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst 
b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
index 68fd8c44cb..60f47955dd 100644
--- a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
+++ b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
@@ -38,10 +38,9 @@ create in Airbyte between a source and destination 
synchronization job.
 Use the ``airbyte_conn_id`` parameter to specify the Airbyte connection to use 
to
 connect to your account.
 
-You can trigger a synchronization job in Airflow in two ways with the 
Operator. The first one
-is a synchronous process. This will trigger the Airbyte job and the Operator 
manage the status
-of the job. Another way is use the flag ``async = True`` so the Operator only 
trigger the job and
-return the ``job_id`` that should be pass to the AirbyteSensor.
+You can trigger a synchronization job in Airflow in two ways with the 
Operator. The first one is a synchronous process.
+This Operator will initiate the Airbyte job, and the Operator manages the job 
status. Another way is to use the flag
+``async = True`` so the Operator only triggers the job and returns the 
``job_id``, passed to the AirbyteSensor.
 
 An example using the synchronous way:
 
diff --git a/tests/providers/airbyte/operators/test_airbyte.py 
b/tests/providers/airbyte/operators/test_airbyte.py
index f8ecd15615..2c0085f53d 100644
--- a/tests/providers/airbyte/operators/test_airbyte.py
+++ b/tests/providers/airbyte/operators/test_airbyte.py
@@ -37,7 +37,7 @@ class TestAirbyteTriggerSyncOp:
     
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.wait_for_job", 
return_value=None)
     def test_execute(self, mock_wait_for_job, mock_submit_sync_connection):
         mock_submit_sync_connection.return_value = mock.Mock(
-            **{"json.return_value": {"job": {"id": self.job_id}}}
+            **{"json.return_value": {"job": {"id": self.job_id, "status": 
"running"}}}
         )
 
         op = AirbyteTriggerSyncOperator(
diff --git a/tests/providers/airbyte/triggers/__init__.py 
b/tests/providers/airbyte/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/airbyte/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/airbyte/triggers/test_airbyte.py 
b/tests/providers/airbyte/triggers/test_airbyte.py
new file mode 100644
index 0000000000..103df7cf00
--- /dev/null
+++ b/tests/providers/airbyte/triggers/test_airbyte.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from unittest import mock
+
+import pytest
+
+from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+class TestAirbyteSyncTrigger:
+    DAG_ID = "airbyte_sync_run"
+    TASK_ID = "airbyte_sync_run_task_op"
+    JOB_ID = 1234
+    CONN_ID = "airbyte_default"
+    END_TIME = time.time() + 60 * 60 * 24 * 7
+    POLL_INTERVAL = 3.0
+
+    def test_serialization(self):
+        """Assert TestAirbyteSyncTrigger correctly serializes its arguments 
and classpath."""
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, 
end_time=self.END_TIME, job_id=self.JOB_ID
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger"
+        assert kwargs == {
+            "job_id": self.JOB_ID,
+            "conn_id": self.CONN_ID,
+            "end_time": self.END_TIME,
+            "poll_interval": self.POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    async def test_airbyte_run_sync_trigger(self, mocked_is_still_running):
+        """Test AirbyteSyncTrigger is triggered with mocked details and run 
successfully."""
+        mocked_is_still_running.return_value = True
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_value, mock_status, mock_message",
+        [
+            (AirbyteHook.SUCCEEDED, "success", "Job run 1234 has completed 
successfully."),
+        ],
+    )
+    
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_for_terminal_status_success(
+        self, mock_get_job_status, mocked_is_still_running, mock_value, 
mock_status, mock_message
+    ):
+        """Assert that run trigger success message in case of job success"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.return_value = mock_value
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        expected_result = {
+            "status": mock_status,
+            "message": mock_message,
+            "job_id": self.JOB_ID,
+        }
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert TriggerEvent(expected_result) == task.result()
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_value, mock_status, mock_message",
+        [
+            (AirbyteHook.CANCELLED, "cancelled", "Job run 1234 has been 
cancelled."),
+        ],
+    )
+    
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_for_terminal_status_cancelled(
+        self, mock_get_job_status, mocked_is_still_running, mock_value, 
mock_status, mock_message
+    ):
+        """Assert that run trigger success message in case of job success"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.return_value = mock_value
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, 
end_time=self.END_TIME, job_id=self.JOB_ID
+        )
+        expected_result = {
+            "status": mock_status,
+            "message": mock_message,
+            "job_id": self.JOB_ID,
+        }
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert TriggerEvent(expected_result) == task.result()
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_value, mock_status, mock_message",
+        [
+            (AirbyteHook.ERROR, "error", "Job run 1234 has failed."),
+        ],
+    )
+    
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_for_terminal_status_error(
+        self, mock_get_job_status, mocked_is_still_running, mock_value, 
mock_status, mock_message
+    ):
+        """Assert that run trigger success message in case of job success"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.return_value = mock_value
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        expected_result = {
+            "status": mock_status,
+            "message": mock_message,
+            "job_id": self.JOB_ID,
+        }
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert TriggerEvent(expected_result) == task.result()
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_exception(self, mock_get_job_status, 
mocked_is_still_running):
+        """Assert that run catch exception if Airbyte Sync job API throw 
exception"""
+        mocked_is_still_running.return_value = False
+        mock_get_job_status.side_effect = Exception("Test exception")
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        task = [i async for i in trigger.run()]
+        response = TriggerEvent(
+            {
+                "status": "error",
+                "message": "Test exception",
+                "job_id": self.JOB_ID,
+            }
+        )
+        assert len(task) == 1
+        assert response in task
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_timeout(self, mock_get_job_status, 
mocked_is_still_running):
+        """Assert that run timeout after end_time elapsed"""
+        mocked_is_still_running.return_value = True
+        mock_get_job_status.side_effect = Exception("Test exception")
+        end_time = time.time()
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=end_time,
+            job_id=self.JOB_ID,
+        )
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        expected = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Job run {self.JOB_ID} has not reached a terminal 
status "
+                f"after {end_time} seconds.",
+                "job_id": self.JOB_ID,
+            }
+        )
+        assert expected == actual
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_response, expected_status",
+        [
+            (AirbyteHook.SUCCEEDED, False),
+        ],
+    )
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_job_is_still_running_success(
+        self, mock_get_job_status, mock_response, expected_status
+    ):
+        """Test is_still_running with mocked response job status and assert
+        the return response with expected value"""
+        hook = mock.AsyncMock(AirbyteHook)
+        hook.get_job_status.return_value = mock_response
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
+        )
+        response = await trigger.is_still_running(hook)
+        assert response == expected_status
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_response, expected_status",
+        [
+            (AirbyteHook.RUNNING, True),
+        ],
+    )
+    
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+    async def test_airbyte_sync_run_is_still_running(
+        self, mock_get_job_status, mock_response, expected_status
+    ):
+        """Test is_still_running with mocked response job status and assert
+        the return response with expected value"""
+        airbyte_hook = mock.AsyncMock(AirbyteHook)
+        airbyte_hook.get_job_status.return_value = mock_response
+        trigger = AirbyteSyncTrigger(
+            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, 
end_time=self.END_TIME, job_id=self.JOB_ID
+        )
+        response = await trigger.is_still_running(airbyte_hook)
+        assert response == expected_status


Reply via email to