kaxil closed pull request #4314: [AIRFLOW-3398] Google Cloud Spanner instance database query operator
URL: https://github.com/apache/incubator-airflow/pull/4314
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/airflow/contrib/example_dags/example_gcp_spanner.py b/airflow/contrib/example_dags/example_gcp_spanner.py index dd8b8c52b9..cec3dcb855 100644 --- a/airflow/contrib/example_dags/example_gcp_spanner.py +++ b/airflow/contrib/example_dags/example_gcp_spanner.py @@ -18,18 +18,18 @@ # under the License. """ -Example Airflow DAG that creates, updates and deletes a Cloud Spanner instance. +Example Airflow DAG that creates, updates, queries and deletes a Cloud Spanner instance. This DAG relies on the following environment variables -* PROJECT_ID - Google Cloud Platform project for the Cloud Spanner instance. -* INSTANCE_ID - Cloud Spanner instance ID. -* CONFIG_NAME - The name of the instance's configuration. Values are of the form +* SPANNER_PROJECT_ID - Google Cloud Platform project for the Cloud Spanner instance. +* SPANNER_INSTANCE_ID - Cloud Spanner instance ID. +* SPANNER_CONFIG_NAME - The name of the instance's configuration. Values are of the form projects/<project>/instanceConfigs/<configuration>. See also: https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs#InstanceConfig https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list#google.spanner.admin.instance.v1.InstanceAdmin.ListInstanceConfigs -* NODE_COUNT - Number of nodes allocated to the instance. -* DISPLAY_NAME - The descriptive name for this instance as it appears in UIs. +* SPANNER_NODE_COUNT - Number of nodes allocated to the instance. +* SPANNER_DISPLAY_NAME - The descriptive name for this instance as it appears in UIs. Must be unique per project and between 4 and 30 characters in length. 
""" @@ -38,15 +38,17 @@ import airflow from airflow import models from airflow.contrib.operators.gcp_spanner_operator import \ - CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDeleteOperator + CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDatabaseQueryOperator, \ + CloudSpannerInstanceDeleteOperator # [START howto_operator_spanner_arguments] -PROJECT_ID = os.environ.get('PROJECT_ID', 'example-project') -INSTANCE_ID = os.environ.get('INSTANCE_ID', 'testinstance') -CONFIG_NAME = os.environ.get('CONFIG_NAME', +PROJECT_ID = os.environ.get('SPANNER_PROJECT_ID', 'example-project') +INSTANCE_ID = os.environ.get('SPANNER_INSTANCE_ID', 'testinstance') +DB_ID = os.environ.get('SPANNER_DB_ID', 'db1') +CONFIG_NAME = os.environ.get('SPANNER_CONFIG_NAME', 'projects/example-project/instanceConfigs/eur3') -NODE_COUNT = os.environ.get('NODE_COUNT', '1') -DISPLAY_NAME = os.environ.get('DISPLAY_NAME', 'Test Instance') +NODE_COUNT = os.environ.get('SPANNER_NODE_COUNT', '1') +DISPLAY_NAME = os.environ.get('SPANNER_DISPLAY_NAME', 'Test Instance') # [END howto_operator_spanner_arguments] default_args = { @@ -80,6 +82,24 @@ task_id='spanner_instance_update_task' ) + # [START howto_operator_spanner_query] + spanner_instance_query = CloudSpannerInstanceDatabaseQueryOperator( + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id='db1', + query="DELETE FROM my_table2 WHERE true", + task_id='spanner_instance_query' + ) + # [END howto_operator_spanner_query] + + spanner_instance_query2 = CloudSpannerInstanceDatabaseQueryOperator( + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id='db1', + query="example_gcp_spanner.sql", + task_id='spanner_instance_query2' + ) + # [START howto_operator_spanner_delete] spanner_instance_delete_task = CloudSpannerInstanceDeleteOperator( project_id=PROJECT_ID, @@ -89,4 +109,5 @@ # [END howto_operator_spanner_delete] spanner_instance_create_task >> spanner_instance_update_task \ + >> spanner_instance_query >> 
spanner_instance_query2 \ >> spanner_instance_delete_task diff --git a/airflow/contrib/example_dags/example_gcp_spanner.sql b/airflow/contrib/example_dags/example_gcp_spanner.sql new file mode 100644 index 0000000000..5d5f238022 --- /dev/null +++ b/airflow/contrib/example_dags/example_gcp_spanner.sql @@ -0,0 +1,3 @@ +INSERT my_table2 (id, name) VALUES (7, 'Seven'); +INSERT my_table2 (id, name) + VALUES (8, 'Eight'); diff --git a/airflow/contrib/hooks/gcp_spanner_hook.py b/airflow/contrib/hooks/gcp_spanner_hook.py index fc73562e8b..96e8bcb71c 100644 --- a/airflow/contrib/hooks/gcp_spanner_hook.py +++ b/airflow/contrib/hooks/gcp_spanner_hook.py @@ -16,12 +16,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from google.longrunning.operations_grpc_pb2 import Operation # noqa: F401 -from typing import Optional, Callable # noqa: F401 - from google.api_core.exceptions import GoogleAPICallError from google.cloud.spanner_v1.client import Client +from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance # noqa: F401 +from google.longrunning.operations_grpc_pb2 import Operation # noqa: F401 +from typing import Optional, Callable # noqa: F401 from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook @@ -181,3 +181,28 @@ def delete_instance(self, project_id, instance_id): except GoogleAPICallError as e: self.log.error('An error occurred: %s. Aborting.', e.message) raise e + + def execute_dml(self, project_id, instance_id, database_id, queries): + # type: (str, str, str, str) -> None + """ + Executes an arbitrary DML query (INSERT, UPDATE, DELETE). + + :param project_id: The ID of the project which owns the instances, tables and data. + :type project_id: str + :param instance_id: The ID of the instance. + :type instance_id: str + :param database_id: The ID of the database. 
+ :type database_id: str + :param queries: The queries to be executed. + :type queries: str + """ + client = self.get_client(project_id) + instance = client.instance(instance_id) + database = Database(database_id, instance) + database.run_in_transaction(lambda transaction: + self._execute_sql_in_transaction(transaction, queries)) + + @staticmethod + def _execute_sql_in_transaction(transaction, queries): + for sql in queries: + transaction.execute_update(sql) diff --git a/airflow/contrib/operators/gcp_spanner_operator.py b/airflow/contrib/operators/gcp_spanner_operator.py index 7b329a3849..b803fcc30a 100644 --- a/airflow/contrib/operators/gcp_spanner_operator.py +++ b/airflow/contrib/operators/gcp_spanner_operator.py @@ -16,6 +16,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import six + from airflow import AirflowException from airflow.contrib.hooks.gcp_spanner_hook import CloudSpannerHook from airflow.models import BaseOperator @@ -130,3 +132,68 @@ def execute(self, context): self.log.info("Instance '%s' does not exist in project '%s'. " "Aborting delete.", self.instance_id, self.project_id) return True + + +class CloudSpannerInstanceDatabaseQueryOperator(BaseOperator): + """ + Executes an arbitrary DML query (INSERT, UPDATE, DELETE). + + :param project_id: The ID of the project which owns the instances, tables and data. + :type project_id: str + :param instance_id: The ID of the instance. + :type instance_id: str + :param database_id: The ID of the database. + :type database_id: str + :param query: The query or list of queries to be executed. Can be a path to a SQL file. + :type query: str or list + :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform. 
+ :type gcp_conn_id: str + """ + # [START gcp_spanner_query_template_fields] + template_fields = ('project_id', 'instance_id', 'database_id', 'query', 'gcp_conn_id') + template_ext = ('.sql',) + # [END gcp_spanner_query_template_fields] + + @apply_defaults + def __init__(self, + project_id, + instance_id, + database_id, + query, + gcp_conn_id='google_cloud_default', + *args, **kwargs): + self.instance_id = instance_id + self.project_id = project_id + self.database_id = database_id + self.query = query + self.gcp_conn_id = gcp_conn_id + self._validate_inputs() + self._hook = CloudSpannerHook(gcp_conn_id=gcp_conn_id) + super(CloudSpannerInstanceDatabaseQueryOperator, self).__init__(*args, **kwargs) + + def _validate_inputs(self): + if not self.project_id: + raise AirflowException("The required parameter 'project_id' is empty") + if not self.instance_id: + raise AirflowException("The required parameter 'instance_id' is empty") + if not self.database_id: + raise AirflowException("The required parameter 'database_id' is empty") + if not self.query: + raise AirflowException("The required parameter 'query' is empty") + + def execute(self, context): + queries = self.query + if isinstance(self.query, six.string_types): + queries = [x.strip() for x in self.query.split(';')] + self.sanitize_queries(queries) + self.log.info("Executing DML query(-ies) on " + "projects/%s/instances/%s/databases/%s", + self.project_id, self.instance_id, self.database_id) + self.log.info(queries) + self._hook.execute_dml(self.project_id, self.instance_id, + self.database_id, queries) + + @staticmethod + def sanitize_queries(queries): + if len(queries) and queries[-1] == '': + del queries[-1] diff --git a/docs/howto/operator.rst b/docs/howto/operator.rst index 095553b3ac..221913dec0 100644 --- a/docs/howto/operator.rst +++ b/docs/howto/operator.rst @@ -545,6 +545,48 @@ See `Google Cloud Functions API documentation Google Cloud Sql Operators -------------------------- 
+CloudSpannerInstanceDatabaseQueryOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Executes an arbitrary DML query (INSERT, UPDATE, DELETE). + +For parameter definition take a look at +:class:`~airflow.contrib.operators.gcp_spanner_operator.CloudSpannerInstanceDatabaseQueryOperator`. + +Arguments +""""""""" + +Some arguments in the example DAG are taken from environment variables: + +.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_spanner.py + :language: python + :start-after: [START howto_operator_spanner_arguments] + :end-before: [END howto_operator_spanner_arguments] + +Using the operator +"""""""""""""""""" + +.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_spanner.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_spanner_query] + :end-before: [END howto_operator_spanner_query] + +Templating +"""""""""" + +.. literalinclude:: ../../airflow/contrib/operators/gcp_spanner_operator.py + :language: python + :dedent: 4 + :start-after: [START gcp_spanner_query_template_fields] + :end-before: [END gcp_spanner_query_template_fields] + +More information +"""""""""""""""" + +See Google Cloud Spanner API documentation for `the DML syntax +<https://cloud.google.com/spanner/docs/dml-syntax>`_. + CloudSpannerInstanceDeployOperator ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/integration.rst b/docs/integration.rst index 0afe555309..e74d8f662b 100644 --- a/docs/integration.rst +++ b/docs/integration.rst @@ -642,10 +642,19 @@ Cloud Spanner Cloud Spanner Operators """"""""""""""""""""""" +- :ref:`CloudSpannerInstanceDatabaseQueryOperator` : executes an arbitrary DML query + (INSERT, UPDATE, DELETE). - :ref:`CloudSpannerInstanceDeployOperator` : creates a new Cloud Spanner instance or, if an instance with the same name exists, updates it. - :ref:`CloudSpannerInstanceDeleteOperator` : deletes a Cloud Spanner instance. +.. 
_CloudSpannerInstanceDatabaseQueryOperator: + +CloudSpannerInstanceDatabaseQueryOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: airflow.contrib.operators.gcp_spanner_operator.CloudSpannerInstanceDatabaseQueryOperator + .. _CloudSpannerInstanceDeployOperator: CloudSpannerInstanceDeployOperator diff --git a/tests/contrib/operators/test_gcp_spanner_operator.py b/tests/contrib/operators/test_gcp_spanner_operator.py index ff2b82fd16..38ae985f26 100644 --- a/tests/contrib/operators/test_gcp_spanner_operator.py +++ b/tests/contrib/operators/test_gcp_spanner_operator.py @@ -22,7 +22,8 @@ from airflow import AirflowException from airflow.contrib.operators.gcp_spanner_operator import \ - CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDeleteOperator + CloudSpannerInstanceDeployOperator, CloudSpannerInstanceDeleteOperator, \ + CloudSpannerInstanceDatabaseQueryOperator from tests.contrib.operators.test_gcp_base import BaseGcpIntegrationTestCase, \ SKIP_TEST_WARNING, GCP_SPANNER_KEY @@ -37,10 +38,15 @@ PROJECT_ID = 'project-id' INSTANCE_ID = 'instance-id' -DB_NAME = 'db1' +DB_ID = 'db1' CONFIG_NAME = 'projects/project-id/instanceConfigs/eur3' NODE_COUNT = '1' DISPLAY_NAME = 'Test Instance' +INSERT_QUERY = "INSERT my_table1 (id, name) VALUES (1, 'One')" +INSERT_QUERY_2 = "INSERT my_table2 (id, name) VALUES (1, 'One')" +CREATE_QUERY = "CREATE TABLE my_table1 (id INT64, name STRING(MAX)) PRIMARY KEY (id)" +CREATE_QUERY_2 = "CREATE TABLE my_table2 (id INT64, name STRING(MAX)) PRIMARY KEY (id)" +QUERY_TYPE = "DML" class CloudSpannerTest(unittest.TestCase): @@ -164,6 +170,76 @@ def test_instance_delete_ex_if_param_missing(self, project_id, instance_id, exp_ self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) mock_hook.assert_not_called() + @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook") + def test_instance_query(self, mock_hook): + mock_hook.return_value.execute_sql.return_value = None + op = 
CloudSpannerInstanceDatabaseQueryOperator( + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id=DB_ID, + query=INSERT_QUERY, + task_id="id" + ) + result = op.execute(None) + mock_hook.assert_called_once_with(gcp_conn_id="google_cloud_default") + mock_hook.return_value.execute_dml.assert_called_once_with( + PROJECT_ID, INSTANCE_ID, DB_ID, [INSERT_QUERY] + ) + self.assertIsNone(result) + + @parameterized.expand([ + ("", INSTANCE_ID, DB_ID, INSERT_QUERY, "project_id"), + (PROJECT_ID, "", DB_ID, INSERT_QUERY, "instance_id"), + (PROJECT_ID, INSTANCE_ID, "", INSERT_QUERY, "database_id"), + (PROJECT_ID, INSTANCE_ID, DB_ID, "", "query"), + ]) + @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook") + def test_instance_query_ex_if_param_missing(self, project_id, instance_id, + database_id, query, exp_msg, mock_hook): + with self.assertRaises(AirflowException) as cm: + CloudSpannerInstanceDatabaseQueryOperator( + project_id=project_id, + instance_id=instance_id, + database_id=database_id, + query=query, + task_id="id" + ) + err = cm.exception + self.assertIn("The required parameter '{}' is empty".format(exp_msg), str(err)) + mock_hook.assert_not_called() + + @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook") + def test_instance_query_dml(self, mock_hook): + mock_hook.return_value.execute_dml.return_value = None + op = CloudSpannerInstanceDatabaseQueryOperator( + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id=DB_ID, + query=INSERT_QUERY, + task_id="id" + ) + op.execute(None) + mock_hook.assert_called_once_with(gcp_conn_id="google_cloud_default") + mock_hook.return_value.execute_dml.assert_called_once_with( + PROJECT_ID, INSTANCE_ID, DB_ID, [INSERT_QUERY] + ) + + @mock.patch("airflow.contrib.operators.gcp_spanner_operator.CloudSpannerHook") + def test_instance_query_dml_list(self, mock_hook): + mock_hook.return_value.execute_dml.return_value = None + op = 
CloudSpannerInstanceDatabaseQueryOperator( + project_id=PROJECT_ID, + instance_id=INSTANCE_ID, + database_id=DB_ID, + query=[INSERT_QUERY, INSERT_QUERY_2], + task_id="id" + ) + op.execute(None) + mock_hook.assert_called_once_with(gcp_conn_id="google_cloud_default") + mock_hook.return_value.execute_dml.assert_called_once_with( + PROJECT_ID, INSTANCE_ID, DB_ID, [INSERT_QUERY, INSERT_QUERY_2] + ) + @unittest.skipIf( BaseGcpIntegrationTestCase.skip_check(GCP_SPANNER_KEY), SKIP_TEST_WARNING) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services