pankajkoti commented on code in PR #48507:
URL: https://github.com/apache/airflow/pull/48507#discussion_r2025313956
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -969,6 +978,183 @@ def on_kill(self) -> None:
self.log.error("Error: Task: %s with invalid run_id was requested
to be cancelled.", self.task_id)
class DatabricksSQLStatementsOperator(BaseOperator):
    """
    Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint.

    See: https://docs.databricks.com/api/workspace/statementexecution

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:DatabricksSQLStatementsOperator`

    :param statement: The SQL statement to execute. The statement can optionally be parameterized, see parameters.
    :param warehouse_id: Warehouse upon which to execute a statement.
    :param catalog: Sets default catalog for statement execution, similar to USE CATALOG in SQL.
    :param schema: Sets default schema for statement execution, similar to USE SCHEMA in SQL.
    :param parameters: A list of parameters to pass into a SQL statement containing parameter markers.

        .. seealso::
            https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters
    :param wait_for_termination: if we should wait for termination of the statement execution. ``True`` by default.
    :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
        By default and in the common case this will be ``databricks_default``. To use
        token based authentication, provide the key ``token`` in the extra field for the
        connection and create the key ``host`` and leave the ``host`` field empty. (templated)
    :param polling_period_seconds: Controls the rate at which we poll for the result of
        this statement. By default the operator will poll every 30 seconds.
    :param databricks_retry_limit: Amount of times to retry if the Databricks backend is
        unreachable. Its value must be greater than or equal to 1.
    :param databricks_retry_delay: Number of seconds to wait between retries (it
        might be a floating point number).
    :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
    :param do_xcom_push: Whether we should push the statement_id to XCom.
    :param deferrable: Run operator in the deferrable mode.
    """

    # Used in airflow.models.BaseOperator
    template_fields: Sequence[str] = ("databricks_conn_id",)
    template_ext: Sequence[str] = (".json-tpl",)
    # Databricks brand color (blue) under white text
    ui_color = "#1CB1C2"
    ui_fgcolor = "#fff"
+
+ def __init__(
+ self,
+ statement: str,
+ warehouse_id: str,
+ *,
+ catalog: str | None = None,
+ schema: str | None = None,
+ parameters: list[dict[str, Any]] | None = None,
+ databricks_conn_id: str = "databricks_default",
+ polling_period_seconds: int = 30,
+ databricks_retry_limit: int = 3,
+ databricks_retry_delay: int = 1,
+ databricks_retry_args: dict[Any, Any] | None = None,
+ do_xcom_push: bool = True,
+ wait_for_termination: bool = True,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ **kwargs,
+ ) -> None:
+ """Create a new ``DatabricksSubmitRunOperator``."""
+ super().__init__(**kwargs)
+ self.statement = statement
+ self.warehouse_id = warehouse_id
+ self.catalog = catalog
+ self.schema = schema
+ self.parameters = parameters
+ self.databricks_conn_id = databricks_conn_id
+ self.polling_period_seconds = polling_period_seconds
+ self.databricks_retry_limit = databricks_retry_limit
+ self.databricks_retry_delay = databricks_retry_delay
+ self.databricks_retry_args = databricks_retry_args
+ self.wait_for_termination = wait_for_termination
+ self.deferrable = deferrable
+
+ # This variable will be used in case our task gets killed.
+ self.statement_id: str | None = None
+ self.do_xcom_push = do_xcom_push
+
+ @cached_property
+ def _hook(self):
+ return self._get_hook(caller="DatabricksSQLStatementsOperator")
+
+ def _get_hook(self, caller: str) -> DatabricksHook:
+ return DatabricksHook(
+ self.databricks_conn_id,
+ retry_limit=self.databricks_retry_limit,
+ retry_delay=self.databricks_retry_delay,
+ retry_args=self.databricks_retry_args,
+ caller=caller,
+ )
+
+ def _handle_operator_execution(self) -> None:
+ while True:
+ statement_state =
self._hook.get_sql_statement_state(self.statement_id)
+ if statement_state.is_terminal:
+ if statement_state.is_successful:
+ self.log.info("%s completed successfully.", self.task_id)
+ return
+ error_message = (
+ f"{self.task_id} failed with terminal state:
{statement_state.state} "
+ f"and with the error code {statement_state.error_code} "
+ f"and error message {statement_state.error_message}"
+ )
+ raise AirflowException(error_message)
+
+ self.log.info("%s in run state: %s", self.task_id,
statement_state.state)
+ self.log.info("Sleeping for %s seconds.",
self.polling_period_seconds)
+ time.sleep(self.polling_period_seconds)
+
+ def _handle_deferrable_operator_execution(self) -> None:
+ statement_state = self._hook.get_sql_statement_state(self.statement_id)
+ if not statement_state.is_terminal:
+ self.defer(
Review Comment:
Good idea! I noticed that most providers — including Databricks — and their operators
currently lack this feature. However, I’ve now introduced a timeout for both
synchronous and asynchronous execution in this newly added operator.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]