pankajkoti commented on code in PR #40013:
URL: https://github.com/apache/airflow/pull/40013#discussion_r1624218143
##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -1178,10 +1118,250 @@ def execute(self, context: Context) -> None:
self.databricks_run_id = workflow_run_metadata.run_id
self.databricks_conn_id = workflow_run_metadata.conn_id
else:
- self.launch_notebook_job()
+ self.launch_job()
if self.wait_for_termination:
self.monitor_databricks_job()
def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
self._handle_terminal_run_state(run_state)
+
+
+class DatabricksNotebookOperator(DatabricksTaskBaseOperator):
+ """
+ Runs a notebook on Databricks using an Airflow operator.
+
+ The DatabricksNotebookOperator allows users to launch and monitor notebook job runs on Databricks as
+ Airflow tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job
+ clusters, which allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksNotebookOperator`
+
+ :param notebook_path: The path to the notebook in Databricks.
+ :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
+ from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
+ defined in git_source. If the value is empty, the task will use GIT if git_source is defined
+ and WORKSPACE otherwise. For more information please visit
+ https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param databricks_retry_delay: Number of seconds to wait between retries.
+ :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+ :param deferrable: Whether to run the operator in the deferrable mode.
+ :param existing_cluster_id: ID for existing cluster on which to run this task.
+ :param job_cluster_key: The key for the job cluster.
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+ notebook.
+ :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
+ :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+ :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+ """
+
+ template_fields = (
+ "notebook_params",
+ "workflow_run_metadata",
+ )
+ CALLER = "DatabricksNotebookOperator"
+
+ def __init__(
+ self,
+ notebook_path: str,
+ source: str,
+ databricks_conn_id: str = "databricks_default",
+ databricks_retry_args: dict[Any, Any] | None = None,
+ databricks_retry_delay: int = 1,
+ databricks_retry_limit: int = 3,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ existing_cluster_id: str = "",
+ job_cluster_key: str = "",
+ new_cluster: dict[str, Any] | None = None,
+ notebook_packages: list[dict[str, Any]] | None = None,
+ notebook_params: dict | None = None,
+ polling_period_seconds: int = 5,
+ wait_for_termination: bool = True,
+ workflow_run_metadata: dict | None = None,
+ **kwargs: Any,
+ ):
+ self.notebook_path = notebook_path
+ self.source = source
+ self.databricks_conn_id = databricks_conn_id
+ self.databricks_retry_args = databricks_retry_args
+ self.databricks_retry_delay = databricks_retry_delay
+ self.databricks_retry_limit = databricks_retry_limit
+ self.databricks_run_id: int | None = None
+ self.deferrable = deferrable
+ self.existing_cluster_id = existing_cluster_id
+ self.job_cluster_key = job_cluster_key
+ self.new_cluster = new_cluster or {}
+ self.notebook_packages = notebook_packages or []
+ self.notebook_params = notebook_params or {}
+ self.polling_period_seconds = polling_period_seconds
+ self.wait_for_termination = wait_for_termination
+
+ # This is used to store the metadata of the Databricks job run when the job is launched from within
+ # DatabricksWorkflowTaskGroup.
+ self.workflow_run_metadata: dict | None = workflow_run_metadata
+
+ super().__init__(
+ caller=self.CALLER,
+ databricks_conn_id=self.databricks_conn_id,
+ databricks_retry_args=self.databricks_retry_args,
+ databricks_retry_delay=self.databricks_retry_delay,
+ databricks_retry_limit=self.databricks_retry_limit,
+ deferrable=self.deferrable,
+ existing_cluster_id=self.existing_cluster_id,
+ job_cluster_key=self.job_cluster_key,
+ new_cluster=self.new_cluster,
+ polling_period_seconds=self.polling_period_seconds,
+ wait_for_termination=self.wait_for_termination,
+ **kwargs,
+ )
+
+ def _get_task_timeout_seconds(self) -> int:
+ """
+ Get the timeout seconds value for the Databricks job based on the execution timeout value provided for the Airflow task.
+
+ By default, tasks in Airflow have an execution_timeout set to None. In Airflow, when
+ execution_timeout is not defined, the task continues to run indefinitely. Therefore,
+ to mirror this behavior in the Databricks Jobs API, we set the timeout to 0, indicating
+ that the job should run indefinitely. This aligns with the default behavior of Databricks jobs,
+ where a timeout seconds value of 0 signifies an indefinite run duration.
+ More details can be found in the Databricks documentation:
+ See https://docs.databricks.com/api/workspace/jobs/submit#timeout_seconds
+ """
+ if self.execution_timeout is None:
+ return 0
+ execution_timeout_seconds = int(self.execution_timeout.total_seconds())
+ if execution_timeout_seconds == 0:
+ raise ValueError(
+ "If you've set an `execution_timeout` for the task, ensure
it's not `0`. Set it instead to "
+ "`None` if you desire the task to run indefinitely."
+ )
+ return execution_timeout_seconds
+
+ def _get_task_base_json(self) -> dict[str, Any]:
+ """Get task base json to be used for task submissions."""
+ return {
+ "timeout_seconds": self._get_task_timeout_seconds(),
+ "email_notifications": {},
+ "notebook_task": {
+ "notebook_path": self.notebook_path,
+ "source": self.source,
+ "base_parameters": self.notebook_params,
+ },
+ "libraries": self.notebook_packages,
+ }
+
+ def _extend_workflow_notebook_packages(
+ self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
+ ) -> None:
+ """Extend the task group packages into the notebook's packages,
without adding any duplicates."""
+ for task_group_package in
databricks_workflow_task_group.notebook_packages:
+ exists = any(
+ task_group_package == existing_package for existing_package in
self.notebook_packages
+ )
+ if not exists:
+ self.notebook_packages.append(task_group_package)
+
+ def _convert_to_databricks_workflow_task(
+ self, relevant_upstreams: list[BaseOperator], context: Context | None = None
+ ) -> dict[str, object]:
+ """Convert the operator to a Databricks workflow task that can be a task in a workflow."""
+ databricks_workflow_task_group = self._databricks_workflow_task_group
+ if not databricks_workflow_task_group:
+ raise AirflowException(
+ "Calling `_convert_to_databricks_workflow_task` without a
parent TaskGroup."
+ )
+
+ if hasattr(databricks_workflow_task_group, "notebook_packages"):
+ self._extend_workflow_notebook_packages(databricks_workflow_task_group)
+
+ if hasattr(databricks_workflow_task_group, "notebook_params"):
+ self.notebook_params = {
+ **self.notebook_params,
+ **databricks_workflow_task_group.notebook_params,
+ }
+
+ return super()._convert_to_databricks_workflow_task(relevant_upstreams, context=context)
+
+
+class DatabricksTaskOperator(DatabricksTaskBaseOperator):
+ """
+ Runs a task on Databricks using an Airflow operator.
+
+ The DatabricksTaskOperator allows users to launch and monitor task job runs on Databricks as Airflow
+ tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job clusters, which
+ allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksTaskOperator`
+
+ :param task_config: The configuration of the task to be run on Databricks.
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param databricks_retry_delay: Number of seconds to wait between retries.
+ :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+ :param deferrable: Whether to run the operator in the deferrable mode.
+ :param existing_cluster_id: ID for existing cluster on which to run this task.
+ :param job_cluster_key: The key for the job cluster.
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+ :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+ """
+
+ CALLER = "DatabricksTaskOperator"
+ template_fields = ("workflow_run_metadata",)
+
+ def __init__(
+ self,
+ task_config: dict,
+ databricks_conn_id: str = "databricks_default",
+ databricks_retry_args: dict[Any, Any] | None = None,
+ databricks_retry_delay: int = 1,
+ databricks_retry_limit: int = 3,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ existing_cluster_id: str = "",
+ job_cluster_key: str = "",
+ new_cluster: dict[str, Any] | None = None,
+ polling_period_seconds: int = 5,
+ wait_for_termination: bool = True,
+ workflow_run_metadata: dict | None = None,
Review Comment:
added
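
For readers following this thread, a minimal standalone usage sketch of the two operators added in this hunk; the connection id, notebook path, cluster spec, and task_config values below are illustrative assumptions, not values taken from this PR:

from datetime import datetime
from airflow import DAG
from airflow.providers.databricks.operators.databricks import (
    DatabricksNotebookOperator,
    DatabricksTaskOperator,
)

with DAG(dag_id="databricks_tasks_example", start_date=datetime(2024, 1, 1), schedule=None):
    # Run a workspace notebook on a fresh job cluster and pass templated params to it.
    notebook = DatabricksNotebookOperator(
        task_id="run_notebook",
        databricks_conn_id="databricks_default",
        notebook_path="/Shared/example_notebook",  # hypothetical notebook path
        source="WORKSPACE",
        notebook_params={"run_date": "{{ ds }}"},
        new_cluster={  # hypothetical cluster spec
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 1,
        },
    )

    # Run an arbitrary Databricks task type by passing its Jobs API task spec directly.
    sql_task = DatabricksTaskOperator(
        task_id="run_sql_query",
        databricks_conn_id="databricks_default",
        task_config={  # hypothetical sql_task spec per the Jobs API
            "sql_task": {
                "query": {"query_id": "my-query-id"},
                "warehouse_id": "my-warehouse-id",
            },
        },
    )

    notebook >> sql_task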
##########
airflow/providers/databricks/operators/databricks.py:
##########
@@ -1178,10 +1118,250 @@ def execute(self, context: Context) -> None:
self.databricks_run_id = workflow_run_metadata.run_id
self.databricks_conn_id = workflow_run_metadata.conn_id
else:
- self.launch_notebook_job()
+ self.launch_job()
if self.wait_for_termination:
self.monitor_databricks_job()
def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
self._handle_terminal_run_state(run_state)
+
+
+class DatabricksNotebookOperator(DatabricksTaskBaseOperator):
+ """
+ Runs a notebook on Databricks using an Airflow operator.
+
+ The DatabricksNotebookOperator allows users to launch and monitor notebook job runs on Databricks as
+ Airflow tasks. It can be used as a part of a DatabricksWorkflowTaskGroup to take advantage of job
+ clusters, which allows users to run their tasks on cheaper clusters that can be shared between tasks.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:DatabricksNotebookOperator`
+
+ :param notebook_path: The path to the notebook in Databricks.
+ :param source: Optional location type of the notebook. When set to WORKSPACE, the notebook will be retrieved
+ from the local Databricks workspace. When set to GIT, the notebook will be retrieved from a Git repository
+ defined in git_source. If the value is empty, the task will use GIT if git_source is defined
+ and WORKSPACE otherwise. For more information please visit
+ https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsCreate
+ :param databricks_conn_id: The name of the Airflow connection to use.
+ :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param databricks_retry_delay: Number of seconds to wait between retries.
+ :param databricks_retry_limit: Amount of times to retry if the Databricks backend is unreachable.
+ :param deferrable: Whether to run the operator in the deferrable mode.
+ :param existing_cluster_id: ID for existing cluster on which to run this task.
+ :param job_cluster_key: The key for the job cluster.
+ :param new_cluster: Specs for a new cluster on which this task will be run.
+ :param notebook_packages: A list of the Python libraries to be installed on the cluster running the
+ notebook.
+ :param notebook_params: A dict of key-value pairs to be passed as optional params to the notebook task.
+ :param polling_period_seconds: Controls the rate which we poll for the result of this notebook job run.
+ :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
+ """
+
+ template_fields = (
+ "notebook_params",
+ "workflow_run_metadata",
+ )
+ CALLER = "DatabricksNotebookOperator"
+
+ def __init__(
+ self,
+ notebook_path: str,
+ source: str,
+ databricks_conn_id: str = "databricks_default",
+ databricks_retry_args: dict[Any, Any] | None = None,
+ databricks_retry_delay: int = 1,
+ databricks_retry_limit: int = 3,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+ existing_cluster_id: str = "",
+ job_cluster_key: str = "",
+ new_cluster: dict[str, Any] | None = None,
+ notebook_packages: list[dict[str, Any]] | None = None,
+ notebook_params: dict | None = None,
+ polling_period_seconds: int = 5,
+ wait_for_termination: bool = True,
+ workflow_run_metadata: dict | None = None,
+ **kwargs: Any,
+ ):
+ self.notebook_path = notebook_path
+ self.source = source
+ self.databricks_conn_id = databricks_conn_id
+ self.databricks_retry_args = databricks_retry_args
+ self.databricks_retry_delay = databricks_retry_delay
+ self.databricks_retry_limit = databricks_retry_limit
+ self.databricks_run_id: int | None = None
+ self.deferrable = deferrable
Review Comment:
done so, please check.
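
Since both docstrings point users to DatabricksWorkflowTaskGroup for sharing a job cluster, here is a rough sketch of that combination as I understand it; the group id, job_clusters spec, packages, and notebook paths are assumptions for illustration only:

from datetime import datetime, timedelta
from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

job_clusters = [
    {
        "job_cluster_key": "shared_cluster",  # referenced by job_cluster_key on each task
        "new_cluster": {  # hypothetical cluster spec
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "i3.xlarge",
            "num_workers": 2,
        },
    }
]

with DAG(dag_id="databricks_workflow_example", start_date=datetime(2024, 1, 1), schedule=None):
    with DatabricksWorkflowTaskGroup(
        group_id="databricks_workflow",
        databricks_conn_id="databricks_default",
        job_clusters=job_clusters,
        # Group-level packages are merged into each notebook task's libraries without duplicates.
        notebook_packages=[{"pypi": {"package": "pandas"}}],
    ):
        first = DatabricksNotebookOperator(
            task_id="first_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/first_notebook",  # hypothetical path
            source="WORKSPACE",
            job_cluster_key="shared_cluster",
            notebook_packages=[{"pypi": {"package": "simplejson"}}],
            # execution_timeout is translated to the Jobs API timeout_seconds field;
            # leaving it unset (None) submits timeout_seconds=0, i.e. no time limit.
            execution_timeout=timedelta(hours=2),
        )
        second = DatabricksNotebookOperator(
            task_id="second_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/second_notebook",  # hypothetical path
            source="WORKSPACE",
            job_cluster_key="shared_cluster",
        )
        first >> second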
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]