This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 25fd66c4b4 Ensure Airbyte Provider is Compatible with Cloud and Config APIs (#37943)
25fd66c4b4 is described below

commit 25fd66c4b48fd940b11b09d3b590ab9d002cda11
Author: Chris Hronek <31361051+chrishro...@users.noreply.github.com>
AuthorDate: Wed Mar 6 12:48:40 2024 -0700

    Ensure Airbyte Provider is Compatible with Cloud and Config APIs (#37943)
    
    * Add parameter to specify which Airbyte API to use
    
    * Fix provider tests
    
    * Add documentation outlining api types
    
    * Add api_type to the AirbyteJobSensor
    
    * Make api_type values less redundant
---
 airflow/providers/airbyte/hooks/airbyte.py         | 134 +++++++++++++++------
 airflow/providers/airbyte/operators/airbyte.py     |  20 ++-
 airflow/providers/airbyte/sensors/airbyte.py       |  24 +++-
 airflow/providers/airbyte/triggers/airbyte.py      |   8 +-
 .../connections.rst                                |  12 +-
 .../operators/airbyte.rst                          |   5 +
 tests/providers/airbyte/triggers/test_airbyte.py   |  26 +++-
 7 files changed, 174 insertions(+), 55 deletions(-)

diff --git a/airflow/providers/airbyte/hooks/airbyte.py 
b/airflow/providers/airbyte/hooks/airbyte.py
index 4545eeb7ba..e0ae41ff13 100644
--- a/airflow/providers/airbyte/hooks/airbyte.py
+++ b/airflow/providers/airbyte/hooks/airbyte.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import base64
 import json
 import time
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, Literal, TypeVar
 
 import aiohttp
 from aiohttp import ClientResponseError
@@ -42,6 +42,7 @@ class AirbyteHook(HttpHook):
     :param airbyte_conn_id: Optional. The name of the Airflow connection to get
         connection information for Airbyte. Defaults to "airbyte_default".
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
+    :param api_type: Optional. The type of Airbyte API to use. Either "config" 
or "cloud". Defaults to "config".
     """
 
     conn_name_attr = "airbyte_conn_id"
@@ -57,23 +58,35 @@ class AirbyteHook(HttpHook):
     ERROR = "error"
     INCOMPLETE = "incomplete"
 
-    def __init__(self, airbyte_conn_id: str = "airbyte_default", api_version: 
str = "v1") -> None:
+    def __init__(
+        self,
+        airbyte_conn_id: str = "airbyte_default",
+        api_version: str = "v1",
+        api_type: Literal["config", "cloud"] = "config",
+    ) -> None:
         super().__init__(http_conn_id=airbyte_conn_id)
         self.api_version: str = api_version
+        self.api_type: str = api_type
 
     async def get_headers_tenants_from_connection(self) -> tuple[dict[str, 
Any], str]:
         """Get Headers, tenants from the connection details."""
         connection: Connection = await 
sync_to_async(self.get_connection)(self.http_conn_id)
         base_url = connection.host
 
-        credentials = f"{connection.login}:{connection.password}"
-        credentials_base64 = 
base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
-
-        authorized_headers = {
-            "accept": "application/json",
-            "content-type": "application/json",
-            "authorization": f"Basic {credentials_base64}",
-        }
+        if self.api_type == "config":
+            credentials = f"{connection.login}:{connection.password}"
+            credentials_base64 = 
base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
+            authorized_headers = {
+                "accept": "application/json",
+                "content-type": "application/json",
+                "authorization": f"Basic {credentials_base64}",
+            }
+        else:
+            authorized_headers = {
+                "accept": "application/json",
+                "content-type": "application/json",
+                "authorization": f"Bearer {connection.password}",
+            }
 
         return authorized_headers, base_url
 
@@ -84,16 +97,28 @@ class AirbyteHook(HttpHook):
         :param job_id: The ID of an Airbyte Sync Job.
         """
         headers, base_url = await self.get_headers_tenants_from_connection()
-        url = f"{base_url}/api/{self.api_version}/jobs/get"
-        self.log.info("URL for api request: %s", url)
-        async with aiohttp.ClientSession(headers=headers) as session:
-            async with session.post(url=url, data=json.dumps({"id": job_id})) 
as response:
-                try:
-                    response.raise_for_status()
-                    return await response.json()
-                except ClientResponseError as e:
-                    msg = f"{e.status}: {e.message} - {e.request_info}"
-                    raise AirflowException(msg)
+        if self.api_type == "config":
+            url = f"{base_url}/api/{self.api_version}/jobs/get"
+            self.log.info("URL for api request: %s", url)
+            async with aiohttp.ClientSession(headers=headers) as session:
+                async with session.post(url=url, data=json.dumps({"id": 
job_id})) as response:
+                    try:
+                        response.raise_for_status()
+                        return await response.json()
+                    except ClientResponseError as e:
+                        msg = f"{e.status}: {e.message} - {e.request_info}"
+                        raise AirflowException(msg)
+        else:
+            url = f"{base_url}/{self.api_version}/jobs/{job_id}"
+            self.log.info("URL for api request: %s", url)
+            async with aiohttp.ClientSession(headers=headers) as session:
+                async with session.get(url=url) as response:
+                    try:
+                        response.raise_for_status()
+                        return await response.json()
+                    except ClientResponseError as e:
+                        msg = f"{e.status}: {e.message} - {e.request_info}"
+                        raise AirflowException(msg)
 
     async def get_job_status(self, job_id: int) -> str:
         """
@@ -103,8 +128,10 @@ class AirbyteHook(HttpHook):
         """
         self.log.info("Getting the status of job run %s.", job_id)
         response = await self.get_job_details(job_id=job_id)
-        job_run_status: str = response["job"]["status"]
-        return job_run_status
+        if self.api_type == "config":
+            return str(response["job"]["status"])
+        else:
+            return str(response["status"])
 
     def wait_for_job(self, job_id: str | int, wait_seconds: float = 3, 
timeout: float | None = 3600) -> None:
         """
@@ -124,7 +151,10 @@ class AirbyteHook(HttpHook):
             time.sleep(wait_seconds)
             try:
                 job = self.get_job(job_id=(int(job_id)))
-                state = job.json()["job"]["status"]
+                if self.api_type == "config":
+                    state = job.json()["job"]["status"]
+                else:
+                    state = job.json()["status"]
             except AirflowException as err:
                 self.log.info("Retrying. Airbyte API returned server error 
when waiting for job: %s", err)
                 continue
@@ -146,11 +176,23 @@ class AirbyteHook(HttpHook):
 
         :param connection_id: Required. The ConnectionId of the Airbyte 
Connection.
         """
-        return self.run(
-            endpoint=f"api/{self.api_version}/connections/sync",
-            json={"connectionId": connection_id},
-            headers={"accept": "application/json"},
-        )
+        if self.api_type == "config":
+            return self.run(
+                endpoint=f"api/{self.api_version}/connections/sync",
+                json={"connectionId": connection_id},
+                headers={"accept": "application/json"},
+            )
+        else:
+            conn = self.get_connection(self.http_conn_id)
+            self.method = "POST"
+            return self.run(
+                endpoint=f"{self.api_version}/jobs",
+                headers={"accept": "application/json", "authorization": 
f"Bearer {conn.password}"},
+                json={
+                    "jobType": "sync",
+                    "connectionId": connection_id,
+                },  # TODO: add an option to pass jobType = reset
+            )
 
     def get_job(self, job_id: int) -> Any:
         """
@@ -158,11 +200,19 @@ class AirbyteHook(HttpHook):
 
         :param job_id: Required. Id of the Airbyte job
         """
-        return self.run(
-            endpoint=f"api/{self.api_version}/jobs/get",
-            json={"id": job_id},
-            headers={"accept": "application/json"},
-        )
+        if self.api_type == "config":
+            return self.run(
+                endpoint=f"api/{self.api_version}/jobs/get",
+                json={"id": job_id},
+                headers={"accept": "application/json"},
+            )
+        else:
+            self.method = "GET"
+            conn = self.get_connection(self.http_conn_id)
+            return self.run(
+                endpoint=f"{self.api_version}/jobs/{job_id}",
+                headers={"accept": "application/json", "authorization": 
f"Bearer {conn.password}"},
+            )
 
     def cancel_job(self, job_id: int) -> Any:
         """
@@ -170,11 +220,19 @@ class AirbyteHook(HttpHook):
 
         :param job_id: Required. Id of the Airbyte job
         """
-        return self.run(
-            endpoint=f"api/{self.api_version}/jobs/cancel",
-            json={"id": job_id},
-            headers={"accept": "application/json"},
-        )
+        if self.api_type == "config":
+            return self.run(
+                endpoint=f"api/{self.api_version}/jobs/cancel",
+                json={"id": job_id},
+                headers={"accept": "application/json"},
+            )
+        else:
+            self.method = "DELETE"
+            conn = self.get_connection(self.http_conn_id)
+            return self.run(
+                endpoint=f"{self.api_version}/jobs/{job_id}",
+                headers={"accept": "application/json", "authorization": 
f"Bearer {conn.password}"},
+            )
 
     def test_connection(self):
         """Tests the Airbyte connection by hitting the health API."""
diff --git a/airflow/providers/airbyte/operators/airbyte.py 
b/airflow/providers/airbyte/operators/airbyte.py
index d8fdddb0e2..c2f0a56202 100644
--- a/airflow/providers/airbyte/operators/airbyte.py
+++ b/airflow/providers/airbyte/operators/airbyte.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import time
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -46,6 +46,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
         waiting on them asynchronously using the AirbyteJobSensor. Defaults to 
False.
     :param deferrable: Run operator in the deferrable mode.
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
+    :param api_type: Optional. The type of Airbyte API to use. Either "config" 
or "cloud". Defaults to "config".
     :param wait_seconds: Optional. Number of seconds between checks. Only used 
when ``asynchronous`` is False.
         Defaults to 3 seconds.
     :param timeout: Optional. The amount of time, in seconds, to wait for the 
request to complete.
@@ -62,6 +63,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
         asynchronous: bool = False,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         api_version: str = "v1",
+        api_type: Literal["config", "cloud"] = "config",
         wait_seconds: float = 3,
         timeout: float = 3600,
         **kwargs,
@@ -71,16 +73,23 @@ class AirbyteTriggerSyncOperator(BaseOperator):
         self.connection_id = connection_id
         self.timeout = timeout
         self.api_version = api_version
+        self.api_type = api_type
         self.wait_seconds = wait_seconds
         self.asynchronous = asynchronous
         self.deferrable = deferrable
 
     def execute(self, context: Context) -> None:
         """Create Airbyte Job and wait to finish."""
-        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version)
+        hook = AirbyteHook(
+            airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version, api_type=self.api_type
+        )
         job_object = 
hook.submit_sync_connection(connection_id=self.connection_id)
-        self.job_id = job_object.json()["job"]["id"]
-        state = job_object.json()["job"]["status"]
+        if self.api_type == "config":
+            self.job_id = job_object.json()["job"]["id"]
+            state = job_object.json()["job"]["status"]
+        else:
+            self.job_id = job_object.json()["jobId"]
+            state = job_object.json()["status"]
         end_time = time.time() + self.timeout
 
         self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
@@ -92,6 +101,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
                         timeout=self.execution_timeout,
                         trigger=AirbyteSyncTrigger(
                             conn_id=self.airbyte_conn_id,
+                            api_type=self.api_type,
                             job_id=self.job_id,
                             end_time=end_time,
                             poll_interval=60,
@@ -128,7 +138,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
 
     def on_kill(self):
         """Cancel the job if task is cancelled."""
-        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id)
+        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, 
api_type=self.api_type)
         if self.job_id:
             self.log.info("on_kill: cancel the airbyte Job %s", self.job_id)
             hook.cancel_job(self.job_id)
