This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c0eaa9b25d Add deferrable mode to CloudSQLExportInstanceOperator
(#30852)
c0eaa9b25d is described below
commit c0eaa9b25d11eeb6fba1b716323c4ff2c2dbd5e1
Author: Beata Kossakowska <[email protected]>
AuthorDate: Thu Jun 29 19:30:32 2023 +0200
Add deferrable mode to CloudSQLExportInstanceOperator (#30852)
Co-authored-by: Beata Kossakowska <[email protected]>
---
airflow/providers/google/cloud/hooks/cloud_sql.py | 50 +++++-
.../providers/google/cloud/operators/cloud_sql.py | 41 ++++-
.../providers/google/cloud/triggers/cloud_sql.py | 102 ++++++++++++
airflow/providers/google/provider.yaml | 3 +
.../operators/cloud/cloud_sql.rst | 8 +
.../providers/google/cloud/hooks/test_cloud_sql.py | 113 +++++++++++--
.../google/cloud/operators/test_cloud_sql.py | 37 ++++-
.../google/cloud/triggers/test_cloud_sql.py | 150 +++++++++++++++++
.../cloud_sql/example_cloud_sql_deferrable.py | 184 +++++++++++++++++++++
9 files changed, 666 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/cloud_sql.py
b/airflow/providers/google/cloud/hooks/cloud_sql.py
index 42abfc1c24..304a6f88fa 100644
--- a/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -39,16 +39,18 @@ from typing import Any, Sequence
from urllib.parse import quote_plus
import httpx
+from aiohttp import ClientSession
+from gcloud.aio.auth import AioSession, Token
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError
-
-from airflow.exceptions import AirflowException
+from requests import Session
# Number of retries - used by googleapiclient method calls to perform retries
# For requests that are "retriable"
+from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook,
get_field
+from airflow.providers.google.common.hooks.base_google import
GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -300,8 +302,7 @@ class CloudSQLHook(GoogleBaseHook):
self._wait_for_operation_to_complete(project_id=project_id,
operation_name=operation_name)
@GoogleBaseHook.fallback_to_default_project_id
- @GoogleBaseHook.operation_in_progress_retry()
- def export_instance(self, instance: str, body: dict, project_id: str) ->
None:
+ def export_instance(self, instance: str, body: dict, project_id: str):
"""
Exports data from a Cloud SQL instance to a Cloud Storage bucket as a
SQL dump
or CSV file.
@@ -321,7 +322,7 @@ class CloudSQLHook(GoogleBaseHook):
.execute(num_retries=self.num_retries)
)
operation_name = response["name"]
- self._wait_for_operation_to_complete(project_id=project_id,
operation_name=operation_name)
+ return operation_name
@GoogleBaseHook.fallback_to_default_project_id
def import_instance(self, instance: str, body: dict, project_id: str) ->
None:
@@ -376,6 +377,7 @@ class CloudSQLHook(GoogleBaseHook):
except HttpError as ex:
raise AirflowException(f"Cloning of instance {instance} failed:
{ex.content}")
+ @GoogleBaseHook.fallback_to_default_project_id
def _wait_for_operation_to_complete(
self, project_id: str, operation_name: str, time_to_sleep: int =
TIME_TO_SLEEP_IN_SECONDS
) -> None:
@@ -412,6 +414,42 @@ CLOUD_SQL_PROXY_VERSION_DOWNLOAD_URL = (
)
+class CloudSQLAsyncHook(GoogleBaseAsyncHook):
+ """Class to get asynchronous hook for Google Cloud SQL."""
+
+ sync_hook_class = CloudSQLHook
+
+ async def _get_conn(self, session: Session, url: str):
+ scopes = [
+ "https://www.googleapis.com/auth/cloud-platform",
+ "https://www.googleapis.com/auth/sqlservice.admin",
+ ]
+
+ async with Token(scopes=scopes) as token:
+ session_aio = AioSession(session)
+ headers = {
+ "Authorization": f"Bearer {await token.get()}",
+ }
+ return await session_aio.get(url=url, headers=headers)
+
+ async def get_operation_name(self, project_id: str, operation_name: str,
session):
+ url =
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project_id}/operations/{operation_name}"
+ return await self._get_conn(url=str(url), session=session)
+
+ async def get_operation(self, project_id: str, operation_name: str):
+ async with ClientSession() as session:
+ try:
+ operation = await self.get_operation_name(
+ project_id=project_id,
+ operation_name=operation_name,
+ session=session,
+ )
+ operation = await operation.json(content_type=None)
+ except HttpError as e:
+ raise e
+ return operation
+
+
class CloudSqlProxyRunner(LoggingMixin):
"""
Downloads and runs cloud-sql-proxy as subprocess of the Python process.
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py
b/airflow/providers/google/cloud/operators/cloud_sql.py
index 20a254b954..5c77cbd86c 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -28,6 +28,7 @@ from airflow.models import Connection
from airflow.providers.google.cloud.hooks.cloud_sql import
CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.links.cloud_sql import
CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_sql import
CloudSQLExportTrigger
from airflow.providers.google.cloud.utils.field_validator import
GcpBodyFieldValidator
from airflow.providers.google.common.hooks.base_google import get_field
from airflow.providers.google.common.links.storage import FileDetailsLink
@@ -926,6 +927,9 @@ class CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
+ :param deferrable: Run operator in the deferrable mode.
+ :param poke_interval: (Deferrable mode only) Time (seconds) to wait
between calls
+ to check the run status.
"""
# [START gcp_sql_export_template_fields]
@@ -951,10 +955,14 @@ class
CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
api_version: str = "v1beta4",
validate_body: bool = True,
impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = False,
+ poke_interval: int = 10,
**kwargs,
) -> None:
self.body = body
self.validate_body = validate_body
+ self.deferrable = deferrable
+ self.poke_interval = poke_interval
super().__init__(
project_id=project_id,
instance=instance,
@@ -994,7 +1002,38 @@ class
CloudSQLExportInstanceOperator(CloudSQLBaseOperator):
uri=self.body["exportContext"]["uri"][5:],
project_id=self.project_id or hook.project_id,
)
- return hook.export_instance(project_id=self.project_id,
instance=self.instance, body=self.body)
+
+ operation_name = hook.export_instance(
+ project_id=self.project_id, instance=self.instance, body=self.body
+ )
+
+ if not self.deferrable:
+ return hook._wait_for_operation_to_complete(
+ project_id=self.project_id, operation_name=operation_name
+ )
+ else:
+ self.defer(
+ trigger=CloudSQLExportTrigger(
+ operation_name=operation_name,
+ project_id=self.project_id or hook.project_id,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ poke_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
+
+ def execute_complete(self, context, event=None) -> None:
+ """
+ Callback for when the trigger fires - returns immediately.
+ Logs the operation name when the event status is "success"; for any other
+ status it raises an AirflowException carrying the event message.
+ """
+ if event["status"] == "success":
+ self.log.info("Operation %s completed successfully",
event["operation_name"])
+ else:
+ self.log.exception("Unexpected error in the operation.")
+ raise AirflowException(event["message"])
class CloudSQLImportInstanceOperator(CloudSQLBaseOperator):
diff --git a/airflow/providers/google/cloud/triggers/cloud_sql.py
b/airflow/providers/google/cloud/triggers/cloud_sql.py
new file mode 100644
index 0000000000..7d2cd5a323
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/cloud_sql.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains Google Cloud SQL triggers."""
+from __future__ import annotations
+
+import asyncio
+from typing import Sequence
+
+from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLAsyncHook,
CloudSqlOperationStatus
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class CloudSQLExportTrigger(BaseTrigger):
+ """
+ Trigger that periodically polls information from Cloud SQL API to verify
job status.
+ Implementation leverages asynchronous transport.
+ """
+
+ def __init__(
+ self,
+ operation_name: str,
+ project_id: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poke_interval: int = 20,
+ ):
+ super().__init__()
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.operation_name = operation_name
+ self.project_id = project_id
+ self.poke_interval = poke_interval
+ self.hook = CloudSQLAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ def serialize(self):
+ return (
+
"airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger",
+ {
+ "operation_name": self.operation_name,
+ "project_id": self.project_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poke_interval": self.poke_interval,
+ },
+ )
+
+ async def run(self):
+ while True:
+ try:
+ operation = await self.hook.get_operation(
+ project_id=self.project_id,
operation_name=self.operation_name
+ )
+ if operation["status"] == CloudSqlOperationStatus.DONE:
+ if "error" in operation:
+ yield TriggerEvent(
+ {
+ "operation_name": operation["name"],
+ "status": "error",
+ "message": operation["error"]["message"],
+ }
+ )
+ return
+ yield TriggerEvent(
+ {
+ "operation_name": operation["name"],
+ "status": "success",
+ }
+ )
+ return
+ else:
+ self.log.info(
+ "Operation status is %s, sleeping for %s seconds.",
+ operation["status"],
+ self.poke_interval,
+ )
+ await asyncio.sleep(self.poke_interval)
+ except Exception as e:
+ self.log.exception("Exception occurred while checking
operation status.")
+ yield TriggerEvent(
+ {
+ "status": "failed",
+ "message": str(e),
+ }
+ )
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index 64fcba31bc..3fc2c90b9b 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -847,6 +847,9 @@ triggers:
- integration-name: Google Cloud Composer
python-modules:
- airflow.providers.google.cloud.triggers.cloud_composer
+ - integration-name: Google Cloud SQL
+ python-modules:
+ - airflow.providers.google.cloud.triggers.cloud_sql
- integration-name: Google Dataflow
python-modules:
- airflow.providers.google.cloud.triggers.dataflow
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
index 80076b95f5..4f7c9428c3 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_sql.rst
@@ -241,6 +241,14 @@ it will be retrieved from the Google Cloud connection
used. Both variants are sh
:start-after: [START howto_operator_cloudsql_export]
:end-before: [END howto_operator_cloudsql_export]
+You can also run this operator in the deferrable mode:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloudsql_export_async]
+ :end-before: [END howto_operator_cloudsql_export_async]
+
Templating
""""""""""
diff --git a/tests/providers/google/cloud/hooks/test_cloud_sql.py
b/tests/providers/google/cloud/hooks/test_cloud_sql.py
index 27eb176da4..0d90f8d0c9 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_sql.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_sql.py
@@ -24,13 +24,17 @@ import tempfile
from unittest import mock
from unittest.mock import PropertyMock
+import aiohttp
import httplib2
import pytest
+from aiohttp.helpers import TimerNoop
from googleapiclient.errors import HttpError
+from yarl import URL
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.google.cloud.hooks.cloud_sql import (
+ CloudSQLAsyncHook,
CloudSQLDatabaseHook,
CloudSQLHook,
CloudSqlProxyRunner,
@@ -40,6 +44,26 @@ from tests.providers.google.cloud.utils.base_gcp_mock import
(
mock_base_gcp_hook_no_default_project_id,
)
+HOOK_STR = "airflow.providers.google.cloud.hooks.cloud_sql.{}"
+PROJECT_ID = "test_project_id"
+OPERATION_NAME = "test_operation_name"
+OPERATION_URL = (
+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
+)
+
+
[email protected]
+def hook_async():
+ with mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseAsyncHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ ):
+ yield CloudSQLAsyncHook()
+
+
+def session():
+ return mock.Mock()
+
class TestGcpSqlHookDefaultProjectId:
def test_delegate_to_runtime_error(self):
@@ -116,9 +140,6 @@ class TestGcpSqlHookDefaultProjectId:
export_method.assert_called_once_with(body={}, instance="instance",
project="example-project")
execute_method.assert_called_once_with(num_retries=5)
- wait_for_operation_to_complete.assert_called_once_with(
- project_id="example-project", operation_name="operation_id"
- )
assert 1 == mock_get_credentials.call_count
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_conn")
@@ -133,14 +154,9 @@ class TestGcpSqlHookDefaultProjectId:
),
{"name": "operation_id"},
]
- wait_for_operation_to_complete.return_value = None
- self.cloudsql_hook.export_instance(project_id="example-project",
instance="instance", body={})
-
- assert 2 == export_method.call_count
- assert 2 == execute_method.call_count
- wait_for_operation_to_complete.assert_called_once_with(
- project_id="example-project", operation_name="operation_id"
- )
+ with pytest.raises(HttpError):
+ self.cloudsql_hook.export_instance(project_id="example-project",
instance="instance", body={})
+ wait_for_operation_to_complete.assert_not_called()
@mock.patch(
"airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook.get_credentials_and_project_id",
@@ -551,9 +567,6 @@ class TestGcpSqlHookNoDefaultProjectID:
)
export_method.assert_called_once_with(body={}, instance="instance",
project="example-project")
execute_method.assert_called_once_with(num_retries=5)
- wait_for_operation_to_complete.assert_called_once_with(
- project_id="example-project", operation_name="operation_id"
- )
@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id",
@@ -1238,3 +1251,75 @@ class TestCloudSqlProxyRunner:
)
with pytest.raises(ValueError, match="The sql_proxy_version should
match the regular expression"):
runner._get_sql_proxy_download_url()
+
+
+class TestCloudSQLAsyncHook:
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook._get_conn"))
+ async def test_async_get_operation_name_should_execute_successfully(self,
mocked_conn, hook_async):
+ await hook_async.get_operation_name(
+ operation_name=OPERATION_NAME,
+ project_id=PROJECT_ID,
+ session=session,
+ )
+
+ mocked_conn.assert_awaited_once_with(url=OPERATION_URL,
session=session)
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation_name"))
+ async def
test_async_get_operation_completed_should_execute_successfully(self,
mocked_get, hook_async):
+ response = aiohttp.ClientResponse(
+ "get",
+ URL(OPERATION_URL),
+ request_info=mock.Mock(),
+ writer=mock.Mock(),
+ continue100=None,
+ timer=TimerNoop(),
+ traces=[],
+ loop=mock.Mock(),
+ session=session,
+ )
+ response.status = 200
+ mocked_get.return_value = response
+ mocked_get.return_value._headers = {"Authorization": "test-token"}
+ mocked_get.return_value._body = b'{"status": "DONE"}'
+
+ operation = await
hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID)
+ mocked_get.assert_awaited_once()
+ assert operation["status"] == "DONE"
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation_name"))
+ async def
test_async_get_operation_running_should_execute_successfully(self, mocked_get,
hook_async):
+ response = aiohttp.ClientResponse(
+ "get",
+ URL(OPERATION_URL),
+ request_info=mock.Mock(),
+ writer=mock.Mock(),
+ continue100=None,
+ timer=TimerNoop(),
+ traces=[],
+ loop=mock.Mock(),
+ session=session,
+ )
+ response.status = 200
+ mocked_get.return_value = response
+ mocked_get.return_value._headers = {"Authorization": "test-token"}
+ mocked_get.return_value._body = b'{"status": "RUNNING"}'
+
+ operation = await
hook_async.get_operation(operation_name=OPERATION_NAME, project_id=PROJECT_ID)
+ mocked_get.assert_awaited_once()
+ assert operation["status"] == "RUNNING"
+
+ @pytest.mark.asyncio
+ @mock.patch(HOOK_STR.format("CloudSQLAsyncHook._get_conn"))
+ async def test_async_get_operation_exception_should_execute_successfully(
+ self, mocked_get_conn, hook_async
+ ):
+ """Asserts that get_operation propagates the HttpError raised by CloudSQLAsyncHook._get_conn"""
+
+ mocked_get_conn.side_effect = HttpError(
+ resp=mock.MagicMock(status=409), content=b"Operation already
exists"
+ )
+ with pytest.raises(HttpError):
+ await hook_async.get_operation(operation_name=OPERATION_NAME,
project_id=PROJECT_ID)
diff --git a/tests/providers/google/cloud/operators/test_cloud_sql.py
b/tests/providers/google/cloud/operators/test_cloud_sql.py
index 888dc11ebb..903c5e3c41 100644
--- a/tests/providers/google/cloud/operators/test_cloud_sql.py
+++ b/tests/providers/google/cloud/operators/test_cloud_sql.py
@@ -22,7 +22,7 @@ from unittest import mock
import pytest
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import Connection
from airflow.providers.google.cloud.operators.cloud_sql import (
CloudSQLCloneInstanceOperator,
@@ -36,6 +36,8 @@ from airflow.providers.google.cloud.operators.cloud_sql
import (
CloudSQLInstancePatchOperator,
CloudSQLPatchInstanceDatabaseOperator,
)
+from airflow.providers.google.cloud.triggers.cloud_sql import
CloudSQLExportTrigger
+from airflow.providers.google.common.consts import
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
PROJECT_ID = os.environ.get("PROJECT_ID", "project-id")
INSTANCE_NAME = os.environ.get("INSTANCE_NAME", "test-name")
@@ -669,6 +671,39 @@ class TestCloudSql:
)
assert result
+
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
+
@mock.patch("airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLAsyncHook")
+ def test_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
+ operator = CloudSQLExportInstanceOperator(
+ task_id="test_task",
+ instance=INSTANCE_NAME,
+ body=EXPORT_BODY,
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
+ operator.execute(mock.MagicMock())
+
+ mock_hook.return_value.export_instance.assert_called_once()
+
+ mock_hook.return_value.get_operation.assert_not_called()
+ assert isinstance(exc.value.trigger, CloudSQLExportTrigger)
+ assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
+
+ def test_async_execute_should_should_throw_exception(self):
+ """Tests that an AirflowException is raised in case of error event"""
+
+ op = CloudSQLExportInstanceOperator(
+ task_id="test_task",
+ instance=INSTANCE_NAME,
+ body=EXPORT_BODY,
+ deferrable=True,
+ )
+ with pytest.raises(AirflowException):
+ op.execute_complete(
+ context=mock.MagicMock(), event={"status": "error", "message":
"test failure message"}
+ )
+
@mock.patch("airflow.providers.google.cloud.operators.cloud_sql.CloudSQLHook")
def test_instance_import(self, mock_hook):
mock_hook.return_value.export_instance.return_value = True
diff --git a/tests/providers/google/cloud/triggers/test_cloud_sql.py
b/tests/providers/google/cloud/triggers/test_cloud_sql.py
new file mode 100644
index 0000000000..c7cbb2046c
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_cloud_sql.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+from unittest import mock as async_mock
+
+import pytest
+
+from airflow.providers.google.cloud.triggers.cloud_sql import
CloudSQLExportTrigger
+from airflow.triggers.base import TriggerEvent
+
+CLASSPATH =
"airflow.providers.google.cloud.triggers.cloud_sql.CloudSQLExportTrigger"
+TASK_ID = "test_task"
+TEST_POLL_INTERVAL = 10
+TEST_GCP_CONN_ID = "test-project"
+HOOK_STR = "airflow.providers.google.cloud.hooks.cloud_sql.{}"
+PROJECT_ID = "test_project_id"
+OPERATION_NAME = "test_operation_name"
+OPERATION_URL = (
+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{PROJECT_ID}/operations/{OPERATION_NAME}"
+)
+
+
[email protected]
+def trigger():
+ return CloudSQLExportTrigger(
+ operation_name=OPERATION_NAME,
+ project_id=PROJECT_ID,
+ impersonation_chain=None,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ poke_interval=TEST_POLL_INTERVAL,
+ )
+
+
+class TestCloudSQLExportTrigger:
+ def
test_async_export_trigger_serialization_should_execute_successfully(self,
trigger):
+ """
+ Asserts that the CloudSQLExportTrigger correctly serializes its
arguments
+ and classpath.
+ """
+ classpath, kwargs = trigger.serialize()
+ assert classpath == CLASSPATH
+ assert kwargs == {
+ "operation_name": OPERATION_NAME,
+ "project_id": PROJECT_ID,
+ "impersonation_chain": None,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "poke_interval": TEST_POLL_INTERVAL,
+ }
+
+ @pytest.mark.asyncio
+ @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ async def test_async_export_trigger_on_success_should_execute_successfully(
+ self, mock_get_operation, trigger
+ ):
+ """
+ Tests the CloudSQLExportTrigger only fires once the job execution
reaches a successful state.
+ """
+ mock_get_operation.return_value = {
+ "status": "DONE",
+ "name": OPERATION_NAME,
+ }
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert (
+ TriggerEvent(
+ {
+ "operation_name": OPERATION_NAME,
+ "status": "success",
+ }
+ )
+ == actual
+ )
+
+ @pytest.mark.asyncio
+ @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ async def test_async_export_trigger_running_should_execute_successfully(
+ self, mock_get_operation, trigger, caplog
+ ):
+ """
+ Test that CloudSQLExportTrigger does not fire while a job is still
running.
+ """
+
+ mock_get_operation.return_value = {
+ "status": "RUNNING",
+ "name": OPERATION_NAME,
+ }
+ caplog.set_level(logging.INFO)
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+
+ # TriggerEvent was not returned
+ assert task.done() is False
+
+ assert f"Operation status is RUNNING, sleeping for
{TEST_POLL_INTERVAL} seconds." in caplog.text
+
+ # Prevents error when task is destroyed while in "pending" state
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+ @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ async def
test_async_export_trigger_error_should_execute_successfully(self,
mock_get_operation, trigger):
+ """
+ Test that CloudSQLExportTrigger fires the correct event in case of an
error.
+ """
+ mock_get_operation.return_value = {
+ "status": "DONE",
+ "name": OPERATION_NAME,
+ "error": {"message": "test_error"},
+ }
+
+ expected_event = {
+ "operation_name": OPERATION_NAME,
+ "status": "error",
+ "message": "test_error",
+ }
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent(expected_event) == actual
+
+ @pytest.mark.asyncio
+ @async_mock.patch(HOOK_STR.format("CloudSQLAsyncHook.get_operation"))
+ async def test_async_export_trigger_exception_should_execute_successfully(
+ self, mock_get_operation, trigger
+ ):
+ """
+ Test that CloudSQLExportTrigger fires a "failed" event when polling the
+ operation status raises an exception.
+ """
+ mock_get_operation.side_effect = Exception("Test exception")
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert TriggerEvent({"status": "failed", "message": "Test exception"})
== actual
diff --git
a/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
new file mode 100644
index 0000000000..7859d93870
--- /dev/null
+++
b/tests/system/providers/google/cloud/cloud_sql/example_cloud_sql_deferrable.py
@@ -0,0 +1,184 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that creates a Cloud SQL instance and a database inside it,
+exports the instance to a Cloud Storage bucket using the deferrable mode of
+CloudSQLExportInstanceOperator, and then deletes the database, the instance and the bucket.
+
+This DAG relies on the following OS environment variables
+https://airflow.apache.org/concepts.html#variables
+* SYSTEM_TESTS_GCP_PROJECT - Google Cloud project for the Cloud SQL instance.
+* SYSTEM_TESTS_ENV_ID - Identifier used to build the names of the instance, database and bucket.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+from urllib.parse import urlsplit
+
+from airflow import models
+from airflow.models.xcom_arg import XComArg
+from airflow.providers.google.cloud.operators.cloud_sql import (
+ CloudSQLCreateInstanceDatabaseOperator,
+ CloudSQLCreateInstanceOperator,
+ CloudSQLDeleteInstanceDatabaseOperator,
+ CloudSQLDeleteInstanceOperator,
+ CloudSQLExportInstanceOperator,
+)
+from airflow.providers.google.cloud.operators.gcs import (
+ GCSBucketCreateAclEntryOperator,
+ GCSCreateBucketOperator,
+ GCSDeleteBucketOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "cloudsql-def"
+
+INSTANCE_NAME = f"{DAG_ID}-{ENV_ID}-instance"
+DB_NAME = f"{DAG_ID}-{ENV_ID}-db"
+
+BUCKET_NAME = f"{DAG_ID}_{ENV_ID}_bucket"
+FILE_NAME = f"{DAG_ID}_{ENV_ID}_exportImportTestFile"
+FILE_URI = f"gs://{BUCKET_NAME}/{FILE_NAME}"
+
+# Bodies below represent Cloud SQL instance resources:
+# https://cloud.google.com/sql/docs/mysql/admin-api/v1beta4/instances
+
+body = {
+ "name": INSTANCE_NAME,
+ "settings": {
+ "tier": "db-n1-standard-1",
+ "backupConfiguration": {"binaryLogEnabled": True, "enabled": True,
"startTime": "05:00"},
+ "activationPolicy": "ALWAYS",
+ "dataDiskSizeGb": 30,
+ "dataDiskType": "PD_SSD",
+ "databaseFlags": [],
+ "ipConfiguration": {
+ "ipv4Enabled": True,
+ "requireSsl": True,
+ },
+ "locationPreference": {"zone": "europe-west4-a"},
+ "maintenanceWindow": {"hour": 5, "day": 7, "updateTrack": "canary"},
+ "pricingPlan": "PER_USE",
+ "replicationType": "ASYNCHRONOUS",
+ "storageAutoResize": True,
+ "storageAutoResizeLimit": 0,
+ "userLabels": {"my-key": "my-value"},
+ },
+ "databaseVersion": "MYSQL_5_7",
+ "region": "europe-west4",
+}
+
+export_body = {
+ "exportContext": {
+ "fileType": "sql",
+ "uri": FILE_URI,
+ "sqlExportOptions": {"schemaOnly": False},
+ "offload": True,
+ }
+}
+
+db_create_body = {"instance": INSTANCE_NAME, "name": DB_NAME, "project":
PROJECT_ID}
+
+
+with models.DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "cloud_sql"],
+) as dag:
+ create_bucket = GCSCreateBucketOperator(task_id="create_bucket",
bucket_name=BUCKET_NAME)
+
+ sql_instance_create_task = CloudSQLCreateInstanceOperator(
+ body=body, instance=INSTANCE_NAME, task_id="sql_instance_create_task"
+ )
+
+ sql_db_create_task = CloudSQLCreateInstanceDatabaseOperator(
+ body=db_create_body, instance=INSTANCE_NAME,
task_id="sql_db_create_task"
+ )
+
+ file_url_split = urlsplit(FILE_URI)
+
+ # For export & import to work we need to grant the Cloud SQL instance's Service Account
+ # write access to the destination GCS bucket.
+ service_account_email = XComArg(sql_instance_create_task,
key="service_account_email")
+
+ sql_gcp_add_bucket_permission_task = GCSBucketCreateAclEntryOperator(
+ entity=f"user-{service_account_email}",
+ role="WRITER",
+ bucket=file_url_split[1], # netloc (bucket)
+ task_id="sql_gcp_add_bucket_permission_task",
+ )
+
+ # [START howto_operator_cloudsql_export_async]
+ sql_export_task = CloudSQLExportInstanceOperator(
+ body=export_body,
+ instance=INSTANCE_NAME,
+ task_id="sql_export_task",
+ deferrable=True,
+ )
+ # [END howto_operator_cloudsql_export_async]
+
+ sql_db_delete_task = CloudSQLDeleteInstanceDatabaseOperator(
+ instance=INSTANCE_NAME, database=DB_NAME, task_id="sql_db_delete_task"
+ )
+ sql_db_delete_task.trigger_rule = TriggerRule.ALL_DONE
+
+ sql_instance_delete_task = CloudSQLDeleteInstanceOperator(
+ instance=INSTANCE_NAME, task_id="sql_instance_delete_task"
+ )
+ sql_instance_delete_task.trigger_rule = TriggerRule.ALL_DONE
+
+ delete_bucket = GCSDeleteBucketOperator(
+ task_id="delete_bucket", bucket_name=BUCKET_NAME,
trigger_rule=TriggerRule.ALL_DONE
+ )
+
+ (
+ # TEST SETUP
+ create_bucket
+ # TEST BODY
+ >> sql_instance_create_task
+ >> sql_db_create_task
+ >> sql_gcp_add_bucket_permission_task
+ >> sql_export_task
+ >> sql_db_delete_task
+ >> sql_instance_delete_task
+ # TEST TEARDOWN
+ >> delete_bucket
+ )
+
+ # Task dependencies created via `XComArgs`:
+ # sql_instance_create_task >> sql_gcp_add_bucket_permission_task
+ # sql_instance_create_task >> sql_gcp_add_object_permission_task
+
+ # ### Everything below this line is not part of example ###
+ # ### Just for system tests purpose ###
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)