This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 322aa649ed fix typos in DatabricksSubmitRunOperator (#36248)
322aa649ed is described below

commit 322aa649edce6655f4bddfb9813ff8cb38616b7a
Author: Adam B <[email protected]>
AuthorDate: Wed Dec 20 18:16:59 2023 -0600

    fix typos in DatabricksSubmitRunOperator (#36248)
    
    * fix typos in DatabricksSubmitRunOperator
    
    * databricks find_pipeline_id_by_name tests and fixes
---
 airflow/providers/databricks/hooks/databricks.py   |  10 +-
 .../providers/databricks/operators/databricks.py   |   2 +-
 .../providers/databricks/hooks/test_databricks.py  | 108 +++++++++++++++++++++
 .../databricks/operators/test_databricks.py        |  29 ++++++
 4 files changed, 143 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks.py 
b/airflow/providers/databricks/hooks/databricks.py
index 6cb5b37e46..b39e3d622c 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -55,7 +55,7 @@ INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
 UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
 
 LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
-LIST_PIPELINES_ENDPOINT = ("GET", "/api/2.0/pipelines")
+LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
 
 WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
 
@@ -322,8 +322,8 @@ class DatabricksHook(BaseDatabricksHook):
             payload["filter"] = filter
 
         while has_more:
-            if next_token:
-                payload["page_token"] = next_token
+            if next_token is not None:
+                payload = {**payload, "page_token": next_token}
             response = self._do_api_call(LIST_PIPELINES_ENDPOINT, payload)
             pipelines = response.get("statuses", [])
             all_pipelines += pipelines
@@ -345,11 +345,11 @@ class DatabricksHook(BaseDatabricksHook):
 
         if len(matching_pipelines) > 1:
             raise AirflowException(
-                f"There are more than one job with name {pipeline_name}. "
+                f"There are more than one pipelines with name {pipeline_name}. 
"
                 "Please delete duplicated pipelines first"
             )
 
-        if not pipeline_name:
+        if not pipeline_name or len(matching_pipelines) == 0:
             return None
         else:
             return matching_pipelines[0]["pipeline_id"]
diff --git a/airflow/providers/databricks/operators/databricks.py 
b/airflow/providers/databricks/operators/databricks.py
index c8e5180f9b..edea8b4e59 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -521,7 +521,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
         ):
             # If pipeline_id is not provided, we need to fetch it from the 
pipeline_name
             pipeline_name = self.json["pipeline_task"]["pipeline_name"]
-            self.json["pipeline_task"]["pipeline_id"] = 
self._hook.get_pipeline_id(pipeline_name)
+            self.json["pipeline_task"]["pipeline_id"] = 
self._hook.find_pipeline_id_by_name(pipeline_name)
             del self.json["pipeline_task"]["pipeline_name"]
         json_normalised = normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
diff --git a/tests/providers/databricks/hooks/test_databricks.py 
b/tests/providers/databricks/hooks/test_databricks.py
index e2f74d773e..1baaab1fea 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -59,6 +59,8 @@ CLUSTER_ID = "cluster_id"
 RUN_ID = 1
 JOB_ID = 42
 JOB_NAME = "job-name"
+PIPELINE_NAME = "some pipeline name"  # pipeline name the lookup tests search for
+PIPELINE_ID = "its-a-pipeline-id"  # id the mocked list response returns for PIPELINE_NAME
 DEFAULT_RETRY_NUMBER = 3
 DEFAULT_RETRY_ARGS = dict(
     wait=tenacity.wait_none(),
@@ -100,6 +102,19 @@ LIST_JOBS_RESPONSE = {
     ],
     "has_more": False,
 }
+LIST_PIPELINES_RESPONSE = {  # one page (no "next_page_token") holding a single pipeline named PIPELINE_NAME
+    "statuses": [
+        {
+            "pipeline_id": PIPELINE_ID,
+            "state": "DEPLOYING",
+            "cluster_id": "string",
+            "name": PIPELINE_NAME,
+            "latest_updates": [{"update_id": "string", "state": "QUEUED", 
"creation_time": "string"}],
+            "creator_user_name": "string",
+            "run_as_user_name": "string",
+        }
+    ]
+}
 LIST_SPARK_VERSIONS_RESPONSE = {
     "versions": [
         {"key": "8.2.x-scala2.12", "name": "8.2 (includes Apache Spark 3.1.1, 
Scala 2.12)"},
@@ -226,6 +241,13 @@ def list_jobs_endpoint(host):
     return f"https://{host}/api/2.1/jobs/list";
 
 
+def list_pipelines_endpoint(host):
+    """
+    Utility function to generate the list pipelines endpoint given the host
+    """
+    return f"https://{host}/api/2.0/pipelines";
+
+
 def list_spark_versions_endpoint(host):
     """Utility function to generate the list spark versions endpoint given the 
host"""
     return f"https://{host}/api/2.0/clusters/spark-versions";
@@ -915,6 +937,92 @@ class TestDatabricksHook:
             timeout=self.hook.timeout_seconds,
         )
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_get_pipeline_id_by_name_success(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = 
LIST_PIPELINES_RESPONSE
+
+        pipeline_id = self.hook.find_pipeline_id_by_name(PIPELINE_NAME)  # one page, exactly one match
+
+        mock_requests.get.assert_called_once_with(  # the name is sent as a LIKE filter in the query params
+            list_pipelines_endpoint(HOST),
+            json=None,
+            params={"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 
25},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert pipeline_id == PIPELINE_ID
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_list_pipelines_success_multiple_pages(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.side_effect = [  # page 1 carries a next_page_token, page 2 does not
+            create_successful_response_mock({**LIST_PIPELINES_RESPONSE, 
"next_page_token": "PAGETOKEN"}),
+            create_successful_response_mock(LIST_PIPELINES_RESPONSE),
+        ]
+
+        pipelines = self.hook.list_pipelines(pipeline_name=PIPELINE_NAME)
+
+        assert mock_requests.get.call_count == 2  # one request per page
+
+        first_call_args = mock_requests.method_calls[0]  # (name, args, kwargs); mock keeps kwargs by reference, so in-place payload mutation would leak page_token here
+        assert first_call_args[1][0] == list_pipelines_endpoint(HOST)
+        assert first_call_args[2]["params"] == {"filter": f"name LIKE 
'{PIPELINE_NAME}'", "max_results": 25}
+
+        second_call_args = mock_requests.method_calls[1]
+        assert second_call_args[1][0] == list_pipelines_endpoint(HOST)
+        assert second_call_args[2]["params"] == {  # the token from page 1 must be forwarded
+            "filter": f"name LIKE '{PIPELINE_NAME}'",
+            "max_results": 25,
+            "page_token": "PAGETOKEN",
+        }
+
+        assert len(pipelines) == 2  # statuses from both pages are concatenated
+        assert pipelines == LIST_PIPELINES_RESPONSE["statuses"] * 2
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_get_pipeline_id_by_name_not_found(self, mock_requests):
+        empty_response = {"statuses": []}  # list endpoint reports no pipelines
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = empty_response
+
+        ne_pipeline_name = "Non existing pipeline"
+        pipeline_id = self.hook.find_pipeline_id_by_name(ne_pipeline_name)
+
+        mock_requests.get.assert_called_once_with(
+            list_pipelines_endpoint(HOST),
+            json=None,
+            params={"filter": f"name LIKE '{ne_pipeline_name}'", 
"max_results": 25},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
+        assert pipeline_id is None  # no match yields None instead of raising
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_list_pipelines_raise_exception_with_duplicates(self, 
mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {  # the same pipeline listed twice
+            **LIST_PIPELINES_RESPONSE,
+            "statuses": LIST_PIPELINES_RESPONSE["statuses"] * 2,
+        }
+
+        exception_message = f"There are more than one pipelines with name 
{PIPELINE_NAME}."
+        with pytest.raises(AirflowException, match=exception_message):  # NOTE(review): match is a regex; the trailing "." is unescaped, re.escape() would be stricter
+            self.hook.find_pipeline_id_by_name(pipeline_name=PIPELINE_NAME)
+
+        mock_requests.get.assert_called_once_with(
+            list_pipelines_endpoint(HOST),
+            json=None,
+            params={"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 
25},
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_connection_success(self, mock_requests):
         mock_requests.codes.ok = 200
diff --git a/tests/providers/databricks/operators/test_databricks.py 
b/tests/providers/databricks/operators/test_databricks.py
index 73cde92f8b..c196f51ee3 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -758,6 +758,35 @@ class TestDatabricksSubmitRunOperator:
         db_mock.get_run.assert_called_once_with(RUN_ID)
         assert RUN_ID == op.run_id
 
+    
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_exec_pipeline_name(self, db_mock_class):
+        """
+        Test that execute() replaces pipeline_name with the id returned by find_pipeline_id_by_name.
+        """
+        run = {"pipeline_task": {"pipeline_name": "This is a test pipeline"}}  # no pipeline_id, so the name lookup must run
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.find_pipeline_id_by_name.return_value = 
PIPELINE_ID_TASK["pipeline_id"]
+        db_mock.submit_run.return_value = 1
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        op.execute(None)
+
+        expected = utils.normalise_json_content({"pipeline_task": 
PIPELINE_ID_TASK, "run_name": TASK_ID})
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunOperator",
+        )
+        db_mock.find_pipeline_id_by_name.assert_called_once_with("This is a 
test pipeline")
+
+        db_mock.submit_run.assert_called_once_with(expected)  # submitted payload carries pipeline_id, not pipeline_name
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
+        assert RUN_ID == op.run_id
+
     
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_failure(self, db_mock_class):
         """

Reply via email to