This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fb2140f39 Make Databricks operators' json parameter compatible with XComs, Jinja expression values (#40471)
4fb2140f39 is described below

commit 4fb2140f393b6332903fb833151c2ce8a9c66fe2
Author: Bora Berke Sahin <[email protected]>
AuthorDate: Tue Jul 2 12:34:59 2024 +0300

    Make Databricks operators' json parameter compatible with XComs, Jinja expression values (#40471)
---
 .../providers/databricks/operators/databricks.py   | 184 +++----
 airflow/providers/databricks/utils/databricks.py   |   4 +-
 .../databricks/operators/test_databricks.py        | 600 ++++++++++++++++-----
 .../providers/databricks/utils/test_databricks.py  |   4 +-
 4 files changed, 571 insertions(+), 221 deletions(-)

diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index e461933ebc..d322519230 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -36,7 +36,7 @@ from airflow.providers.databricks.operators.databricks_workflow import (
     WorkflowRunMetadata,
 )
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import _normalise_json_content, validate_trigger_event
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -182,6 +182,17 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
     raise AirflowException(error_message)
 
 
+def _handle_overridden_json_params(operator):
+    for key, value in operator.overridden_json_params.items():
+        if value is not None:
+            operator.json[key] = value
+
+
+def normalise_json_content(operator):
+    if operator.json:
+        operator.json = _normalise_json_content(operator.json)
+
+
 class DatabricksJobRunLink(BaseOperatorLink):
     """Constructs a link to monitor a Databricks Job Run."""
 
@@ -285,34 +296,21 @@ class DatabricksCreateJobsOperator(BaseOperator):
         self.databricks_retry_limit = databricks_retry_limit
         self.databricks_retry_delay = databricks_retry_delay
         self.databricks_retry_args = databricks_retry_args
-        if name is not None:
-            self.json["name"] = name
-        if description is not None:
-            self.json["description"] = description
-        if tags is not None:
-            self.json["tags"] = tags
-        if tasks is not None:
-            self.json["tasks"] = tasks
-        if job_clusters is not None:
-            self.json["job_clusters"] = job_clusters
-        if email_notifications is not None:
-            self.json["email_notifications"] = email_notifications
-        if webhook_notifications is not None:
-            self.json["webhook_notifications"] = webhook_notifications
-        if notification_settings is not None:
-            self.json["notification_settings"] = notification_settings
-        if timeout_seconds is not None:
-            self.json["timeout_seconds"] = timeout_seconds
-        if schedule is not None:
-            self.json["schedule"] = schedule
-        if max_concurrent_runs is not None:
-            self.json["max_concurrent_runs"] = max_concurrent_runs
-        if git_source is not None:
-            self.json["git_source"] = git_source
-        if access_control_list is not None:
-            self.json["access_control_list"] = access_control_list
-        if self.json:
-            self.json = normalise_json_content(self.json)
+        self.overridden_json_params = {
+            "name": name,
+            "description": description,
+            "tags": tags,
+            "tasks": tasks,
+            "job_clusters": job_clusters,
+            "email_notifications": email_notifications,
+            "webhook_notifications": webhook_notifications,
+            "notification_settings": notification_settings,
+            "timeout_seconds": timeout_seconds,
+            "schedule": schedule,
+            "max_concurrent_runs": max_concurrent_runs,
+            "git_source": git_source,
+            "access_control_list": access_control_list,
+        }
 
     @cached_property
     def _hook(self):
@@ -324,16 +322,24 @@ class DatabricksCreateJobsOperator(BaseOperator):
             caller="DatabricksCreateJobsOperator",
         )
 
-    def execute(self, context: Context) -> int:
+    def _setup_and_validate_json(self):
+        _handle_overridden_json_params(self)
+
         if "name" not in self.json:
             raise AirflowException("Missing required parameter: name")
+
+        normalise_json_content(self)
+
+    def execute(self, context: Context) -> int:
+        self._setup_and_validate_json()
+
         job_id = self._hook.find_job_id_by_name(self.json["name"])
         if job_id is None:
             return self._hook.create_job(self.json)
         self._hook.reset_job(str(job_id), self.json)
         if (access_control_list := self.json.get("access_control_list")) is not None:
             acl_json = {"access_control_list": access_control_list}
-            self._hook.update_job_permission(job_id, normalise_json_content(acl_json))
+            self._hook.update_job_permission(job_id, _normalise_json_content(acl_json))
 
         return job_id
 
