tatiana commented on code in PR #39178:
URL: https://github.com/apache/airflow/pull/39178#discussion_r1576027888


##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -884,3 +884,148 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
     def __init__(self, *args, **kwargs):
         super().__init__(deferrable=True, *args, **kwargs)
+
+
+class DatabricksNotebookOperator(BaseOperator):
+    """
+    Runs a notebook on Databricks using an Airflow operator.
+
+    The DatabricksNotebookOperator allows users to launch and monitor notebook
+    job runs on Databricks as Airflow tasks.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:DatabricksNotebookOperator`
+
+    :param notebook_path: The path to the notebook in Databricks.
+    :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
+            from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
+            defined in git_source. If the value is empty, the task will use GIT if git_source is defined
+            and WORKSPACE otherwise. For more information please visit
+            https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
+    :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
+    :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+        notebook.
+    :param new_cluster: Specs for a new cluster on which this task will be run.
+    :param existing_cluster_id: ID for existing cluster on which to run this task.
+    :param job_cluster_key: The key for the job cluster.
+    :param databricks_conn_id: The name of the Airflow connection to use.
+    """
+
+    template_fields = ("notebook_params",)
+
+    def __init__(
+        self,
+        notebook_path: str,
+        source: str,
+        notebook_params: dict | None = None,
+        notebook_packages: list[dict[str, Any]] | None = None,
+        new_cluster: dict[str, Any] | None = None,
+        existing_cluster_id: str | None = None,
+        job_cluster_key: str | None = None,
+        polling_period_seconds: int = 5,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,

Review Comment:
   It's probably worth documenting these parameters.
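One possible shape for those entries is sketched below. The wording is hypothetical (loosely modelled on how the provider's other Databricks operators describe the same arguments, as far as we recall), and `undocumented_params` is a hypothetical helper that lists `__init__` arguments lacking a `:param name:` line, which a unit test could use to catch this kind of gap.

```python
from __future__ import annotations

import inspect
import re


class DatabricksNotebookOperatorDocSketch:
    """
    Hypothetical docstring excerpt, not the PR's text.

    :param polling_period_seconds: Seconds to sleep between checks of the run state
        while waiting for the notebook run to finish.
    :param databricks_retry_limit: Number of times the hook retries a Databricks API
        call if the backend is unreachable.
    :param databricks_retry_delay: Seconds to wait between those retries.
    :param databricks_retry_args: Optional dict of arguments for the hook's retry
        behaviour (to our understanding, passed on to ``tenacity.Retrying``).
    """

    def __init__(
        self,
        polling_period_seconds: int = 5,
        databricks_retry_limit: int = 3,
        databricks_retry_delay: int = 1,
        databricks_retry_args: dict | None = None,
    ):
        pass


def undocumented_params(cls: type) -> list[str]:
    """Return __init__ parameters that have no ``:param <name>:`` entry in the class docstring."""
    documented = set(re.findall(r":param (\w+):", cls.__doc__ or ""))
    params = [
        name
        for name, p in inspect.signature(cls.__init__).parameters.items()
        # Skip self and catch-all *args / **kwargs, which are usually not documented.
        if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    return [name for name in params if name not in documented]
```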



##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -884,3 +884,148 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
+    def __init__(
+        self,
+        notebook_path: str,
+        source: str,
+        notebook_params: dict | None = None,
+        notebook_packages: list[dict[str, Any]] | None = None,
+        new_cluster: dict[str, Any] | None = None,
+        existing_cluster_id: str | None = None,
+        job_cluster_key: str | None = None,

Review Comment:
   ```suggestion
           existing_cluster_id: str = "",
           job_cluster_key: str = "",
   ```
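A minimal sketch of what the suggestion would look like, restricted to the cluster-related arguments (`ClusterArgsSketch` is a hypothetical stand-in for the operator). Strings are immutable, so `""` is a safe default; the dict argument keeps `None`, because a literal `{}` default would be created once and shared by every instance.

```python
from __future__ import annotations

from typing import Any


class ClusterArgsSketch:
    """Hypothetical stand-in showing only the cluster-related arguments."""

    def __init__(
        self,
        new_cluster: dict[str, Any] | None = None,
        # Immutable defaults: no need for `existing_cluster_id or ""` in the body.
        existing_cluster_id: str = "",
        job_cluster_key: str = "",
    ):
        # Mutable containers keep the None default and get a fresh dict per instance.
        self.new_cluster = new_cluster or {}
        self.existing_cluster_id = existing_cluster_id
        self.job_cluster_key = job_cluster_key
```

One behavioural difference: with `str = ""`, a caller who explicitly passes `existing_cluster_id=None` would now store `None` rather than `""`, so the `or ""` lines are only redundant if callers never pass `None`.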



##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -884,3 +884,148 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
+    def __init__(
+        self,
+        notebook_path: str,
+        source: str,
+        notebook_params: dict | None = None,
+        notebook_packages: list[dict[str, Any]] | None = None,
+        new_cluster: dict[str, Any] | None = None,
+        existing_cluster_id: str | None = None,
+        job_cluster_key: str | None = None,
+        polling_period_seconds: int = 5,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        **kwargs: Any,
+    ):
+        self.notebook_path = notebook_path
+        self.source = source
+        self.notebook_params = notebook_params or {}
+        self.notebook_packages = notebook_packages or []
+        self.new_cluster = new_cluster or {}
+        self.existing_cluster_id = existing_cluster_id or ""
+        self.job_cluster_key = job_cluster_key or ""
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.databricks_conn_id = databricks_conn_id
+        self.databricks_run_id = ""
+        super().__init__(**kwargs)
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksNotebookOperator")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def _get_task_base_json(self) -> dict[str, Any]:
+        """Get task base json to be used for task submissions."""
+        return {
+            # Timeout seconds value of 0 for the Databricks Jobs API means the job runs forever.
+            # That is also the default behavior of Databricks jobs to run a job forever without a default
+            # timeout value.
+            "timeout_seconds": int(self.execution_timeout.total_seconds()) if self.execution_timeout else 0,
+            "email_notifications": {},
+            "notebook_task": {
+                "notebook_path": self.notebook_path,
+                "source": self.source,
+                "base_parameters": self.notebook_params,
+            },
+            "libraries": self.notebook_packages,
+        }
+
+    def _get_databricks_task_id(self, task_id: str):
+        """Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
+        return f"{self.dag_id}__" + task_id.replace(".", "__")
+
+    def _get_run_json(self):
+        """Get run json to be used for task submissions."""
+        run_json = {
+            "run_name": self._get_databricks_task_id(self.task_id),
+            **self._get_task_base_json(),
+        }
+        if self.new_cluster and self.existing_cluster_id:
+            raise ValueError("Both new_cluster and existing_cluster_id are set. Only one should be set.")
+        if self.new_cluster:
+            run_json["new_cluster"] = self.new_cluster
+        elif self.existing_cluster_id:
+            run_json["existing_cluster_id"] = self.existing_cluster_id
+        else:
+            raise ValueError("Must specify either existing_cluster_id or new_cluster.")
+        return run_json
+
+    def launch_notebook_job(self):
+        run_json = self._get_run_json()
+        self.databricks_run_id = self._hook.submit_run(run_json)
+        url = self._hook.get_run_page_url(self.databricks_run_id)
+        self.log.info("Check the job run in Databricks: %s", url)
+        return self.databricks_run_id
+
+    def monitor_databricks_job(self):
+        run = self._hook.get_run(self.databricks_run_id)
+        run_state = RunState(**run["state"])
+        self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        while not run_state.is_terminal:
+            time.sleep(self.polling_period_seconds)
+            run = self._hook.get_run(self.databricks_run_id)
+            run_state = RunState(**run["state"])
+            self.log.info(
+                "task %s %s", self._get_databricks_task_id(self.task_id), run_state.life_cycle_state
+            )
+            self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        if run_state.life_cycle_state != "TERMINATED":

Review Comment:
   Would it be worth having a constant for this value (`"TERMINATED"`) at the top of the file?
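For example, a sketch assuming a hypothetical module-level name `TERMINATED_LIFE_CYCLE_STATE`; `RunStateSketch` is a minimal stand-in for the provider's `RunState`, not its real definition, and `reached_terminated_state` is a hypothetical helper.

```python
from __future__ import annotations

from dataclasses import dataclass

# Hypothetical module-level constant replacing the inline "TERMINATED" literal.
TERMINATED_LIFE_CYCLE_STATE = "TERMINATED"


@dataclass
class RunStateSketch:
    """Minimal stand-in for the provider's RunState, for illustration only."""

    life_cycle_state: str
    result_state: str = ""
    state_message: str = ""


def reached_terminated_state(state: RunStateSketch) -> bool:
    # Single place that knows the literal; the operator would compare against the constant.
    return state.life_cycle_state == TERMINATED_LIFE_CYCLE_STATE
```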



##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -884,3 +884,148 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
+    def monitor_databricks_job(self):
+        run = self._hook.get_run(self.databricks_run_id)
+        run_state = RunState(**run["state"])
+        self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        while not run_state.is_terminal:
+            time.sleep(self.polling_period_seconds)

Review Comment:
   How would this logic change if we wanted to make this operator deferrable in the future?
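In Airflow, a deferrable operator hands the wait to a trigger via `self.defer(trigger=..., method_name=...)`; the trigger's `run()` is an async generator executed in the triggerer, so the blocking `time.sleep` loop would have to become an `await asyncio.sleep` loop there. One way to prepare for that, sketched here under assumptions, is to keep the terminal-state decision in one helper that both the current synchronous loop and a future async loop share; only the waiting primitive changes. `TERMINAL_STATES`, `wait_blocking` and `wait_async` are hypothetical names, and the state set reflects our understanding of the Databricks terminal life-cycle states. The provider's existing `DatabricksExecutionTrigger` (used by the deferrable run-now operator) might be reusable, though that has not been checked against this operator.

```python
from __future__ import annotations

import asyncio
import time
from collections.abc import Awaitable, Callable

# Hypothetical set of terminal life-cycle states, for illustration.
TERMINAL_STATES = frozenset({"TERMINATED", "SKIPPED", "INTERNAL_ERROR"})


def is_terminal(life_cycle_state: str) -> bool:
    """Shared decision logic used by both the blocking and the async loop."""
    return life_cycle_state in TERMINAL_STATES


def wait_blocking(get_state: Callable[[], str], poll_seconds: float) -> str:
    """Current behaviour: occupies the worker slot while sleeping between polls."""
    state = get_state()
    while not is_terminal(state):
        time.sleep(poll_seconds)
        state = get_state()
    return state


async def wait_async(get_state: Callable[[], Awaitable[str]], poll_seconds: float) -> str:
    """Trigger-style loop: same logic, but yields to the event loop while waiting."""
    state = await get_state()
    while not is_terminal(state):
        await asyncio.sleep(poll_seconds)
        state = await get_state()
    return state
```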



##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -884,3 +884,148 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
+    def _get_task_base_json(self) -> dict[str, Any]:
+        """Get task base json to be used for task submissions."""
+        return {
+            # Timeout seconds value of 0 for the Databricks Jobs API means the job runs forever.
+            # That is also the default behavior of Databricks jobs to run a job forever without a default
+            # timeout value.
+            "timeout_seconds": int(self.execution_timeout.total_seconds()) if self.execution_timeout else 0,

Review Comment:
   It is worth isolating the job timeout logic in a dedicated property/function within the operator and documenting it in more detail.
   
   From a behavioural perspective, it may be worth highlighting this mapping somewhere:
   - Airflow default `execution_timeout` == None -> Databricks = 0 (job runs forever)
   - Airflow `execution_timeout` == 0 -> Databricks = 0 (job runs forever)
   - Airflow `execution_timeout` > 0 -> Databricks > 0 (user-defined setting)
   
   From a Databricks perspective, if we did not define this timeout, what would the default timeout be?
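A minimal sketch of such a helper, assuming a standalone function (`databricks_timeout_seconds` is a hypothetical name; in the operator it could be a property reading `self.execution_timeout`). It encodes the three cases above. It also deliberately deviates from the PR by rounding up with `math.ceil`: `int()` truncates, so a sub-second `execution_timeout` would become 0, which the PR's own code comment says Databricks treats as "run forever".

```python
from __future__ import annotations

import math
from datetime import timedelta


def databricks_timeout_seconds(execution_timeout: timedelta | None) -> int:
    """
    Map Airflow's execution_timeout to the Jobs API ``timeout_seconds`` field.

    - None (Airflow default) -> 0, i.e. no Databricks timeout
    - timedelta(0)           -> 0, i.e. no Databricks timeout
    - positive timedelta     -> that many seconds, rounded up
    """
    if not execution_timeout:  # timedelta(0) is falsy, so this covers None and zero
        return 0
    if execution_timeout < timedelta(0):
        raise ValueError("execution_timeout must not be negative")
    # Round up so a sub-second timeout never collapses to 0 ("no timeout").
    return math.ceil(execution_timeout.total_seconds())
```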



##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -884,3 +884,148 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
 
+        self.notebook_path = notebook_path
+        self.source = source
+        self.notebook_params = notebook_params or {}
+        self.notebook_packages = notebook_packages or []
+        self.new_cluster = new_cluster or {}
+        self.existing_cluster_id = existing_cluster_id or ""
+        self.job_cluster_key = job_cluster_key or ""

Review Comment:
   ```suggestion
   ```
   Perhaps we could set the defaults to empty strings in the method signature, as suggested previously, so these lines are not needed.


