pankajkoti commented on code in PR #48507:
URL: https://github.com/apache/airflow/pull/48507#discussion_r2025311741


##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -969,6 +978,183 @@ def on_kill(self) -> None:
             self.log.error("Error: Task: %s with invalid run_id was requested 
to be cancelled.", self.task_id)
 
 
+class DatabricksSQLStatementsOperator(BaseOperator):
+    """
+    Submits a Databricks SQL Statement to Databricks using the 
api/2.0/sql/statements/ API endpoint.
+
+    See: https://docs.databricks.com/api/workspace/statementexecution
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:DatabricksSQLStatementsOperator`
+
+    :param statement: The SQL statement to execute. The statement can 
optionally be parameterized, see parameters.
+    :param warehouse_id: Warehouse upon which to execute a statement.
+    :param catalog: Sets default catalog for statement execution, similar to 
USE CATALOG in SQL.
+    :param schema: Sets default schema for statement execution, similar to USE 
SCHEMA in SQL.
+    :param parameters: A list of parameters to pass into a SQL statement 
containing parameter markers.
+
+        .. seealso::
+            
https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters
+    :param wait_for_termination: if we should wait for termination of the 
statement execution. ``True`` by default.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection 
<howto/connection:databricks>`.
+        By default and in the common case this will be ``databricks_default``. 
To use
+        token based authentication, provide the key ``token`` in the extra 
field for the
+        connection and create the key ``host`` and leave the ``host`` field 
empty. (templated)
+    :param polling_period_seconds: Controls the rate which we poll for the 
result of
+        this statement. By default the operator will poll every 30 seconds.
+    :param databricks_retry_limit: Amount of times retry if the Databricks 
backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_delay: Number of seconds to wait between retries 
(it
+            might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed 
to ``tenacity.Retrying`` class.
+    :param do_xcom_push: Whether we should push statement_id to xcom.
+    :param deferrable: Run operator in the deferrable mode.
+    """
+
+    # Used in airflow.models.BaseOperator
+    template_fields: Sequence[str] = ("databricks_conn_id",)
+    template_ext: Sequence[str] = (".json-tpl",)
+    # Databricks brand color (blue) under white text
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        statement: str,
+        warehouse_id: str,
+        *,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        **kwargs,
+    ) -> None:
+        """Create a new ``DatabricksSubmitRunOperator``."""
+        super().__init__(**kwargs)
+        self.statement = statement
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+
+        # This variable will be used in case our task gets killed.
+        self.statement_id: str | None = None
+        self.do_xcom_push = do_xcom_push
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSQLStatementsOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def _handle_operator_execution(self) -> None:
+        while True:
+            statement_state = 
self._hook.get_sql_statement_state(self.statement_id)
+            if statement_state.is_terminal:
+                if statement_state.is_successful:
+                    self.log.info("%s completed successfully.", self.task_id)
+                    return
+                error_message = (
+                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+            self.log.info("%s in run state: %s", self.task_id, 
statement_state.state)
+            self.log.info("Sleeping for %s seconds.", 
self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+    def _handle_deferrable_operator_execution(self) -> None:
+        statement_state = self._hook.get_sql_statement_state(self.statement_id)
+        if not statement_state.is_terminal:
+            self.defer(
+                trigger=DatabricksSQLStatementExecutionTrigger(
+                    statement_id=self.statement_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                ),
+                method_name=DEFER_METHOD_NAME,

Review Comment:
   I noticed that the module already defines a constant for this, `DEFER_METHOD_NAME`:
https://github.com/apache/airflow/blob/e85b61dd40ef6a1a16fd9260aff4fe6f228be6e8/providers/databricks/src/airflow/providers/databricks/operators/databricks.py#L58
   Its value is `execute_complete`, so I am simply reusing that constant here :)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to