@@ -506,43 +512,23 @@ class DatabricksSubmitRunOperator(BaseOperator):
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
-        if tasks is not None:
-            self.json["tasks"] = tasks
-        if spark_jar_task is not None:
-            self.json["spark_jar_task"] = spark_jar_task
-        if notebook_task is not None:
-            self.json["notebook_task"] = notebook_task
-        if spark_python_task is not None:
-            self.json["spark_python_task"] = spark_python_task
-        if spark_submit_task is not None:
-            self.json["spark_submit_task"] = spark_submit_task
-        if pipeline_task is not None:
-            self.json["pipeline_task"] = pipeline_task
-        if dbt_task is not None:
-            self.json["dbt_task"] = dbt_task
-        if new_cluster is not None:
-            self.json["new_cluster"] = new_cluster
-        if existing_cluster_id is not None:
-            self.json["existing_cluster_id"] = existing_cluster_id
-        if libraries is not None:
-            self.json["libraries"] = libraries
-        if run_name is not None:
-            self.json["run_name"] = run_name
-        if timeout_seconds is not None:
-            self.json["timeout_seconds"] = timeout_seconds
-        if "run_name" not in self.json:
-            self.json["run_name"] = run_name or kwargs["task_id"]
-        if idempotency_token is not None:
-            self.json["idempotency_token"] = idempotency_token
-        if access_control_list is not None:
-            self.json["access_control_list"] = access_control_list
-        if git_source is not None:
-            self.json["git_source"] = git_source
-
-        if "dbt_task" in self.json and "git_source" not in self.json:
-            raise AirflowException("git_source is required for dbt_task")
-        if pipeline_task is not None and "pipeline_id" in pipeline_task and "pipeline_name" in pipeline_task:
-            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
+        self.overridden_json_params = {
+            "tasks": tasks,
+            "spark_jar_task": spark_jar_task,
+            "notebook_task": notebook_task,
+            "spark_python_task": spark_python_task,
+            "spark_submit_task": spark_submit_task,
+            "pipeline_task": pipeline_task,
+            "dbt_task": dbt_task,
+            "new_cluster": new_cluster,
+            "existing_cluster_id": existing_cluster_id,
+            "libraries": libraries,
+            "run_name": run_name,
+            "timeout_seconds": timeout_seconds,
+            "idempotency_token": idempotency_token,
+            "access_control_list": access_control_list,
+            "git_source": git_source,
+        }
 
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
@@ -561,7 +547,25 @@ class DatabricksSubmitRunOperator(BaseOperator):
             caller=caller,
         )
 
+    def _setup_and_validate_json(self):
+        _handle_overridden_json_params(self)
+
+        if "run_name" not in self.json or self.json["run_name"] is None:
+            self.json["run_name"] = self.task_id
+
+        if "dbt_task" in self.json and "git_source" not in self.json:
+            raise AirflowException("git_source is required for dbt_task")
+        if (
+            "pipeline_task" in self.json
+            and "pipeline_id" in self.json["pipeline_task"]
+            and "pipeline_name" in self.json["pipeline_task"]
+        ):
+            raise AirflowException("'pipeline_name' is not allowed in conjunction with 'pipeline_id'")
+
+        normalise_json_content(self)
+
     def execute(self, context: Context):
+        self._setup_and_validate_json()
         if (
             "pipeline_task" in self.json
             and self.json["pipeline_task"].get("pipeline_id") is None
@@ -571,7 +575,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
             pipeline_name = self.json["pipeline_task"]["pipeline_name"]
             self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
             del self.json["pipeline_task"]["pipeline_name"]
-        json_normalised = normalise_json_content(self.json)
+        json_normalised = _normalise_json_content(self.json)
         self.run_id = self._hook.submit_run(json_normalised)
         if self.deferrable:
             _handle_deferrable_databricks_operator_execution(self, self._hook, self.log, context)
@@ -607,7 +611,7 @@ class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
 
     def execute(self, context):
         hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
-        json_normalised = normalise_json_content(self.json)
+        json_normalised = _normalise_json_content(self.json)
         self.run_id = hook.submit_run(json_normalised)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
 
@@ -807,27 +811,16 @@ class DatabricksRunNowOperator(BaseOperator):
         self.deferrable = deferrable
         self.repair_run = repair_run
         self.cancel_previous_runs = cancel_previous_runs
-
-        if job_id is not None:
-            self.json["job_id"] = job_id
-        if job_name is not None:
-            self.json["job_name"] = job_name
-        if "job_id" in self.json and "job_name" in self.json:
-            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
-        if notebook_params is not None:
-            self.json["notebook_params"] = notebook_params
-        if python_params is not None:
-            self.json["python_params"] = python_params
-        if python_named_params is not None:
-            self.json["python_named_params"] = python_named_params
-        if jar_params is not None:
-            self.json["jar_params"] = jar_params
-        if spark_submit_params is not None:
-            self.json["spark_submit_params"] = spark_submit_params
-        if idempotency_token is not None:
-            self.json["idempotency_token"] = idempotency_token
-        if self.json:
-            self.json = normalise_json_content(self.json)
+        self.overridden_json_params = {
+            "job_id": job_id,
+            "job_name": job_name,
+            "notebook_params": notebook_params,
+            "python_params": python_params,
+            "python_named_params": python_named_params,
+            "jar_params": jar_params,
+            "spark_submit_params": spark_submit_params,
+            "idempotency_token": idempotency_token,
+        }
         # This variable will be used in case our task gets killed.
         self.run_id: int | None = None
         self.do_xcom_push = do_xcom_push
@@ -845,7 +838,16 @@ class DatabricksRunNowOperator(BaseOperator):
             caller=caller,
         )
 
