This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f05beb7219 Add Alibaba Cloud AnalyticDB Spark Support (#31787)
f05beb7219 is described below
commit f05beb721940c36f2c501ea8657516262b99c01e
Author: Qian.Sun <[email protected]>
AuthorDate: Tue Jun 27 14:50:25 2023 +0800
Add Alibaba Cloud AnalyticDB Spark Support (#31787)
* Add Alibaba Cloud AnalyticDB Spark Support
---
.../alibaba/cloud/hooks/analyticdb_spark.py | 366 +++++++++++++++++++++
.../alibaba/cloud/operators/analyticdb_spark.py | 223 +++++++++++++
.../alibaba/cloud/sensors/analyticdb_spark.py | 68 ++++
airflow/providers/alibaba/provider.yaml | 18 +
.../connections/alibaba.rst | 2 +-
docs/apache-airflow-providers-alibaba/index.rst | 14 +-
.../operators/analyticdb_spark.rst | 45 +++
docs/conf.py | 2 +
docs/spelling_wordlist.txt | 2 +
generated/provider_dependencies.json | 2 +
.../alibaba/cloud/hooks/test_analyticdb_spark.py | 203 ++++++++++++
.../cloud/operators/test_analyticdb_spark.py | 176 ++++++++++
.../alibaba/cloud/sensors/test_analyticdb_spark.py | 72 ++++
.../alibaba/cloud/utils/analyticdb_spark_mock.py | 41 +++
.../providers/alibaba/example_adb_spark_batch.py | 62 ++++
.../providers/alibaba/example_adb_spark_sql.py | 54 +++
16 files changed, 1343 insertions(+), 7 deletions(-)
diff --git a/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
new file mode 100644
index 0000000000..bf3eca1722
--- /dev/null
+++ b/airflow/providers/alibaba/cloud/hooks/analyticdb_spark.py
@@ -0,0 +1,366 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from enum import Enum
+from typing import Any, Sequence
+
+from alibabacloud_adb20211201.client import Client
+from alibabacloud_adb20211201.models import (
+ GetSparkAppLogRequest,
+ GetSparkAppStateRequest,
+ GetSparkAppWebUiAddressRequest,
+ KillSparkAppRequest,
+ SubmitSparkAppRequest,
+ SubmitSparkAppResponse,
+)
+from alibabacloud_tea_openapi.models import Config
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class AppState(Enum):
+ """
+ AnalyticDB Spark application states doc:
+ https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/api-doc-adb-2021-12-01-api-struct-sparkappinfo.
+
+ """
+
+ SUBMITTED = "SUBMITTED"
+ STARTING = "STARTING"
+ RUNNING = "RUNNING"
+ FAILING = "FAILING"
+ FAILED = "FAILED"
+ KILLING = "KILLING"
+ KILLED = "KILLED"
+ SUCCEEDING = "SUCCEEDING"
+ COMPLETED = "COMPLETED"
+ FATAL = "FATAL"
+ UNKNOWN = "UNKNOWN"
+
+
+class AnalyticDBSparkHook(BaseHook, LoggingMixin):
+ """
+ Hook for AnalyticDB MySQL Spark through the REST API.
+
+ :param adb_spark_conn_id: The Airflow connection used for AnalyticDB MySQL Spark credentials.
+ :param region: AnalyticDB MySQL region you want to submit spark application.
+ """
+
+ TERMINAL_STATES = {AppState.COMPLETED, AppState.FAILED, AppState.FATAL, AppState.KILLED}
+
+ conn_name_attr = "alibabacloud_conn_id"
+ default_conn_name = "adb_spark_default"
+ conn_type = "adb_spark"
+ hook_name = "AnalyticDB Spark"
+
+ def __init__(
+ self, adb_spark_conn_id: str = "adb_spark_default", region: str | None = None, *args, **kwargs
+ ) -> None:
+ self.adb_spark_conn_id = adb_spark_conn_id
+ self.adb_spark_conn = self.get_connection(adb_spark_conn_id)
+ self.region = self.get_default_region() if region is None else region
+ super().__init__(*args, **kwargs)
+
+ def submit_spark_app(
+ self, cluster_id: str, rg_name: str, *args: Any, **kwargs: Any
+ ) -> SubmitSparkAppResponse:
+ """
+ Perform request to submit spark application.
+
+ :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse.
+ :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster.
+ """
+ self.log.info("Submitting application")
+ request = SubmitSparkAppRequest(
+ dbcluster_id=cluster_id,
+ resource_group_name=rg_name,
+ data=json.dumps(self.build_submit_app_data(*args, **kwargs)),
+ app_type="BATCH",
+ )
+ try:
+ return self.get_adb_spark_client().submit_spark_app(request)
+ except Exception as e:
+ self.log.error(e)
+ raise AirflowException("Errors when submit spark application")
from e
+
+ def submit_spark_sql(
+ self, cluster_id: str, rg_name: str, *args: Any, **kwargs: Any
+ ) -> SubmitSparkAppResponse:
+ """
+ Perform request to submit spark sql.
+
+ :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse.
+ :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster.
+ """
+ self.log.info("Submitting Spark SQL")
+ request = SubmitSparkAppRequest(
+ dbcluster_id=cluster_id,
+ resource_group_name=rg_name,
+ data=self.build_submit_sql_data(*args, **kwargs),
+ app_type="SQL",
+ )
+ try:
+ return self.get_adb_spark_client().submit_spark_app(request)
+ except Exception as e:
+ self.log.error(e)
+ raise AirflowException("Errors when submit spark sql") from e
+
+ def get_spark_state(self, app_id: str) -> str:
+ """
+ Fetch the state of the specified spark application.
+
+ :param app_id: identifier of the spark application
+ """
+ self.log.debug("Fetching state for spark application %s", app_id)
+ try:
+ return (
+ self.get_adb_spark_client()
+ .get_spark_app_state(GetSparkAppStateRequest(app_id=app_id))
+ .body.data.state
+ )
+ except Exception as e:
+ self.log.error(e)
+ raise AirflowException(f"Errors when fetching state for spark
application: {app_id}") from e
+
+ def get_spark_web_ui_address(self, app_id: str) -> str:
+ """
+ Fetch the web ui address of the specified spark application.
+
+ :param app_id: identifier of the spark application
+ """
+ self.log.debug("Fetching web ui address for spark application %s",
app_id)
+ try:
+ return (
+ self.get_adb_spark_client()
+ .get_spark_app_web_ui_address(GetSparkAppWebUiAddressRequest(app_id=app_id))
+ .body.data.web_ui_address
+ )
+ except Exception as e:
+ self.log.error(e)
+ raise AirflowException(
+ f"Errors when fetching web ui address for spark application:
{app_id}"
+ ) from e
+
+ def get_spark_log(self, app_id: str) -> str:
+ """
+ Get the logs for a specified spark application.
+
+ :param app_id: identifier of the spark application
+ """
+ self.log.debug("Fetching log for spark application %s", app_id)
+ try:
+ return (
+ self.get_adb_spark_client()
+ .get_spark_app_log(GetSparkAppLogRequest(app_id=app_id))
+ .body.data.log_content
+ )
+ except Exception as e:
+ self.log.error(e)
+ raise AirflowException(f"Errors when fetching log for spark
application: {app_id}") from e
+
+ def kill_spark_app(self, app_id: str) -> None:
+ """
+ Kill the specified spark application.
+
+ :param app_id: identifier of the spark application
+ """
+ self.log.info("Killing spark application %s", app_id)
+ try:
+ self.get_adb_spark_client().kill_spark_app(KillSparkAppRequest(app_id=app_id))
+ except Exception as e:
+ self.log.error(e)
+ raise AirflowException(f"Errors when killing spark application:
{app_id}") from e
+
+ @staticmethod
+ def build_submit_app_data(
+ file: str | None = None,
+ class_name: str | None = None,
+ args: Sequence[str | int | float] | None = None,
+ conf: dict[Any, Any] | None = None,
+ jars: Sequence[str] | None = None,
+ py_files: Sequence[str] | None = None,
+ files: Sequence[str] | None = None,
+ driver_resource_spec: str | None = None,
+ executor_resource_spec: str | None = None,
+ num_executors: int | str | None = None,
+ archives: Sequence[str] | None = None,
+ name: str | None = None,
+ ) -> dict:
+ """
+ Build the submit application request data.
+
+ :param file: path of the file containing the application to execute.
+ :param class_name: name of the application Java/Spark main class.
+ :param args: application command line arguments.
+ :param conf: Spark configuration properties.
+ :param jars: jars to be used in this application.
+ :param py_files: python files to be used in this application.
+ :param files: files to be used in this application.
+ :param driver_resource_spec: The resource specifications of the Spark driver.
+ :param executor_resource_spec: The resource specifications of each Spark executor.
+ :param num_executors: number of executors to launch for this application.
+ :param archives: archives to be used in this application.
+ :param name: name of this application.
+ """
+ if file is None:
+ raise ValueError("Parameter file is need when submit spark
application.")
+
+ data: dict[str, Any] = {"file": file}
+ extra_conf: dict[str, str] = {}
+
+ if class_name:
+ data["className"] = class_name
+ if args and AnalyticDBSparkHook._validate_list_of_stringables(args):
+ data["args"] = [str(val) for val in args]
+ if driver_resource_spec:
+ extra_conf["spark.driver.resourceSpec"] = driver_resource_spec
+ if executor_resource_spec:
+ extra_conf["spark.executor.resourceSpec"] = executor_resource_spec
+ if num_executors:
+ extra_conf["spark.executor.instances"] = str(num_executors)
+ data["conf"] = extra_conf.copy()
+ if conf and AnalyticDBSparkHook._validate_extra_conf(conf):
+ data["conf"].update(conf)
+ if jars and AnalyticDBSparkHook._validate_list_of_stringables(jars):
+ data["jars"] = jars
+ if py_files and AnalyticDBSparkHook._validate_list_of_stringables(py_files):
+ data["pyFiles"] = py_files
+ if files and AnalyticDBSparkHook._validate_list_of_stringables(files):
+ data["files"] = files
+ if archives and AnalyticDBSparkHook._validate_list_of_stringables(archives):
+ data["archives"] = archives
+ if name:
+ data["name"] = name
+
+ return data
+
+ @staticmethod
+ def build_submit_sql_data(
+ sql: str | None = None,
+ conf: dict[Any, Any] | None = None,
+ driver_resource_spec: str | None = None,
+ executor_resource_spec: str | None = None,
+ num_executors: int | str | None = None,
+ name: str | None = None,
+ ) -> str:
+ """
+ Build the submit spark sql request data.
+
+ :param sql: The SQL query to execute. (templated)
+ :param conf: Spark configuration properties.
+ :param driver_resource_spec: The resource specifications of the Spark driver.
+ :param executor_resource_spec: The resource specifications of each Spark executor.
+ :param num_executors: number of executors to launch for this application.
+ :param name: name of this application.
+ """
+ if sql is None:
+ raise ValueError("Parameter sql is need when submit spark sql.")
+
+ extra_conf: dict[str, str] = {}
+ formatted_conf = ""
+
+ if driver_resource_spec:
+ extra_conf["spark.driver.resourceSpec"] = driver_resource_spec
+ if executor_resource_spec:
+ extra_conf["spark.executor.resourceSpec"] = executor_resource_spec
+ if num_executors:
+ extra_conf["spark.executor.instances"] = str(num_executors)
+ if name:
+ extra_conf["spark.app.name"] = name
+ if conf and AnalyticDBSparkHook._validate_extra_conf(conf):
+ extra_conf.update(conf)
+ for key, value in extra_conf.items():
+ formatted_conf += f"set {key} = {value};"
+
+ return (formatted_conf + sql).strip()
+
+ @staticmethod
+ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool:
+ """
+ Check the values in the provided list can be converted to strings.
+
+ :param vals: list to validate
+ """
+ if (
+ vals is None
+ or not isinstance(vals, (tuple, list))
+ or any(1 for val in vals if not isinstance(val, (str, int, float)))
+ ):
+ raise ValueError("List of strings expected")
+ return True
+
+ @staticmethod
+ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
+ """
+ Check configuration values are either strings or ints.
+
+ :param conf: configuration variable
+ """
+ if conf:
+ if not isinstance(conf, dict):
+ raise ValueError("'conf' argument must be a dict")
+ if any(True for k, v in conf.items() if not (v and isinstance(v, str) or isinstance(v, int))):
+ raise ValueError("'conf' values must be either strings or ints")
+ return True
+
+ def get_adb_spark_client(self) -> Client:
+ """Get valid AnalyticDB MySQL Spark client."""
+ assert self.region is not None
+
+ extra_config = self.adb_spark_conn.extra_dejson
+ auth_type = extra_config.get("auth_type", None)
+ if not auth_type:
+ raise ValueError("No auth_type specified in extra_config.")
+
+ if auth_type != "AK":
+ raise ValueError(f"Unsupported auth_type: {auth_type}")
+ adb_spark_access_key_id = extra_config.get("access_key_id", None)
+ adb_spark_access_secret = extra_config.get("access_key_secret", None)
+ if not adb_spark_access_key_id:
+ raise ValueError(f"No access_key_id is specified for connection:
{self.adb_spark_conn_id}")
+
+ if not adb_spark_access_secret:
+ raise ValueError(f"No access_key_secret is specified for
connection: {self.adb_spark_conn_id}")
+
+ return Client(
+ Config(
+ access_key_id=adb_spark_access_key_id,
+ access_key_secret=adb_spark_access_secret,
+ endpoint=f"adb.{self.region}.aliyuncs.com",
+ )
+ )
+
+ def get_default_region(self) -> str | None:
+ """Get default region from connection."""
+ extra_config = self.adb_spark_conn.extra_dejson
+ auth_type = extra_config.get("auth_type", None)
+ if not auth_type:
+ raise ValueError("No auth_type specified in extra_config. ")
+
+ if auth_type != "AK":
+ raise ValueError(f"Unsupported auth_type: {auth_type}")
+
+ default_region = extra_config.get("region", None)
+ if not default_region:
+ raise ValueError(f"No region is specified for connection:
{self.adb_spark_conn}")
+ return default_region
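
For reference, the hook above resolves AK-style credentials and a default region from the connection's extra field (auth_type, access_key_id, access_key_secret, region). A minimal sketch of such a connection, with illustrative placeholder values rather than anything from this commit, could look like this:

    import json

    from airflow.models import Connection

    # Placeholder credentials and region -- replace with real values.
    adb_spark_conn = Connection(
        conn_id="adb_spark_default",
        conn_type="adb_spark",
        extra=json.dumps(
            {
                "auth_type": "AK",  # the hook only accepts access-key authentication
                "access_key_id": "<your-access-key-id>",
                "access_key_secret": "<your-access-key-secret>",
                "region": "cn-hangzhou",  # used by get_default_region() when no region is passed
            }
        ),
    )
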
diff --git a/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py
new file mode 100644
index 0000000000..6ddd47dab2
--- /dev/null
+++ b/airflow/providers/alibaba/cloud/operators/analyticdb_spark.py
@@ -0,0 +1,223 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from time import sleep
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class AnalyticDBSparkBaseOperator(BaseOperator):
+ """Abstract base class that defines how users develop AnalyticDB Spark."""
+
+ def __init__(
+ self,
+ *,
+ adb_spark_conn_id: str = "adb_spark_default",
+ region: str | None = None,
+ polling_interval: int = 0,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.app_id: str | None = None
+ self.polling_interval = polling_interval
+
+ self._adb_spark_conn_id = adb_spark_conn_id
+ self._region = region
+
+ self._adb_spark_hook: AnalyticDBSparkHook | None = None
+
+ @cached_property
+ def get_hook(self) -> AnalyticDBSparkHook:
+ """Get valid hook."""
+ if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook):
+ self._adb_spark_hook = AnalyticDBSparkHook(
+ adb_spark_conn_id=self._adb_spark_conn_id, region=self._region
+ )
+ return self._adb_spark_hook
+
+ def execute(self, context: Context) -> Any:
+ ...
+
+ def monitor_application(self):
+ self.log.info("Monitoring application with %s", self.app_id)
+
+ if self.polling_interval > 0:
+ self.poll_for_termination(self.app_id)
+
+ def poll_for_termination(self, app_id: str) -> None:
+ """
+ Poll for spark application termination.
+
+ :param app_id: id of the spark application to monitor
+ """
+ hook = self.get_hook
+ state = hook.get_spark_state(app_id)
+ while AppState(state) not in AnalyticDBSparkHook.TERMINAL_STATES:
+ self.log.debug("Application with id %s is in state: %s", app_id,
state)
+ sleep(self.polling_interval)
+ state = hook.get_spark_state(app_id)
+ self.log.info("Application with id %s terminated with state: %s",
app_id, state)
+ self.log.info(
+ "Web ui address is %s for application with id %s",
hook.get_spark_web_ui_address(app_id), app_id
+ )
+ self.log.info(hook.get_spark_log(app_id))
+ if AppState(state) != AppState.COMPLETED:
+ raise AirflowException(f"Application {app_id} did not succeed")
+
+ def on_kill(self) -> None:
+ self.kill()
+
+ def kill(self) -> None:
+ """Delete the specified application."""
+ if self.app_id is not None:
+ self.get_hook.kill_spark_app(self.app_id)
+
+
+class AnalyticDBSparkSQLOperator(AnalyticDBSparkBaseOperator):
+ """
+ This operator wraps the AnalyticDB Spark REST API, allowing submission of a Spark SQL
+ application to the underlying cluster.
+
+ :param sql: The SQL query to execute.
+ :param conf: Spark configuration properties.
+ :param driver_resource_spec: The resource specifications of the Spark driver.
+ :param executor_resource_spec: The resource specifications of each Spark executor.
+ :param num_executors: number of executors to launch for this application.
+ :param name: name of this application.
+ :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse.
+ :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster.
+ """
+
+ template_fields: Sequence[str] = ("spark_params",)
+ template_fields_renderers = {"spark_params": "json"}
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ conf: dict[Any, Any] | None = None,
+ driver_resource_spec: str | None = None,
+ executor_resource_spec: str | None = None,
+ num_executors: int | str | None = None,
+ name: str | None = None,
+ cluster_id: str,
+ rg_name: str,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.spark_params = {
+ "sql": sql,
+ "conf": conf,
+ "driver_resource_spec": driver_resource_spec,
+ "executor_resource_spec": executor_resource_spec,
+ "num_executors": num_executors,
+ "name": name,
+ }
+
+ self._cluster_id = cluster_id
+ self._rg_name = rg_name
+
+ def execute(self, context: Context) -> Any:
+ submit_response = self.get_hook.submit_spark_sql(
+ cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params
+ )
+ self.app_id = submit_response.body.data.app_id
+ self.monitor_application()
+ return self.app_id
+
+
+class AnalyticDBSparkBatchOperator(AnalyticDBSparkBaseOperator):
+ """
+ This operator wraps the AnalyticDB Spark REST API, allowing submission of a Spark batch
+ application to the underlying cluster.
+
+ :param file: path of the file containing the application to execute.
+ :param class_name: name of the application Java/Spark main class.
+ :param args: application command line arguments.
+ :param conf: Spark configuration properties.
+ :param jars: jars to be used in this application.
+ :param py_files: python files to be used in this application.
+ :param files: files to be used in this application.
+ :param driver_resource_spec: The resource specifications of the Spark driver.
+ :param executor_resource_spec: The resource specifications of each Spark executor.
+ :param num_executors: number of executors to launch for this application.
+ :param archives: archives to be used in this application.
+ :param name: name of this application.
+ :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse.
+ :param rg_name: The name of resource group in AnalyticDB MySQL 3.0 Data Lakehouse cluster.
+ """
+
+ template_fields: Sequence[str] = ("spark_params",)
+ template_fields_renderers = {"spark_params": "json"}
+
+ def __init__(
+ self,
+ *,
+ file: str,
+ class_name: str | None = None,
+ args: Sequence[str | int | float] | None = None,
+ conf: dict[Any, Any] | None = None,
+ jars: Sequence[str] | None = None,
+ py_files: Sequence[str] | None = None,
+ files: Sequence[str] | None = None,
+ driver_resource_spec: str | None = None,
+ executor_resource_spec: str | None = None,
+ num_executors: int | str | None = None,
+ archives: Sequence[str] | None = None,
+ name: str | None = None,
+ cluster_id: str,
+ rg_name: str,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.spark_params = {
+ "file": file,
+ "class_name": class_name,
+ "args": args,
+ "conf": conf,
+ "jars": jars,
+ "py_files": py_files,
+ "files": files,
+ "driver_resource_spec": driver_resource_spec,
+ "executor_resource_spec": executor_resource_spec,
+ "num_executors": num_executors,
+ "archives": archives,
+ "name": name,
+ }
+
+ self._cluster_id = cluster_id
+ self._rg_name = rg_name
+
+ def execute(self, context: Context) -> Any:
+ submit_response = self.get_hook.submit_spark_app(
+ cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params
+ )
+ self.app_id = submit_response.body.data.app_id
+ self.monitor_application()
+ return self.app_id
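
For reference, a minimal DAG sketch using the batch operator above; the cluster, resource group, and region values are placeholders, and polling_interval > 0 makes the task wait for a terminal state before finishing:

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkBatchOperator

    with DAG(dag_id="adb_spark_batch_sketch", start_date=datetime(2021, 1, 1), catchup=False) as dag:
        spark_pi = AnalyticDBSparkBatchOperator(
            task_id="spark_pi",
            file="oss://your-bucket/spark-examples.jar",  # placeholder artifact location
            class_name="org.apache.spark.examples.SparkPi",
            cluster_id="your-cluster-id",
            rg_name="your-resource-group",
            region="cn-hangzhou",  # placeholder region
            polling_interval=10,  # seconds between state polls; 0 submits without waiting
        )
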
diff --git a/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py
new file mode 100644
index 0000000000..fb6a962d43
--- /dev/null
+++ b/airflow/providers/alibaba/cloud/sensors/analyticdb_spark.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class AnalyticDBSparkSensor(BaseSensorOperator):
+ """
+ Monitor an AnalyticDB Spark application for termination.
+
+ :param app_id: identifier of the spark application to monitor.
+ :param adb_spark_conn_id: reference to a pre-defined ADB Spark connection.
+ :param region: AnalyticDB MySQL region you want to submit spark application.
+ """
+
+ template_fields: Sequence[str] = ("app_id",)
+
+ def __init__(
+ self,
+ *,
+ app_id: str,
+ adb_spark_conn_id: str = "adb_spark_default",
+ region: str | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.app_id = app_id
+ self._region = region
+ self._adb_spark_conn_id = adb_spark_conn_id
+ self._adb_spark_hook: AnalyticDBSparkHook | None = None
+
+ @cached_property
+ def get_hook(self) -> AnalyticDBSparkHook:
+ """Get valid hook."""
+ if self._adb_spark_hook is None or not isinstance(self._adb_spark_hook, AnalyticDBSparkHook):
+ self._adb_spark_hook = AnalyticDBSparkHook(
+ adb_spark_conn_id=self._adb_spark_conn_id, region=self._region
+ )
+ return self._adb_spark_hook
+
+ def poke(self, context: Context) -> bool:
+ app_id = self.app_id
+
+ state = self.get_hook.get_spark_state(app_id)
+ return AppState(state) in AnalyticDBSparkHook.TERMINAL_STATES
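
A sketch of pairing this sensor with a fire-and-forget submission: the operators above return the application id, which Airflow pushes to XCom, so the sensor can pick it up via a template. The task ids here are illustrative:

    from airflow.providers.alibaba.cloud.sensors.analyticdb_spark import AnalyticDBSparkSensor

    # Assumes an upstream task with task_id "spark_pi" submitted the application
    # with polling_interval=0 and returned its application id.
    wait_for_spark_pi = AnalyticDBSparkSensor(
        task_id="wait_for_spark_pi",
        app_id="{{ task_instance.xcom_pull(task_ids='spark_pi') }}",
        poke_interval=60,  # check the application state every minute
    )
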
diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml
index aa2d998971..660dccdcd6 100644
--- a/airflow/providers/alibaba/provider.yaml
+++ b/airflow/providers/alibaba/provider.yaml
@@ -38,6 +38,8 @@ versions:
dependencies:
- apache-airflow>=2.4.0
- oss2>=2.14.0
+ - alibabacloud_adb20211201>=1.0.0
+ - alibabacloud_tea_openapi>=0.3.7
integrations:
- integration-name: Alibaba Cloud OSS
@@ -46,26 +48,42 @@ integrations:
how-to-guide:
- /docs/apache-airflow-providers-alibaba/operators/oss.rst
tags: [alibaba]
+ - integration-name: Alibaba Cloud AnalyticDB Spark
+ external-doc-url: https://www.alibabacloud.com/help/en/analyticdb-for-mysql/latest/spark-developerment
+ how-to-guide:
+ - /docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst
+ tags: [alibaba]
operators:
- integration-name: Alibaba Cloud OSS
python-modules:
- airflow.providers.alibaba.cloud.operators.oss
+ - integration-name: Alibaba Cloud AnalyticDB Spark
+ python-modules:
+ - airflow.providers.alibaba.cloud.operators.analyticdb_spark
sensors:
- integration-name: Alibaba Cloud OSS
python-modules:
- airflow.providers.alibaba.cloud.sensors.oss_key
+ - integration-name: Alibaba Cloud AnalyticDB Spark
+ python-modules:
+ - airflow.providers.alibaba.cloud.sensors.analyticdb_spark
hooks:
- integration-name: Alibaba Cloud OSS
python-modules:
- airflow.providers.alibaba.cloud.hooks.oss
+ - integration-name: Alibaba Cloud AnalyticDB Spark
+ python-modules:
+ - airflow.providers.alibaba.cloud.hooks.analyticdb_spark
connection-types:
- hook-class-name: airflow.providers.alibaba.cloud.hooks.oss.OSSHook
connection-type: oss
+ - hook-class-name: airflow.providers.alibaba.cloud.hooks.analyticdb_spark.AnalyticDBSparkHook
+ connection-type: adb_spark
logging:
- airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler
diff --git a/docs/apache-airflow-providers-alibaba/connections/alibaba.rst b/docs/apache-airflow-providers-alibaba/connections/alibaba.rst
index d697b4dc0c..4cf4747d7e 100644
--- a/docs/apache-airflow-providers-alibaba/connections/alibaba.rst
+++ b/docs/apache-airflow-providers-alibaba/connections/alibaba.rst
@@ -26,7 +26,7 @@ Authentication may be performed using `Security Token Service (STS) or a signed
Default Connection IDs
----------------------
-The default connection ID is ``oss_default``.
+The default connection IDs are ``oss_default`` and ``adb_spark_default``.
Configuring the Connection
--------------------------
diff --git a/docs/apache-airflow-providers-alibaba/index.rst b/docs/apache-airflow-providers-alibaba/index.rst
index 8e0da17c71..c050e7e9da 100644
--- a/docs/apache-airflow-providers-alibaba/index.rst
+++ b/docs/apache-airflow-providers-alibaba/index.rst
@@ -85,11 +85,13 @@ Requirements
The minimum Apache Airflow version supported by this provider package is ``2.4.0``.
-================== ==================
-PIP package Version required
-================== ==================
-``apache-airflow`` ``>=2.4.0``
-``oss2`` ``>=2.14.0``
-================== ==================
+============================ ==================
+PIP package Version required
+============================ ==================
+``apache-airflow`` ``>=2.4.0``
+``oss2`` ``>=2.14.0``
+``alibabacloud_adb20211201`` ``>=1.0.0``
+``alibabacloud_tea_openapi`` ``>=0.3.7``
+============================ ==================
.. include:: ../../airflow/providers/alibaba/CHANGELOG.rst
diff --git a/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst b/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst
new file mode 100644
index 0000000000..ac3f0638ad
--- /dev/null
+++ b/docs/apache-airflow-providers-alibaba/operators/analyticdb_spark.rst
@@ -0,0 +1,45 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Alibaba Cloud AnalyticDB Spark Operators
+========================================
+
+Overview
+--------
+
+Airflow to Alibaba Cloud AnalyticDB Spark integration provides several operators to develop spark batch and sql applications.
+
+ - :class:`~airflow.providers.alibaba.cloud.operators.analyticdb_spark.AnalyticDBSparkBatchOperator`
+ - :class:`~airflow.providers.alibaba.cloud.operators.analyticdb_spark.AnalyticDBSparkSQLOperator`
+
+Develop Spark batch applications
+-------------------------------------------
+
+Purpose
+"""""""
+
+This example dag uses ``AnalyticDBSparkBatchOperator`` to submit Spark Pi and Spark Logistic regression applications.
+
+Defining tasks
+""""""""""""""
+
+In the following code we submit Spark Pi and Spark Logistic regression applications.
+
+.. exampleinclude:: /../../tests/system/providers/alibaba/example_adb_spark_batch.py
+ :language: python
+ :start-after: [START howto_operator_adb_spark_batch]
+ :end-before: [END howto_operator_adb_spark_batch]
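
The SQL operator is driven the same way; a short sketch, mirroring the example_adb_spark_sql.py system test added later in this diff (cluster and resource group values are placeholders):

    from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkSQLOperator

    show_databases = AnalyticDBSparkSQLOperator(
        task_id="show_databases",
        sql="SHOW DATABASES;",
        cluster_id="your-cluster-id",
        rg_name="your-resource-group",
    )
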
diff --git a/docs/conf.py b/docs/conf.py
index 27b3e07384..f45956c9c0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -588,6 +588,8 @@ elif PACKAGE_NAME == "helm-chart":
autodoc_mock_imports = [
"MySQLdb",
"adal",
+ "alibabacloud_adb20211201",
+ "alibabacloud_tea_openapi",
"analytics",
"azure",
"azure.cosmos",
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index d2a759e931..f51dc1215f 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -862,6 +862,7 @@ kwargs
KYLIN
Kylin
kylin
+Lakehouse
LanguageServiceClient
lastname
latencies
@@ -1357,6 +1358,7 @@ sourceArchiveUrl
sourceRepository
sourceUploadUrl
Spark
+sparkappinfo
sparkApplication
sparkcmd
SparkPi
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 2755e9f08b..5ce7ccd295 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -11,6 +11,8 @@
},
"alibaba": {
"deps": [
+ "alibabacloud_adb20211201>=1.0.0",
+ "alibabacloud_tea_openapi>=0.3.7",
"apache-airflow>=2.4.0",
"oss2>=2.14.0"
],
diff --git a/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py
new file mode 100644
index 0000000000..bf38a3f7ca
--- /dev/null
+++ b/tests/providers/alibaba/cloud/hooks/test_analyticdb_spark.py
@@ -0,0 +1,203 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from alibabacloud_adb20211201.models import (
+ GetSparkAppLogResponse,
+ GetSparkAppLogResponseBody,
+ GetSparkAppLogResponseBodyData,
+ GetSparkAppStateResponse,
+ GetSparkAppStateResponseBody,
+ GetSparkAppStateResponseBodyData,
+ GetSparkAppWebUiAddressResponse,
+ GetSparkAppWebUiAddressResponseBody,
+ GetSparkAppWebUiAddressResponseBodyData,
+ KillSparkAppResponse,
+ SubmitSparkAppResponse,
+)
+
+from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook
+from tests.providers.alibaba.cloud.utils.analyticdb_spark_mock import mock_adb_spark_hook_default_project_id
+
+ADB_SPARK_STRING = "airflow.providers.alibaba.cloud.hooks.analyticdb_spark.{}"
+MOCK_ADB_SPARK_CONN_ID = "mock_id"
+MOCK_ADB_CLUSTER_ID = "mock_adb_cluster_id"
+MOCK_ADB_RG_NAME = "mock_adb_rg_name"
+MOCK_ADB_SPARK_ID = "mock_adb_spark_id"
+
+
+class TestAnalyticDBSparkHook:
+ def setup_method(self):
+ with mock.patch(
+ ADB_SPARK_STRING.format("AnalyticDBSparkHook.__init__"),
+ new=mock_adb_spark_hook_default_project_id,
+ ):
+ self.hook = AnalyticDBSparkHook(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID)
+
+ def test_build_submit_app_data(self):
+ """Test build submit application data for analyticDB spark as
expected."""
+ res_data = self.hook.build_submit_app_data(
+ file="oss://test_file",
+ class_name="com.aliyun.spark.SparkPi",
+ args=[1000, "test-args"],
+ conf={"spark.executor.instances": 1, "spark.eventLog.enabled":
"true"},
+ jars=["oss://1.jar", "oss://2.jar"],
+ py_files=["oss://1.py", "oss://2.py"],
+ files=["oss://1.file", "oss://2.file"],
+ driver_resource_spec="medium",
+ executor_resource_spec="medium",
+ num_executors=2,
+ archives=["oss://1.zip", "oss://2.zip"],
+ name="test",
+ )
+ except_data = {
+ "file": "oss://test_file",
+ "className": "com.aliyun.spark.SparkPi",
+ "args": ["1000", "test-args"],
+ "conf": {
+ "spark.executor.instances": 1,
+ "spark.eventLog.enabled": "true",
+ "spark.driver.resourceSpec": "medium",
+ "spark.executor.resourceSpec": "medium",
+ },
+ "jars": ["oss://1.jar", "oss://2.jar"],
+ "pyFiles": ["oss://1.py", "oss://2.py"],
+ "files": ["oss://1.file", "oss://2.file"],
+ "archives": ["oss://1.zip", "oss://2.zip"],
+ "name": "test",
+ }
+ assert res_data == except_data
+
+ def test_build_submit_sql_data(self):
+ """Test build submit sql data for analyticDB spark as expected."""
+ res_data = self.hook.build_submit_sql_data(
+ sql="""
+ set spark.executor.instances=1;
+ show databases;
+ """,
+ conf={"spark.executor.instances": 2},
+ driver_resource_spec="medium",
+ executor_resource_spec="medium",
+ num_executors=3,
+ name="test",
+ )
+ except_data = (
+ "set spark.driver.resourceSpec = medium;set
spark.executor.resourceSpec = medium;set "
+ "spark.executor.instances = 2;set spark.app.name = test;\n
set "
+ "spark.executor.instances=1;\n show databases;"
+ )
+ assert res_data == except_data
+
+ @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client"))
+ def test_submit_spark_app(self, mock_service):
+ """Test submit_spark_app function works as expected."""
+ # Given
+ mock_client = mock_service.return_value
+ exists_method = mock_client.submit_spark_app
+ exists_method.return_value = SubmitSparkAppResponse(status_code=200)
+
+ # When
+ res = self.hook.submit_spark_app(MOCK_ADB_CLUSTER_ID, MOCK_ADB_RG_NAME, "oss://test.py")
+
+ # Then
+ assert isinstance(res, SubmitSparkAppResponse)
+ mock_service.assert_called_once_with()
+
+ @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client"))
+ def test_submit_spark_sql(self, mock_service):
+ """Test submit_spark_app function works as expected."""
+ # Given
+ mock_client = mock_service.return_value
+ exists_method = mock_client.submit_spark_app
+ exists_method.return_value = SubmitSparkAppResponse(status_code=200)
+
+ # When
+ res = self.hook.submit_spark_sql(MOCK_ADB_CLUSTER_ID, MOCK_ADB_RG_NAME, "SELECT 1")
+
+ # Then
+ assert isinstance(res, SubmitSparkAppResponse)
+ mock_service.assert_called_once_with()
+
+ @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client"))
+ def test_get_spark_state(self, mock_service):
+ """Test get_spark_state function works as expected."""
+ # Given
+ mock_client = mock_service.return_value
+ exists_method = mock_client.get_spark_app_state
+ exists_method.return_value = GetSparkAppStateResponse(
+ body=GetSparkAppStateResponseBody(data=GetSparkAppStateResponseBodyData(state="RUNNING"))
+ )
+
+ # When
+ res = self.hook.get_spark_state(MOCK_ADB_SPARK_ID)
+
+ # Then
+ assert res == "RUNNING"
+ mock_service.assert_called_once_with()
+
+ @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client"))
+ def test_get_spark_web_ui_address(self, mock_service):
+ """Test get_spark_web_ui_address function works as expected."""
+ # Given
+ mock_client = mock_service.return_value
+ exists_method = mock_client.get_spark_app_web_ui_address
+ exists_method.return_value = GetSparkAppWebUiAddressResponse(
+ body=GetSparkAppWebUiAddressResponseBody(
+ data=GetSparkAppWebUiAddressResponseBodyData(web_ui_address="https://mock-web-ui-address.com")
+ )
+ )
+
+ # When
+ res = self.hook.get_spark_web_ui_address(MOCK_ADB_SPARK_ID)
+
+ # Then
+ assert res == "https://mock-web-ui-address.com"
+ mock_service.assert_called_once_with()
+
+ @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client"))
+ def test_get_spark_log(self, mock_service):
+ """Test get_spark_log function works as expected."""
+ # Given
+ mock_client = mock_service.return_value
+ exists_method = mock_client.get_spark_app_log
+ exists_method.return_value = GetSparkAppLogResponse(
+ body=GetSparkAppLogResponseBody(data=GetSparkAppLogResponseBodyData(log_content="Pi is 3.14"))
+ )
+
+ # When
+ res = self.hook.get_spark_log(MOCK_ADB_SPARK_ID)
+
+ # Then
+ assert res == "Pi is 3.14"
+ mock_service.assert_called_once_with()
+
+ @mock.patch(ADB_SPARK_STRING.format("AnalyticDBSparkHook.get_adb_spark_client"))
+ def test_kill_spark_app(self, mock_service):
+ """Test kill_spark_app function works as expected."""
+ # Given
+ mock_client = mock_service.return_value
+ exists_method = mock_client.kill_spark_app
+ exists_method.return_value = KillSparkAppResponse()
+
+ # When
+ self.hook.kill_spark_app(MOCK_ADB_SPARK_ID)
+
+ # Then
+ mock_service.assert_called_once_with()
diff --git a/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py
new file mode 100644
index 0000000000..eb2db3ff39
--- /dev/null
+++ b/tests/providers/alibaba/cloud/operators/test_analyticdb_spark.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow import AirflowException
+from airflow.providers.alibaba.cloud.operators.analyticdb_spark import (
+ AnalyticDBSparkBaseOperator,
+ AnalyticDBSparkBatchOperator,
+ AnalyticDBSparkSQLOperator,
+)
+
+ADB_SPARK_OPERATOR_STRING = "airflow.providers.alibaba.cloud.operators.analyticdb_spark.{}"
+
+MOCK_FILE = "oss://test.py"
+MOCK_CLUSTER_ID = "mock_cluster_id"
+MOCK_RG_NAME = "mock_rg_name"
+MOCK_ADB_SPARK_CONN_ID = "mock_adb_spark_conn_id"
+MOCK_REGION = "mock_region"
+MOCK_TASK_ID = "mock_task_id"
+MOCK_APP_ID = "mock_app_id"
+MOCK_SQL = "SELECT 1;"
+
+
+class TestAnalyticDBSparkBaseOperator:
+ def setup_method(self):
+ self.operator = AnalyticDBSparkBaseOperator(
+ adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
+ region=MOCK_REGION,
+ task_id=MOCK_TASK_ID,
+ )
+
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook"))
+ def test_get_hook(self, mock_hook):
+ """Test get_hook function works as expected."""
+ self.operator.get_hook()
+ mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION)
+
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook"))
+ def test_poll_for_termination(self, mock_hook):
+ """Test poll_for_termination works as expected with COMPLETED
application."""
+ # Given
+ mock_hook.get_spark_state.return_value = "COMPLETED"
+
+ # When
+ self.operator.poll_for_termination(MOCK_APP_ID)
+
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook"))
+ def test_poll_for_termination_with_exception(self, mock_hook):
+ """Test poll_for_termination raises AirflowException with FATAL
application."""
+ # Given
+ mock_hook.get_spark_state.return_value = "FATAL"
+
+ # When
+ with pytest.raises(AirflowException, match="Application mock_app_id
did not succeed"):
+ self.operator.poll_for_termination(MOCK_APP_ID)
+
+
+class TestAnalyticDBSparkBatchOperator:
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook"))
+ def test_execute(self, mock_hook):
+ """Test submit AnalyticDB Spark Batch Application works as expected."""
+ operator = AnalyticDBSparkBatchOperator(
+ file=MOCK_FILE,
+ cluster_id=MOCK_CLUSTER_ID,
+ rg_name=MOCK_RG_NAME,
+ adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
+ region=MOCK_REGION,
+ task_id=MOCK_TASK_ID,
+ )
+
+ operator.execute(None)
+
+ mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION)
+ mock_hook.return_value.submit_spark_app.assert_called_once_with(
+ cluster_id=MOCK_CLUSTER_ID,
+ rg_name=MOCK_RG_NAME,
+ file=MOCK_FILE,
+ class_name=None,
+ args=None,
+ conf=None,
+ jars=None,
+ py_files=None,
+ files=None,
+ driver_resource_spec=None,
+ executor_resource_spec=None,
+ num_executors=None,
+ archives=None,
+ name=None,
+ )
+
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook"))
+ def test_execute_with_exception(self, mock_hook):
+ """Test submit AnalyticDB Spark Batch Application raises ValueError
with invalid parameter."""
+ # Given
+ mock_hook.submit_spark_app.side_effect = ValueError("List of strings expected")
+
+ # When
+ operator = AnalyticDBSparkBatchOperator(
+ file=MOCK_FILE,
+ args=(True, False),
+ cluster_id=MOCK_CLUSTER_ID,
+ rg_name=MOCK_RG_NAME,
+ adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
+ region=MOCK_REGION,
+ task_id=MOCK_TASK_ID,
+ )
+
+ with pytest.raises(ValueError, match="List of strings expected"):
+ operator.execute(None)
+
+
+class TestAnalyticDBSparkSQLOperator:
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkHook"))
+ def test_execute(self, mock_hook):
+ """Test submit AnalyticDB Spark SQL Application works as expected."""
+ operator = AnalyticDBSparkSQLOperator(
+ sql=MOCK_SQL,
+ cluster_id=MOCK_CLUSTER_ID,
+ rg_name=MOCK_RG_NAME,
+ adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
+ region=MOCK_REGION,
+ task_id=MOCK_TASK_ID,
+ )
+
+ operator.execute(None)
+
+ mock_hook.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID, region=MOCK_REGION)
+ mock_hook.return_value.submit_spark_sql.assert_called_once_with(
+ cluster_id=MOCK_CLUSTER_ID,
+ rg_name=MOCK_RG_NAME,
+ sql=MOCK_SQL,
+ conf=None,
+ driver_resource_spec=None,
+ executor_resource_spec=None,
+ num_executors=None,
+ name=None,
+ )
+
+ @mock.patch(ADB_SPARK_OPERATOR_STRING.format("AnalyticDBSparkBaseOperator.get_hook"))
+ def test_execute_with_exception(self, mock_hook):
+ """Test submit AnalyticDB Spark SQL Application raises ValueError with
invalid parameter."""
+ # Given
+ mock_hook.submit_spark_sql.side_effect = ValueError("List of strings expected")
+
+ # When
+ operator = AnalyticDBSparkSQLOperator(
+ sql=MOCK_SQL,
+ conf={"spark.eventLog.enabled": True},
+ cluster_id=MOCK_CLUSTER_ID,
+ rg_name=MOCK_RG_NAME,
+ adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
+ region=MOCK_REGION,
+ task_id=MOCK_TASK_ID,
+ )
+
+ with pytest.raises(ValueError, match="List of strings expected"):
+ operator.execute(None)
diff --git a/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py
new file mode 100644
index 0000000000..8cef517500
--- /dev/null
+++ b/tests/providers/alibaba/cloud/sensors/test_analyticdb_spark.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.alibaba.cloud.sensors.analyticdb_spark import AnalyticDBSparkSensor
+from airflow.utils import timezone
+
+ADB_SPARK_SENSOR_STRING = "airflow.providers.alibaba.cloud.sensors.analyticdb_spark.{}"
+DEFAULT_DATE = timezone.datetime(2017, 1, 1)
+MOCK_ADB_SPARK_CONN_ID = "mock_adb_spark_default"
+MOCK_ADB_SPARK_ID = "mock_adb_spark_id"
+MOCK_SENSOR_TASK_ID = "test-adb-spark-operator"
+MOCK_REGION = "mock_region"
+
+
+class TestAnalyticDBSparkSensor:
+ def setup_method(self):
+ self.sensor = AnalyticDBSparkSensor(
+ app_id=MOCK_ADB_SPARK_ID,
+ adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
+ region=MOCK_REGION,
+ task_id=MOCK_SENSOR_TASK_ID,
+ )
+
+ @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkHook"))
+ def test_get_hook(self, mock_service):
+ """Test get_hook function works as expected."""
+ self.sensor.get_hook()
+
mock_service.assert_called_once_with(adb_spark_conn_id=MOCK_ADB_SPARK_CONN_ID,
region=MOCK_REGION)
+
+
+ @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook"))
+ def test_poke_terminal_state(self, mock_service):
+ """Test poke_terminal_state works as expected with COMPLETED
application."""
+ # Given
+ mock_service.get_spark_state.return_value = "COMPLETED"
+
+ # When
+ res = self.sensor.poke(None)
+
+ # Then
+ assert res is True
+ mock_service.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID)
+
+ @mock.patch(ADB_SPARK_SENSOR_STRING.format("AnalyticDBSparkSensor.get_hook"))
+ def test_poke_non_terminal_state(self, mock_service):
+ """Test poke_terminal_state works as expected with RUNNING
application."""
+ # Given
+ mock_service.get_spark_state.return_value = "RUNNING"
+
+ # When
+ res = self.sensor.poke(None)
+
+ # Then
+ assert res is False
+ mock_service.get_spark_state.assert_called_once_with(MOCK_ADB_SPARK_ID)
diff --git a/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py b/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py
new file mode 100644
index 0000000000..b43cc1feb2
--- /dev/null
+++ b/tests/providers/alibaba/cloud/utils/analyticdb_spark_mock.py
@@ -0,0 +1,41 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+
+from airflow.models import Connection
+
+ANALYTICDB_SPARK_PROJECT_ID_HOOK_UNIT_TEST = "example-project"
+
+
+def mock_adb_spark_hook_default_project_id(
+ self, adb_spark_conn_id="mock_adb_spark_default", region="mock_region"
+):
+ self.adb_spark_conn_id = adb_spark_conn_id
+ self.adb_spark_conn = Connection(
+ extra=json.dumps(
+ {
+ "auth_type": "AK",
+ "access_key_id": "mock_access_key_id",
+ "access_key_secret": "mock_access_key_secret",
+ "region": "mock_region",
+ }
+ )
+ )
+ self.region = region
diff --git a/tests/system/providers/alibaba/example_adb_spark_batch.py b/tests/system/providers/alibaba/example_adb_spark_batch.py
new file mode 100644
index 0000000000..b6945190bc
--- /dev/null
+++ b/tests/system/providers/alibaba/example_adb_spark_batch.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+# Ignore missing args provided by default_args
+# type: ignore[call-arg]
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkBatchOperator
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "adb_spark_batch_dag"
+# [START howto_operator_adb_spark_batch]
+with DAG(
+ dag_id=DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ default_args={"cluster_id": "your cluster", "rg_name": "your resource
group", "region": "your region"},
+ max_active_runs=1,
+ catchup=False,
+) as dag:
+
+ spark_pi = AnalyticDBSparkBatchOperator(
+ task_id="task1",
+ file="local:///tmp/spark-examples.jar",
+ class_name="org.apache.spark.examples.SparkPi",
+ )
+
+ spark_lr = AnalyticDBSparkBatchOperator(
+ task_id="task2",
+ file="local:///tmp/spark-examples.jar",
+ class_name="org.apache.spark.examples.SparkLR",
+ )
+
+ spark_pi >> spark_lr
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+# [END howto_operator_adb_spark_batch]
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/alibaba/example_adb_spark_sql.py b/tests/system/providers/alibaba/example_adb_spark_sql.py
new file mode 100644
index 0000000000..851880fa73
--- /dev/null
+++ b/tests/system/providers/alibaba/example_adb_spark_sql.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+# Ignore missing args provided by default_args
+# type: ignore[call-arg]
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.alibaba.cloud.operators.analyticdb_spark import AnalyticDBSparkSQLOperator
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "adb_spark_sql_dag"
+# [START howto_operator_adb_spark_sql]
+with DAG(
+ dag_id=DAG_ID,
+ start_date=datetime(2021, 1, 1),
+ default_args={"cluster_id": "your cluster", "rg_name": "your resource
group", "region": "your region"},
+ max_active_runs=1,
+ catchup=False,
+) as dag:
+
+ show_databases = AnalyticDBSparkSQLOperator(task_id="task1", sql="SHOE
DATABASES;")
+
+ show_tables = AnalyticDBSparkSQLOperator(task_id="task2", sql="SHOW
TABLES;")
+
+ show_databases >> show_tables
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+# [END howto_operator_adb_spark_sql]
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)