This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4fb2140f39 Make Databricks operators' json parameter compatible with
XComs, Jinja expression values (#40471)
4fb2140f39 is described below
commit 4fb2140f393b6332903fb833151c2ce8a9c66fe2
Author: Bora Berke Sahin <[email protected]>
AuthorDate: Tue Jul 2 12:34:59 2024 +0300
Make Databricks operators' json parameter compatible with XComs, Jinja
expression values (#40471)
---
.../providers/databricks/operators/databricks.py | 184 +++----
airflow/providers/databricks/utils/databricks.py | 4 +-
.../databricks/operators/test_databricks.py | 600 ++++++++++++++++-----
.../providers/databricks/utils/test_databricks.py | 4 +-
4 files changed, 571 insertions(+), 221 deletions(-)
diff --git a/airflow/providers/databricks/operators/databricks.py
b/airflow/providers/databricks/operators/databricks.py
index e461933ebc..d322519230 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -36,7 +36,7 @@ from
airflow.providers.databricks.operators.databricks_workflow import (
WorkflowRunMetadata,
)
from airflow.providers.databricks.triggers.databricks import
DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import
normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import
_normalise_json_content, validate_trigger_event
if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -182,6 +182,17 @@ def
_handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
raise AirflowException(error_message)
+def _handle_overridden_json_params(operator):
+ for key, value in operator.overridden_json_params.items():
+ if value is not None:
+ operator.json[key] = value
+
+
+def normalise_json_content(operator):
+ if operator.json:
+ operator.json = _normalise_json_content(operator.json)
+
+
class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""
@@ -285,34 +296,21 @@ class DatabricksCreateJobsOperator(BaseOperator):
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
- if name is not None:
- self.json["name"] = name
- if description is not None:
- self.json["description"] = description
- if tags is not None:
- self.json["tags"] = tags
- if tasks is not None:
- self.json["tasks"] = tasks
- if job_clusters is not None:
- self.json["job_clusters"] = job_clusters
- if email_notifications is not None:
- self.json["email_notifications"] = email_notifications
- if webhook_notifications is not None:
- self.json["webhook_notifications"] = webhook_notifications
- if notification_settings is not None:
- self.json["notification_settings"] = notification_settings
- if timeout_seconds is not None:
- self.json["timeout_seconds"] = timeout_seconds
- if schedule is not None:
- self.json["schedule"] = schedule
- if max_concurrent_runs is not None:
- self.json["max_concurrent_runs"] = max_concurrent_runs
- if git_source is not None:
- self.json["git_source"] = git_source
- if access_control_list is not None:
- self.json["access_control_list"] = access_control_list
- if self.json:
- self.json = normalise_json_content(self.json)
+ self.overridden_json_params = {
+ "name": name,
+ "description": description,
+ "tags": tags,
+ "tasks": tasks,
+ "job_clusters": job_clusters,
+ "email_notifications": email_notifications,
+ "webhook_notifications": webhook_notifications,
+ "notification_settings": notification_settings,
+ "timeout_seconds": timeout_seconds,
+ "schedule": schedule,
+ "max_concurrent_runs": max_concurrent_runs,
+ "git_source": git_source,
+ "access_control_list": access_control_list,
+ }
@cached_property
def _hook(self):
@@ -324,16 +322,24 @@ class DatabricksCreateJobsOperator(BaseOperator):
caller="DatabricksCreateJobsOperator",
)
- def execute(self, context: Context) -> int:
+ def _setup_and_validate_json(self):
+ _handle_overridden_json_params(self)
+
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")
+
+ normalise_json_content(self)
+
+ def execute(self, context: Context) -> int:
+ self._setup_and_validate_json()
+
job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is
not None:
acl_json = {"access_control_list": access_control_list}
- self._hook.update_job_permission(job_id,
normalise_json_content(acl_json))
+ self._hook.update_job_permission(job_id,
_normalise_json_content(acl_json))
return job_id
@@ -506,43 +512,23 @@ class DatabricksSubmitRunOperator(BaseOperator):
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
- if tasks is not None:
- self.json["tasks"] = tasks
- if spark_jar_task is not None:
- self.json["spark_jar_task"] = spark_jar_task
- if notebook_task is not None:
- self.json["notebook_task"] = notebook_task
- if spark_python_task is not None:
- self.json["spark_python_task"] = spark_python_task
- if spark_submit_task is not None:
- self.json["spark_submit_task"] = spark_submit_task
- if pipeline_task is not None:
- self.json["pipeline_task"] = pipeline_task
- if dbt_task is not None:
- self.json["dbt_task"] = dbt_task
- if new_cluster is not None:
- self.json["new_cluster"] = new_cluster
- if existing_cluster_id is not None:
- self.json["existing_cluster_id"] = existing_cluster_id
- if libraries is not None:
- self.json["libraries"] = libraries
- if run_name is not None:
- self.json["run_name"] = run_name
- if timeout_seconds is not None:
- self.json["timeout_seconds"] = timeout_seconds
- if "run_name" not in self.json:
- self.json["run_name"] = run_name or kwargs["task_id"]
- if idempotency_token is not None:
- self.json["idempotency_token"] = idempotency_token
- if access_control_list is not None:
- self.json["access_control_list"] = access_control_list
- if git_source is not None:
- self.json["git_source"] = git_source
-
- if "dbt_task" in self.json and "git_source" not in self.json:
- raise AirflowException("git_source is required for dbt_task")
- if pipeline_task is not None and "pipeline_id" in pipeline_task and
"pipeline_name" in pipeline_task:
- raise AirflowException("'pipeline_name' is not allowed in
conjunction with 'pipeline_id'")
+ self.overridden_json_params = {
+ "tasks": tasks,
+ "spark_jar_task": spark_jar_task,
+ "notebook_task": notebook_task,
+ "spark_python_task": spark_python_task,
+ "spark_submit_task": spark_submit_task,
+ "pipeline_task": pipeline_task,
+ "dbt_task": dbt_task,
+ "new_cluster": new_cluster,
+ "existing_cluster_id": existing_cluster_id,
+ "libraries": libraries,
+ "run_name": run_name,
+ "timeout_seconds": timeout_seconds,
+ "idempotency_token": idempotency_token,
+ "access_control_list": access_control_list,
+ "git_source": git_source,
+ }
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
@@ -561,7 +547,25 @@ class DatabricksSubmitRunOperator(BaseOperator):
caller=caller,
)
+ def _setup_and_validate_json(self):
+ _handle_overridden_json_params(self)
+
+ if "run_name" not in self.json or self.json["run_name"] is None:
+ self.json["run_name"] = self.task_id
+
+ if "dbt_task" in self.json and "git_source" not in self.json:
+ raise AirflowException("git_source is required for dbt_task")
+ if (
+ "pipeline_task" in self.json
+ and "pipeline_id" in self.json["pipeline_task"]
+ and "pipeline_name" in self.json["pipeline_task"]
+ ):
+ raise AirflowException("'pipeline_name' is not allowed in
conjunction with 'pipeline_id'")
+
+ normalise_json_content(self)
+
def execute(self, context: Context):
+ self._setup_and_validate_json()
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
@@ -571,7 +575,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] =
self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
- json_normalised = normalise_json_content(self.json)
+ json_normalised = _normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook,
self.log, context)
@@ -607,7 +611,7 @@ class
DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
def execute(self, context):
hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
- json_normalised = normalise_json_content(self.json)
+ json_normalised = _normalise_json_content(self.json)
self.run_id = hook.submit_run(json_normalised)
_handle_deferrable_databricks_operator_execution(self, hook, self.log,
context)
@@ -807,27 +811,16 @@ class DatabricksRunNowOperator(BaseOperator):
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs
-
- if job_id is not None:
- self.json["job_id"] = job_id
- if job_name is not None:
- self.json["job_name"] = job_name
- if "job_id" in self.json and "job_name" in self.json:
- raise AirflowException("Argument 'job_name' is not allowed with
argument 'job_id'")
- if notebook_params is not None:
- self.json["notebook_params"] = notebook_params
- if python_params is not None:
- self.json["python_params"] = python_params
- if python_named_params is not None:
- self.json["python_named_params"] = python_named_params
- if jar_params is not None:
- self.json["jar_params"] = jar_params
- if spark_submit_params is not None:
- self.json["spark_submit_params"] = spark_submit_params
- if idempotency_token is not None:
- self.json["idempotency_token"] = idempotency_token
- if self.json:
- self.json = normalise_json_content(self.json)
+ self.overridden_json_params = {
+ "job_id": job_id,
+ "job_name": job_name,
+ "notebook_params": notebook_params,
+ "python_params": python_params,
+ "python_named_params": python_named_params,
+ "jar_params": jar_params,
+ "spark_submit_params": spark_submit_params,
+ "idempotency_token": idempotency_token,
+ }
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
@@ -845,7 +838,16 @@ class DatabricksRunNowOperator(BaseOperator):
caller=caller,
)
+ def _setup_and_validate_json(self):
+ _handle_overridden_json_params(self)
+
+ if "job_id" in self.json and "job_name" in self.json:
+ raise AirflowException("Argument 'job_name' is not allowed with
argument 'job_id'")
+
+ normalise_json_content(self)
+
def execute(self, context: Context):
+ self._setup_and_validate_json()
hook = self._hook
if "job_name" in self.json:
job_id = hook.find_job_id_by_name(self.json["job_name"])
diff --git a/airflow/providers/databricks/utils/databricks.py
b/airflow/providers/databricks/utils/databricks.py
index 88d622c3bc..ec99bce178 100644
--- a/airflow/providers/databricks/utils/databricks.py
+++ b/airflow/providers/databricks/utils/databricks.py
@@ -21,7 +21,7 @@ from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
-def normalise_json_content(content, json_path: str = "json") -> str | bool |
list | dict:
+def _normalise_json_content(content, json_path: str = "json") -> str | bool |
list | dict:
"""
Normalize content or all values of content if it is a dict to a string.
@@ -33,7 +33,7 @@ def normalise_json_content(content, json_path: str = "json")
-> str | bool | lis
The only one exception is when we have boolean values, they can not be
converted
to string type because databricks does not understand 'True' or 'False'
values.
"""
- normalise = normalise_json_content
+ normalise = _normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(content, (int, float)):
diff --git a/tests/providers/databricks/operators/test_databricks.py
b/tests/providers/databricks/operators/test_databricks.py
index 7ff2295eda..ae2bb49766 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -23,8 +23,10 @@ from unittest.mock import MagicMock
import pytest
+from airflow.decorators import task
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import DAG
+from airflow.operators.python import PythonOperator
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.operators.databricks import (
DatabricksCreateJobsOperator,
@@ -36,6 +38,7 @@ from airflow.providers.databricks.operators.databricks import
(
)
from airflow.providers.databricks.triggers.databricks import
DatabricksExecutionTrigger
from airflow.providers.databricks.utils import databricks as utils
+from airflow.utils import timezone
pytestmark = pytest.mark.db_test
@@ -248,9 +251,9 @@ def make_run_with_state_mock(
class TestDatabricksCreateJobsOperator:
- def test_init_with_named_parameters(self):
+ def test_validate_json_with_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with the named parameters.
"""
op = DatabricksCreateJobsOperator(
task_id=TASK_ID,
@@ -266,7 +269,9 @@ class TestDatabricksCreateJobsOperator:
git_source=GIT_SOURCE,
access_control_list=ACCESS_CONTROL_LIST,
)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{
"name": JOB_NAME,
"tags": TAGS,
@@ -284,9 +289,9 @@ class TestDatabricksCreateJobsOperator:
assert expected == op.json
- def test_init_with_json(self):
+ def test_validate_json_with_json(self):
"""
- Test the initializer with json data.
+ Test the _setup_and_validate_json function with json data.
"""
json = {
"name": JOB_NAME,
@@ -302,8 +307,9 @@ class TestDatabricksCreateJobsOperator:
"access_control_list": ACCESS_CONTROL_LIST,
}
op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+ op._setup_and_validate_json()
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"name": JOB_NAME,
"tags": TAGS,
@@ -321,9 +327,9 @@ class TestDatabricksCreateJobsOperator:
assert expected == op.json
- def test_init_with_merging(self):
+ def test_validate_json_with_merging(self):
"""
- Test the initializer when json and other named parameters are both
+ Test the _setup_and_validate_json function when json and other named
parameters are both
provided. The named parameters should override top level keys in the
json dict.
"""
@@ -367,8 +373,9 @@ class TestDatabricksCreateJobsOperator:
git_source=override_git_source,
access_control_list=override_access_control_list,
)
+ op._setup_and_validate_json()
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"name": override_name,
"tags": override_tags,
@@ -386,24 +393,158 @@ class TestDatabricksCreateJobsOperator:
assert expected == op.json
- def test_init_with_templating(self):
+ def test_validate_json_with_templating(self):
json = {"name": "test-{{ ds }}"}
dag = DAG("test", start_date=datetime.now())
op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json)
op.render_template_fields(context={"ds": DATE})
- expected = utils.normalise_json_content({"name": f"test-{DATE}"})
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content({"name": f"test-{DATE}"})
assert expected == op.json
- def test_init_with_bad_type(self):
- json = {"test": datetime.now()}
+ def test_validate_json_with_bad_type(self):
+ json = {"test": datetime.now(), "name": "test"}
# Looks a bit weird since we have to escape regex reserved symbols.
exception_message = (
r"Type \<(type|class) \'datetime.datetime\'\> used "
r"for parameter json\[test\] is not a number or a string"
)
with pytest.raises(AirflowException, match=exception_message):
- DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+ DatabricksCreateJobsOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
+
+ def test_validate_json_with_no_name(self):
+ json = {}
+ exception_message = "Missing required parameter: name"
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksCreateJobsOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
+ json = "{{ ti.xcom_pull(task_ids='push_json') }}"
+ with dag_maker("test_templated", render_template_as_native_obj=True):
+ push_json = PythonOperator(
+ task_id="push_json",
+ python_callable=lambda: {
+ "name": JOB_NAME,
+ "description": JOB_DESCRIPTION,
+ "tags": TAGS,
+ "tasks": TASKS,
+ "job_clusters": JOB_CLUSTERS,
+ "email_notifications": EMAIL_NOTIFICATIONS,
+ "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+ "notification_settings": NOTIFICATION_SETTINGS,
+ "timeout_seconds": TIMEOUT_SECONDS,
+ "schedule": SCHEDULE,
+ "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+ "git_source": GIT_SOURCE,
+ "access_control_list": ACCESS_CONTROL_LIST,
+ },
+ )
+ op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+ push_json >> op
+
+ db_mock = db_mock_class.return_value
+ db_mock.create_job.return_value = JOB_ID
+
+ db_mock.find_job_id_by_name.return_value = None
+
+ dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ tis = {ti.task_id: ti for ti in dagrun.task_instances}
+ tis["push_json"].run()
+ tis[TASK_ID].run()
+
+ expected = utils._normalise_json_content(
+ {
+ "name": JOB_NAME,
+ "description": JOB_DESCRIPTION,
+ "tags": TAGS,
+ "tasks": TASKS,
+ "job_clusters": JOB_CLUSTERS,
+ "email_notifications": EMAIL_NOTIFICATIONS,
+ "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+ "notification_settings": NOTIFICATION_SETTINGS,
+ "timeout_seconds": TIMEOUT_SECONDS,
+ "schedule": SCHEDULE,
+ "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+ "git_source": GIT_SOURCE,
+ "access_control_list": ACCESS_CONTROL_LIST,
+ }
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksCreateJobsOperator",
+ )
+
+ db_mock.create_job.assert_called_once_with(expected)
+ assert JOB_ID == tis[TASK_ID].xcom_pull()
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
+ with dag_maker("test_xcomarg", render_template_as_native_obj=True):
+
+ @task
+ def push_json() -> dict:
+ return {
+ "name": JOB_NAME,
+ "description": JOB_DESCRIPTION,
+ "tags": TAGS,
+ "tasks": TASKS,
+ "job_clusters": JOB_CLUSTERS,
+ "email_notifications": EMAIL_NOTIFICATIONS,
+ "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+ "notification_settings": NOTIFICATION_SETTINGS,
+ "timeout_seconds": TIMEOUT_SECONDS,
+ "schedule": SCHEDULE,
+ "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+ "git_source": GIT_SOURCE,
+ "access_control_list": ACCESS_CONTROL_LIST,
+ }
+
+ json = push_json()
+ op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+
+ db_mock = db_mock_class.return_value
+ db_mock.create_job.return_value = JOB_ID
+
+ db_mock.find_job_id_by_name.return_value = None
+
+ dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ tis = {ti.task_id: ti for ti in dagrun.task_instances}
+ tis["push_json"].run()
+ tis[TASK_ID].run()
+
+ expected = utils._normalise_json_content(
+ {
+ "name": JOB_NAME,
+ "description": JOB_DESCRIPTION,
+ "tags": TAGS,
+ "tasks": TASKS,
+ "job_clusters": JOB_CLUSTERS,
+ "email_notifications": EMAIL_NOTIFICATIONS,
+ "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+ "notification_settings": NOTIFICATION_SETTINGS,
+ "timeout_seconds": TIMEOUT_SECONDS,
+ "schedule": SCHEDULE,
+ "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+ "git_source": GIT_SOURCE,
+ "access_control_list": ACCESS_CONTROL_LIST,
+ }
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksCreateJobsOperator",
+ )
+
+ db_mock.create_job.assert_called_once_with(expected)
+ assert JOB_ID == tis[TASK_ID].xcom_pull()
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_create(self, db_mock_class):
@@ -433,7 +574,7 @@ class TestDatabricksCreateJobsOperator:
return_result = op.execute({})
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"name": JOB_NAME,
"description": JOB_DESCRIPTION,
@@ -487,7 +628,7 @@ class TestDatabricksCreateJobsOperator:
return_result = op.execute({})
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"name": JOB_NAME,
"description": JOB_DESCRIPTION,
@@ -539,7 +680,7 @@ class TestDatabricksCreateJobsOperator:
op.execute({})
- expected = utils.normalise_json_content({"access_control_list":
ACCESS_CONTROL_LIST})
+ expected = utils._normalise_json_content({"access_control_list":
ACCESS_CONTROL_LIST})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
@@ -586,66 +727,76 @@ class TestDatabricksCreateJobsOperator:
class TestDatabricksSubmitRunOperator:
- def test_init_with_notebook_task_named_parameters(self):
+ def test_validate_json_with_notebook_task_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with named parameters.
"""
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK
)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_spark_python_task_named_parameters(self):
+ def test_validate_json_with_spark_python_task_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with the named parameters.
"""
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
spark_python_task=SPARK_PYTHON_TASK
)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "spark_python_task":
SPARK_PYTHON_TASK, "run_name": TASK_ID}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_pipeline_name_task_named_parameters(self):
+ def test_validate_json_with_pipeline_name_task_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with the named parameters.
"""
op = DatabricksSubmitRunOperator(task_id=TASK_ID,
pipeline_task=PIPELINE_NAME_TASK)
- expected = utils.normalise_json_content({"pipeline_task":
PIPELINE_NAME_TASK, "run_name": TASK_ID})
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content({"pipeline_task":
PIPELINE_NAME_TASK, "run_name": TASK_ID})
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_pipeline_id_task_named_parameters(self):
+ def test_validate_json_with_pipeline_id_task_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with the named parameters.
"""
op = DatabricksSubmitRunOperator(task_id=TASK_ID,
pipeline_task=PIPELINE_ID_TASK)
- expected = utils.normalise_json_content({"pipeline_task":
PIPELINE_ID_TASK, "run_name": TASK_ID})
+ op._setup_and_validate_json()
- assert expected == utils.normalise_json_content(op.json)
+ expected = utils._normalise_json_content({"pipeline_task":
PIPELINE_ID_TASK, "run_name": TASK_ID})
- def test_init_with_spark_submit_task_named_parameters(self):
+ assert expected == utils._normalise_json_content(op.json)
+
+ def test_validate_json_with_spark_submit_task_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with the named parameters.
"""
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
spark_submit_task=SPARK_SUBMIT_TASK
)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "spark_submit_task":
SPARK_SUBMIT_TASK, "run_name": TASK_ID}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_dbt_task_named_parameters(self):
+ def test_validate_json_with_dbt_task_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with the named parameters.
"""
git_source = {
"git_url": "https://github.com/dbt-labs/jaffle_shop",
@@ -655,15 +806,17 @@ class TestDatabricksSubmitRunOperator:
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK,
git_source=git_source
)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source":
git_source, "run_name": TASK_ID}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_dbt_task_mixed_parameters(self):
+ def test_validate_json_with_dbt_task_mixed_parameters(self):
"""
- Test the initializer with mixed parameters.
+ Test the _setup_and_validate_json function with mixed parameters.
"""
git_source = {
"git_url": "https://github.com/dbt-labs/jaffle_shop",
@@ -674,73 +827,85 @@ class TestDatabricksSubmitRunOperator:
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK,
json=json
)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source":
git_source, "run_name": TASK_ID}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_dbt_task_without_git_source_raises_error(self):
+ def test_validate_json_with_dbt_task_without_git_source_raises_error(self):
"""
- Test the initializer without the necessary git_source for dbt_task
raises error.
+ Test the _setup_and_validate_json function without the necessary
git_source for dbt_task raises error.
"""
exception_message = "git_source is required for dbt_task"
with pytest.raises(AirflowException, match=exception_message):
- DatabricksSubmitRunOperator(task_id=TASK_ID,
new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK)
+ DatabricksSubmitRunOperator(
+ task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK
+ )._setup_and_validate_json()
- def test_init_with_dbt_task_json_without_git_source_raises_error(self):
+ def
test_validate_json_with_dbt_task_json_without_git_source_raises_error(self):
"""
- Test the initializer without the necessary git_source for dbt_task
raises error.
+ Test the _setup_and_validate_json function without the necessary
git_source for dbt_task raises error.
"""
json = {"dbt_task": DBT_TASK, "new_cluster": NEW_CLUSTER}
exception_message = "git_source is required for dbt_task"
with pytest.raises(AirflowException, match=exception_message):
- DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+ DatabricksSubmitRunOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
- def test_init_with_json(self):
+ def test_validate_json_with_json(self):
"""
- Test the initializer with json data.
+ Test the _setup_and_validate_json function with json data.
"""
json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_tasks(self):
+ def test_validate_json_with_tasks(self):
tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task":
NOTEBOOK_TASK}]
op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
- expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks":
tasks})
- assert expected == utils.normalise_json_content(op.json)
+ op._setup_and_validate_json()
- def test_init_with_specified_run_name(self):
+ expected = utils._normalise_json_content({"run_name": TASK_ID,
"tasks": tasks})
+ assert expected == utils._normalise_json_content(op.json)
+
+ def test_validate_json_with_specified_run_name(self):
"""
- Test the initializer with a specified run_name.
+ Test the _setup_and_validate_json function with a specified run_name.
"""
json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": RUN_NAME}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": RUN_NAME}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_pipeline_task(self):
+ def test_validate_json_with_pipeline_task(self):
"""
- Test the initializer with a pipeline task.
+ Test the _setup_and_validate_json function with a pipeline task.
"""
pipeline_task = {"pipeline_id": "test-dlt"}
json = {"new_cluster": NEW_CLUSTER, "run_name": RUN_NAME,
"pipeline_task": pipeline_task}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task,
"run_name": RUN_NAME}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_merging(self):
+ def test_validate_json_with_merging(self):
"""
- Test the initializer when json and other named parameters are both
+ Test the _setup_and_validate_json function when json and other named
parameters are both
provided. The named parameters should override top level keys in the
json dict.
"""
@@ -750,34 +915,38 @@ class TestDatabricksSubmitRunOperator:
"notebook_task": NOTEBOOK_TASK,
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json,
new_cluster=override_new_cluster)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{
"new_cluster": override_new_cluster,
"notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID,
}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
@pytest.mark.db_test
- def test_init_with_templating(self):
+ def test_validate_json_with_templating(self):
json = {
"new_cluster": NEW_CLUSTER,
"notebook_task": TEMPLATED_NOTEBOOK_TASK,
}
dag = DAG("test", start_date=datetime.now())
op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
+ op._setup_and_validate_json()
+
op.render_template_fields(context={"ds": DATE})
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"new_cluster": NEW_CLUSTER,
"notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK,
"run_name": TASK_ID,
}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_git_source(self):
+ def test_validate_json_with_git_source(self):
json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": RUN_NAME}
git_source = {
"git_url": "https://github.com/apache/airflow",
@@ -785,7 +954,9 @@ class TestDatabricksSubmitRunOperator:
"git_branch": "main",
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID,
git_source=git_source, json=json)
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content(
{
"new_cluster": NEW_CLUSTER,
"notebook_task": NOTEBOOK_TASK,
@@ -793,18 +964,95 @@ class TestDatabricksSubmitRunOperator:
"git_source": git_source,
}
)
- assert expected == utils.normalise_json_content(op.json)
+ assert expected == utils._normalise_json_content(op.json)
- def test_init_with_bad_type(self):
+ def test_validate_json_with_bad_type(self):
json = {"test": datetime.now()}
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
# Looks a bit weird since we have to escape regex reserved symbols.
exception_message = (
r"Type \<(type|class) \'datetime.datetime\'\> used "
r"for parameter json\[test\] is not a number or a string"
)
with pytest.raises(AirflowException, match=exception_message):
- utils.normalise_json_content(op.json)
+ DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)._setup_and_validate_json()
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
+ json = "{{ ti.xcom_pull(task_ids='push_json') }}"
+ with dag_maker("test_templated", render_template_as_native_obj=True):
+ push_json = PythonOperator(
+ task_id="push_json",
+ python_callable=lambda: {
+ "new_cluster": NEW_CLUSTER,
+ "notebook_task": NOTEBOOK_TASK,
+ },
+ )
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+ push_json >> op
+
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = RUN_ID
+ db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+ dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ tis = {ti.task_id: ti for ti in dagrun.task_instances}
+ tis["push_json"].run()
+ tis[TASK_ID].run()
+
+ expected = utils._normalise_json_content(
+ {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksSubmitRunOperator",
+ )
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run.assert_called_once_with(RUN_ID)
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
+ with dag_maker("test_xcomarg", render_template_as_native_obj=True):
+
+ @task
+ def push_json() -> dict:
+ return {
+ "new_cluster": NEW_CLUSTER,
+ "notebook_task": NOTEBOOK_TASK,
+ }
+
+ json = push_json()
+
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+ db_mock = db_mock_class.return_value
+ db_mock.submit_run.return_value = RUN_ID
+ db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+ dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ tis = {ti.task_id: ti for ti in dagrun.task_instances}
+ tis["push_json"].run()
+ tis[TASK_ID].run()
+
+ expected = utils._normalise_json_content(
+ {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID}
+ )
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksSubmitRunOperator",
+ )
+
+ db_mock.submit_run.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run.assert_called_once_with(RUN_ID)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_success(self, db_mock_class):
@@ -822,7 +1070,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -852,7 +1100,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils.normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
+ expected = utils._normalise_json_content({"pipeline_task": PIPELINE_ID_TASK, "run_name": TASK_ID})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
@@ -884,7 +1132,7 @@ class TestDatabricksSubmitRunOperator:
with pytest.raises(AirflowException):
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"new_cluster": NEW_CLUSTER,
"notebook_task": NOTEBOOK_TASK,
@@ -932,7 +1180,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -961,7 +1209,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -995,7 +1243,7 @@ class TestDatabricksSubmitRunOperator:
assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
assert exc.value.method_name == "execute_complete"
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1077,7 +1325,7 @@ class TestDatabricksSubmitRunOperator:
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1107,7 +1355,7 @@ class TestDatabricksSubmitRunOperator:
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1125,18 +1373,20 @@ class TestDatabricksSubmitRunOperator:
class TestDatabricksRunNowOperator:
- def test_init_with_named_parameters(self):
+ def test_validate_json_with_named_parameters(self):
"""
- Test the initializer with the named parameters.
+ Test the _setup_and_validate_json function with named parameters.
"""
op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
- expected = utils.normalise_json_content({"job_id": 42})
+ op._setup_and_validate_json()
+
+ expected = utils._normalise_json_content({"job_id": 42})
assert expected == op.json
- def test_init_with_json(self):
+ def test_validate_json_with_json(self):
"""
- Test the initializer with json data.
+ Test the _setup_and_validate_json function with json data.
"""
json = {
"notebook_params": NOTEBOOK_PARAMS,
@@ -1147,8 +1397,9 @@ class TestDatabricksRunNowOperator:
"repair_run": False,
}
op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
+ op._setup_and_validate_json()
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"jar_params": JAR_PARAMS,
@@ -1161,9 +1412,9 @@ class TestDatabricksRunNowOperator:
assert expected == op.json
- def test_init_with_merging(self):
+ def test_validate_json_with_merging(self):
"""
- Test the initializer when json and other named parameters are both
+ Test the _setup_and_validate_json function when json and other named
parameters are both
provided. The named parameters should override top level keys in the
json dict.
"""
@@ -1180,8 +1431,9 @@ class TestDatabricksRunNowOperator:
jar_params=override_jar_params,
spark_submit_params=SPARK_SUBMIT_PARAMS,
)
+ op._setup_and_validate_json()
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": override_notebook_params,
"jar_params": override_jar_params,
@@ -1194,13 +1446,14 @@ class TestDatabricksRunNowOperator:
assert expected == op.json
@pytest.mark.db_test
- def test_init_with_templating(self):
+ def test_validate_json_with_templating(self):
json = {"notebook_params": NOTEBOOK_PARAMS, "jar_params":
TEMPLATED_JAR_PARAMS}
dag = DAG("test", start_date=datetime.now())
op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID,
json=json)
op.render_template_fields(context={"ds": DATE})
- expected = utils.normalise_json_content(
+ op._setup_and_validate_json()
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"jar_params": RENDERED_TEMPLATED_JAR_PARAMS,
@@ -1209,7 +1462,7 @@ class TestDatabricksRunNowOperator:
)
assert expected == op.json
- def test_init_with_bad_type(self):
+ def test_validate_json_with_bad_type(self):
json = {"test": datetime.now()}
# Looks a bit weird since we have to escape regex reserved symbols.
exception_message = (
@@ -1217,7 +1470,116 @@ class TestDatabricksRunNowOperator:
r"for parameter json\[test\] is not a number or a string"
)
with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
+ DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)._setup_and_validate_json()
+
+ def test_validate_json_exception_with_job_name_and_job_id(self):
+ exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
+
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(
+ task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME
+ )._setup_and_validate_json()
+
+ run = {"job_id": JOB_ID, "job_name": JOB_NAME}
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(task_id=TASK_ID, json=run)._setup_and_validate_json()
+
+ run = {"job_id": JOB_ID}
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME)._setup_and_validate_json()
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
+ json = "{{ ti.xcom_pull(task_ids='push_json') }}"
+ with dag_maker("test_templated", render_template_as_native_obj=True):
+ push_json = PythonOperator(
+ task_id="push_json",
+ python_callable=lambda: {
+ "notebook_params": NOTEBOOK_PARAMS,
+ "notebook_task": NOTEBOOK_TASK,
+ "jar_params": JAR_PARAMS,
+ "job_id": JOB_ID,
+ },
+ )
+ op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
+ push_json >> op
+
+ db_mock = db_mock_class.return_value
+ db_mock.run_now.return_value = RUN_ID
+ db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+ dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ tis = {ti.task_id: ti for ti in dagrun.task_instances}
+ tis["push_json"].run()
+ tis[TASK_ID].run()
+
+ expected = utils._normalise_json_content(
+ {
+ "notebook_params": NOTEBOOK_PARAMS,
+ "notebook_task": NOTEBOOK_TASK,
+ "jar_params": JAR_PARAMS,
+ "job_id": JOB_ID,
+ }
+ )
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksRunNowOperator",
+ )
+ db_mock.run_now.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run.assert_called_once_with(RUN_ID)
+
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+ def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
+ with dag_maker("test_xcomarg", render_template_as_native_obj=True):
+
+ @task
+ def push_json() -> dict:
+ return {
+ "notebook_params": NOTEBOOK_PARAMS,
+ "notebook_task": NOTEBOOK_TASK,
+ "jar_params": JAR_PARAMS,
+ "job_id": JOB_ID,
+ }
+
+ json = push_json()
+
+ op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
+
+ db_mock = db_mock_class.return_value
+ db_mock.run_now.return_value = RUN_ID
+ db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+ db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+ dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+ tis = {ti.task_id: ti for ti in dagrun.task_instances}
+ tis["push_json"].run()
+ tis[TASK_ID].run()
+
+ expected = utils._normalise_json_content(
+ {
+ "notebook_params": NOTEBOOK_PARAMS,
+ "notebook_task": NOTEBOOK_TASK,
+ "jar_params": JAR_PARAMS,
+ "job_id": JOB_ID,
+ }
+ )
+
+ db_mock_class.assert_called_once_with(
+ DEFAULT_CONN_ID,
+ retry_limit=op.databricks_retry_limit,
+ retry_delay=op.databricks_retry_delay,
+ retry_args=None,
+ caller="DatabricksRunNowOperator",
+ )
+ db_mock.run_now.assert_called_once_with(expected)
+ db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+ db_mock.get_run.assert_called_once_with(RUN_ID)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_success(self, db_mock_class):
@@ -1232,7 +1594,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1267,7 +1629,7 @@ class TestDatabricksRunNowOperator:
with pytest.raises(AirflowException):
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1323,7 +1685,7 @@ class TestDatabricksRunNowOperator:
with pytest.raises(AirflowException, match="Exception: Something went
wrong"):
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1391,7 +1753,7 @@ class TestDatabricksRunNowOperator:
):
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1435,7 +1797,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1466,7 +1828,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1486,20 +1848,6 @@ class TestDatabricksRunNowOperator:
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_not_called()
- def test_init_exception_with_job_name_and_job_id(self):
- exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
-
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME)
-
- run = {"job_id": JOB_ID, "job_name": JOB_NAME}
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID, json=run)
-
- run = {"job_id": JOB_ID}
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID, json=run, job_name=JOB_NAME)
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_with_job_name(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task":
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
@@ -1511,7 +1859,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1559,7 +1907,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1593,7 +1941,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1630,7 +1978,7 @@ class TestDatabricksRunNowOperator:
assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
assert exc.value.method_name == "execute_complete"
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1740,7 +2088,7 @@ class TestDatabricksRunNowOperator:
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1774,7 +2122,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils.normalise_json_content(
+ expected = utils._normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py
index 8c6ce8ce4b..4b57573253 100644
--- a/tests/providers/databricks/utils/test_databricks.py
+++ b/tests/providers/databricks/utils/test_databricks.py
@@ -21,7 +21,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
-from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
RUN_ID = 1
RUN_PAGE_URL = "run-page-url"
@@ -46,7 +46,7 @@ class TestDatabricksOperatorSharedFunctions:
"test_list": ["1", "1.0", "a", "b"],
"test_tuple": ["1", "1.0", "a", "b"],
}
- assert normalise_json_content(test_json) == expected
+ assert _normalise_json_content(test_json) == expected
def test_validate_trigger_event_success(self):
event = {