+    def _setup_and_validate_json(self):
+        _handle_overridden_json_params(self)
+
+        if "job_id" in self.json and "job_name" in self.json:
+            raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
+
+        normalise_json_content(self)
+
     def execute(self, context: Context):
+        self._setup_and_validate_json()
         hook = self._hook
         if "job_name" in self.json:
             job_id = hook.find_job_id_by_name(self.json["job_name"])
diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py
index 88d622c3bc..ec99bce178 100644
--- a/airflow/providers/databricks/utils/databricks.py
+++ b/airflow/providers/databricks/utils/databricks.py
@@ -21,7 +21,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.databricks.hooks.databricks import RunState
 
 
-def normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
+def _normalise_json_content(content, json_path: str = "json") -> str | bool | list | dict:
     """
     Normalize content or all values of content if it is a dict to a string.
 
@@ -33,7 +33,7 @@ def normalise_json_content(content, json_path: str = "json") -> str | bool | lis
     The only one exception is when we have boolean values, they can not be converted
     to string type because databricks does not understand 'True' or 'False' values.
     """
-    normalise = normalise_json_content
+    normalise = _normalise_json_content
     if isinstance(content, (str, bool)):
         return content
     elif isinstance(content, (int, float)):
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 7ff2295eda..ae2bb49766 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -23,8 +23,10 @@ from unittest.mock import MagicMock
 
 import pytest
 
