pankajastro commented on code in PR #48507:
URL: https://github.com/apache/airflow/pull/48507#discussion_r2024748340


##########
providers/databricks/src/airflow/providers/databricks/hooks/databricks.py:
##########
@@ -709,6 +771,47 @@ def update_job_permission(self, job_id: int, json: 
dict[str, Any]) -> dict:
         """
         return self._do_api_call(("PATCH", 
f"api/2.0/permissions/jobs/{job_id}"), json)
 
+    def post_sql_statement(self, json: dict[str, Any]) -> str:

Review Comment:
   Just for consistency, shall we add a docstring entry for the `json` param?



##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -969,6 +978,183 @@ def on_kill(self) -> None:
             self.log.error("Error: Task: %s with invalid run_id was requested 
to be cancelled.", self.task_id)
 
 
+class DatabricksSQLStatementsOperator(BaseOperator):
+    """
+    Submits a Databricks SQL Statement to Databricks using the 
api/2.0/sql/statements/ API endpoint.
+
+    See: https://docs.databricks.com/api/workspace/statementexecution
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:DatabricksSQLStatementsOperator`
+
+    :param statement: The SQL statement to execute. The statement can 
optionally be parameterized, see parameters.
+    :param warehouse_id: Warehouse upon which to execute a statement.
+    :param catalog: Sets default catalog for statement execution, similar to 
USE CATALOG in SQL.
+    :param schema: Sets default schema for statement execution, similar to USE 
SCHEMA in SQL.
+    :param parameters: A list of parameters to pass into a SQL statement 
containing parameter markers.
+
+        .. seealso::
+            
https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters
+    :param wait_for_termination: if we should wait for termination of the 
statement execution. ``True`` by default.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection 
<howto/connection:databricks>`.
+        By default and in the common case this will be ``databricks_default``. 
To use
+        token based authentication, provide the key ``token`` in the extra 
field for the
+        connection and create the key ``host`` and leave the ``host`` field 
empty. (templated)
+    :param polling_period_seconds: Controls the rate which we poll for the 
result of
+        this statement. By default the operator will poll every 30 seconds.
+    :param databricks_retry_limit: Amount of times retry if the Databricks 
backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_delay: Number of seconds to wait between retries 
(it
+            might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed 
to ``tenacity.Retrying`` class.
+    :param do_xcom_push: Whether we should push statement_id to xcom.
+    :param deferrable: Run operator in the deferrable mode.
+    """
+
+    # Used in airflow.models.BaseOperator
+    template_fields: Sequence[str] = ("databricks_conn_id",)
+    template_ext: Sequence[str] = (".json-tpl",)
+    # Databricks brand color (blue) under white text
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        statement: str,
+        warehouse_id: str,
+        *,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        **kwargs,
+    ) -> None:
+        """Create a new ``DatabricksSubmitRunOperator``."""
+        super().__init__(**kwargs)
+        self.statement = statement
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+
+        # This variable will be used in case our task gets killed.
+        self.statement_id: str | None = None
+        self.do_xcom_push = do_xcom_push
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSQLStatementsOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def _handle_operator_execution(self) -> None:
+        while True:
+            statement_state = 
self._hook.get_sql_statement_state(self.statement_id)
+            if statement_state.is_terminal:
+                if statement_state.is_successful:
+                    self.log.info("%s completed successfully.", self.task_id)
+                    return
+                error_message = (
+                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+            self.log.info("%s in run state: %s", self.task_id, 
statement_state.state)
+            self.log.info("Sleeping for %s seconds.", 
self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+    def _handle_deferrable_operator_execution(self) -> None:
+        statement_state = self._hook.get_sql_statement_state(self.statement_id)
+        if not statement_state.is_terminal:
+            self.defer(
+                trigger=DatabricksSQLStatementExecutionTrigger(
+                    statement_id=self.statement_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                ),
+                method_name=DEFER_METHOD_NAME,

Review Comment:
   What is `DEFER_METHOD_NAME`?



##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -969,6 +978,183 @@ def on_kill(self) -> None:
             self.log.error("Error: Task: %s with invalid run_id was requested 
to be cancelled.", self.task_id)
 
 
+class DatabricksSQLStatementsOperator(BaseOperator):
+    """
+    Submits a Databricks SQL Statement to Databricks using the 
api/2.0/sql/statements/ API endpoint.
+
+    See: https://docs.databricks.com/api/workspace/statementexecution
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:DatabricksSQLStatementsOperator`
+
+    :param statement: The SQL statement to execute. The statement can 
optionally be parameterized, see parameters.
+    :param warehouse_id: Warehouse upon which to execute a statement.
+    :param catalog: Sets default catalog for statement execution, similar to 
USE CATALOG in SQL.
+    :param schema: Sets default schema for statement execution, similar to USE 
SCHEMA in SQL.
+    :param parameters: A list of parameters to pass into a SQL statement 
containing parameter markers.
+
+        .. seealso::
+            
https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters
+    :param wait_for_termination: if we should wait for termination of the 
statement execution. ``True`` by default.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection 
<howto/connection:databricks>`.
+        By default and in the common case this will be ``databricks_default``. 
To use
+        token based authentication, provide the key ``token`` in the extra 
field for the
+        connection and create the key ``host`` and leave the ``host`` field 
empty. (templated)
+    :param polling_period_seconds: Controls the rate which we poll for the 
result of
+        this statement. By default the operator will poll every 30 seconds.
+    :param databricks_retry_limit: Amount of times retry if the Databricks 
backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_delay: Number of seconds to wait between retries 
(it
+            might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed 
to ``tenacity.Retrying`` class.
+    :param do_xcom_push: Whether we should push statement_id to xcom.
+    :param deferrable: Run operator in the deferrable mode.
+    """
+
+    # Used in airflow.models.BaseOperator
+    template_fields: Sequence[str] = ("databricks_conn_id",)
+    template_ext: Sequence[str] = (".json-tpl",)
+    # Databricks brand color (blue) under white text
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        statement: str,
+        warehouse_id: str,
+        *,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        **kwargs,
+    ) -> None:
+        """Create a new ``DatabricksSubmitRunOperator``."""
+        super().__init__(**kwargs)
+        self.statement = statement
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+
+        # This variable will be used in case our task gets killed.

Review Comment:
   These variables are also used in the `execute` method, so I'm not sure this
comment makes sense. Or am I missing something?



##########
providers/databricks/src/airflow/providers/databricks/triggers/databricks.py:
##########
@@ -119,3 +119,84 @@ async def run(self):
                     }
                 )
                 return
+
+
+class DatabricksSQLStatementExecutionTrigger(BaseTrigger):
+    """
+    The trigger handles the logic of async communication with DataBricks SQL 
Statements API.
+
+    :param statement_id: ID of the SQL statement.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection 
<howto/connection:databricks>`.
+    :param polling_period_seconds: Controls the rate of the poll for the 
result of this run.
+        By default, the trigger will poll every 30 seconds.
+    :param retry_limit: The number of times to retry the connection in case of 
service outages.
+    :param retry_delay: The number of seconds to wait between retries.
+    :param retry_args: An optional dictionary with arguments passed to 
``tenacity.Retrying`` class.
+    """
+
+    def __init__(
+        self,
+        statement_id: str,
+        databricks_conn_id: str,
+        polling_period_seconds: int = 30,
+        retry_limit: int = 3,
+        retry_delay: int = 10,
+        retry_args: dict[Any, Any] | None = None,
+        caller: str = "DatabricksSQLStatementExecutionTrigger",

Review Comment:
   Also, this param is not included in the docstring.



##########
providers/databricks/src/airflow/providers/databricks/triggers/databricks.py:
##########
@@ -119,3 +119,84 @@ async def run(self):
                     }
                 )
                 return
+
+
+class DatabricksSQLStatementExecutionTrigger(BaseTrigger):
+    """
+    The trigger handles the logic of async communication with DataBricks SQL 
Statements API.
+
+    :param statement_id: ID of the SQL statement.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection 
<howto/connection:databricks>`.
+    :param polling_period_seconds: Controls the rate of the poll for the 
result of this run.
+        By default, the trigger will poll every 30 seconds.
+    :param retry_limit: The number of times to retry the connection in case of 
service outages.
+    :param retry_delay: The number of seconds to wait between retries.
+    :param retry_args: An optional dictionary with arguments passed to 
``tenacity.Retrying`` class.
+    """
+
+    def __init__(
+        self,
+        statement_id: str,
+        databricks_conn_id: str,
+        polling_period_seconds: int = 30,
+        retry_limit: int = 3,
+        retry_delay: int = 10,
+        retry_args: dict[Any, Any] | None = None,
+        caller: str = "DatabricksSQLStatementExecutionTrigger",

Review Comment:
   Shall we remove this?
   
   I noticed that this parameter is neither set when the operator instantiates
`DatabricksSQLStatementExecutionTrigger` nor included in the trigger's
`serialize` method.



##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -969,6 +978,183 @@ def on_kill(self) -> None:
             self.log.error("Error: Task: %s with invalid run_id was requested 
to be cancelled.", self.task_id)
 
 
+class DatabricksSQLStatementsOperator(BaseOperator):
+    """
+    Submits a Databricks SQL Statement to Databricks using the 
api/2.0/sql/statements/ API endpoint.
+
+    See: https://docs.databricks.com/api/workspace/statementexecution
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:DatabricksSQLStatementsOperator`
+
+    :param statement: The SQL statement to execute. The statement can 
optionally be parameterized, see parameters.
+    :param warehouse_id: Warehouse upon which to execute a statement.
+    :param catalog: Sets default catalog for statement execution, similar to 
USE CATALOG in SQL.
+    :param schema: Sets default schema for statement execution, similar to USE 
SCHEMA in SQL.
+    :param parameters: A list of parameters to pass into a SQL statement 
containing parameter markers.
+
+        .. seealso::
+            
https://docs.databricks.com/api/workspace/statementexecution/executestatement#parameters
+    :param wait_for_termination: if we should wait for termination of the 
statement execution. ``True`` by default.
+    :param databricks_conn_id: Reference to the :ref:`Databricks connection 
<howto/connection:databricks>`.
+        By default and in the common case this will be ``databricks_default``. 
To use
+        token based authentication, provide the key ``token`` in the extra 
field for the
+        connection and create the key ``host`` and leave the ``host`` field 
empty. (templated)
+    :param polling_period_seconds: Controls the rate which we poll for the 
result of
+        this statement. By default the operator will poll every 30 seconds.
+    :param databricks_retry_limit: Amount of times retry if the Databricks 
backend is
+        unreachable. Its value must be greater than or equal to 1.
+    :param databricks_retry_delay: Number of seconds to wait between retries 
(it
+            might be a floating point number).
+    :param databricks_retry_args: An optional dictionary with arguments passed 
to ``tenacity.Retrying`` class.
+    :param do_xcom_push: Whether we should push statement_id to xcom.
+    :param deferrable: Run operator in the deferrable mode.
+    """
+
+    # Used in airflow.models.BaseOperator
+    template_fields: Sequence[str] = ("databricks_conn_id",)
+    template_ext: Sequence[str] = (".json-tpl",)
+    # Databricks brand color (blue) under white text
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        statement: str,
+        warehouse_id: str,
+        *,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        **kwargs,
+    ) -> None:
+        """Create a new ``DatabricksSubmitRunOperator``."""
+        super().__init__(**kwargs)
+        self.statement = statement
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+
+        # This variable will be used in case our task gets killed.
+        self.statement_id: str | None = None
+        self.do_xcom_push = do_xcom_push
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSQLStatementsOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def _handle_operator_execution(self) -> None:
+        while True:
+            statement_state = 
self._hook.get_sql_statement_state(self.statement_id)
+            if statement_state.is_terminal:
+                if statement_state.is_successful:
+                    self.log.info("%s completed successfully.", self.task_id)
+                    return
+                error_message = (
+                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+            self.log.info("%s in run state: %s", self.task_id, 
statement_state.state)
+            self.log.info("Sleeping for %s seconds.", 
self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+    def _handle_deferrable_operator_execution(self) -> None:
+        statement_state = self._hook.get_sql_statement_state(self.statement_id)
+        if not statement_state.is_terminal:
+            self.defer(

Review Comment:
   Shall we also pass the `timeout` param to `self.defer`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to