This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 322aa649ed fix typos in DatabricksSubmitRunOperator (#36248)
322aa649ed is described below
commit 322aa649edce6655f4bddfb9813ff8cb38616b7a
Author: Adam B <[email protected]>
AuthorDate: Wed Dec 20 18:16:59 2023 -0600
fix typos in DatabricksSubmitRunOperator (#36248)
* fix typos in DatabricksSubmitRunOperator
* databricks find_pipeline_id_by_name tests and fixes
---
airflow/providers/databricks/hooks/databricks.py | 10 +-
.../providers/databricks/operators/databricks.py | 2 +-
.../providers/databricks/hooks/test_databricks.py | 108 +++++++++++++++++++++
.../databricks/operators/test_databricks.py | 29 ++++++
4 files changed, 143 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 6cb5b37e46..b39e3d622c 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -55,7 +55,7 @@ INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
-LIST_PIPELINES_ENDPOINT = ("GET", "/api/2.0/pipelines")
+LIST_PIPELINES_ENDPOINT = ("GET", "api/2.0/pipelines")
WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
@@ -322,8 +322,8 @@ class DatabricksHook(BaseDatabricksHook):
payload["filter"] = filter
while has_more:
- if next_token:
- payload["page_token"] = next_token
+ if next_token is not None:
+ payload = {**payload, "page_token": next_token}
response = self._do_api_call(LIST_PIPELINES_ENDPOINT, payload)
pipelines = response.get("statuses", [])
all_pipelines += pipelines
@@ -345,11 +345,11 @@ class DatabricksHook(BaseDatabricksHook):
if len(matching_pipelines) > 1:
raise AirflowException(
- f"There are more than one job with name {pipeline_name}. "
+                f"There are more than one pipelines with name {pipeline_name}. "
"Please delete duplicated pipelines first"
)
- if not pipeline_name:
+ if not pipeline_name or len(matching_pipelines) == 0:
return None
else:
return matching_pipelines[0]["pipeline_id"]
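(For context: a minimal, standalone sketch of the paginated lookup that the hunks above correct. The page-token flow and the copy-instead-of-mutate fix mirror the hook code shown in the diff; the api_call stand-in and ValueError are hypothetical simplifications of the hook's _do_api_call and AirflowException.)

    def find_pipeline_id_by_name(api_call, pipeline_name, max_results=25):
        # Server-side LIKE filter, then page until next_page_token runs out.
        payload = {"max_results": max_results, "filter": f"name LIKE '{pipeline_name}'"}
        matching, next_token, has_more = [], None, True
        while has_more:
            if next_token is not None:
                # Build a fresh dict instead of mutating, as the fixed hook does.
                payload = {**payload, "page_token": next_token}
            # Hypothetical stand-in for the hook's _do_api_call.
            response = api_call("GET", "api/2.0/pipelines", payload)
            matching += response.get("statuses", [])
            next_token = response.get("next_page_token")
            has_more = next_token is not None
        if len(matching) > 1:
            # The hook raises AirflowException here.
            raise ValueError(f"More than one pipeline named {pipeline_name}")
        return matching[0]["pipeline_id"] if matching else None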
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index c8e5180f9b..edea8b4e59 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -521,7 +521,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
):
            # If pipeline_id is not provided, we need to fetch it from the pipeline_name
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
-            self.json["pipeline_task"]["pipeline_id"] = self._hook.get_pipeline_id(pipeline_name)
+            self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
json_normalised = normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
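(Usage illustration, not part of the commit; the task and pipeline names below are hypothetical. When only pipeline_name is supplied, execute() now resolves the id via find_pipeline_id_by_name, writes it into pipeline_task, and deletes pipeline_name before submitting the run, as the hunk above shows.)

    from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator

    # Hypothetical task: no pipeline_id given, so the operator looks it up by name.
    submit_pipeline_run = DatabricksSubmitRunOperator(
        task_id="run_dlt_pipeline",
        json={"pipeline_task": {"pipeline_name": "my-dlt-pipeline"}},
    )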
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index e2f74d773e..1baaab1fea 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -59,6 +59,8 @@ CLUSTER_ID = "cluster_id"
RUN_ID = 1
JOB_ID = 42
JOB_NAME = "job-name"
+PIPELINE_NAME = "some pipeline name"
+PIPELINE_ID = "its-a-pipeline-id"
DEFAULT_RETRY_NUMBER = 3
DEFAULT_RETRY_ARGS = dict(
wait=tenacity.wait_none(),
@@ -100,6 +102,19 @@ LIST_JOBS_RESPONSE = {
],
"has_more": False,
}
+LIST_PIPELINES_RESPONSE = {
+ "statuses": [
+ {
+ "pipeline_id": PIPELINE_ID,
+ "state": "DEPLOYING",
+ "cluster_id": "string",
+ "name": PIPELINE_NAME,
+            "latest_updates": [{"update_id": "string", "state": "QUEUED", "creation_time": "string"}],
+ "creator_user_name": "string",
+ "run_as_user_name": "string",
+ }
+ ]
+}
LIST_SPARK_VERSIONS_RESPONSE = {
"versions": [
{"key": "8.2.x-scala2.12", "name": "8.2 (includes Apache Spark 3.1.1,
Scala 2.12)"},
@@ -226,6 +241,13 @@ def list_jobs_endpoint(host):
return f"https://{host}/api/2.1/jobs/list"
+def list_pipelines_endpoint(host):
+ """
+    Utility function to generate the list pipelines endpoint given the host
+ """
+ return f"https://{host}/api/2.0/pipelines"
+
+
def list_spark_versions_endpoint(host):
"""Utility function to generate the list spark versions endpoint given the
host"""
return f"https://{host}/api/2.0/clusters/spark-versions"
@@ -915,6 +937,92 @@ class TestDatabricksHook:
timeout=self.hook.timeout_seconds,
)
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_get_pipeline_id_by_name_success(self, mock_requests):
+ mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = LIST_PIPELINES_RESPONSE
+
+ pipeline_id = self.hook.find_pipeline_id_by_name(PIPELINE_NAME)
+
+ mock_requests.get.assert_called_once_with(
+ list_pipelines_endpoint(HOST),
+ json=None,
+            params={"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 25},
+ auth=HTTPBasicAuth(LOGIN, PASSWORD),
+ headers=self.hook.user_agent_header,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ assert pipeline_id == PIPELINE_ID
+
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_list_pipelines_success_multiple_pages(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.side_effect = [
+            create_successful_response_mock({**LIST_PIPELINES_RESPONSE, "next_page_token": "PAGETOKEN"}),
+ create_successful_response_mock(LIST_PIPELINES_RESPONSE),
+ ]
+
+ pipelines = self.hook.list_pipelines(pipeline_name=PIPELINE_NAME)
+
+ assert mock_requests.get.call_count == 2
+
+ first_call_args = mock_requests.method_calls[0]
+ assert first_call_args[1][0] == list_pipelines_endpoint(HOST)
+        assert first_call_args[2]["params"] == {"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 25}
+
+ second_call_args = mock_requests.method_calls[1]
+ assert second_call_args[1][0] == list_pipelines_endpoint(HOST)
+ assert second_call_args[2]["params"] == {
+ "filter": f"name LIKE '{PIPELINE_NAME}'",
+ "max_results": 25,
+ "page_token": "PAGETOKEN",
+ }
+
+ assert len(pipelines) == 2
+ assert pipelines == LIST_PIPELINES_RESPONSE["statuses"] * 2
+
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+ def test_get_pipeline_id_by_name_not_found(self, mock_requests):
+ empty_response = {"statuses": []}
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = empty_response
+
+ ne_pipeline_name = "Non existing pipeline"
+ pipeline_id = self.hook.find_pipeline_id_by_name(ne_pipeline_name)
+
+ mock_requests.get.assert_called_once_with(
+ list_pipelines_endpoint(HOST),
+ json=None,
+            params={"filter": f"name LIKE '{ne_pipeline_name}'", "max_results": 25},
+ auth=HTTPBasicAuth(LOGIN, PASSWORD),
+ headers=self.hook.user_agent_header,
+ timeout=self.hook.timeout_seconds,
+ )
+
+ assert pipeline_id is None
+
+ @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_list_pipelines_raise_exception_with_duplicates(self, mock_requests):
+ mock_requests.codes.ok = 200
+ mock_requests.get.return_value.json.return_value = {
+ **LIST_PIPELINES_RESPONSE,
+ "statuses": LIST_PIPELINES_RESPONSE["statuses"] * 2,
+ }
+
+        exception_message = f"There are more than one pipelines with name {PIPELINE_NAME}."
+ with pytest.raises(AirflowException, match=exception_message):
+ self.hook.find_pipeline_id_by_name(pipeline_name=PIPELINE_NAME)
+
+ mock_requests.get.assert_called_once_with(
+ list_pipelines_endpoint(HOST),
+ json=None,
+            params={"filter": f"name LIKE '{PIPELINE_NAME}'", "max_results": 25},
+ auth=HTTPBasicAuth(LOGIN, PASSWORD),
+ headers=self.hook.user_agent_header,
+ timeout=self.hook.timeout_seconds,
+ )
+
@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_connection_success(self, mock_requests):
mock_requests.codes.ok = 200
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 73cde92f8b..c196f51ee3 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -758,6 +758,35 @@ class TestDatabricksSubmitRunOperator:
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_exec_pipeline_name(self, db_mock_class):
+ """
+ Test the execute function when provided a pipeline name.
+ """
+ run = {"pipeline_task": {"pipeline_name": "This is a test pipeline"}}
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+ db_mock = db_mock_class.return_value
+        db_mock.find_pipeline_id_by_name.return_value = PIPELINE_ID_TASK["pipeline_id"]
+ db_mock.submit_run.return_value = 1
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+ op.execute(None)
+
+        expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksSubmitRunOperator",
+ )
+        db_mock.find_pipeline_id_by_name.assert_called_once_with("This is a test pipeline")
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run.assert_called_once_with(RUN_ID)
+ assert RUN_ID == op.run_id
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_failure(self, db_mock_class):
"""