+from airflow.decorators import task
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import DAG
+from airflow.operators.python import PythonOperator
 from airflow.providers.databricks.hooks.databricks import RunState
 from airflow.providers.databricks.operators.databricks import (
     DatabricksCreateJobsOperator,
@@ -36,6 +38,7 @@ from airflow.providers.databricks.operators.databricks import (
 )
 from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
 from airflow.providers.databricks.utils import databricks as utils
+from airflow.utils import timezone
 
 pytestmark = pytest.mark.db_test
 
@@ -248,9 +251,9 @@ def make_run_with_state_mock(
 
 
 class TestDatabricksCreateJobsOperator:
-    def test_init_with_named_parameters(self):
+    def test_validate_json_with_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with the named parameters.
         """
         op = DatabricksCreateJobsOperator(
             task_id=TASK_ID,
@@ -266,7 +269,9 @@ class TestDatabricksCreateJobsOperator:
             git_source=GIT_SOURCE,
             access_control_list=ACCESS_CONTROL_LIST,
         )
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {
                 "name": JOB_NAME,
                 "tags": TAGS,
@@ -284,9 +289,9 @@ class TestDatabricksCreateJobsOperator:
 
         assert expected == op.json
 
-    def test_init_with_json(self):
+    def test_validate_json_with_json(self):
         """
-        Test the initializer with json data.
+        Test the _setup_and_validate_json function with json data.
         """
         json = {
             "name": JOB_NAME,
@@ -302,8 +307,9 @@ class TestDatabricksCreateJobsOperator:
             "access_control_list": ACCESS_CONTROL_LIST,
         }
         op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+        op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "name": JOB_NAME,
                 "tags": TAGS,
@@ -321,9 +327,9 @@ class TestDatabricksCreateJobsOperator:
 
         assert expected == op.json
 
-    def test_init_with_merging(self):
+    def test_validate_json_with_merging(self):
         """
-        Test the initializer when json and other named parameters are both
+        Test the _setup_and_validate_json function when json and other named 
parameters are both
         provided. The named parameters should override top level keys in the
         json dict.
         """
@@ -367,8 +373,9 @@ class TestDatabricksCreateJobsOperator:
             git_source=override_git_source,
             access_control_list=override_access_control_list,
         )
+        op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "name": override_name,
                 "tags": override_tags,
@@ -386,24 +393,158 @@ class TestDatabricksCreateJobsOperator:
 
         assert expected == op.json
 
-    def test_init_with_templating(self):
+    def test_validate_json_with_templating(self):
         json = {"name": "test-{{ ds }}"}
 
         dag = DAG("test", start_date=datetime.now())
         op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json)
         op.render_template_fields(context={"ds": DATE})
-        expected = utils.normalise_json_content({"name": f"test-{DATE}"})
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content({"name": f"test-{DATE}"})
         assert expected == op.json
 
-    def test_init_with_bad_type(self):
-        json = {"test": datetime.now()}
+    def test_validate_json_with_bad_type(self):
+        json = {"test": datetime.now(), "name": "test"}
         # Looks a bit weird since we have to escape regex reserved symbols.
         exception_message = (
             r"Type \<(type|class) \'datetime.datetime\'\> used "
             r"for parameter json\[test\] is not a number or a string"
         )
         with pytest.raises(AirflowException, match=exception_message):
-            DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+            DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)._setup_and_validate_json()
+
+    def test_validate_json_with_no_name(self):
+        json = {}
+        exception_message = "Missing required parameter: name"
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)._setup_and_validate_json()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
+        json = "{{ ti.xcom_pull(task_ids='push_json') }}"
+        with dag_maker("test_templated", render_template_as_native_obj=True):
+            push_json = PythonOperator(
+                task_id="push_json",
+                python_callable=lambda: {
+                    "name": JOB_NAME,
+                    "description": JOB_DESCRIPTION,
+                    "tags": TAGS,
+                    "tasks": TASKS,
+                    "job_clusters": JOB_CLUSTERS,
+                    "email_notifications": EMAIL_NOTIFICATIONS,
+                    "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+                    "notification_settings": NOTIFICATION_SETTINGS,
+                    "timeout_seconds": TIMEOUT_SECONDS,
+                    "schedule": SCHEDULE,
+                    "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+                    "git_source": GIT_SOURCE,
+                    "access_control_list": ACCESS_CONTROL_LIST,
+                },
+            )
+            op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+            push_json >> op
+
+        db_mock = db_mock_class.return_value
+        db_mock.create_job.return_value = JOB_ID
+
+        db_mock.find_job_id_by_name.return_value = None
+
+        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        tis = {ti.task_id: ti for ti in dagrun.task_instances}
+        tis["push_json"].run()
+        tis[TASK_ID].run()
+
+        expected = utils._normalise_json_content(
+            {
+                "name": JOB_NAME,
+                "description": JOB_DESCRIPTION,
+                "tags": TAGS,
+                "tasks": TASKS,
+                "job_clusters": JOB_CLUSTERS,
+                "email_notifications": EMAIL_NOTIFICATIONS,
+                "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+                "notification_settings": NOTIFICATION_SETTINGS,
+                "timeout_seconds": TIMEOUT_SECONDS,
+                "schedule": SCHEDULE,
+                "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+                "git_source": GIT_SOURCE,
+                "access_control_list": ACCESS_CONTROL_LIST,
+            }
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksCreateJobsOperator",
+        )
+
+        db_mock.create_job.assert_called_once_with(expected)
+        assert JOB_ID == tis[TASK_ID].xcom_pull()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
+        with dag_maker("test_xcomarg", render_template_as_native_obj=True):
+
+            @task
+            def push_json() -> dict:
+                return {
+                    "name": JOB_NAME,
+                    "description": JOB_DESCRIPTION,
+                    "tags": TAGS,
+                    "tasks": TASKS,
+                    "job_clusters": JOB_CLUSTERS,
+                    "email_notifications": EMAIL_NOTIFICATIONS,
+                    "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+                    "notification_settings": NOTIFICATION_SETTINGS,
+                    "timeout_seconds": TIMEOUT_SECONDS,
+                    "schedule": SCHEDULE,
+                    "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+                    "git_source": GIT_SOURCE,
+                    "access_control_list": ACCESS_CONTROL_LIST,
+                }
+
+            json = push_json()
+            op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
+
+        db_mock = db_mock_class.return_value
+        db_mock.create_job.return_value = JOB_ID
+
+        db_mock.find_job_id_by_name.return_value = None
+
+        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        tis = {ti.task_id: ti for ti in dagrun.task_instances}
+        tis["push_json"].run()
+        tis[TASK_ID].run()
+
+        expected = utils._normalise_json_content(
+            {
+                "name": JOB_NAME,
+                "description": JOB_DESCRIPTION,
+                "tags": TAGS,
+                "tasks": TASKS,
+                "job_clusters": JOB_CLUSTERS,
+                "email_notifications": EMAIL_NOTIFICATIONS,
+                "webhook_notifications": WEBHOOK_NOTIFICATIONS,
+                "notification_settings": NOTIFICATION_SETTINGS,
+                "timeout_seconds": TIMEOUT_SECONDS,
+                "schedule": SCHEDULE,
+                "max_concurrent_runs": MAX_CONCURRENT_RUNS,
+                "git_source": GIT_SOURCE,
+                "access_control_list": ACCESS_CONTROL_LIST,
+            }
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksCreateJobsOperator",
+        )
+
+        db_mock.create_job.assert_called_once_with(expected)
+        assert JOB_ID == tis[TASK_ID].xcom_pull()
 
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_create(self, db_mock_class):
@@ -433,7 +574,7 @@ class TestDatabricksCreateJobsOperator:
 
         return_result = op.execute({})
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "name": JOB_NAME,
                 "description": JOB_DESCRIPTION,
@@ -487,7 +628,7 @@ class TestDatabricksCreateJobsOperator:
 
         return_result = op.execute({})
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "name": JOB_NAME,
                 "description": JOB_DESCRIPTION,
@@ -539,7 +680,7 @@ class TestDatabricksCreateJobsOperator:
 
         op.execute({})
 
-        expected = utils.normalise_json_content({"access_control_list": 
ACCESS_CONTROL_LIST})
+        expected = utils._normalise_json_content({"access_control_list": 
ACCESS_CONTROL_LIST})
 
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
@@ -586,66 +727,76 @@ class TestDatabricksCreateJobsOperator:
 
 
 class TestDatabricksSubmitRunOperator:
-    def test_init_with_notebook_task_named_parameters(self):
+    def test_validate_json_with_notebook_task_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with named parameters.
         """
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, 
notebook_task=NOTEBOOK_TASK
         )
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_spark_python_task_named_parameters(self):
+    def test_validate_json_with_spark_python_task_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with the named parameters.
         """
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, 
spark_python_task=SPARK_PYTHON_TASK
         )
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "spark_python_task": 
SPARK_PYTHON_TASK, "run_name": TASK_ID}
         )
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_pipeline_name_task_named_parameters(self):
+    def test_validate_json_with_pipeline_name_task_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with the named parameters.
         """
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, 
pipeline_task=PIPELINE_NAME_TASK)
-        expected = utils.normalise_json_content({"pipeline_task": 
PIPELINE_NAME_TASK, "run_name": TASK_ID})
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content({"pipeline_task": 
PIPELINE_NAME_TASK, "run_name": TASK_ID})
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_pipeline_id_task_named_parameters(self):
+    def test_validate_json_with_pipeline_id_task_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with the named parameters.
         """
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, 
pipeline_task=PIPELINE_ID_TASK)
-        expected = utils.normalise_json_content({"pipeline_task": 
PIPELINE_ID_TASK, "run_name": TASK_ID})
+        op._setup_and_validate_json()
 
-        assert expected == utils.normalise_json_content(op.json)
+        expected = utils._normalise_json_content({"pipeline_task": 
PIPELINE_ID_TASK, "run_name": TASK_ID})
 
-    def test_init_with_spark_submit_task_named_parameters(self):
+        assert expected == utils._normalise_json_content(op.json)
+
+    def test_validate_json_with_spark_submit_task_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with the named parameters.
         """
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, 
spark_submit_task=SPARK_SUBMIT_TASK
         )
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "spark_submit_task": 
SPARK_SUBMIT_TASK, "run_name": TASK_ID}
         )
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_dbt_task_named_parameters(self):
+    def test_validate_json_with_dbt_task_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with the named parameters.
         """
         git_source = {
             "git_url": "https://github.com/dbt-labs/jaffle_shop",
@@ -655,15 +806,17 @@ class TestDatabricksSubmitRunOperator:
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK, 
git_source=git_source
         )
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": 
git_source, "run_name": TASK_ID}
         )
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_dbt_task_mixed_parameters(self):
+    def test_validate_json_with_dbt_task_mixed_parameters(self):
         """
