This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c45617c4d5 allow DatabricksSubmitRunOperator to accept a pipeline name 
for a pipeline_task (#32903)
c45617c4d5 is described below

commit c45617c4d5988555f2f52684e082b96b65ca6c17
Author: Adam Best <[email protected]>
AuthorDate: Wed Sep 6 19:44:06 2023 -0500

    allow DatabricksSubmitRunOperator to accept a pipeline name for a 
pipeline_task (#32903)
    
    
    ---------
    
    Co-authored-by: Adam Best <[email protected]>
    Co-authored-by: Adam Best <>
    Co-authored-by: Adam Best <unknown>
    Co-authored-by: Hussein Awala <[email protected]>
---
 airflow/providers/databricks/hooks/databricks.py   | 62 ++++++++++++++++++++++
 .../providers/databricks/operators/databricks.py   | 11 ++++
 .../operators/submit_run.rst                       |  2 +
 .../databricks/operators/test_databricks.py        | 20 +++++++
 4 files changed, 95 insertions(+)

diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index e8b124b99f..593367c543 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -52,6 +52,7 @@ INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
 UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
 
 LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
+LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
 
 WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
 
@@ -215,6 +216,70 @@ class DatabricksHook(BaseDatabricksHook):
         else:
             return matching_jobs[0]["job_id"]
 
+    def list_pipelines(
+        self, batch_size: int = 25, pipeline_name: str | None = None, notebook_path: str | None = None
+    ) -> list[dict[str, Any]]:
+        """
+        Lists the pipelines in Databricks Delta Live Tables.
+
+        Pipelines are requested ``batch_size`` at a time, following ``next_page_token`` until a
+        response no longer carries a (non-empty) token.
+
+        :param batch_size: The limit/batch size used to retrieve pipelines.
+        :param pipeline_name: Optional name of a pipeline to search. Cannot be combined with path.
+        :param notebook_path: Optional notebook of a pipeline to search. Cannot be combined with name.
+        :return: A list of pipelines.
+        :raises AirflowException: If both ``pipeline_name`` and ``notebook_path`` are given.
+        """
+        if pipeline_name and notebook_path:
+            raise AirflowException("Cannot combine pipeline_name and notebook_path in one request")
+
+        # NOTE(review): values are interpolated verbatim inside single quotes, so a value containing
+        # a quote ends the literal early, and ``name LIKE`` presumably treats ``%``/``_`` as wildcards.
+        filter_expr = None
+        if notebook_path:
+            filter_expr = f"notebook='{notebook_path}'"
+        elif pipeline_name:
+            filter_expr = f"name LIKE '{pipeline_name}'"
+
+        payload: dict[str, Any] = {"max_results": batch_size}
+        if filter_expr:
+            payload["filter"] = filter_expr
+
+        all_pipelines: list[dict[str, Any]] = []
+        while True:
+            response = self._do_api_call(LIST_PIPELINES_ENDPOINT, payload)
+            all_pipelines += response.get("statuses", [])
+            # Stop on a missing *or empty* token: continuing without a new page_token would
+            # re-send the same request indefinitely.
+            next_token = response.get("next_page_token")
+            if not next_token:
+                break
+            payload["page_token"] = next_token
+
+        return all_pipelines
+
+    def find_pipeline_id_by_name(self, pipeline_name: str) -> str | None:
+        """
+        Finds pipeline id by its name. Raises AirflowException if several pipelines share the name.
+
+        :param pipeline_name: The name of the pipeline to look up.
+        :return: The pipeline_id as a GUID string or None if no pipeline was found.
+        :raises AirflowException: If more than one pipeline matches ``pipeline_name``.
+        """
+        matching_pipelines = self.list_pipelines(pipeline_name=pipeline_name)
+
+        if len(matching_pipelines) > 1:
+            raise AirflowException(
+                f"There is more than one pipeline with name {pipeline_name}. "
+                "Please delete duplicated pipelines first"
+            )
+
+        # No match: report "not found" instead of indexing into an empty list.
+        if not matching_pipelines:
+            return None
+        return matching_pipelines[0]["pipeline_id"]
+
     def get_run_page_url(self, run_id: int) -> str:
         """
         Retrieves run_page_url.
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 3cc24f1b1b..8551c8d43f 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -365,6 +365,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
 
         if "dbt_task" in self.json and "git_source" not in self.json:
             raise AirflowException("git_source is required for dbt_task")
+        if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
+            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
 
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
@@ -384,6 +386,20 @@ class DatabricksSubmitRunOperator(BaseOperator):
         )
 
     def execute(self, context: Context):
+        if (
+            "pipeline_task" in self.json
+            and self.json["pipeline_task"].get("pipeline_id") is None
+            and self.json["pipeline_task"].get("pipeline_name")
+        ):
+            # Resolve pipeline_name to a pipeline_id before submitting. Work on a copy so a
+            # pipeline_task dict supplied by the caller is not mutated.
+            pipeline_task = dict(self.json["pipeline_task"])
+            pipeline_name = pipeline_task.pop("pipeline_name")
+            pipeline_id = self._hook.find_pipeline_id_by_name(pipeline_name)
+            if pipeline_id is None:
+                raise AirflowException(f"Pipeline name {pipeline_name} not found.")
+            pipeline_task["pipeline_id"] = pipeline_id
+            self.json = {**self.json, "pipeline_task": pipeline_task}
         json_normalised = normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
         if self.deferrable:
diff --git a/docs/apache-airflow-providers-databricks/operators/submit_run.rst b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
index 9aa3eb912c..706920458c 100644
--- a/docs/apache-airflow-providers-databricks/operators/submit_run.rst
+++ b/docs/apache-airflow-providers-databricks/operators/submit_run.rst
@@ -61,6 +61,8 @@ one named parameter for each top level parameter in the ``runs/submit`` endpoint
   * ``new_cluster`` - specs for a new cluster on which this task will be run
   * ``existing_cluster_id`` - ID for existing cluster on which to run this task
 
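+* ``pipeline_task`` - may reference the pipeline by either ``pipeline_id`` or ``pipeline_name``, but not both; a ``pipeline_name`` is resolved to its ``pipeline_id`` when the task runs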
+
 In the case where both the json parameter **AND** the named parameters
 are provided, they will be merged together. If there are conflicts during the merge,
 the named parameters will take precedence and override the top level ``json`` keys.
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 81379ea339..e03eb7dccc 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -43,6 +43,8 @@ TEMPLATED_NOTEBOOK_TASK = {"notebook_path": "/test-{{ ds }}"}
 RENDERED_TEMPLATED_NOTEBOOK_TASK = {"notebook_path": f"/test-{DATE}"}
 SPARK_JAR_TASK = {"main_class_name": "com.databricks.Test"}
 SPARK_PYTHON_TASK = {"python_file": "test.py", "parameters": ["--param", "123"]}
+PIPELINE_ID_TASK = {"pipeline_id": "1234abcd"}
+PIPELINE_NAME_TASK = {"pipeline_name": "This is a test pipeline"}
 SPARK_SUBMIT_TASK = {
     "parameters": ["--class", "org.apache.spark.examples.SparkPi", "dbfs:/path/to/examples.jar", "10"]
 }
@@ -120,6 +122,24 @@ class TestDatabricksSubmitRunOperator:
 
         assert expected == utils.normalise_json_content(op.json)
 
+    def test_init_with_pipeline_name_task_named_parameters(self):
+        """
+        Test the initializer with a named ``pipeline_task`` that references the pipeline by name.
+        """
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_NAME_TASK)
+        expected = utils.normalise_json_content({"pipeline_task": PIPELINE_NAME_TASK, "run_name": TASK_ID})
+
+        assert expected == utils.normalise_json_content(op.json)
+
+    def test_init_with_pipeline_id_task_named_parameters(self):
+        """
+        Test the initializer with a named ``pipeline_task`` that references the pipeline by id.
+        """
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, pipeline_task=PIPELINE_ID_TASK)
+        expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
+
+        assert expected == utils.normalise_json_content(op.json)
+
     def test_init_with_spark_submit_task_named_parameters(self):
         """
         Test the initializer with the named parameters.

Reply via email to