diff --git a/airflow/providers/airbyte/sensors/airbyte.py 
b/airflow/providers/airbyte/sensors/airbyte.py
index 4cc280cfc4..2db010d3bb 100644
--- a/airflow/providers/airbyte/sensors/airbyte.py
+++ b/airflow/providers/airbyte/sensors/airbyte.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import time
 import warnings
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, AirflowSkipException
@@ -41,6 +41,7 @@ class AirbyteJobSensor(BaseSensorOperator):
     :param deferrable: Run sensor in the deferrable mode.
         connection information for Airbyte. Defaults to "airbyte_default".
     :param api_version: Optional. Airbyte API version. Defaults to "v1".
+    :param api_type: Optional. The type of Airbyte API to use. Either "config" 
or "cloud". Defaults to "config".
     """
 
     template_fields: Sequence[str] = ("airbyte_job_id",)
@@ -53,6 +54,7 @@ class AirbyteJobSensor(BaseSensorOperator):
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         airbyte_conn_id: str = "airbyte_default",
         api_version: str = "v1",
+        api_type: Literal["config", "cloud"] = "config",
         **kwargs,
     ) -> None:
         if deferrable:
@@ -77,11 +79,17 @@ class AirbyteJobSensor(BaseSensorOperator):
         self.airbyte_conn_id = airbyte_conn_id
         self.airbyte_job_id = airbyte_job_id
         self.api_version = api_version
+        self.api_type = api_type
 
     def poke(self, context: Context) -> bool:
-        hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version)
+        hook = AirbyteHook(
+            airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version, api_type=self.api_type
+        )
         job = hook.get_job(job_id=self.airbyte_job_id)
-        status = job.json()["job"]["status"]
+        if self.api_type == "config":
+            status = job.json()["job"]["status"]
+        else:
+            status = job.json()["status"]
 
         if status == hook.FAILED:
             # TODO: remove this if block when min_airflow_version is set to 
higher than 2.7.1
@@ -109,9 +117,14 @@ class AirbyteJobSensor(BaseSensorOperator):
         if not self.deferrable:
             super().execute(context)
         else:
-            hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id)
+            hook = AirbyteHook(
+                airbyte_conn_id=self.airbyte_conn_id, 
api_version=self.api_version, api_type=self.api_type
+            )
             job = hook.get_job(job_id=(int(self.airbyte_job_id)))
-            state = job.json()["job"]["status"]
+            if self.api_type == "config":
+                state = job.json()["job"]["status"]
+            else:
+                state = job.json()["status"]
             end_time = time.time() + self.timeout
 
             self.log.info("Airbyte Job Id: Job %s", self.airbyte_job_id)
@@ -120,6 +133,7 @@ class AirbyteJobSensor(BaseSensorOperator):
                 self.defer(
                     timeout=self.execution_timeout,
                     trigger=AirbyteSyncTrigger(
+                        api_type=self.api_type,
                         conn_id=self.airbyte_conn_id,
                         job_id=self.airbyte_job_id,
                         end_time=end_time,
diff --git a/airflow/providers/airbyte/triggers/airbyte.py 
b/airflow/providers/airbyte/triggers/airbyte.py
index cec032a5e0..67a9be69b8 100644
--- a/airflow/providers/airbyte/triggers/airbyte.py
+++ b/airflow/providers/airbyte/triggers/airbyte.py
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 import asyncio
 import time
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator, Literal
 
 from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -32,6 +32,7 @@ class AirbyteSyncTrigger(BaseTrigger):
     makes use of asynchronous communication to check the progress of a job run 
over time.
 
     :param conn_id: The connection identifier for connecting to Airbyte.
+    :param api_type: The type of Airbyte API to use. Either "config" or 
"cloud".
     :param job_id: The ID of an Airbyte Sync job.
     :param end_time: Time in seconds to wait for a job run to reach a terminal 
status. Defaults to 7 days.
     :param poll_interval:  polling period in seconds to check for the status.
@@ -41,12 +42,14 @@ class AirbyteSyncTrigger(BaseTrigger):
         self,
         job_id: int,
         conn_id: str,
+        api_type: Literal["config", "cloud"],
         end_time: float,
         poll_interval: float,
     ):
         super().__init__()
         self.job_id = job_id
         self.conn_id = conn_id
+        self.api_type = api_type
         self.end_time = end_time
         self.poll_interval = poll_interval
 
@@ -57,6 +60,7 @@ class AirbyteSyncTrigger(BaseTrigger):
             {
                 "job_id": self.job_id,
                 "conn_id": self.conn_id,
+                "api_type": self.api_type,
                 "end_time": self.end_time,
                 "poll_interval": self.poll_interval,
             },
@@ -64,7 +68,7 @@ class AirbyteSyncTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Make async connection to Airbyte, polls for the pipeline run 
status."""
-        hook = AirbyteHook(airbyte_conn_id=self.conn_id)
+        hook = AirbyteHook(airbyte_conn_id=self.conn_id, 
api_type=self.api_type)
         try:
             while await self.is_still_running(hook):
                 if self.end_time < time.time():
