This is an automated email from the ASF dual-hosted git repository.
pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4535e08b86 [Databricks Provider] Revert PRs #40864 and #40471 (#41050)
4535e08b86 is described below
commit 4535e08b862e2b7ff4f2a76de7124983d4efe9db
Author: Pankaj Koti <[email protected]>
AuthorDate: Sat Jul 27 18:35:02 2024 +0530
[Databricks Provider] Revert PRs #40864 and #40471 (#41050)
* Revert "Fix named parameters templating in Databricks operators (#40864)"
This reverts commit cfe1d53ed041ea903292e3789e1a5238db5b5031.
* Revert "Make Databricks operators' json parameter compatible with XComs,
Jinja expression values (#40471)"
This reverts commit 4fb2140f393b6332903fb833151c2ce8a9c66fe2.
This reverts PR #40864 and PR #40471.
Previously, PR https://github.com/apache/airflow/pull/40471 was contributed
to address issue https://github.com/apache/airflow/issues/35433.
However, that contribution gave rise to another issue
https://github.com/apache/airflow/issues/40788.
Next, an attempt to resolve https://github.com/apache/airflow/issues/40788
was made in PR #40864.
However, with the second PR, it appears that the earlier
issue #35433 has
[resurfaced](https://github.com/apache/airflow/pull/40864#issuecomment-2239061933).
So, at the moment, we have two PRs stacked on top of the existing
implementation that together have no net effect, and the previous
issues persist.
I believe it is better to revert those two PRs, reopen the earlier
issue #35433, and address it carefully, taking the time it needs.
---
.../providers/databricks/operators/databricks.py | 273 +++-----
airflow/providers/databricks/utils/databricks.py | 4 +-
.../databricks/operators/test_databricks.py | 756 ++++-----------------
.../providers/databricks/utils/test_databricks.py | 4 +-
4 files changed, 224 insertions(+), 813 deletions(-)
diff --git a/airflow/providers/databricks/operators/databricks.py
b/airflow/providers/databricks/operators/databricks.py
index b1299b9d85..89aa6e9df9 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -40,7 +40,7 @@ from airflow.providers.databricks.plugins.databricks_workflow
import (
WorkflowJobRunLink,
)
from airflow.providers.databricks.triggers.databricks import
DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import
_normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import
normalise_json_content, validate_trigger_event
if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -186,17 +186,6 @@ def
_handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
raise AirflowException(error_message)
-def _handle_overridden_json_params(operator):
- for key, value in operator.overridden_json_params.items():
- if value is not None:
- operator.json[key] = value
-
-
-def normalise_json_content(operator):
- if operator.json:
- operator.json = _normalise_json_content(operator.json)
-
-
class DatabricksJobRunLink(BaseOperatorLink):
"""Constructs a link to monitor a Databricks Job Run."""
@@ -263,23 +252,7 @@ class DatabricksCreateJobsOperator(BaseOperator):
"""
# Used in airflow.models.BaseOperator
- template_fields: Sequence[str] = (
- "json",
- "databricks_conn_id",
- "name",
- "description",
- "tags",
- "tasks",
- "job_clusters",
- "email_notifications",
- "webhook_notifications",
- "notification_settings",
- "timeout_seconds",
- "schedule",
- "max_concurrent_runs",
- "git_source",
- "access_control_list",
- )
+ template_fields: Sequence[str] = ("json", "databricks_conn_id")
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
ui_fgcolor = "#fff"
@@ -316,19 +289,34 @@ class DatabricksCreateJobsOperator(BaseOperator):
self.databricks_retry_limit = databricks_retry_limit
self.databricks_retry_delay = databricks_retry_delay
self.databricks_retry_args = databricks_retry_args
- self.name = name
- self.description = description
- self.tags = tags
- self.tasks = tasks
- self.job_clusters = job_clusters
- self.email_notifications = email_notifications
- self.webhook_notifications = webhook_notifications
- self.notification_settings = notification_settings
- self.timeout_seconds = timeout_seconds
- self.schedule = schedule
- self.max_concurrent_runs = max_concurrent_runs
- self.git_source = git_source
- self.access_control_list = access_control_list
+ if name is not None:
+ self.json["name"] = name
+ if description is not None:
+ self.json["description"] = description
+ if tags is not None:
+ self.json["tags"] = tags
+ if tasks is not None:
+ self.json["tasks"] = tasks
+ if job_clusters is not None:
+ self.json["job_clusters"] = job_clusters
+ if email_notifications is not None:
+ self.json["email_notifications"] = email_notifications
+ if webhook_notifications is not None:
+ self.json["webhook_notifications"] = webhook_notifications
+ if notification_settings is not None:
+ self.json["notification_settings"] = notification_settings
+ if timeout_seconds is not None:
+ self.json["timeout_seconds"] = timeout_seconds
+ if schedule is not None:
+ self.json["schedule"] = schedule
+ if max_concurrent_runs is not None:
+ self.json["max_concurrent_runs"] = max_concurrent_runs
+ if git_source is not None:
+ self.json["git_source"] = git_source
+ if access_control_list is not None:
+ self.json["access_control_list"] = access_control_list
+ if self.json:
+ self.json = normalise_json_content(self.json)
@cached_property
def _hook(self):
@@ -340,40 +328,16 @@ class DatabricksCreateJobsOperator(BaseOperator):
caller="DatabricksCreateJobsOperator",
)
- def _setup_and_validate_json(self):
- self.overridden_json_params = {
- "name": self.name,
- "description": self.description,
- "tags": self.tags,
- "tasks": self.tasks,
- "job_clusters": self.job_clusters,
- "email_notifications": self.email_notifications,
- "webhook_notifications": self.webhook_notifications,
- "notification_settings": self.notification_settings,
- "timeout_seconds": self.timeout_seconds,
- "schedule": self.schedule,
- "max_concurrent_runs": self.max_concurrent_runs,
- "git_source": self.git_source,
- "access_control_list": self.access_control_list,
- }
-
- _handle_overridden_json_params(self)
-
+ def execute(self, context: Context) -> int:
if "name" not in self.json:
raise AirflowException("Missing required parameter: name")
-
- normalise_json_content(self)
-
- def execute(self, context: Context) -> int:
- self._setup_and_validate_json()
-
job_id = self._hook.find_job_id_by_name(self.json["name"])
if job_id is None:
return self._hook.create_job(self.json)
self._hook.reset_job(str(job_id), self.json)
if (access_control_list := self.json.get("access_control_list")) is
not None:
acl_json = {"access_control_list": access_control_list}
- self._hook.update_job_permission(job_id,
_normalise_json_content(acl_json))
+ self._hook.update_job_permission(job_id,
normalise_json_content(acl_json))
return job_id
@@ -500,25 +464,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
"""
# Used in airflow.models.BaseOperator
- template_fields: Sequence[str] = (
- "json",
- "databricks_conn_id",
- "tasks",
- "spark_jar_task",
- "notebook_task",
- "spark_python_task",
- "spark_submit_task",
- "pipeline_task",
- "dbt_task",
- "new_cluster",
- "existing_cluster_id",
- "libraries",
- "run_name",
- "timeout_seconds",
- "idempotency_token",
- "access_control_list",
- "git_source",
- )
+ template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
@@ -564,21 +510,43 @@ class DatabricksSubmitRunOperator(BaseOperator):
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
- self.tasks = tasks
- self.spark_jar_task = spark_jar_task
- self.notebook_task = notebook_task
- self.spark_python_task = spark_python_task
- self.spark_submit_task = spark_submit_task
- self.pipeline_task = pipeline_task
- self.dbt_task = dbt_task
- self.new_cluster = new_cluster
- self.existing_cluster_id = existing_cluster_id
- self.libraries = libraries
- self.run_name = run_name
- self.timeout_seconds = timeout_seconds
- self.idempotency_token = idempotency_token
- self.access_control_list = access_control_list
- self.git_source = git_source
+ if tasks is not None:
+ self.json["tasks"] = tasks
+ if spark_jar_task is not None:
+ self.json["spark_jar_task"] = spark_jar_task
+ if notebook_task is not None:
+ self.json["notebook_task"] = notebook_task
+ if spark_python_task is not None:
+ self.json["spark_python_task"] = spark_python_task
+ if spark_submit_task is not None:
+ self.json["spark_submit_task"] = spark_submit_task
+ if pipeline_task is not None:
+ self.json["pipeline_task"] = pipeline_task
+ if dbt_task is not None:
+ self.json["dbt_task"] = dbt_task
+ if new_cluster is not None:
+ self.json["new_cluster"] = new_cluster
+ if existing_cluster_id is not None:
+ self.json["existing_cluster_id"] = existing_cluster_id
+ if libraries is not None:
+ self.json["libraries"] = libraries
+ if run_name is not None:
+ self.json["run_name"] = run_name
+ if timeout_seconds is not None:
+ self.json["timeout_seconds"] = timeout_seconds
+ if "run_name" not in self.json:
+ self.json["run_name"] = run_name or kwargs["task_id"]
+ if idempotency_token is not None:
+ self.json["idempotency_token"] = idempotency_token
+ if access_control_list is not None:
+ self.json["access_control_list"] = access_control_list
+ if git_source is not None:
+ self.json["git_source"] = git_source
+
+ if "dbt_task" in self.json and "git_source" not in self.json:
+ raise AirflowException("git_source is required for dbt_task")
+ if pipeline_task is not None and "pipeline_id" in pipeline_task and
"pipeline_name" in pipeline_task:
+ raise AirflowException("'pipeline_name' is not allowed in
conjunction with 'pipeline_id'")
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
@@ -597,43 +565,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
caller=caller,
)
- def _setup_and_validate_json(self):
- self.overridden_json_params = {
- "tasks": self.tasks,
- "spark_jar_task": self.spark_jar_task,
- "notebook_task": self.notebook_task,
- "spark_python_task": self.spark_python_task,
- "spark_submit_task": self.spark_submit_task,
- "pipeline_task": self.pipeline_task,
- "dbt_task": self.dbt_task,
- "new_cluster": self.new_cluster,
- "existing_cluster_id": self.existing_cluster_id,
- "libraries": self.libraries,
- "run_name": self.run_name,
- "timeout_seconds": self.timeout_seconds,
- "idempotency_token": self.idempotency_token,
- "access_control_list": self.access_control_list,
- "git_source": self.git_source,
- }
-
- _handle_overridden_json_params(self)
-
- if "run_name" not in self.json or self.json["run_name"] is None:
- self.json["run_name"] = self.task_id
-
- if "dbt_task" in self.json and "git_source" not in self.json:
- raise AirflowException("git_source is required for dbt_task")
- if (
- "pipeline_task" in self.json
- and "pipeline_id" in self.json["pipeline_task"]
- and "pipeline_name" in self.json["pipeline_task"]
- ):
- raise AirflowException("'pipeline_name' is not allowed in
conjunction with 'pipeline_id'")
-
- normalise_json_content(self)
-
def execute(self, context: Context):
- self._setup_and_validate_json()
if (
"pipeline_task" in self.json
and self.json["pipeline_task"].get("pipeline_id") is None
@@ -643,7 +575,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
pipeline_name = self.json["pipeline_task"]["pipeline_name"]
self.json["pipeline_task"]["pipeline_id"] =
self._hook.find_pipeline_id_by_name(pipeline_name)
del self.json["pipeline_task"]["pipeline_name"]
- json_normalised = _normalise_json_content(self.json)
+ json_normalised = normalise_json_content(self.json)
self.run_id = self._hook.submit_run(json_normalised)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, self._hook,
self.log, context)
@@ -679,7 +611,7 @@ class
DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
def execute(self, context):
hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
- json_normalised = _normalise_json_content(self.json)
+ json_normalised = normalise_json_content(self.json)
self.run_id = hook.submit_run(json_normalised)
_handle_deferrable_databricks_operator_execution(self, hook, self.log,
context)
@@ -836,18 +768,7 @@ class DatabricksRunNowOperator(BaseOperator):
"""
# Used in airflow.models.BaseOperator
- template_fields: Sequence[str] = (
- "json",
- "databricks_conn_id",
- "job_id",
- "job_name",
- "notebook_params",
- "python_params",
- "python_named_params",
- "jar_params",
- "spark_submit_params",
- "idempotency_token",
- )
+ template_fields: Sequence[str] = ("json", "databricks_conn_id")
template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
ui_color = "#1CB1C2"
@@ -890,14 +811,27 @@ class DatabricksRunNowOperator(BaseOperator):
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs
- self.job_id = job_id
- self.job_name = job_name
- self.notebook_params = notebook_params
- self.python_params = python_params
- self.python_named_params = python_named_params
- self.jar_params = jar_params
- self.spark_submit_params = spark_submit_params
- self.idempotency_token = idempotency_token
+
+ if job_id is not None:
+ self.json["job_id"] = job_id
+ if job_name is not None:
+ self.json["job_name"] = job_name
+ if "job_id" in self.json and "job_name" in self.json:
+ raise AirflowException("Argument 'job_name' is not allowed with
argument 'job_id'")
+ if notebook_params is not None:
+ self.json["notebook_params"] = notebook_params
+ if python_params is not None:
+ self.json["python_params"] = python_params
+ if python_named_params is not None:
+ self.json["python_named_params"] = python_named_params
+ if jar_params is not None:
+ self.json["jar_params"] = jar_params
+ if spark_submit_params is not None:
+ self.json["spark_submit_params"] = spark_submit_params
+ if idempotency_token is not None:
+ self.json["idempotency_token"] = idempotency_token
+ if self.json:
+ self.json = normalise_json_content(self.json)
# This variable will be used in case our task gets killed.
self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
@@ -915,26 +849,7 @@ class DatabricksRunNowOperator(BaseOperator):
caller=caller,
)
- def _setup_and_validate_json(self):
- self.overridden_json_params = {
- "job_id": self.job_id,
- "job_name": self.job_name,
- "notebook_params": self.notebook_params,
- "python_params": self.python_params,
- "python_named_params": self.python_named_params,
- "jar_params": self.jar_params,
- "spark_submit_params": self.spark_submit_params,
- "idempotency_token": self.idempotency_token,
- }
- _handle_overridden_json_params(self)
-
- if "job_id" in self.json and "job_name" in self.json:
- raise AirflowException("Argument 'job_name' is not allowed with
argument 'job_id'")
-
- normalise_json_content(self)
-
def execute(self, context: Context):
- self._setup_and_validate_json()
hook = self._hook
if "job_name" in self.json:
job_id = hook.find_job_id_by_name(self.json["job_name"])
diff --git a/airflow/providers/databricks/utils/databricks.py
b/airflow/providers/databricks/utils/databricks.py
index ec99bce178..88d622c3bc 100644
--- a/airflow/providers/databricks/utils/databricks.py
+++ b/airflow/providers/databricks/utils/databricks.py
@@ -21,7 +21,7 @@ from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
-def _normalise_json_content(content, json_path: str = "json") -> str | bool |
list | dict:
+def normalise_json_content(content, json_path: str = "json") -> str | bool |
list | dict:
"""
Normalize content or all values of content if it is a dict to a string.
@@ -33,7 +33,7 @@ def _normalise_json_content(content, json_path: str = "json")
-> str | bool | li
The only one exception is when we have boolean values, they can not be
converted
to string type because databricks does not understand 'True' or 'False'
values.
"""
- normalise = _normalise_json_content
+ normalise = normalise_json_content
if isinstance(content, (str, bool)):
return content
elif isinstance(content, (int, float)):
diff --git a/tests/providers/databricks/operators/test_databricks.py
b/tests/providers/databricks/operators/test_databricks.py
index a733766904..7ff2295eda 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -23,10 +23,8 @@ from unittest.mock import MagicMock
import pytest
-from airflow.decorators import task
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import DAG
-from airflow.operators.python import PythonOperator
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.operators.databricks import (
DatabricksCreateJobsOperator,
@@ -38,7 +36,6 @@ from airflow.providers.databricks.operators.databricks import
(
)
from airflow.providers.databricks.triggers.databricks import
DatabricksExecutionTrigger
from airflow.providers.databricks.utils import databricks as utils
-from airflow.utils import timezone
pytestmark = pytest.mark.db_test
@@ -66,11 +63,7 @@ RUN_NAME = "run-name"
RUN_ID = 1
RUN_PAGE_URL = "run-page-url"
JOB_ID = "42"
-TEMPLATED_JOB_ID = "job-id-{{ ds }}"
-RENDERED_TEMPLATED_JOB_ID = f"job-id-{DATE}"
JOB_NAME = "job-name"
-TEMPLATED_JOB_NAME = "job-name-{{ ds }}"
-RENDERED_TEMPLATED_JOB_NAME = f"job-name-{DATE}"
JOB_DESCRIPTION = "job-description"
NOTEBOOK_PARAMS = {"dry-run": "true", "oldest-time-to-consider":
"1457570074236"}
JAR_PARAMS = ["param1", "param2"]
@@ -255,9 +248,9 @@ def make_run_with_state_mock(
class TestDatabricksCreateJobsOperator:
- def test_validate_json_with_named_parameters(self):
+ def test_init_with_named_parameters(self):
"""
- Test the _setup_and_validate_json function with the named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksCreateJobsOperator(
task_id=TASK_ID,
@@ -273,9 +266,7 @@ class TestDatabricksCreateJobsOperator:
git_source=GIT_SOURCE,
access_control_list=ACCESS_CONTROL_LIST,
)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"name": JOB_NAME,
"tags": TAGS,
@@ -293,9 +284,9 @@ class TestDatabricksCreateJobsOperator:
assert expected == op.json
- def test_validate_json_with_json(self):
+ def test_init_with_json(self):
"""
- Test the _setup_and_validate_json function with json data.
+ Test the initializer with json data.
"""
json = {
"name": JOB_NAME,
@@ -311,9 +302,8 @@ class TestDatabricksCreateJobsOperator:
"access_control_list": ACCESS_CONTROL_LIST,
}
op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
- op._setup_and_validate_json()
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"name": JOB_NAME,
"tags": TAGS,
@@ -331,9 +321,9 @@ class TestDatabricksCreateJobsOperator:
assert expected == op.json
- def test_validate_json_with_merging(self):
+ def test_init_with_merging(self):
"""
- Test the _setup_and_validate_json function when json and other named
parameters are both
+ Test the initializer when json and other named parameters are both
provided. The named parameters should override top level keys in the
json dict.
"""
@@ -377,9 +367,8 @@ class TestDatabricksCreateJobsOperator:
git_source=override_git_source,
access_control_list=override_access_control_list,
)
- op._setup_and_validate_json()
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"name": override_name,
"tags": override_tags,
@@ -397,220 +386,24 @@ class TestDatabricksCreateJobsOperator:
assert expected == op.json
- def test_validate_json_with_templating(self):
+ def test_init_with_templating(self):
json = {"name": "test-{{ ds }}"}
dag = DAG("test", start_date=datetime.now())
op = DatabricksCreateJobsOperator(dag=dag, task_id=TASK_ID, json=json)
op.render_template_fields(context={"ds": DATE})
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content({"name": f"test-{DATE}"})
+ expected = utils.normalise_json_content({"name": f"test-{DATE}"})
assert expected == op.json
- def test_validate_json_with_bad_type(self):
- json = {"test": datetime.now(), "name": "test"}
+ def test_init_with_bad_type(self):
+ json = {"test": datetime.now()}
# Looks a bit weird since we have to escape regex reserved symbols.
exception_message = (
r"Type \<(type|class) \'datetime.datetime\'\> used "
r"for parameter json\[test\] is not a number or a string"
)
with pytest.raises(AirflowException, match=exception_message):
- DatabricksCreateJobsOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
-
- def test_validate_json_with_no_name(self):
- json = {}
- exception_message = "Missing required parameter: name"
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksCreateJobsOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
- json = "{{ ti.xcom_pull(task_ids='push_json') }}"
- with dag_maker("test_templated", render_template_as_native_obj=True):
- push_json = PythonOperator(
- task_id="push_json",
- python_callable=lambda: {
- "name": JOB_NAME,
- "description": JOB_DESCRIPTION,
- "tags": TAGS,
- "tasks": TASKS,
- "job_clusters": JOB_CLUSTERS,
- "email_notifications": EMAIL_NOTIFICATIONS,
- "webhook_notifications": WEBHOOK_NOTIFICATIONS,
- "notification_settings": NOTIFICATION_SETTINGS,
- "timeout_seconds": TIMEOUT_SECONDS,
- "schedule": SCHEDULE,
- "max_concurrent_runs": MAX_CONCURRENT_RUNS,
- "git_source": GIT_SOURCE,
- "access_control_list": ACCESS_CONTROL_LIST,
- },
- )
- op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
- push_json >> op
-
- db_mock = db_mock_class.return_value
- db_mock.create_job.return_value = JOB_ID
-
- db_mock.find_job_id_by_name.return_value = None
-
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "name": JOB_NAME,
- "description": JOB_DESCRIPTION,
- "tags": TAGS,
- "tasks": TASKS,
- "job_clusters": JOB_CLUSTERS,
- "email_notifications": EMAIL_NOTIFICATIONS,
- "webhook_notifications": WEBHOOK_NOTIFICATIONS,
- "notification_settings": NOTIFICATION_SETTINGS,
- "timeout_seconds": TIMEOUT_SECONDS,
- "schedule": SCHEDULE,
- "max_concurrent_runs": MAX_CONCURRENT_RUNS,
- "git_source": GIT_SOURCE,
- "access_control_list": ACCESS_CONTROL_LIST,
- }
- )
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksCreateJobsOperator",
- )
-
- db_mock.create_job.assert_called_once_with(expected)
- assert JOB_ID == tis[TASK_ID].xcom_pull()
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_templated_named_param(self, db_mock_class,
dag_maker):
- json = "{{ ti.xcom_pull(task_ids='push_json') }}"
- with dag_maker("test_templated", render_template_as_native_obj=True):
- push_json = PythonOperator(
- task_id="push_json",
- python_callable=lambda: {
- "description": JOB_DESCRIPTION,
- "tags": TAGS,
- "tasks": TASKS,
- "job_clusters": JOB_CLUSTERS,
- "email_notifications": EMAIL_NOTIFICATIONS,
- "webhook_notifications": WEBHOOK_NOTIFICATIONS,
- "notification_settings": NOTIFICATION_SETTINGS,
- "timeout_seconds": TIMEOUT_SECONDS,
- "schedule": SCHEDULE,
- "max_concurrent_runs": MAX_CONCURRENT_RUNS,
- "git_source": GIT_SOURCE,
- "access_control_list": ACCESS_CONTROL_LIST,
- },
- )
- op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json,
name=TEMPLATED_JOB_NAME)
- push_json >> op
-
- db_mock = db_mock_class.return_value
- db_mock.create_job.return_value = JOB_ID
-
- db_mock.find_job_id_by_name.return_value = None
-
- dagrun =
dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d"))
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "name": RENDERED_TEMPLATED_JOB_NAME,
- "description": JOB_DESCRIPTION,
- "tags": TAGS,
- "tasks": TASKS,
- "job_clusters": JOB_CLUSTERS,
- "email_notifications": EMAIL_NOTIFICATIONS,
- "webhook_notifications": WEBHOOK_NOTIFICATIONS,
- "notification_settings": NOTIFICATION_SETTINGS,
- "timeout_seconds": TIMEOUT_SECONDS,
- "schedule": SCHEDULE,
- "max_concurrent_runs": MAX_CONCURRENT_RUNS,
- "git_source": GIT_SOURCE,
- "access_control_list": ACCESS_CONTROL_LIST,
- }
- )
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksCreateJobsOperator",
- )
-
- db_mock.create_job.assert_called_once_with(expected)
- assert JOB_ID == tis[TASK_ID].xcom_pull()
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
- with dag_maker("test_xcomarg", render_template_as_native_obj=True):
-
- @task
- def push_json() -> dict:
- return {
- "name": JOB_NAME,
- "description": JOB_DESCRIPTION,
- "tags": TAGS,
- "tasks": TASKS,
- "job_clusters": JOB_CLUSTERS,
- "email_notifications": EMAIL_NOTIFICATIONS,
- "webhook_notifications": WEBHOOK_NOTIFICATIONS,
- "notification_settings": NOTIFICATION_SETTINGS,
- "timeout_seconds": TIMEOUT_SECONDS,
- "schedule": SCHEDULE,
- "max_concurrent_runs": MAX_CONCURRENT_RUNS,
- "git_source": GIT_SOURCE,
- "access_control_list": ACCESS_CONTROL_LIST,
- }
-
- json = push_json()
- op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
-
- db_mock = db_mock_class.return_value
- db_mock.create_job.return_value = JOB_ID
-
- db_mock.find_job_id_by_name.return_value = None
-
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "name": JOB_NAME,
- "description": JOB_DESCRIPTION,
- "tags": TAGS,
- "tasks": TASKS,
- "job_clusters": JOB_CLUSTERS,
- "email_notifications": EMAIL_NOTIFICATIONS,
- "webhook_notifications": WEBHOOK_NOTIFICATIONS,
- "notification_settings": NOTIFICATION_SETTINGS,
- "timeout_seconds": TIMEOUT_SECONDS,
- "schedule": SCHEDULE,
- "max_concurrent_runs": MAX_CONCURRENT_RUNS,
- "git_source": GIT_SOURCE,
- "access_control_list": ACCESS_CONTROL_LIST,
- }
- )
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksCreateJobsOperator",
- )
-
- db_mock.create_job.assert_called_once_with(expected)
- assert JOB_ID == tis[TASK_ID].xcom_pull()
+ DatabricksCreateJobsOperator(task_id=TASK_ID, json=json)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_create(self, db_mock_class):
@@ -640,7 +433,7 @@ class TestDatabricksCreateJobsOperator:
return_result = op.execute({})
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"name": JOB_NAME,
"description": JOB_DESCRIPTION,
@@ -694,7 +487,7 @@ class TestDatabricksCreateJobsOperator:
return_result = op.execute({})
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"name": JOB_NAME,
"description": JOB_DESCRIPTION,
@@ -746,7 +539,7 @@ class TestDatabricksCreateJobsOperator:
op.execute({})
- expected = utils._normalise_json_content({"access_control_list":
ACCESS_CONTROL_LIST})
+ expected = utils.normalise_json_content({"access_control_list":
ACCESS_CONTROL_LIST})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
@@ -793,76 +586,66 @@ class TestDatabricksCreateJobsOperator:
class TestDatabricksSubmitRunOperator:
- def test_validate_json_with_notebook_task_named_parameters(self):
+ def test_init_with_notebook_task_named_parameters(self):
"""
- Test the _setup_and_validate_json function with named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
notebook_task=NOTEBOOK_TASK
)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_spark_python_task_named_parameters(self):
+ def test_init_with_spark_python_task_named_parameters(self):
"""
- Test the _setup_and_validate_json function with the named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
spark_python_task=SPARK_PYTHON_TASK
)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "spark_python_task":
SPARK_PYTHON_TASK, "run_name": TASK_ID}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_pipeline_name_task_named_parameters(self):
+ def test_init_with_pipeline_name_task_named_parameters(self):
"""
- Test the _setup_and_validate_json function with the named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksSubmitRunOperator(task_id=TASK_ID,
pipeline_task=PIPELINE_NAME_TASK)
- op._setup_and_validate_json()
+ expected = utils.normalise_json_content({"pipeline_task":
PIPELINE_NAME_TASK, "run_name": TASK_ID})
- expected = utils._normalise_json_content({"pipeline_task":
PIPELINE_NAME_TASK, "run_name": TASK_ID})
+ assert expected == utils.normalise_json_content(op.json)
- assert expected == utils._normalise_json_content(op.json)
-
- def test_validate_json_with_pipeline_id_task_named_parameters(self):
+ def test_init_with_pipeline_id_task_named_parameters(self):
"""
- Test the _setup_and_validate_json function with the named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksSubmitRunOperator(task_id=TASK_ID,
pipeline_task=PIPELINE_ID_TASK)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content({"pipeline_task":
PIPELINE_ID_TASK, "run_name": TASK_ID})
+ expected = utils.normalise_json_content({"pipeline_task":
PIPELINE_ID_TASK, "run_name": TASK_ID})
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_spark_submit_task_named_parameters(self):
+ def test_init_with_spark_submit_task_named_parameters(self):
"""
- Test the _setup_and_validate_json function with the named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER,
spark_submit_task=SPARK_SUBMIT_TASK
)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "spark_submit_task":
SPARK_SUBMIT_TASK, "run_name": TASK_ID}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_dbt_task_named_parameters(self):
+ def test_init_with_dbt_task_named_parameters(self):
"""
- Test the _setup_and_validate_json function with the named parameters.
+ Test the initializer with the named parameters.
"""
git_source = {
"git_url": "https://github.com/dbt-labs/jaffle_shop",
@@ -872,17 +655,15 @@ class TestDatabricksSubmitRunOperator:
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK,
git_source=git_source
)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source":
git_source, "run_name": TASK_ID}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_dbt_task_mixed_parameters(self):
+ def test_init_with_dbt_task_mixed_parameters(self):
"""
- Test the _setup_and_validate_json function with mixed parameters.
+ Test the initializer with mixed parameters.
"""
git_source = {
"git_url": "https://github.com/dbt-labs/jaffle_shop",
@@ -893,85 +674,73 @@ class TestDatabricksSubmitRunOperator:
op = DatabricksSubmitRunOperator(
task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK,
json=json
)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "dbt_task": DBT_TASK, "git_source":
git_source, "run_name": TASK_ID}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_dbt_task_without_git_source_raises_error(self):
+ def test_init_with_dbt_task_without_git_source_raises_error(self):
"""
- Test the _setup_and_validate_json function without the necessary
git_source for dbt_task raises error.
+ Test the initializer without the necessary git_source for dbt_task
raises error.
"""
exception_message = "git_source is required for dbt_task"
with pytest.raises(AirflowException, match=exception_message):
- DatabricksSubmitRunOperator(
- task_id=TASK_ID, new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK
- )._setup_and_validate_json()
+ DatabricksSubmitRunOperator(task_id=TASK_ID,
new_cluster=NEW_CLUSTER, dbt_task=DBT_TASK)
- def
test_validate_json_with_dbt_task_json_without_git_source_raises_error(self):
+ def test_init_with_dbt_task_json_without_git_source_raises_error(self):
"""
- Test the _setup_and_validate_json function without the necessary
git_source for dbt_task raises error.
+ Test the initializer without the necessary git_source for dbt_task
raises error.
"""
json = {"dbt_task": DBT_TASK, "new_cluster": NEW_CLUSTER}
exception_message = "git_source is required for dbt_task"
with pytest.raises(AirflowException, match=exception_message):
- DatabricksSubmitRunOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
+ DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- def test_validate_json_with_json(self):
+ def test_init_with_json(self):
"""
- Test the _setup_and_validate_json function with json data.
+ Test the initializer with json data.
"""
json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_tasks(self):
+ def test_init_with_tasks(self):
tasks = [{"task_key": 1, "new_cluster": NEW_CLUSTER, "notebook_task":
NOTEBOOK_TASK}]
op = DatabricksSubmitRunOperator(task_id=TASK_ID, tasks=tasks)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content({"run_name": TASK_ID,
"tasks": tasks})
- assert expected == utils._normalise_json_content(op.json)
+ expected = utils.normalise_json_content({"run_name": TASK_ID, "tasks":
tasks})
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_specified_run_name(self):
+ def test_init_with_specified_run_name(self):
"""
- Test the _setup_and_validate_json function with a specified run_name.
+ Test the initializer with a specified run_name.
"""
json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": RUN_NAME}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": RUN_NAME}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_pipeline_task(self):
+ def test_pipeline_task(self):
"""
- Test the _setup_and_validate_json function with a pipeline task.
+ Test the initializer with a pipeline task.
"""
pipeline_task = {"pipeline_id": "test-dlt"}
json = {"new_cluster": NEW_CLUSTER, "run_name": RUN_NAME,
"pipeline_task": pipeline_task}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "pipeline_task": pipeline_task,
"run_name": RUN_NAME}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_merging(self):
+ def test_init_with_merging(self):
"""
- Test the _setup_and_validate_json function when json and other named
parameters are both
+ Test the initializer when json and other named parameters are both
provided. The named parameters should override top level keys in the
json dict.
"""
@@ -981,38 +750,34 @@ class TestDatabricksSubmitRunOperator:
"notebook_task": NOTEBOOK_TASK,
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json,
new_cluster=override_new_cluster)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"new_cluster": override_new_cluster,
"notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID,
}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
@pytest.mark.db_test
- def test_validate_json_with_templating(self):
+ def test_init_with_templating(self):
json = {
"new_cluster": NEW_CLUSTER,
"notebook_task": TEMPLATED_NOTEBOOK_TASK,
}
dag = DAG("test", start_date=datetime.now())
op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
- op._setup_and_validate_json()
-
op.render_template_fields(context={"ds": DATE})
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"new_cluster": NEW_CLUSTER,
"notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK,
"run_name": TASK_ID,
}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_git_source(self):
+ def test_init_with_git_source(self):
json = {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": RUN_NAME}
git_source = {
"git_url": "https://github.com/apache/airflow",
@@ -1020,9 +785,7 @@ class TestDatabricksSubmitRunOperator:
"git_branch": "main",
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID,
git_source=git_source, json=json)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"new_cluster": NEW_CLUSTER,
"notebook_task": NOTEBOOK_TASK,
@@ -1030,139 +793,18 @@ class TestDatabricksSubmitRunOperator:
"git_source": git_source,
}
)
- assert expected == utils._normalise_json_content(op.json)
+ assert expected == utils.normalise_json_content(op.json)
- def test_validate_json_with_bad_type(self):
+ def test_init_with_bad_type(self):
json = {"test": datetime.now()}
+ op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
# Looks a bit weird since we have to escape regex reserved symbols.
exception_message = (
r"Type \<(type|class) \'datetime.datetime\'\> used "
r"for parameter json\[test\] is not a number or a string"
)
with pytest.raises(AirflowException, match=exception_message):
- DatabricksSubmitRunOperator(task_id=TASK_ID,
json=json)._setup_and_validate_json()
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
- json = "{{ ti.xcom_pull(task_ids='push_json') }}"
- with dag_maker("test_templated", render_template_as_native_obj=True):
- push_json = PythonOperator(
- task_id="push_json",
- python_callable=lambda: {
- "new_cluster": NEW_CLUSTER,
- "notebook_task": NOTEBOOK_TASK,
- },
- )
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- push_json >> op
-
- db_mock = db_mock_class.return_value
- db_mock.submit_run.return_value = RUN_ID
- db_mock.get_run_page_url.return_value = RUN_PAGE_URL
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
-
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
- )
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksSubmitRunOperator",
- )
-
- db_mock.submit_run.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run.assert_called_once_with(RUN_ID)
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_templated_named_params(self, db_mock_class,
dag_maker):
- json = "{{ ti.xcom_pull(task_ids='push_json') }}"
- with dag_maker("test_templated", render_template_as_native_obj=True):
- push_json = PythonOperator(
- task_id="push_json",
- python_callable=lambda: {
- "new_cluster": NEW_CLUSTER,
- },
- )
- op = DatabricksSubmitRunOperator(
- task_id=TASK_ID, json=json,
notebook_task=TEMPLATED_NOTEBOOK_TASK
- )
- push_json >> op
-
- db_mock = db_mock_class.return_value
- db_mock.submit_run.return_value = RUN_ID
- db_mock.get_run_page_url.return_value = RUN_PAGE_URL
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
-
- dagrun =
dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d"))
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "new_cluster": NEW_CLUSTER,
- "notebook_task": RENDERED_TEMPLATED_NOTEBOOK_TASK,
- "run_name": TASK_ID,
- }
- )
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksSubmitRunOperator",
- )
-
- db_mock.submit_run.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run.assert_called_once_with(RUN_ID)
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
- with dag_maker("test_xcomarg", render_template_as_native_obj=True):
-
- @task
- def push_json() -> dict:
- return {
- "new_cluster": NEW_CLUSTER,
- "notebook_task": NOTEBOOK_TASK,
- }
-
- json = push_json()
-
- op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
- db_mock = db_mock_class.return_value
- db_mock.submit_run.return_value = RUN_ID
- db_mock.get_run_page_url.return_value = RUN_PAGE_URL
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
-
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
- )
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksSubmitRunOperator",
- )
-
- db_mock.submit_run.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run.assert_called_once_with(RUN_ID)
+ utils.normalise_json_content(op.json)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_success(self, db_mock_class):
@@ -1180,7 +822,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1210,7 +852,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils._normalise_json_content({"pipeline_task":
PIPELINE_ID_TASK, "run_name": TASK_ID})
+ expected = utils.normalise_json_content({"pipeline_task":
PIPELINE_ID_TASK, "run_name": TASK_ID})
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
@@ -1242,7 +884,7 @@ class TestDatabricksSubmitRunOperator:
with pytest.raises(AirflowException):
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"new_cluster": NEW_CLUSTER,
"notebook_task": NOTEBOOK_TASK,
@@ -1290,7 +932,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1319,7 +961,7 @@ class TestDatabricksSubmitRunOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1353,7 +995,7 @@ class TestDatabricksSubmitRunOperator:
assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
assert exc.value.method_name == "execute_complete"
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1435,7 +1077,7 @@ class TestDatabricksSubmitRunOperator:
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1465,7 +1107,7 @@ class TestDatabricksSubmitRunOperator:
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK,
"run_name": TASK_ID}
)
db_mock_class.assert_called_once_with(
@@ -1483,20 +1125,18 @@ class TestDatabricksSubmitRunOperator:
class TestDatabricksRunNowOperator:
- def test_validate_json_with_named_parameters(self):
+ def test_init_with_named_parameters(self):
"""
- Test the _setup_and_validate_json function with named parameters.
+ Test the initializer with the named parameters.
"""
op = DatabricksRunNowOperator(job_id=JOB_ID, task_id=TASK_ID)
- op._setup_and_validate_json()
-
- expected = utils._normalise_json_content({"job_id": 42})
+ expected = utils.normalise_json_content({"job_id": 42})
assert expected == op.json
- def test_validate_json_with_json(self):
+ def test_init_with_json(self):
"""
- Test the _setup_and_validate_json function with json data.
+ Test the initializer with json data.
"""
json = {
"notebook_params": NOTEBOOK_PARAMS,
@@ -1507,9 +1147,8 @@ class TestDatabricksRunNowOperator:
"repair_run": False,
}
op = DatabricksRunNowOperator(task_id=TASK_ID, json=json)
- op._setup_and_validate_json()
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"jar_params": JAR_PARAMS,
@@ -1522,9 +1161,9 @@ class TestDatabricksRunNowOperator:
assert expected == op.json
- def test_validate_json_with_merging(self):
+ def test_init_with_merging(self):
"""
- Test the _setup_and_validate_json function when json and other named
parameters are both
+ Test the initializer when json and other named parameters are both
provided. The named parameters should override top level keys in the
json dict.
"""
@@ -1541,9 +1180,8 @@ class TestDatabricksRunNowOperator:
jar_params=override_jar_params,
spark_submit_params=SPARK_SUBMIT_PARAMS,
)
- op._setup_and_validate_json()
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": override_notebook_params,
"jar_params": override_jar_params,
@@ -1556,14 +1194,13 @@ class TestDatabricksRunNowOperator:
assert expected == op.json
@pytest.mark.db_test
- def test_validate_json_with_templating(self):
+ def test_init_with_templating(self):
json = {"notebook_params": NOTEBOOK_PARAMS, "jar_params":
TEMPLATED_JAR_PARAMS}
dag = DAG("test", start_date=datetime.now())
op = DatabricksRunNowOperator(dag=dag, task_id=TASK_ID, job_id=JOB_ID,
json=json)
op.render_template_fields(context={"ds": DATE})
- op._setup_and_validate_json()
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"jar_params": RENDERED_TEMPLATED_JAR_PARAMS,
@@ -1572,7 +1209,7 @@ class TestDatabricksRunNowOperator:
)
assert expected == op.json
- def test_validate_json_with_bad_type(self):
+ def test_init_with_bad_type(self):
json = {"test": datetime.now()}
# Looks a bit weird since we have to escape regex reserved symbols.
exception_message = (
@@ -1580,162 +1217,7 @@ class TestDatabricksRunNowOperator:
r"for parameter json\[test\] is not a number or a string"
)
with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID,
json=json)._setup_and_validate_json()
-
- def test_validate_json_exception_with_job_name_and_job_id(self):
- exception_message = "Argument 'job_name' is not allowed with argument
'job_id'"
-
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(
- task_id=TASK_ID, job_id=JOB_ID, job_name=JOB_NAME
- )._setup_and_validate_json()
-
- run = {"job_id": JOB_ID, "job_name": JOB_NAME}
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID,
json=run)._setup_and_validate_json()
-
- run = {"job_id": JOB_ID}
- with pytest.raises(AirflowException, match=exception_message):
- DatabricksRunNowOperator(task_id=TASK_ID, json=run,
job_name=JOB_NAME)._setup_and_validate_json()
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_templated_json(self, db_mock_class, dag_maker):
- json = "{{ ti.xcom_pull(task_ids='push_json') }}"
- with dag_maker("test_templated", render_template_as_native_obj=True):
- push_json = PythonOperator(
- task_id="push_json",
- python_callable=lambda: {
- "notebook_params": NOTEBOOK_PARAMS,
- "notebook_task": NOTEBOOK_TASK,
- "jar_params": JAR_PARAMS,
- "job_id": JOB_ID,
- },
- )
- op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID,
json=json)
- push_json >> op
-
- db_mock = db_mock_class.return_value
- db_mock.run_now.return_value = RUN_ID
- db_mock.get_run_page_url.return_value = RUN_PAGE_URL
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
-
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "notebook_params": NOTEBOOK_PARAMS,
- "notebook_task": NOTEBOOK_TASK,
- "jar_params": JAR_PARAMS,
- "job_id": JOB_ID,
- }
- )
-
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksRunNowOperator",
- )
- db_mock.run_now.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run.assert_called_once_with(RUN_ID)
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_templated_named_params(self, db_mock_class,
dag_maker):
- json = "{{ ti.xcom_pull(task_ids='push_json') }}"
- with dag_maker("test_templated", render_template_as_native_obj=True):
- push_json = PythonOperator(
- task_id="push_json",
- python_callable=lambda: {
- "notebook_params": NOTEBOOK_PARAMS,
- "notebook_task": NOTEBOOK_TASK,
- },
- )
- op = DatabricksRunNowOperator(
- task_id=TASK_ID, job_id=TEMPLATED_JOB_ID,
jar_params=TEMPLATED_JAR_PARAMS, json=json
- )
- push_json >> op
-
- db_mock = db_mock_class.return_value
- db_mock.run_now.return_value = RUN_ID
- db_mock.get_run_page_url.return_value = RUN_PAGE_URL
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
-
- dagrun =
dag_maker.create_dagrun(execution_date=datetime.strptime(DATE, "%Y-%m-%d"))
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "notebook_params": NOTEBOOK_PARAMS,
- "notebook_task": NOTEBOOK_TASK,
- "jar_params": RENDERED_TEMPLATED_JAR_PARAMS,
- "job_id": RENDERED_TEMPLATED_JOB_ID,
- }
- )
-
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksRunNowOperator",
- )
- db_mock.run_now.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run.assert_called_once_with(RUN_ID)
-
-
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
- def test_validate_json_with_xcomarg_json(self, db_mock_class, dag_maker):
- with dag_maker("test_xcomarg", render_template_as_native_obj=True):
-
- @task
- def push_json() -> dict:
- return {
- "notebook_params": NOTEBOOK_PARAMS,
- "notebook_task": NOTEBOOK_TASK,
- "jar_params": JAR_PARAMS,
- "job_id": JOB_ID,
- }
-
- json = push_json()
-
- op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID,
json=json)
-
- db_mock = db_mock_class.return_value
- db_mock.run_now.return_value = RUN_ID
- db_mock.get_run_page_url.return_value = RUN_PAGE_URL
- db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
-
- dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
- tis = {ti.task_id: ti for ti in dagrun.task_instances}
- tis["push_json"].run()
- tis[TASK_ID].run()
-
- expected = utils._normalise_json_content(
- {
- "notebook_params": NOTEBOOK_PARAMS,
- "notebook_task": NOTEBOOK_TASK,
- "jar_params": JAR_PARAMS,
- "job_id": JOB_ID,
- }
- )
-
- db_mock_class.assert_called_once_with(
- DEFAULT_CONN_ID,
- retry_limit=op.databricks_retry_limit,
- retry_delay=op.databricks_retry_delay,
- retry_args=None,
- caller="DatabricksRunNowOperator",
- )
- db_mock.run_now.assert_called_once_with(expected)
- db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
- db_mock.get_run.assert_called_once_with(RUN_ID)
+ DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=json)
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_success(self, db_mock_class):
@@ -1750,7 +1232,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1785,7 +1267,7 @@ class TestDatabricksRunNowOperator:
with pytest.raises(AirflowException):
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1841,7 +1323,7 @@ class TestDatabricksRunNowOperator:
with pytest.raises(AirflowException, match="Exception: Something went
wrong"):
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1909,7 +1391,7 @@ class TestDatabricksRunNowOperator:
):
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1953,7 +1435,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -1984,7 +1466,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -2004,6 +1486,20 @@ class TestDatabricksRunNowOperator:
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_not_called()
+ def test_init_exception_with_job_name_and_job_id(self):
+ exception_message = "Argument 'job_name' is not allowed with argument
'job_id'"
+
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID,
job_name=JOB_NAME)
+
+ run = {"job_id": JOB_ID, "job_name": JOB_NAME}
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(task_id=TASK_ID, json=run)
+
+ run = {"job_id": JOB_ID}
+ with pytest.raises(AirflowException, match=exception_message):
+ DatabricksRunNowOperator(task_id=TASK_ID, json=run,
job_name=JOB_NAME)
+
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_with_job_name(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task":
NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
@@ -2015,7 +1511,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -2063,7 +1559,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -2097,7 +1593,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -2134,7 +1630,7 @@ class TestDatabricksRunNowOperator:
assert isinstance(exc.value.trigger, DatabricksExecutionTrigger)
assert exc.value.method_name == "execute_complete"
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -2244,7 +1740,7 @@ class TestDatabricksRunNowOperator:
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
@@ -2278,7 +1774,7 @@ class TestDatabricksRunNowOperator:
op.execute(None)
- expected = utils._normalise_json_content(
+ expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
diff --git a/tests/providers/databricks/utils/test_databricks.py
b/tests/providers/databricks/utils/test_databricks.py
index 4b57573253..8c6ce8ce4b 100644
--- a/tests/providers/databricks/utils/test_databricks.py
+++ b/tests/providers/databricks/utils/test_databricks.py
@@ -21,7 +21,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
-from airflow.providers.databricks.utils.databricks import
_normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import
normalise_json_content, validate_trigger_event
RUN_ID = 1
RUN_PAGE_URL = "run-page-url"
@@ -46,7 +46,7 @@ class TestDatabricksOperatorSharedFunctions:
"test_list": ["1", "1.0", "a", "b"],
"test_tuple": ["1", "1.0", "a", "b"],
}
- assert _normalise_json_content(test_json) == expected
+ assert normalise_json_content(test_json) == expected
def test_validate_trigger_event_success(self):
event = {