This is an automated email from the ASF dual-hosted git repository.
pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3806a63bfe Add deferrable functionality to the AirbyteJobSensor and
AirbyteTriggerSyncOperator (#36780)
3806a63bfe is described below
commit 3806a63bfe0c07d120a7181d751033c850f54997
Author: Chris Hronek <[email protected]>
AuthorDate: Wed Jan 17 12:29:57 2024 -0700
Add deferrable functionality to the AirbyteJobSensor and
AirbyteTriggerSyncOperator (#36780)
* Add primary components
* Create a drop-in replacement for the sensor
* Add test for the async sensor
* Add serialization test for airbyte trigger
* Add test for is_still_running on airbyte trigger
* Add test for terminal success status
* Add test for cancelled status
* Add additional trigger tests
* Add tests for is_still_running
* Add example to airbyte provider airflow docs
* Update airflow docs to include async sensor
* Resolve comments
* Combine sensors and add deferrable parameter
* Add class doc for deferrable parameter
* Add deferrable parameter to the operator itself
* Remove deprecated line from system tests
* Add trigger to the provider yaml
* Fix ci tests
---
airflow/providers/airbyte/hooks/airbyte.py | 58 ++++-
airflow/providers/airbyte/operators/airbyte.py | 56 ++++-
airflow/providers/airbyte/provider.yaml | 5 +
airflow/providers/airbyte/sensors/airbyte.py | 74 +++++-
airflow/providers/airbyte/triggers/__init__.py | 16 ++
airflow/providers/airbyte/triggers/airbyte.py | 117 ++++++++++
.../operators/airbyte.rst | 7 +-
tests/providers/airbyte/operators/test_airbyte.py | 2 +-
tests/providers/airbyte/triggers/__init__.py | 16 ++
tests/providers/airbyte/triggers/test_airbyte.py | 253 +++++++++++++++++++++
10 files changed, 591 insertions(+), 13 deletions(-)
diff --git a/airflow/providers/airbyte/hooks/airbyte.py
b/airflow/providers/airbyte/hooks/airbyte.py
index a9ce336022..b8ad957a9c 100644
--- a/airflow/providers/airbyte/hooks/airbyte.py
+++ b/airflow/providers/airbyte/hooks/airbyte.py
@@ -17,12 +17,23 @@
# under the License.
from __future__ import annotations
+import base64
+import json
import time
-from typing import Any
+from typing import TYPE_CHECKING, Any, TypeVar
+
+import aiohttp
+from aiohttp import ClientResponseError
+from asgiref.sync import sync_to_async
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
+if TYPE_CHECKING:
+ from airflow.models import Connection
+
+T = TypeVar("T", bound=Any)
+
class AirbyteHook(HttpHook):
"""
@@ -50,6 +61,51 @@ class AirbyteHook(HttpHook):
super().__init__(http_conn_id=airbyte_conn_id)
self.api_version: str = api_version
+ async def get_headers_tenants_from_connection(self) -> tuple[dict[str,
Any], str]:
+ """Get Headers, tenants from the connection details."""
+ connection: Connection = await
sync_to_async(self.get_connection)(self.http_conn_id)
+ base_url = connection.host
+
+ credentials = f"{connection.login}:{connection.password}"
+ credentials_base64 =
base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
+
+ authorized_headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "authorization": f"Basic {credentials_base64}",
+ }
+
+ return authorized_headers, base_url
+
+ async def get_job_details(self, job_id: int) -> Any:
+ """
+ Uses Http async call to retrieve metadata for a specific job of an
Airbyte Sync.
+
+ :param job_id: The ID of an Airbyte Sync Job.
+ """
+ headers, base_url = await self.get_headers_tenants_from_connection()
+ url = f"{base_url}/api/{self.api_version}/jobs/get"
+ self.log.info("URL for api request: %s", url)
+ async with aiohttp.ClientSession(headers=headers) as session:
+ async with session.post(url=url, data=json.dumps({"id": job_id}))
as response:
+ try:
+ response.raise_for_status()
+ return await response.json()
+ except ClientResponseError as e:
+ msg = f"{e.status}: {e.message} - {e.request_info}"
+ raise AirflowException(msg)
+
+ async def get_job_status(self, job_id: int) -> str:
+ """
+ Retrieves the status for a specific job of an Airbyte Sync.
+
+ :param job_id: The ID of an Airbyte Sync Job.
+ """
+ self.log.info("Getting the status of job run %s.", job_id)
+ response = await self.get_job_details(job_id=job_id)
+ job_run_status: str = response["job"]["status"]
+ return job_run_status
+
def wait_for_job(self, job_id: str | int, wait_seconds: float = 3,
timeout: float | None = 3600) -> None:
"""
Poll a job to check if it finishes.
diff --git a/airflow/providers/airbyte/operators/airbyte.py
b/airflow/providers/airbyte/operators/airbyte.py
index 6d101662db..84a12dadfa 100644
--- a/airflow/providers/airbyte/operators/airbyte.py
+++ b/airflow/providers/airbyte/operators/airbyte.py
@@ -17,10 +17,14 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+import time
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -40,6 +44,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
:param asynchronous: Optional. Flag to get job_id after submitting the job
to the Airbyte API.
This is useful for submitting long running jobs and
waiting on them asynchronously using the AirbyteJobSensor. Defaults to
False.
+ :param deferrable: Run operator in the deferrable mode.
:param api_version: Optional. Airbyte API version. Defaults to "v1".
:param wait_seconds: Optional. Number of seconds between checks. Only used
when ``asynchronous`` is False.
Defaults to 3 seconds.
@@ -48,12 +53,14 @@ class AirbyteTriggerSyncOperator(BaseOperator):
"""
template_fields: Sequence[str] = ("connection_id",)
+ ui_color = "#6C51FD"
def __init__(
self,
connection_id: str,
airbyte_conn_id: str = "airbyte_default",
asynchronous: bool = False,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
api_version: str = "v1",
wait_seconds: float = 3,
timeout: float = 3600,
@@ -66,23 +73,62 @@ class AirbyteTriggerSyncOperator(BaseOperator):
self.api_version = api_version
self.wait_seconds = wait_seconds
self.asynchronous = asynchronous
+ self.deferrable = deferrable
def execute(self, context: Context) -> None:
"""Create Airbyte Job and wait to finish."""
- self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id,
api_version=self.api_version)
- job_object =
self.hook.submit_sync_connection(connection_id=self.connection_id)
+ hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id,
api_version=self.api_version)
+ job_object =
hook.submit_sync_connection(connection_id=self.connection_id)
self.job_id = job_object.json()["job"]["id"]
+ state = job_object.json()["job"]["status"]
+ end_time = time.time() + self.timeout
self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
if not self.asynchronous:
self.log.info("Waiting for job %s to complete", self.job_id)
- self.hook.wait_for_job(job_id=self.job_id,
wait_seconds=self.wait_seconds, timeout=self.timeout)
+ if self.deferrable:
+ if state in (hook.RUNNING, hook.PENDING, hook.INCOMPLETE):
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=AirbyteSyncTrigger(
+ conn_id=self.airbyte_conn_id,
+ job_id=self.job_id,
+ end_time=end_time,
+ poll_interval=60,
+ ),
+ method_name="execute_complete",
+ )
+ elif state == hook.SUCCEEDED:
+ self.log.info("Job %s completed successfully", self.job_id)
+ return
+ elif state == hook.ERROR:
+ raise AirflowException(f"Job failed:\n{self.job_id}")
+ elif state == hook.CANCELLED:
+ raise AirflowException(f"Job was
cancelled:\n{self.job_id}")
+ else:
+ raise Exception(f"Encountered unexpected state `{state}`
for job_id `{self.job_id}")
+ else:
+ hook.wait_for_job(job_id=self.job_id,
wait_seconds=self.wait_seconds, timeout=self.timeout)
self.log.info("Job %s completed successfully", self.job_id)
return self.job_id
+ def execute_complete(self, context: Context, event: Any = None) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+
+ self.log.info("%s completed successfully.", self.task_id)
+ return None
+
def on_kill(self):
"""Cancel the job if task is cancelled."""
+ hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id)
if self.job_id:
self.log.info("on_kill: cancel the airbyte Job %s", self.job_id)
- self.hook.cancel_job(self.job_id)
+ hook.cancel_job(self.job_id)
diff --git a/airflow/providers/airbyte/provider.yaml
b/airflow/providers/airbyte/provider.yaml
index c973844dd7..4f163eedfc 100644
--- a/airflow/providers/airbyte/provider.yaml
+++ b/airflow/providers/airbyte/provider.yaml
@@ -69,6 +69,11 @@ sensors:
python-modules:
- airflow.providers.airbyte.sensors.airbyte
+triggers:
+ - integration-name: Airbyte
+ python-modules:
+ - airflow.providers.airbyte.triggers.airbyte
+
connection-types:
- hook-class-name: airflow.providers.airbyte.hooks.airbyte.AirbyteHook
connection-type: airbyte
diff --git a/airflow/providers/airbyte/sensors/airbyte.py
b/airflow/providers/airbyte/sensors/airbyte.py
index f38206246e..4556d55430 100644
--- a/airflow/providers/airbyte/sensors/airbyte.py
+++ b/airflow/providers/airbyte/sensors/airbyte.py
@@ -18,10 +18,14 @@
"""This module contains a Airbyte Job sensor."""
from __future__ import annotations
-from typing import TYPE_CHECKING, Sequence
+import time
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
from airflow.sensors.base import BaseSensorOperator
if TYPE_CHECKING:
@@ -34,6 +38,7 @@ class AirbyteJobSensor(BaseSensorOperator):
:param airbyte_job_id: Required. Id of the Airbyte job
:param airbyte_conn_id: Optional. The name of the Airflow connection to get
+ :param deferrable: Run sensor in the deferrable mode.
connection information for Airbyte. Defaults to "airbyte_default".
:param api_version: Optional. Airbyte API version. Defaults to "v1".
"""
@@ -45,11 +50,30 @@ class AirbyteJobSensor(BaseSensorOperator):
self,
*,
airbyte_job_id: int,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
airbyte_conn_id: str = "airbyte_default",
api_version: str = "v1",
**kwargs,
) -> None:
+ if deferrable:
+ if "poke_interval" not in kwargs:
+ # TODO: Remove once deprecated
+ if "polling_interval" in kwargs:
+ kwargs["poke_interval"] = kwargs["polling_interval"]
+ warnings.warn(
+ "Argument `poll_interval` is deprecated and will be
removed "
+ "in a future release. Please use `poke_interval`
instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ kwargs["poke_interval"] = 5
+
+ if "timeout" not in kwargs:
+ kwargs["timeout"] = 60 * 60 * 24 * 7
+
super().__init__(**kwargs)
+ self.deferrable = deferrable
self.airbyte_conn_id = airbyte_conn_id
self.airbyte_job_id = airbyte_job_id
self.api_version = api_version
@@ -79,3 +103,49 @@ class AirbyteJobSensor(BaseSensorOperator):
self.log.info("Waiting for job %s to complete.", self.airbyte_job_id)
return False
+
+ def execute(self, context: Context) -> Any:
+ """Submits a job which generates a run_id and gets deferred."""
+ if not self.deferrable:
+ super().execute(context)
+ else:
+ hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id)
+ job = hook.get_job(job_id=(int(self.airbyte_job_id)))
+ state = job.json()["job"]["status"]
+ end_time = time.time() + self.timeout
+
+ self.log.info("Airbyte Job Id: Job %s", self.airbyte_job_id)
+
+ if state in (hook.RUNNING, hook.PENDING, hook.INCOMPLETE):
+ self.defer(
+ timeout=self.execution_timeout,
+ trigger=AirbyteSyncTrigger(
+ conn_id=self.airbyte_conn_id,
+ job_id=self.airbyte_job_id,
+ end_time=end_time,
+ poll_interval=60,
+ ),
+ method_name="execute_complete",
+ )
+ elif state == hook.SUCCEEDED:
+ self.log.info("%s completed successfully.", self.task_id)
+ return
+ elif state == hook.ERROR:
+ raise AirflowException(f"Job failed:\n{job}")
+ elif state == hook.CANCELLED:
+ raise AirflowException(f"Job was cancelled:\n{job}")
+ else:
+ raise Exception(f"Encountered unexpected state `{state}` for
job_id `{self.airbyte_job_id}")
+
+ def execute_complete(self, context: Context, event: Any = None) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+
+ Relies on trigger to throw an exception, otherwise it assumes
execution was
+ successful.
+ """
+ if event["status"] == "error":
+ raise AirflowException(event["message"])
+
+ self.log.info("%s completed successfully.", self.task_id)
+ return None
diff --git a/airflow/providers/airbyte/triggers/__init__.py
b/airflow/providers/airbyte/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/airbyte/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/airbyte/triggers/airbyte.py
b/airflow/providers/airbyte/triggers/airbyte.py
new file mode 100644
index 0000000000..06c926d681
--- /dev/null
+++ b/airflow/providers/airbyte/triggers/airbyte.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from typing import Any, AsyncIterator
+
+from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class AirbyteSyncTrigger(BaseTrigger):
+ """
+ Triggers Airbyte Sync, makes an asynchronous HTTP call to get the status
via a job ID.
+
+ This trigger is designed to initiate and monitor the status of Airbyte
Sync jobs. It
+ makes use of asynchronous communication to check the progress of a job run
over time.
+
+ :param conn_id: The connection identifier for connecting to Airbyte.
+ :param job_id: The ID of an Airbyte Sync job.
+ :param end_time: Time in seconds to wait for a job run to reach a terminal
status. Defaults to 7 days.
+ :param poll_interval: polling period in seconds to check for the status.
+ """
+
+ def __init__(
+ self,
+ job_id: int,
+ conn_id: str,
+ end_time: float,
+ poll_interval: float,
+ ):
+ super().__init__()
+ self.job_id = job_id
+ self.conn_id = conn_id
+ self.end_time = end_time
+ self.poll_interval = poll_interval
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ """Serializes AirbyteSyncTrigger arguments and classpath."""
+ return (
+ "airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger",
+ {
+ "job_id": self.job_id,
+ "conn_id": self.conn_id,
+ "end_time": self.end_time,
+ "poll_interval": self.poll_interval,
+ },
+ )
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """Make async connection to Airbyte, polls for the pipeline run
status."""
+ hook = AirbyteHook(airbyte_conn_id=self.conn_id)
+ try:
+ while await self.is_still_running(hook):
+ if self.end_time < time.time():
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Job run {self.job_id} has not reached
a terminal status after "
+ f"{self.end_time} seconds.",
+ "job_id": self.job_id,
+ }
+ )
+ await asyncio.sleep(self.poll_interval)
+ job_run_status = await hook.get_job_status(self.job_id)
+ if job_run_status == hook.SUCCEEDED:
+ yield TriggerEvent(
+ {
+ "status": "success",
+ "message": f"Job run {self.job_id} has completed
successfully.",
+ "job_id": self.job_id,
+ }
+ )
+ elif job_run_status == hook.CANCELLED:
+ yield TriggerEvent(
+ {
+ "status": "cancelled",
+ "message": f"Job run {self.job_id} has been
cancelled.",
+ "job_id": self.job_id,
+ }
+ )
+ else:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Job run {self.job_id} has failed.",
+ "job_id": self.job_id,
+ }
+ )
+ except Exception as e:
+ yield TriggerEvent({"status": "error", "message": str(e),
"job_id": self.job_id})
+
+ async def is_still_running(self, hook: AirbyteHook) -> bool:
+ """
+ Async function to check whether the job is submitted via async API.
+
+ If job is in running state returns True if it is still running else
return False
+ """
+ job_run_status = await hook.get_job_status(self.job_id)
+ if job_run_status in (AirbyteHook.RUNNING, AirbyteHook.PENDING,
AirbyteHook.INCOMPLETE):
+ return True
+ return False
diff --git a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
index 68fd8c44cb..60f47955dd 100644
--- a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
+++ b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
@@ -38,10 +38,9 @@ create in Airbyte between a source and destination
synchronization job.
Use the ``airbyte_conn_id`` parameter to specify the Airbyte connection to use
to
connect to your account.
-You can trigger a synchronization job in Airflow in two ways with the
Operator. The first one
-is a synchronous process. This will trigger the Airbyte job and the Operator
manage the status
-of the job. Another way is use the flag ``async = True`` so the Operator only
trigger the job and
-return the ``job_id`` that should be pass to the AirbyteSensor.
+You can trigger a synchronization job in Airflow in two ways with the
Operator. The first one is a synchronous process.
+This Operator will initiate the Airbyte job, and the Operator manages the job
status. Another way is to use the flag
+``async = True`` so the Operator only triggers the job and returns the
``job_id``, passed to the AirbyteSensor.
An example using the synchronous way:
diff --git a/tests/providers/airbyte/operators/test_airbyte.py
b/tests/providers/airbyte/operators/test_airbyte.py
index f8ecd15615..2c0085f53d 100644
--- a/tests/providers/airbyte/operators/test_airbyte.py
+++ b/tests/providers/airbyte/operators/test_airbyte.py
@@ -37,7 +37,7 @@ class TestAirbyteTriggerSyncOp:
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.wait_for_job",
return_value=None)
def test_execute(self, mock_wait_for_job, mock_submit_sync_connection):
mock_submit_sync_connection.return_value = mock.Mock(
- **{"json.return_value": {"job": {"id": self.job_id}}}
+ **{"json.return_value": {"job": {"id": self.job_id, "status":
"running"}}}
)
op = AirbyteTriggerSyncOperator(
diff --git a/tests/providers/airbyte/triggers/__init__.py
b/tests/providers/airbyte/triggers/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/airbyte/triggers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/airbyte/triggers/test_airbyte.py
b/tests/providers/airbyte/triggers/test_airbyte.py
new file mode 100644
index 0000000000..103df7cf00
--- /dev/null
+++ b/tests/providers/airbyte/triggers/test_airbyte.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import time
+from unittest import mock
+
+import pytest
+
+from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
+from airflow.providers.airbyte.triggers.airbyte import AirbyteSyncTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+class TestAirbyteSyncTrigger:
+ DAG_ID = "airbyte_sync_run"
+ TASK_ID = "airbyte_sync_run_task_op"
+ JOB_ID = 1234
+ CONN_ID = "airbyte_default"
+ END_TIME = time.time() + 60 * 60 * 24 * 7
+ POLL_INTERVAL = 3.0
+
+ def test_serialization(self):
+ """Assert TestAirbyteSyncTrigger correctly serializes its arguments
and classpath."""
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL,
end_time=self.END_TIME, job_id=self.JOB_ID
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger"
+ assert kwargs == {
+ "job_id": self.JOB_ID,
+ "conn_id": self.CONN_ID,
+ "end_time": self.END_TIME,
+ "poll_interval": self.POLL_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+ async def test_airbyte_run_sync_trigger(self, mocked_is_still_running):
+ """Test AirbyteSyncTrigger is triggered with mocked details and run
successfully."""
+ mocked_is_still_running.return_value = True
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID,
+ poll_interval=self.POLL_INTERVAL,
+ end_time=self.END_TIME,
+ job_id=self.JOB_ID,
+ )
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mock_value, mock_status, mock_message",
+ [
+ (AirbyteHook.SUCCEEDED, "success", "Job run 1234 has completed
successfully."),
+ ],
+ )
+
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_job_for_terminal_status_success(
+ self, mock_get_job_status, mocked_is_still_running, mock_value,
mock_status, mock_message
+ ):
+ """Assert that run trigger success message in case of job success"""
+ mocked_is_still_running.return_value = False
+ mock_get_job_status.return_value = mock_value
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID,
+ poll_interval=self.POLL_INTERVAL,
+ end_time=self.END_TIME,
+ job_id=self.JOB_ID,
+ )
+ expected_result = {
+ "status": mock_status,
+ "message": mock_message,
+ "job_id": self.JOB_ID,
+ }
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+ assert TriggerEvent(expected_result) == task.result()
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mock_value, mock_status, mock_message",
+ [
+ (AirbyteHook.CANCELLED, "cancelled", "Job run 1234 has been
cancelled."),
+ ],
+ )
+
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_job_for_terminal_status_cancelled(
+ self, mock_get_job_status, mocked_is_still_running, mock_value,
mock_status, mock_message
+ ):
+ """Assert that run trigger success message in case of job success"""
+ mocked_is_still_running.return_value = False
+ mock_get_job_status.return_value = mock_value
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL,
end_time=self.END_TIME, job_id=self.JOB_ID
+ )
+ expected_result = {
+ "status": mock_status,
+ "message": mock_message,
+ "job_id": self.JOB_ID,
+ }
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+ assert TriggerEvent(expected_result) == task.result()
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mock_value, mock_status, mock_message",
+ [
+ (AirbyteHook.ERROR, "error", "Job run 1234 has failed."),
+ ],
+ )
+
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_job_for_terminal_status_error(
+ self, mock_get_job_status, mocked_is_still_running, mock_value,
mock_status, mock_message
+ ):
+ """Assert that run trigger success message in case of job success"""
+ mocked_is_still_running.return_value = False
+ mock_get_job_status.return_value = mock_value
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID,
+ poll_interval=self.POLL_INTERVAL,
+ end_time=self.END_TIME,
+ job_id=self.JOB_ID,
+ )
+ expected_result = {
+ "status": mock_status,
+ "message": mock_message,
+ "job_id": self.JOB_ID,
+ }
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+ assert TriggerEvent(expected_result) == task.result()
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_job_exception(self, mock_get_job_status,
mocked_is_still_running):
+ """Assert that run catch exception if Airbyte Sync job API throw
exception"""
+ mocked_is_still_running.return_value = False
+ mock_get_job_status.side_effect = Exception("Test exception")
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID,
+ poll_interval=self.POLL_INTERVAL,
+ end_time=self.END_TIME,
+ job_id=self.JOB_ID,
+ )
+ task = [i async for i in trigger.run()]
+ response = TriggerEvent(
+ {
+ "status": "error",
+ "message": "Test exception",
+ "job_id": self.JOB_ID,
+ }
+ )
+ assert len(task) == 1
+ assert response in task
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger.is_still_running")
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_job_timeout(self, mock_get_job_status,
mocked_is_still_running):
+ """Assert that run timeout after end_time elapsed"""
+ mocked_is_still_running.return_value = True
+ mock_get_job_status.side_effect = Exception("Test exception")
+ end_time = time.time()
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID,
+ poll_interval=self.POLL_INTERVAL,
+ end_time=end_time,
+ job_id=self.JOB_ID,
+ )
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ expected = TriggerEvent(
+ {
+ "status": "error",
+ "message": f"Job run {self.JOB_ID} has not reached a terminal
status "
+ f"after {end_time} seconds.",
+ "job_id": self.JOB_ID,
+ }
+ )
+ assert expected == actual
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mock_response, expected_status",
+ [
+ (AirbyteHook.SUCCEEDED, False),
+ ],
+ )
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_job_is_still_running_success(
+ self, mock_get_job_status, mock_response, expected_status
+ ):
+ """Test is_still_running with mocked response job status and assert
+ the return response with expected value"""
+ hook = mock.AsyncMock(AirbyteHook)
+ hook.get_job_status.return_value = mock_response
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID,
+ poll_interval=self.POLL_INTERVAL,
+ end_time=self.END_TIME,
+ job_id=self.JOB_ID,
+ )
+ response = await trigger.is_still_running(hook)
+ assert response == expected_status
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "mock_response, expected_status",
+ [
+ (AirbyteHook.RUNNING, True),
+ ],
+ )
+
@mock.patch("airflow.providers.airbyte.hooks.airbyte.AirbyteHook.get_job_status")
+ async def test_airbyte_sync_run_is_still_running(
+ self, mock_get_job_status, mock_response, expected_status
+ ):
+ """Test is_still_running with mocked response job status and assert
+ the return response with expected value"""
+ airbyte_hook = mock.AsyncMock(AirbyteHook)
+ airbyte_hook.get_job_status.return_value = mock_response
+ trigger = AirbyteSyncTrigger(
+ conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL,
end_time=self.END_TIME, job_id=self.JOB_ID
+ )
+ response = await trigger.is_still_running(airbyte_hook)
+ assert response == expected_status