This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c0eaa9b25d Add deferrable mode to CloudSQLExportInstanceOperator (#30852)
c0eaa9b25d is described below

commit c0eaa9b25d11eeb6fba1b716323c4ff2c2dbd5e1
Author: Beata Kossakowska <[email protected]>
AuthorDate: Thu Jun 29 19:30:32 2023 +0200

    Add deferrable mode to CloudSQLExportInstanceOperator (#30852)
    
    Co-authored-by: Beata Kossakowska <[email protected]>
---
 airflow/providers/google/cloud/hooks/cloud_sql.py  |  50 +++++-
 .../providers/google/cloud/operators/cloud_sql.py  |  41 ++++-
 .../providers/google/cloud/triggers/cloud_sql.py   | 102 ++++++++++++
 airflow/providers/google/provider.yaml             |   3 +
 .../operators/cloud/cloud_sql.rst                  |   8 +
 .../providers/google/cloud/hooks/test_cloud_sql.py | 113 +++++++++++--
 .../google/cloud/operators/test_cloud_sql.py       |  37 ++++-
 .../google/cloud/triggers/test_cloud_sql.py        | 150 +++++++++++++++++
 .../cloud_sql/example_cloud_sql_deferrable.py      | 184 +++++++++++++++++++++
 9 files changed, 666 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 42abfc1c24..304a6f88fa 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -39,16 +39,18 @@ from typing import Any, Sequence
 from urllib.parse import quote_plus
 
 import httpx
+from aiohttp import ClientSession
+from gcloud.aio.auth import AioSession, Token
 from googleapiclient.discovery import Resource, build
 from googleapiclient.errors import HttpError
-
-from airflow.exceptions import AirflowException
+from requests import Session
 
 # Number of retries - used by googleapiclient method calls to perform retries
 # For requests that are "retriable"
+from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
 from airflow.providers.mysql.hooks.mysql import MySqlHook
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils.log.logging_mixin import LoggingMixin
@@ -300,8 +302,7 @@ class CloudSQLHook(GoogleBaseHook):
         self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
-    @GoogleBaseHook.operation_in_progress_retry()
-    def export_instance(self, instance: str, body: dict, project_id: str) -> None:
+    def export_instance(self, instance: str, body: dict, project_id: str):
         """
         Exports data from a Cloud SQL instance to a Cloud Storage bucket as a SQL dump
         or CSV file.
@@ -321,7 +322,7 @@ class CloudSQLHook(GoogleBaseHook):
             .execute(num_retries=self.num_retries)
         )
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id, operation_name=operation_name)
+        return operation_name
 
     @GoogleBaseHook.fallback_to_default_project_id
     def import_instance(self, instance: str, body: dict, project_id: str) -> None:
@@ -376,6 +377,7 @@ class CloudSQLHook(GoogleBaseHook):
         except HttpError as ex:
             raise AirflowException(f"Cloning of instance {instance} failed: 
{ex.content}")
 
+    @GoogleBaseHook.fallback_to_default_project_id
     def _wait_for_operation_to_complete(
         self, project_id: str, operation_name: str, time_to_sleep: int = TIME_TO_SLEEP_IN_SECONDS
     ) -> None:
@@ -412,6 +414,42 @@ CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL = (
 )
 
 
+class CloudSQLAsyncHook(GoogleBaseAsyncHook):
+    """Class to get asynchronous hook for Google Cloud SQL."""
+
+    sync_hook_class = CloudSQLHook
+
+    async def _get_conn(self, session: Session, url: str):
+        scopes = [
+            "https://www.googleapis.com/auth/cloud-platform";,
+            "https://www.googleapis.com/auth/sqlservice.admin";,
+        ]
+
+        async with Token(scopes=scopes) as token:
+            session_aio = AioSession(session)
+            headers = {
+                "Authorization": f"Bearer {await token.get()}",
+            }
+            return await session_aio.get(url=url, headers=headers)
+
+    async def get_operation_name(self, project_id: str, operation_name: str, session):
+        url = f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project_id}/operations/{operation_name}"
+        return await self._get_conn(url=str(url), session=session)
+
+    async def get_operation(self, project_id: str, operation_name: str):
+        async with ClientSession() as session:
+            try:
+                operation = await self.get_operation_name(
+                    project_id=project_id,
+                    operation_name=operation_name,
+                    session=session,
+                )
+                operation = await operation.json(content_type=None)
+            except HttpError as e:
+                raise e
+            return operation
+
+
 class CloudSqlProxyRunner(LoggingMixin):
     """
     Downloads and runs cloud-sql-proxy as subprocess of the Python process.
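The new CloudSQLAsyncHook above is consumed by the CloudSQLExportTrigger added later in
this patch, but it can also be driven directly. A minimal sketch, assuming valid Google
Cloud (ADC) credentials and placeholder project/operation names that are not part of
this commit:

    import asyncio

    from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook


    async def poll_once(project_id: str, operation_name: str) -> str:
        hook = CloudSQLAsyncHook(gcp_conn_id="google_cloud_default")
        # get_operation() opens its own aiohttp ClientSession and returns the parsed
        # JSON body of the sqladmin operations.get response.
        operation = await hook.get_operation(project_id=project_id, operation_name=operation_name)
        return operation["status"]  # e.g. "PENDING", "RUNNING" or "DONE"


    # Placeholder identifiers, for illustration only.
    print(asyncio.run(poll_once("my-project", "my-operation-name")))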
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py
index 20a254b954..5c77cbd86c 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -28,6 +28,7 @@ from airflow.models import Connection
 from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
 from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
 from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
 from airflow.providers.google.common.hooks.base_google import get_field
 from airflow.providers.google.common.links.storage import FileDetailsLink
@@ -926,6 +927,9 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: Run operator in the deferrable mode.
+    :param poke_interval: (Deferrable mode only) Time (seconds) to wait between calls
+        to check the run status.
     """
 
     # [START gcp_sql_export_template_fields]
@@ -951,10 +955,14 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
         api_version: str = "v1beta4",
         validate_body: bool = True,
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = False,
+        poke_interval: int = 10,
         **kwargs,
     ) -> None:
         self.body = body
         self.validate_body = validate_body
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
         super().__init__(
             project_id=project_id,
             instance=instance,
@@ -994,7 +1002,38 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
             uri=self.body["exportContext"]["uri"][5:],
             project_id=self.project_id or hook.project_id,
         )
-        return hook.export_instance(project_id=self.project_id, instance=self.instance, body=self.body)
+
+        operation_name = hook.export_instance(
+            project_id=self.project_id, instance=self.instance, body=self.body
+        )
+
+        if not self.deferrable:
+            return hook._wait_for_operation_to_complete(
+                project_id=self.project_id, operation_name=operation_name
+            )
+        else:
+            self.defer(
+                trigger=CloudSQLExportTrigger(
+                    operation_name=operation_name,
+                    project_id=self.project_id or hook.project_id,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    poke_interval=self.poke_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context, event=None) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+        Relies on trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        if event["status"] == "success":
+            self.log.info("Operation %s completed successfully", 
event["operation_name"])
+        else:
+            self.log.exception("Unexpected error in the operation.")
+            raise AirflowException(event["message"])
 
 
 class CloudSQLImportInstanceOperator(CloudSQLBaseOperator):
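The event payloads handled by execute_complete() above are produced by
CloudSQLExportTrigger.run(), added below. A minimal sketch of that contract, calling
the callback directly with hand-built events; the operator arguments are placeholders,
and in a real run Airflow invokes the callback when the deferred task resumes:

    from airflow.exceptions import AirflowException
    from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExportInstanceOperator

    op = CloudSQLExportInstanceOperator(
        task_id="sql_export_task",
        instance="my-instance",  # placeholder
        body={"exportContext": {"fileType": "sql", "uri": "gs://my-bucket/dump.sql"}},
        deferrable=True,
    )

    # A "success" event is logged and the task finishes normally.
    op.execute_complete(context={}, event={"status": "success", "operation_name": "op-123"})

    # Any other status surfaces the trigger's message as an AirflowException.
    try:
        op.execute_complete(context={}, event={"status": "error", "message": "export failed"})
    except AirflowException as err:
        print(err)  # export failed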
diff --git a/airflow/providers/google/cloud/triggers/cloud_sql.py b/airflow/providers/google/cloud/triggers/cloud_sql.py
new file mode 100644
index 0000000000..7d2cd5a323
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/cloud_sql.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Cloud SQL triggers."""
+from __future__ import annotations
+
+import asyncio
+from typing import Sequence
+
+from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook, CloudSqlOperationStatus
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class CloudSQLExportTrigger(BaseTrigger):
+    """
+    Trigger that periodically polls information from Cloud SQL API to verify job status.
+    Implementation leverages asynchronous transport.
+    """
+
+    def __init__(
+        self,
+        operation_name: str,
+        project_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poke_interval: int = 20,
+    ):
+        super().__init__()
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.operation_name = operation_name
+        self.project_id = project_id
+        self.poke_interval = poke_interval
+        self.hook = CloudSQLAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    def serialize(self):
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger",
+            {
+                "operation_name": self.operation_name,
+                "project_id": self.project_id,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self):
+        while True:
+            try:
+                operation = await self.hook.get_operation(
+                    project_id=self.project_id, operation_name=self.operation_name
+                )
+                if operation["status"] == CloudSqlOperationStatus.DONE:
+                    if "error" in operation:
+                        yield TriggerEvent(
+                            {
+                                "operation_name": operation["name"],
+                                "status": "error",
+                                "message": operation["error"]["message"],
+                            }
+                        )
+                        return
+                    yield TriggerEvent(
+                        {
+                            "operation_name": operation["name"],
+                            "status": "success",
+                        }
+                    )
+                    return
+                else:
+                    self.log.info(
+                        "Operation status is %s, sleeping for %s seconds.",
+                        operation["status"],
+                        self.poke_interval,
+                    )
+                    await asyncio.sleep(self.poke_interval)
+            except Exception as e:
+                self.log.exception("Exception occurred while checking 
operation status.")
+                yield TriggerEvent(
+                    {
+                        "status": "failed",
+                        "message": str(e),
+                    }
+                )
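The serialize() method above is what lets the triggerer persist the trigger and rebuild
it after a restart. A small sketch of that round trip under the standard BaseTrigger
contract; operation and project names are placeholders:

    from importlib import import_module

    from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger

    trigger = CloudSQLExportTrigger(
        operation_name="my-operation",  # placeholder
        project_id="my-project",  # placeholder
        poke_interval=20,
    )

    # The classpath plus kwargs must be enough to recreate an equivalent trigger.
    classpath, kwargs = trigger.serialize()
    module_path, _, class_name = classpath.rpartition(".")
    rebuilt = getattr(import_module(module_path), class_name)(**kwargs)
    assert rebuilt.operation_name == trigger.operation_name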
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 64fcba31bc..3fc2c90b9b 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -847,6 +847,9 @@ triggers:
   - integration-name: Google Cloud Composer
     python-modules:
       - airflow.providers.google.cloud.triggers.cloud_composer
+  - integration-name: Google Cloud SQL
+    python-modules:
+      - airflow.providers.google.cloud.triggers.cloud_sql
   - integration-name: Google Dataflow
     python-modules:
       - airflow.providers.google.cloud.triggers.dataflow
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
index 80076b95f5..4f7c9428c3 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
@@ -241,6 +241,14 @@ it will be retrieved from the Google Cloud connection used. Both variants are sh
     :start-after: [START howto_operator_cloudsql_export]
     :end-before: [END howto_operator_cloudsql_export]
 
+All of these actions can also be performed with the operator in deferrable mode:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_cloudsql_export_async]
+    :end-before: [END howto_operator_cloudsql_export_async]
+
 Templating
 """"""""""
 
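For readers who want the shape of the documented usage without opening the system test,
a minimal deferrable-mode sketch; the bucket and instance names are placeholders, and
the canonical example is the new example_cloud_sql_deferrable.py further down this patch:

    from datetime import datetime

    from airflow import models
    from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExportInstanceOperator

    with models.DAG("cloud_sql_export_deferrable", start_date=datetime(2023, 1, 1), schedule=None):
        CloudSQLExportInstanceOperator(
            task_id="sql_export_task",
            instance="my-instance",
            body={"exportContext": {"fileType": "sql", "uri": "gs://my-bucket/dump.sql"}},
            deferrable=True,  # hand polling off to the triggerer
            poke_interval=10,  # seconds between status checks
        )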
diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py b/tests/providers/google/cloud/hooks/test_cloud_sql.py
index 27eb176da4..0d90f8d0c9 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_sql.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py
@@ -24,13 +24,17 @@ import tempfile
 from unittest import mock
 from unittest.mock import PropertyMock
 
+import aiohttp
 import httplib2
 import pytest
+from aiohttp.helpers import TimerNoop
 from googleapiclient.errors import HttpError
+from yarl import URL
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.google.cloud.hooks.cloud_sql import (
+    CloudSQLAsyncHook,
     CloudSQLDatabaseHook,
     CloudSQLHook,
     CloudSqlProxyRunner,
@@ -40,6 +44,26 @@ from tests.providers.google.cloud.utils.base_gcp_mock import (
     mock_base_gcp_hook_no_default_project_id,
 )
 
+HOOK_STR = "airflow.providers.google.cloud.hooks.cloud_sql.{}"
+PROJECT_ID = "test_project_id"
+OPERATION_NAME = "test_operation_name"
+OPERATION_URL = (
+    f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
+)
+
+
+@pytest.fixture
+def hook_async():
+    with mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    ):
+        yield CloudSQLAsyncHook()
+
+
+def session():
+    return mock.Mock()
+
 
 class TestGcpSqlHookDefaultProjectId:
     def test_delegate_to_runtime_error(self):
@@ -116,9 +140,6 @@ class TestGcpSqlHookDefaultProjectId:
 
         export_method.assert_called_once_with(body={}, instance="instance", project="example-project")
         execute_method.assert_called_once_with(num_retries=5)
-        wait_for_operation_to_complete.assert_called_once_with(
-            project_id="example-project", operation_name="operation_id"
-        )
         assert 1 == mock_get_credentials.call_count
 
    @mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn")
@@ -133,14 +154,9 @@ class TestGcpSqlHookDefaultProjectId:
             ),
             {"name": "operation_id"},
         ]
-        wait_for_operation_to_complete.return_value = None
-        self.cloudsql_hook.export_instance(project_id="example-project", instance="instance", body={})
-
-        assert 2 == export_method.call_count
-        assert 2 == execute_method.call_count
-        wait_for_operation_to_complete.assert_called_once_with(
-            project_id="example-project", operation_name="operation_id"
-        )
+        with pytest.raises(HttpError):
+            self.cloudsql_hook.export_instance(project_id="example-project", instance="instance", body={})
+        wait_for_operation_to_complete.assert_not_called()
 
     @mock.patch(
         "airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_credentials_and_project_id",
@@ -551,9 +567,6 @@ class TestGcpSqlHookNoDefaultProjectID:
         )
         export_method.assert_called_once_with(body={}, instance="instance", project="example-project")
         execute_method.assert_called_once_with(num_retries=5)
-        wait_for_operation_to_complete.assert_called_once_with(
-            project_id="example-project", operation_name="operation_id"
-        )
 
     @mock.patch(
         "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id",
@@ -1238,3 +1251,75 @@ class TestCloudSqlProxyRunner:
         )
         with pytest.raises(ValueError, match="The sql_proxy_version should match the regular expression"):
             runner._get_sql_proxy_download_url()
+
+
+class TestCloudSQLAsyncHook:
+    @pytest.mark.asyncio
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook._get_conn"))
    async def test_async_get_operation_name_should_execute_successfully(self, mocked_conn, hook_async):
+        await hook_async.get_operation_name(
+            operation_name=OPERATION_NAME,
+            project_id=PROJECT_ID,
+            session=session,
+        )
+
+        mocked_conn.assert_awaited_once_with(url=OPERATION_URL, session=session)
+
+    @pytest.mark.asyncio
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation_name"))
+    async def test_async_get_operation_completed_should_execute_successfully(self, mocked_get, hook_async):
+        response = aiohttp.ClientResponse(
+            "get",
+            URL(OPERATION_URL),
+            request_info=mock.Mock(),
+            writer=mock.Mock(),
+            continue100=None,
+            timer=TimerNoop(),
+            traces=[],
+            loop=mock.Mock(),
+            session=session,
+        )
+        response.status = 200
+        mocked_get.return_value = response
+        mocked_get.return_value._headers = {"Authorization": "test-token"}
+        mocked_get.return_value._body = b'{"status": "DONE"}'
+
+        operation = await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID)
+        mocked_get.assert_awaited_once()
+        assert operation["status"] == "DONE"
+
+    @pytest.mark.asyncio
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation_name"))
+    async def test_async_get_operation_running_should_execute_successfully(self, mocked_get, hook_async):
+        response = aiohttp.ClientResponse(
+            "get",
+            URL(OPERATION_URL),
+            request_info=mock.Mock(),
+            writer=mock.Mock(),
+            continue100=None,
+            timer=TimerNoop(),
+            traces=[],
+            loop=mock.Mock(),
+            session=session,
+        )
+        response.status = 200
+        mocked_get.return_value = response
+        mocked_get.return_value._headers = {"Authorization": "test-token"}
+        mocked_get.return_value._body = b'{"status": "RUNNING"}'
+
+        operation = await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID)
+        mocked_get.assert_awaited_once()
+        assert operation["status"] == "RUNNING"
+
+    @pytest.mark.asyncio
+    @mock.patch(HOOK_STR.format("CloudSQLAsyncHook._get_conn"))
+    async def test_async_get_operation_exception_should_execute_successfully(
+        self, mocked_get_conn, hook_async
+    ):
+        """Assets that the logging is done correctly when CloudSQLAsyncHook 
raises HttpError"""
+
+        mocked_get_conn.side_effect = HttpError(
+            resp=mock.MagicMock(status=409), content=b"Operation already exists"
+        )
+        with pytest.raises(HttpError):
+            await hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID)
diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py b/tests/providers/google/cloud/operators/test_cloud_sql.py
index 888dc11ebb..903c5e3c41 100644
--- a/tests/providers/google/cloud/operators/test_cloud_sql.py
+++ b/tests/providers/google/cloud/operators/test_cloud_sql.py
@@ -22,7 +22,7 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import Connection
 from airflow.providers.google.cloud.operators.cloud_sql import (
     CloudSQLCloneInstanceOperator,
@@ -36,6 +36,8 @@ from airflow.providers.google.cloud.operators.cloud_sql import (
     CloudSQLInstancePatchOperator,
     CloudSQLPatchInstanceDatabaseOperator,
 )
+from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
+from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
 PROJECT_ID = os.environ.get("PROJECT_ID", "project-id")
 INSTANCE_NAME = os.environ.get("INSTANCE_NAME", "test-name")
@@ -669,6 +671,39 @@ class TestCloudSql:
         )
         assert result
 
+    @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
+    @mock.patch("airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLAsyncHook")
+    def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
+        operator = CloudSQLExportInstanceOperator(
+            task_id="test_task",
+            instance=INSTANCE_NAME,
+            body=EXPORT_BODY,
+            deferrable=True,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(mock.MagicMock())
+
+        mock_hook.return_value.export_instance.assert_called_once()
+
+        mock_hook.return_value.get_operation.assert_not_called()
+        assert isinstance(exc.value.trigger, CloudSQLExportTrigger)
+        assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
+    def test_async_execute_should_should_throw_exception(self):
+        """Tests that an AirflowException is raised in case of error event"""
+
+        op = CloudSQLExportInstanceOperator(
+            task_id="test_task",
+            instance=INSTANCE_NAME,
+            body=EXPORT_BODY,
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException):
+            op.execute_complete(
+                context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
+            )
+
    @mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
     def test_instance_import(self, mock_hook):
         mock_hook.return_value.export_instance.return_value = True
diff --git a/tests/providers/google/cloud/triggers/test_cloud_sql.py b/tests/providers/google/cloud/triggers/test_cloud_sql.py
new file mode 100644
index 0000000000..c7cbb2046c
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_cloud_sql.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+from unittest import mock as async_mock
+
+import pytest
+
+from airflow.providers.google.cloud.triggers.cloud_sql import CloudSQLExportTrigger
+from airflow.triggers.base import TriggerEvent
+
CLASSPATH = "airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger"
+TASK_ID = "test_task"
+TEST_POLL_INTERVAL = 10
+TEST_GCP_CONN_ID = "test-project"
+HOOK_STR = "airflow.providers.google.cloud.hooks.cloud_sql.{}"
+PROJECT_ID = "test_project_id"
+OPERATION_NAME = "test_operation_name"
+OPERATION_URL = (
+    f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
+)
+
+
+@pytest.fixture
+def trigger():
+    return CloudSQLExportTrigger(
+        operation_name=OPERATION_NAME,
+        project_id=PROJECT_ID,
+        impersonation_chain=None,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        poke_interval=TEST_POLL_INTERVAL,
+    )
+
+
+class TestCloudSQLExportTrigger:
+    def test_async_export_trigger_serialization_should_execute_successfully(self, trigger):
+        """
+        Asserts that the CloudSQLExportTrigger correctly serializes its arguments
+        and classpath.
+        """
+        classpath, kwargs = trigger.serialize()
+        assert classpath == CLASSPATH
+        assert kwargs == {
+            "operation_name": OPERATION_NAME,
+            "project_id": PROJECT_ID,
+            "impersonation_chain": None,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "poke_interval": TEST_POLL_INTERVAL,
+        }
+
+    @pytest.mark.asyncio
+    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_on_success_should_execute_successfully(
+        self, mock_get_operation, trigger
+    ):
+        """
+        Tests the CloudSQLExportTrigger only fires once the job execution reaches a successful state.
+        """
+        mock_get_operation.return_value = {
+            "status": "DONE",
+            "name": OPERATION_NAME,
+        }
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {
+                    "operation_name": OPERATION_NAME,
+                    "status": "success",
+                }
+            )
+            == actual
+        )
+
+    @pytest.mark.asyncio
+    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_running_should_execute_successfully(
+        self, mock_get_operation, trigger, caplog
+    ):
+        """
+        Test that CloudSQLExportTrigger does not fire while a job is still running.
+        """
+
+        mock_get_operation.return_value = {
+            "status": "RUNNING",
+            "name": OPERATION_NAME,
+        }
+        caplog.set_level(logging.INFO)
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+
+        assert f"Operation status is RUNNING, sleeping for 
{TEST_POLL_INTERVAL} seconds." in caplog.text
+
+        # Prevents error when task is destroyed while in "pending" state
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_error_should_execute_successfully(self, mock_get_operation, trigger):
+        """
+        Test that CloudSQLExportTrigger fires the correct event in case of an error.
+        """
+        mock_get_operation.return_value = {
+            "status": "DONE",
+            "name": OPERATION_NAME,
+            "error": {"message": "test_error"},
+        }
+
+        expected_event = {
+            "operation_name": OPERATION_NAME,
+            "status": "error",
+            "message": "test_error",
+        }
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent(expected_event) == actual
+
+    @pytest.mark.asyncio
+    @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+    async def test_async_export_trigger_exception_should_execute_successfully(
+        self, mock_get_operation, trigger
+    ):
+        """
+        Test that CloudSQLExportTrigger fires the correct event in case of an exception.
+        """
+        mock_get_operation.side_effect = Exception("Test exception")
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent({"status": "failed", "message": "Test exception"}) 
== actual
diff --git a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
new file mode 100644
index 0000000000..7859d93870
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
@@ -0,0 +1,184 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that creates, patches and deletes a Cloud SQL instance, and also
+creates, patches and deletes a database inside the instance, in Google Cloud.
+
+This DAG relies on the following OS environment variables
+https://airflow.apache.org/concepts.html#variables
+* GCP_PROJECT_ID - Google Cloud project for the Cloud SQL instance.
+* INSTANCE_NAME - Name of the Cloud SQL instance.
+* DB_NAME - Name of the database inside a Cloud SQL instance.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from urllib.parse import urlsplit
+
+from airflow import models
+from airflow.models.xcom_arg import XComArg
+from airflow.providers.google.cloud.operators.cloud_sql import (
+    CloudSQLCreateInstanceDatabaseOperator,
+    CloudSQLCreateInstanceOperator,
+    CloudSQLDeleteInstanceDatabaseOperator,
+    CloudSQLDeleteInstanceOperator,
+    CloudSQLExportInstanceOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+    GCSBucketCreateAclEntryOperator,
+    GCSCreateBucketOperator,
+    GCSDeleteBucketOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "cloudsql-def"
+
+INSTANCE_NAME = f"{DAG_ID}-{ENV_ID}-instance"
+DB_NAME = f"{DAG_ID}-{ENV_ID}-db"
+
+BUCKET_NAME = f"{DAG_ID}_{ENV_ID}_bucket"
+FILE_NAME = f"{DAG_ID}_{ENV_ID}_exportImportTestFile"
+FILE_URI = f"gs://{BUCKET_NAME}/{FILE_NAME}"
+
+# Bodies below represent Cloud SQL instance resources:
+# https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances
+
+body = {
+    "name": INSTANCE_NAME,
+    "settings": {
+        "tier": "db-n1-standard-1",
+        "backupConfiguration": {"binaryLogEnabled": True, "enabled": True, 
"startTime": "05:00"},
+        "activationPolicy": "ALWAYS",
+        "dataDiskSizeGb": 30,
+        "dataDiskType": "PD_SSD",
+        "databaseFlags": [],
+        "ipConfiguration": {
+            "ipv4Enabled": True,
+            "requireSsl": True,
+        },
+        "locationPreference": {"zone": "europe-west4-a"},
+        "maintenanceWindow": {"hour": 5, "day": 7, "updateTrack": "canary"},
+        "pricingPlan": "PER_USE",
+        "replicationType": "ASYNCHRONOUS",
+        "storageAutoResize": True,
+        "storageAutoResizeLimit": 0,
+        "userLabels": {"my-key": "my-value"},
+    },
+    "databaseVersion": "MYSQL_5_7",
+    "region": "europe-west4",
+}
+
+export_body = {
+    "exportContext": {
+        "fileType": "sql",
+        "uri": FILE_URI,
+        "sqlExportOptions": {"schemaOnly": False},
+        "offload": True,
+    }
+}
+
+db_create_body = {"instance": INSTANCE_NAME, "name": DB_NAME, "project": 
PROJECT_ID}
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "cloud_sql"],
+) as dag:
+    create_bucket = GCSCreateBucketOperator(task_id="create_bucket", 
bucket_name=BUCKET_NAME)
+
+    sql_instance_create_task = CloudSQLCreateInstanceOperator(
+        body=body, instance=INSTANCE_NAME, task_id="sql_instance_create_task"
+    )
+
+    sql_db_create_task = CloudSQLCreateInstanceDatabaseOperator(
+        body=db_create_body, instance=INSTANCE_NAME, task_id="sql_db_create_task"
+    )
+
+    file_url_split = urlsplit(FILE_URI)
+
+    # For export & import to work we need to add the Cloud SQL instance's 
Service Account
+    # write access to the destination GCS bucket.
+    service_account_email = XComArg(sql_instance_create_task, key="service_account_email")
+
+    sql_gcp_add_bucket_permission_task = GCSBucketCreateAclEntryOperator(
+        entity=f"user-{service_account_email}",
+        role="WRITER",
+        bucket=file_url_split[1],  # netloc (bucket)
+        task_id="sql_gcp_add_bucket_permission_task",
+    )
+
+    # [START howto_operator_cloudsql_export_async]
+    sql_export_task = CloudSQLExportInstanceOperator(
+        body=export_body,
+        instance=INSTANCE_NAME,
+        task_id="sql_export_task",
+        deferrable=True,
+    )
+    # [END howto_operator_cloudsql_export_async]
+
+    sql_db_delete_task = CloudSQLDeleteInstanceDatabaseOperator(
+        instance=INSTANCE_NAME, database=DB_NAME, task_id="sql_db_delete_task"
+    )
+    sql_db_delete_task.trigger_rule = TriggerRule.ALL_DONE
+
+    sql_instance_delete_task = CloudSQLDeleteInstanceOperator(
+        instance=INSTANCE_NAME, task_id="sql_instance_delete_task"
+    )
+    sql_instance_delete_task.trigger_rule = TriggerRule.ALL_DONE
+
+    delete_bucket = GCSDeleteBucketOperator(
+        task_id="delete_bucket", bucket_name=BUCKET_NAME, 
trigger_rule=TriggerRule.ALL_DONE
+    )
+
+    (
+        # TEST SETUP
+        create_bucket
+        # TEST BODY
+        >> sql_instance_create_task
+        >> sql_db_create_task
+        >> sql_gcp_add_bucket_permission_task
+        >> sql_export_task
+        >> sql_db_delete_task
+        >> sql_instance_delete_task
+        # TEST TEARDOWN
+        >> delete_bucket
+    )
+
+    # Task dependencies created via `XComArgs`:
+    #   sql_instance_create_task >> sql_gcp_add_bucket_permission_task
+    #   sql_instance_create_task >> sql_gcp_add_object_permission_task
+
+    # ### Everything below this line is not part of example ###
+    # ### Just for system tests purpose ###
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)