-        Test the initializer with mixed parameters.
+        Test the _setup_and_validate_json function with mixed parameters.
         """
         git_source = {
             "git_url": "https://github.com/dbt-labs/jaffle_shop",
@@ -674,73 +827,85 @@ class TestDatabricksSubmitRunOperator:
         op = DatabricksSubmitRunOperator(
             task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK, 
json=json
         )
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source": 
git_source, "run_name": TASK_ID}
         )
 
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_dbt_task_without_git_source_raises_error(self):
+    def test_validate_json_with_dbt_task_without_git_source_raises_error(self):
         """
-        Test the initializer without the necessary git_source for dbt_task 
raises error.
+        Test the _setup_and_validate_json function without the necessary 
git_source for dbt_task raises error.
         """
         exception_message = "git_source is required for dbt_task"
         with pytest.raises(AirflowException, match=exception_message):
-            DatabricksSubmitRunOperator(task_id=TASK_ID, 
new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK)
+            DatabricksSubmitRunOperator(
+                task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK
+            )._setup_and_validate_json()
 
-    def test_init_with_dbt_task_json_without_git_source_raises_error(self):
+    def test_validate_json_with_dbt_task_json_without_git_source_raises_error(self):
         """
-        Test the initializer without the necessary git_source for dbt_task 
raises error.
+        Test the _setup_and_validate_json function without the necessary 
git_source for dbt_task raises error.
         """
         json = {"dbt_task": DBT_TASK, "new_cluster": NEW_CLUSTER}
 
         exception_message = "git_source is required for dbt_task"
         with pytest.raises(AirflowException, match=exception_message):
-            DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+            DatabricksSubmitRunOperator(task_id=TASK_ID, 
json=json)._setup_and_validate_json()
 
-    def test_init_with_json(self):
+    def test_validate_json_with_json(self):
         """