diff --git a/docs/apache-airflow-providers-airbyte/connections.rst 
b/docs/apache-airflow-providers-airbyte/connections.rst
index 31b69c70a6..28f6a3a747 100644
--- a/docs/apache-airflow-providers-airbyte/connections.rst
+++ b/docs/apache-airflow-providers-airbyte/connections.rst
@@ -21,8 +21,8 @@ Airbyte Connection
 ==================
 The Airbyte connection type use the HTTP protocol.
 
-Configuring the Connection
---------------------------
+Configuring the Connection - Config API
+---------------------------------------
 Host(required)
     The host to connect to the Airbyte server.
 
@@ -34,3 +34,11 @@ Login (optional)
 
 Password (optional)
     Specify the password to connect.
+
+Configuring the Connection - Cloud API
+--------------------------------------
+Host (required)
+    The host of the Airbyte Cloud API (typically ``https://api.airbyte.com``).
+
+Password (required)
+    Cloud API key obtained from https://portal.airbyte.com/. It is sent as a Bearer token; Login is not used.
diff --git a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst 
b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
index 60f47955dd..55eb110163 100644
--- a/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
+++ b/docs/apache-airflow-providers-airbyte/operators/airbyte.rst
@@ -38,6 +38,11 @@ create in Airbyte between a source and destination 
synchronization job.
 Use the ``airbyte_conn_id`` parameter to specify the Airbyte connection to use 
to
 connect to your account.
 
+Airbyte currently supports two different APIs. The first one is the `Config API <https://airbyte-public-api-docs.s3.us-east-2.amazonaws.com/rapidoc-api-docs.html>`_
+which is used by open-source Airbyte instances. The second is the `Cloud API <https://reference.airbyte.com/reference/start>`_
+which is used by the Airbyte Cloud service. If you are using Airbyte Cloud, you need to pass
+``api_type="cloud"`` to the operator (and to ``AirbyteJobSensor``, if you use it); the default is ``"config"``.
+
 You can trigger a synchronization job in Airflow in two ways with the 
Operator. The first one is a synchronous process.
 This Operator will initiate the Airbyte job, and the Operator manages the job 
status. Another way is to use the flag
 ``async = True`` so the Operator only triggers the job and returns the 
``job_id``, passed to the AirbyteSensor.
diff --git a/tests/providers/airbyte/triggers/test_airbyte.py 
b/tests/providers/airbyte/triggers/test_airbyte.py
index 103df7cf00..4de34f5053 100644
--- a/tests/providers/airbyte/triggers/test_airbyte.py
+++ b/tests/providers/airbyte/triggers/test_airbyte.py
@@ -31,6 +31,7 @@ class TestAirbyteSyncTrigger:
     DAG_ID = "airbyte_sync_run"
     TASK_ID = "airbyte_sync_run_task_op"
     JOB_ID = 1234
+    API_TYPE = "config"
     CONN_ID = "airbyte_default"
     END_TIME = time.time() + 60 * 60 * 24 * 7
     POLL_INTERVAL = 3.0
@@ -38,13 +39,18 @@ class TestAirbyteSyncTrigger:
     def test_serialization(self):
         """Assert TestAirbyteSyncTrigger correctly serializes its arguments 
and classpath."""
         trigger = AirbyteSyncTrigger(
-            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, 
end_time=self.END_TIME, job_id=self.JOB_ID
+            api_type=self.API_TYPE,
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
         )
         classpath, kwargs = trigger.serialize()
         assert classpath == 
"airflow.providers.airbyte.triggers.airbyte.AirbyteSyncTrigger"
         assert kwargs == {
             "job_id": self.JOB_ID,
             "conn_id": self.CONN_ID,
+            "api_type": self.API_TYPE,
             "end_time": self.END_TIME,
             "poll_interval": self.POLL_INTERVAL,
         }
@@ -55,6 +61,7 @@ class TestAirbyteSyncTrigger:
         """Test AirbyteSyncTrigger is triggered with mocked details and run 
successfully."""
         mocked_is_still_running.return_value = True
         trigger = AirbyteSyncTrigger(
+            api_type=self.API_TYPE,
             conn_id=self.CONN_ID,
             poll_interval=self.POLL_INTERVAL,
             end_time=self.END_TIME,
@@ -83,6 +90,7 @@ class TestAirbyteSyncTrigger:
         mocked_is_still_running.return_value = False
         mock_get_job_status.return_value = mock_value
         trigger = AirbyteSyncTrigger(
+            api_type=self.API_TYPE,
             conn_id=self.CONN_ID,
             poll_interval=self.POLL_INTERVAL,
             end_time=self.END_TIME,
@@ -114,7 +122,11 @@ class TestAirbyteSyncTrigger:
         mocked_is_still_running.return_value = False
         mock_get_job_status.return_value = mock_value
         trigger = AirbyteSyncTrigger(
-            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, 
end_time=self.END_TIME, job_id=self.JOB_ID
+            api_type=self.API_TYPE,
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
         )
         expected_result = {
             "status": mock_status,
@@ -142,6 +154,7 @@ class TestAirbyteSyncTrigger:
         mocked_is_still_running.return_value = False
         mock_get_job_status.return_value = mock_value
         trigger = AirbyteSyncTrigger(
+            api_type=self.API_TYPE,
             conn_id=self.CONN_ID,
             poll_interval=self.POLL_INTERVAL,
             end_time=self.END_TIME,
@@ -165,6 +178,7 @@ class TestAirbyteSyncTrigger:
         mocked_is_still_running.return_value = False
         mock_get_job_status.side_effect = Exception("Test exception")
         trigger = AirbyteSyncTrigger(
+            api_type=self.API_TYPE,
             conn_id=self.CONN_ID,
             poll_interval=self.POLL_INTERVAL,
             end_time=self.END_TIME,
@@ -190,6 +204,7 @@ class TestAirbyteSyncTrigger:
         mock_get_job_status.side_effect = Exception("Test exception")
         end_time = time.time()
         trigger = AirbyteSyncTrigger(
+            api_type=self.API_TYPE,
             conn_id=self.CONN_ID,
             poll_interval=self.POLL_INTERVAL,
             end_time=end_time,
@@ -223,6 +238,7 @@ class TestAirbyteSyncTrigger:
         hook = mock.AsyncMock(AirbyteHook)
         hook.get_job_status.return_value = mock_response
         trigger = AirbyteSyncTrigger(
+            api_type=self.API_TYPE,
             conn_id=self.CONN_ID,
             poll_interval=self.POLL_INTERVAL,
             end_time=self.END_TIME,
@@ -247,7 +263,11 @@ class TestAirbyteSyncTrigger:
         airbyte_hook = mock.AsyncMock(AirbyteHook)
         airbyte_hook.get_job_status.return_value = mock_response
         trigger = AirbyteSyncTrigger(
-            conn_id=self.CONN_ID, poll_interval=self.POLL_INTERVAL, 
end_time=self.END_TIME, job_id=self.JOB_ID
+            api_type=self.API_TYPE,
+            conn_id=self.CONN_ID,
+            poll_interval=self.POLL_INTERVAL,
+            end_time=self.END_TIME,
+            job_id=self.JOB_ID,
         )
         response = await trigger.is_still_running(airbyte_hook)
         assert response == expected_status

Reply via email to