-        Test the initializer with json data.
+        Test the _setup_and_validate_json function with json data.
         """
         json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_tasks(self):
+    def test_validate_json_with_tasks(self):
         tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task": 
NOTEBOOK_TASK}]
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
-        expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks": 
tasks})
-        assert expected == utils.normalise_json_content(op.json)
+        op._setup_and_validate_json()
 
-    def test_init_with_specified_run_name(self):
+        expected = utils._normalise_json_content({"run_name": TASK_ID, 
"tasks": tasks})
+        assert expected == utils._normalise_json_content(op.json)
+
+    def test_validate_json_with_specified_run_name(self):
         """
-        Test the initializer with a specified run_name.
+        Test the _setup_and_validate_json function with a specified run_name.
         """
         json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": RUN_NAME}
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": RUN_NAME}
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_pipeline_task(self):
+    def test_validate_json_with_pipeline_task(self):
         """
-        Test the initializer with a pipeline task.
+        Test the _setup_and_validate_json function with a pipeline task.
         """
         pipeline_task = {"pipeline_id": "test-dlt"}
         json = {"new_cluster": NEW_CLUSTER, "run_name": RUN_NAME, 
"pipeline_task": pipeline_task}
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task, 
"run_name": RUN_NAME}
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_merging(self):
+    def test_validate_json_with_merging(self):
         """
-        Test the initializer when json and other named parameters are both
+        Test the _setup_and_validate_json function when json and other named 
parameters are both
         provided. The named parameters should override top level keys in the
         json dict.
         """
@@ -750,34 +915,38 @@ class TestDatabricksSubmitRunOperator:
             "notebook_task": NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, 
new_cluster=override_new_cluster)
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": override_new_cluster,
                 "notebook_task": NOTEBOOK_TASK,
                 "run_name": TASK_ID,
             }
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
     @pytest.mark.db_test
-    def test_init_with_templating(self):
+    def test_validate_json_with_templating(self):
         json = {
             "new_cluster": NEW_CLUSTER,
             "notebook_task": TEMPLATED_NOTEBOOK_TASK,
         }
         dag = DAG("test", start_date=datetime.now())
         op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
+        op._setup_and_validate_json()
+
         op.render_template_fields(context={"ds": DATE})
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": NEW_CLUSTER,
                 "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK,
                 "run_name": TASK_ID,
             }
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_git_source(self):
+    def test_validate_json_with_git_source(self):
         json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": RUN_NAME}
         git_source = {
             "git_url": "https://github.com/apache/airflow",
@@ -785,7 +954,9 @@ class TestDatabricksSubmitRunOperator:
             "git_branch": "main",
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, 
git_source=git_source, json=json)
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": NEW_CLUSTER,
                 "notebook_task": NOTEBOOK_TASK,
@@ -793,18 +964,95 @@ class TestDatabricksSubmitRunOperator:
                 "git_source": git_source,
             }
         )
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils._normalise_json_content(op.json)
 
-    def test_init_with_bad_type(self):
+    def test_validate_json_with_bad_type(self):
         json = {"test": datetime.now()}
-        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
         # Looks a bit weird since we have to escape regex reserved symbols.
         exception_message = (
             r"Type \<(type|class) \'datetime.datetime\'\> used "
             r"for parameter json\[test\] is not a number or a string"
         )
         with pytest.raises(AirflowException, match=exception_message):
-            utils.normalise_json_content(op.json)
+            DatabricksSubmitRunOperator(task_id=TASK_ID, 
json=json)._setup_and_validate_json()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
+        json = "{{ ti.xcom_pull(task_ids='push_json') }}"
+        with dag_maker("test_templated", render_template_as_native_obj=True):
+            push_json = PythonOperator(
+                task_id="push_json",
+                python_callable=lambda: {
+                    "new_cluster": NEW_CLUSTER,
+                    "notebook_task": NOTEBOOK_TASK,
+                },
+            )
+            op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+            push_json >> op
+
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        tis = {ti.task_id: ti for ti in dagrun.task_instances}
+        tis["push_json"].run()
+        tis[TASK_ID].run()
+
+        expected = utils._normalise_json_content(
+            {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunOperator",
+        )
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
+        with dag_maker("test_xcomarg", render_template_as_native_obj=True):
+
+            @task
+            def push_json() -> dict:
+                return {
+                    "new_cluster": NEW_CLUSTER,
+                    "notebook_task": NOTEBOOK_TASK,
+                }
+
+            json = push_json()
+
+            op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        tis = {ti.task_id: ti for ti in dagrun.task_instances}
+        tis["push_json"].run()
+        tis[TASK_ID].run()
+
+        expected = utils._normalise_json_content(
+            {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSubmitRunOperator",
+        )
+
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
 
     
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_success(self, db_mock_class):
@@ -822,7 +1070,7 @@ class TestDatabricksSubmitRunOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -852,7 +1100,7 @@ class TestDatabricksSubmitRunOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content({"pipeline_task": 
PIPELINE_ID_TASK, "run_name": TASK_ID})
+        expected = utils._normalise_json_content({"pipeline_task": 
PIPELINE_ID_TASK, "run_name": TASK_ID})
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
             retry_limit=op.databricks_retry_limit,
@@ -884,7 +1132,7 @@ class TestDatabricksSubmitRunOperator:
         with pytest.raises(AirflowException):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "new_cluster": NEW_CLUSTER,
                 "notebook_task": NOTEBOOK_TASK,
@@ -932,7 +1180,7 @@ class TestDatabricksSubmitRunOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -961,7 +1209,7 @@ class TestDatabricksSubmitRunOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -995,7 +1243,7 @@ class TestDatabricksSubmitRunOperator:
         assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
         assert exc.value.method_name == "execute_complete"
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1077,7 +1325,7 @@ class TestDatabricksSubmitRunOperator:
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1107,7 +1355,7 @@ class TestDatabricksSubmitRunOperator:
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, 
"run_name": TASK_ID}
         )
         db_mock_class.assert_called_once_with(
@@ -1125,18 +1373,20 @@ class TestDatabricksSubmitRunOperator:
 
 
 class TestDatabricksRunNowOperator:
-    def test_init_with_named_parameters(self):
+    def test_validate_json_with_named_parameters(self):
         """
-        Test the initializer with the named parameters.
+        Test the _setup_and_validate_json function with named parameters.
         """
         op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
-        expected = utils.normalise_json_content({"job_id": 42})
+        op._setup_and_validate_json()
+
+        expected = utils._normalise_json_content({"job_id": 42})
 
         assert expected == op.json
 
-    def test_init_with_json(self):
+    def test_validate_json_with_json(self):
         """
-        Test the initializer with json data.
+        Test the _setup_and_validate_json function with json data.
         """
         json = {
             "notebook_params": NOTEBOOK_PARAMS,
@@ -1147,8 +1397,9 @@ class TestDatabricksRunNowOperator:
             "repair_run": False,
         }
         op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
+        op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "jar_params": JAR_PARAMS,
@@ -1161,9 +1412,9 @@ class TestDatabricksRunNowOperator:
 
         assert expected == op.json
 
-    def test_init_with_merging(self):
+    def test_validate_json_with_merging(self):
         """
-        Test the initializer when json and other named parameters are both
+        Test the _setup_and_validate_json function when json and other named 
parameters are both
         provided. The named parameters should override top level keys in the
         json dict.
         """
@@ -1180,8 +1431,9 @@ class TestDatabricksRunNowOperator:
             jar_params=override_jar_params,
             spark_submit_params=SPARK_SUBMIT_PARAMS,
         )
+        op._setup_and_validate_json()
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": override_notebook_params,
                 "jar_params": override_jar_params,
@@ -1194,13 +1446,14 @@ class TestDatabricksRunNowOperator:
         assert expected == op.json
 
     @pytest.mark.db_test
-    def test_init_with_templating(self):
+    def test_validate_json_with_templating(self):
         json = {"notebook_params": NOTEBOOK_PARAMS, "jar_params": 
TEMPLATED_JAR_PARAMS}
 
         dag = DAG("test", start_date=datetime.now())
         op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID, 
json=json)
         op.render_template_fields(context={"ds": DATE})
-        expected = utils.normalise_json_content(
+        op._setup_and_validate_json()
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "jar_params": RENDERED_TEMPLATED_JAR_PARAMS,
@@ -1209,7 +1462,7 @@ class TestDatabricksRunNowOperator:
         )
         assert expected == op.json
 
-    def test_init_with_bad_type(self):
+    def test_validate_json_with_bad_type(self):
         json = {"test": datetime.now()}
         # Looks a bit weird since we have to escape regex reserved symbols.
         exception_message = (
@@ -1217,7 +1470,116 @@ class TestDatabricksRunNowOperator:
             r"for parameter json\[test\] is not a number or a string"
         )
         with pytest.raises(AirflowException, match=exception_message):
-            DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
+            DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, 
json=json)._setup_and_validate_json()
+
+    def test_validate_json_exception_with_job_name_and_job_id(self):
+        exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
+
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksRunNowOperator(
+                task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME
+            )._setup_and_validate_json()
+
+        run = {"job_id": JOB_ID, "job_name": JOB_NAME}
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksRunNowOperator(task_id=TASK_ID, 
json=run)._setup_and_validate_json()
+
+        run = {"job_id": JOB_ID}
+        with pytest.raises(AirflowException, match=exception_message):
+            DatabricksRunNowOperator(task_id=TASK_ID, json=run, 
job_name=JOB_NAME)._setup_and_validate_json()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
+        json = "{{ ti.xcom_pull(task_ids='push_json') }}"
+        with dag_maker("test_templated", render_template_as_native_obj=True):
+            push_json = PythonOperator(
+                task_id="push_json",
+                python_callable=lambda: {
+                    "notebook_params": NOTEBOOK_PARAMS,
+                    "notebook_task": NOTEBOOK_TASK,
+                    "jar_params": JAR_PARAMS,
+                    "job_id": JOB_ID,
+                },
+            )
+            op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, 
json=json)
+            push_json >> op
+
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = RUN_ID
+        db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        tis = {ti.task_id: ti for ti in dagrun.task_instances}
+        tis["push_json"].run()
+        tis[TASK_ID].run()
+
+        expected = utils._normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
+        with dag_maker("test_xcomarg", render_template_as_native_obj=True):
+
+            @task
+            def push_json() -> dict:
+                return {
+                    "notebook_params": NOTEBOOK_PARAMS,
+                    "notebook_task": NOTEBOOK_TASK,
+                    "jar_params": JAR_PARAMS,
+                    "job_id": JOB_ID,
+                }
+
+            json = push_json()
+
+            op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, 
json=json)
+
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = RUN_ID
+        db_mock.get_run_page_url.return_value = RUN_PAGE_URL
+        db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
+
+        dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
+        tis = {ti.task_id: ti for ti in dagrun.task_instances}
+        tis["push_json"].run()
+        tis[TASK_ID].run()
+
+        expected = utils._normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_called_once_with(RUN_ID)
 
     
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_success(self, db_mock_class):
@@ -1232,7 +1594,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1267,7 +1629,7 @@ class TestDatabricksRunNowOperator:
         with pytest.raises(AirflowException):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1323,7 +1685,7 @@ class TestDatabricksRunNowOperator:
         with pytest.raises(AirflowException, match="Exception: Something went wrong"):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1391,7 +1753,7 @@ class TestDatabricksRunNowOperator:
         ):
             op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1435,7 +1797,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1466,7 +1828,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1486,20 +1848,6 @@ class TestDatabricksRunNowOperator:
         db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
         db_mock.get_run.assert_not_called()
 
-    def test_init_exception_with_job_name_and_job_id(self):
-        exception_message = "Argument 'job_name' is not allowed with argument 'job_id'"
-
-        with pytest.raises(AirflowException, match=exception_message):
-            DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, 
job_name=JOB_NAME)
-
-        run = {"job_id": JOB_ID, "job_name": JOB_NAME}
-        with pytest.raises(AirflowException, match=exception_message):
-            DatabricksRunNowOperator(task_id=TASK_ID, json=run)
-
-        run = {"job_id": JOB_ID}
-        with pytest.raises(AirflowException, match=exception_message):
-            DatabricksRunNowOperator(task_id=TASK_ID, json=run, 
job_name=JOB_NAME)
-
     
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_exec_with_job_name(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": 
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
@@ -1511,7 +1859,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1559,7 +1907,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1593,7 +1941,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1630,7 +1978,7 @@ class TestDatabricksRunNowOperator:
         assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
         assert exc.value.method_name == "execute_complete"
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1740,7 +2088,7 @@ class TestDatabricksRunNowOperator:
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         op.execute(None)
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
@@ -1774,7 +2122,7 @@ class TestDatabricksRunNowOperator:
 
         op.execute(None)
 
-        expected = utils.normalise_json_content(
+        expected = utils._normalise_json_content(
             {
                 "notebook_params": NOTEBOOK_PARAMS,
                 "notebook_task": NOTEBOOK_TASK,
diff --git a/tests/providers/databricks/utils/test_databricks.py 
b/tests/providers/databricks/utils/test_databricks.py
index 8c6ce8ce4b..4b57573253 100644
--- a/tests/providers/databricks/utils/test_databricks.py
+++ b/tests/providers/databricks/utils/test_databricks.py
@@ -21,7 +21,7 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.providers.databricks.hooks.databricks import RunState
-from airflow.providers.databricks.utils.databricks import 
normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import 
_normalise_json_content, validate_trigger_event
 
 RUN_ID = 1
 RUN_PAGE_URL = "run-page-url"
@@ -46,7 +46,7 @@ class TestDatabricksOperatorSharedFunctions:
             "test_list": ["1", "1.0", "a", "b"],
             "test_tuple": ["1", "1.0", "a", "b"],
         }
-        assert normalise_json_content(test_json) == expected
+        assert _normalise_json_content(test_json) == expected
 
     def test_validate_trigger_event_success(self):
         event = {

Reply via email to