Vamsi-klu commented on code in PR #68519:
URL: https://github.com/apache/airflow/pull/68519#discussion_r3424462673
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -281,6 +284,50 @@ def _inject_airflow_params_into_task(task: dict, params: dict) -> None:
task_def[field] = dict(params)
+def _coerce_json_to_dict(json: Any) -> dict[str, Any]:
+    if json is None:
+        return {}
+    if isinstance(json, Mapping):
+        return dict(json)
+    if isinstance(json, str):
+        return _parse_json_string_to_dict(json)
+    raise DatabricksOperatorPayloadError(
+        f"Databricks json payload must resolve to a mapping, not {type(json).__name__}."
+    )
+
+
+def _parse_json_string_to_dict(json: str) -> dict[str, Any]:
+    if not json:
+        return {}
+    try:
+        parsed_json = json_utils.loads(json)
+    except json_utils.JSONDecodeError:
+        try:
+            parsed_json = ast.literal_eval(json)
Review Comment:
Kept and documented. It is required, not incidental. Classic Jinja rendering
of a dict pulled from XCom (json="{{ ti.xcom_pull(...) }}") produces a
single-quoted Python repr like "{'job_id': 42}", which json.loads rejects and
ast.literal_eval parses correctly; the PR's own
test_exec_with_rendered_python_literal_json_and_templated_named_parameters
depends on it. I documented it in the _parse_json_string_to_dict docstring, in
the :param json: of all three operators (recommending the cleaner native
XComArg/producer.output path, which resolves to a real dict and skips the
string parser), and in the changelog note. Anything that still is not a mapping
raises DatabricksOperatorPayloadError.
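
A minimal sketch of the fallback described above, assuming only the standard
library. `PayloadError` and `parse_json_string_to_dict` are hypothetical
stand-ins for the PR's exception and helper, not the provider's actual code:

```python
import ast
import json
from collections.abc import Mapping
from typing import Any


class PayloadError(ValueError):
    """Hypothetical stand-in for DatabricksOperatorPayloadError."""


def parse_json_string_to_dict(raw: str) -> dict[str, Any]:
    """Parse a rendered string as JSON, falling back to a Python literal."""
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        try:
            # Classic Jinja rendering of a dict yields a Python repr such as
            # "{'job_id': 42}", which is not valid JSON but is a valid literal.
            parsed = ast.literal_eval(raw)
        except (ValueError, SyntaxError) as exc:
            raise PayloadError("payload is neither JSON nor a Python literal") from exc
    if not isinstance(parsed, Mapping):
        raise PayloadError(
            f"payload must resolve to a mapping, not {type(parsed).__name__}"
        )
    return dict(parsed)
```

`ast.literal_eval` only evaluates literals, so the fallback does not execute
arbitrary expressions from the rendered string.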
---
Drafted-by: Claude Code (Opus 4.8); reviewed by @Vamsi-klu before posting
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -430,14 +498,16 @@ def _hook(self):
         )
     def execute(self, context: Context) -> int:
-        if "name" not in self.json:
+        json = cast("dict[str, Any]", normalise_json_content(self._get_merged_json()))
+        if "name" not in json:
             raise AirflowException("Missing required parameter: name")
-        job_id = self._hook.find_job_id_by_name(self.json["name"])
-        if not self.json.get("parameters") and self.params:
-            self.json["parameters"] = [{"name": k, "default": v} for k, v in dict(self.params).items()]
+        job_id = self._hook.find_job_id_by_name(json["name"])
+        if not json.get("parameters") and self.params:
+            json["parameters"] = [{"name": k, "default": v} for k, v in dict(self.params).items()]
+        self.json = json
Review Comment:
Fixed. Removed the self.json = json write entirely. It had no downstream
reader (create_job/reset_job take the local json and the method then returns),
so the payload is derived on demand via _get_merged_json() and the json
template field is never overwritten.
test_execute_does_not_mutate_json_template_field renders the operator, executes
it with a mocked hook, and asserts that op.json is unchanged and that the
params-to-parameters injection still reaches the submitted payload.
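
A rough sketch of the "derive on demand" pattern, under stated assumptions:
`CreateJobsSketch` is a hypothetical stand-in for the operator, and the
`submitted` attribute merely records what a hook would have received.

```python
from typing import Any


class CreateJobsSketch:
    """Hypothetical stand-in for the operator; not the provider's real class."""

    def __init__(self, json: dict[str, Any], params: dict[str, Any]) -> None:
        self.json = json          # plays the role of the template field
        self.params = params
        self.submitted = None     # records the payload a hook would receive

    def _get_merged_json(self) -> dict[str, Any]:
        # Return a fresh top-level dict so callers never edit self.json directly.
        return dict(self.json)

    def execute(self) -> None:
        payload = self._get_merged_json()
        if "name" not in payload:
            raise ValueError("Missing required parameter: name")
        if not payload.get("parameters") and self.params:
            payload["parameters"] = [
                {"name": k, "default": v} for k, v in self.params.items()
            ]
        # The local payload is submitted; self.json is never reassigned.
        self.submitted = payload
```

The check in the test then amounts to comparing `op.json` before and after
`execute()` while inspecting the submitted payload separately.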
---
Drafted-by: Claude Code (Opus 4.8); reviewed by @Vamsi-klu before posting
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -674,28 +779,31 @@ def _get_hook(self, caller: str) -> DatabricksHook:
         )
     def execute(self, context: Context):
+        json = self._get_merged_json()
+        self._validate_merged_json(json)
         if (
-            "pipeline_task" in self.json
-            and self.json["pipeline_task"].get("pipeline_id") is None
-            and self.json["pipeline_task"].get("pipeline_name")
+            isinstance(json.get("pipeline_task"), Mapping)
+            and json["pipeline_task"].get("pipeline_id") is None
+            and json["pipeline_task"].get("pipeline_name")
         ):
             # If pipeline_id is not provided, we need to fetch it from the pipeline_name
-            pipeline_name = self.json["pipeline_task"]["pipeline_name"]
-            self.json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
-            del self.json["pipeline_task"]["pipeline_name"]
+            pipeline_name = json["pipeline_task"]["pipeline_name"]
+            json["pipeline_task"] = dict(json["pipeline_task"])
+            json["pipeline_task"]["pipeline_id"] = self._hook.find_pipeline_id_by_name(pipeline_name)
+            del json["pipeline_task"]["pipeline_name"]
         if self.params:
             params_dump = dict(self.params)
-            tasks = self.json.get("tasks")
+            tasks = json.get("tasks")
             if isinstance(tasks, list):
                 for task in tasks:
                     if isinstance(task, dict):
                         _inject_airflow_params_into_task(task, params_dump)
             else:
-                _inject_airflow_params_into_task(self.json, params_dump)
+                _inject_airflow_params_into_task(json, params_dump)
-        json_normalised = normalise_json_content(self.json)
-        self.run_id = self._hook.submit_run(json_normalised)
+        self.json = normalise_json_content(json)
Review Comment:
Fixed. execute() now works on copy.deepcopy(self._get_merged_json()) and
submits a local normalised value, so self.json and the named template fields
are never written. The deep copy also fixes a subtler aliasing bug:
_get_merged_json() only shallow-copies, so _inject_airflow_params_into_task was
mutating the nested tasks/notebook_task dicts in place whenever self.params is
set, corrupting the named template fields even apart from the self.json write.
test_execute_does_not_mutate_template_fields populates params with a named
notebook_task and asserts the field is untouched while the params still reach
the submitted payload.
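
The aliasing issue can be reproduced with plain dicts. This is an illustrative
sketch only; `template_field` and the `base_parameters` key are example names,
not taken from the PR:

```python
import copy

# A shallow copy shares the nested dict with the original.
template_field = {"notebook_task": {"notebook_path": "/demo"}}
shallow = dict(template_field)
shallow["notebook_task"]["base_parameters"] = {"env": "prod"}
leaked = "base_parameters" in template_field["notebook_task"]

# A deep copy duplicates nested containers, so the original stays untouched.
template_field = {"notebook_task": {"notebook_path": "/demo"}}
deep = copy.deepcopy(template_field)
deep["notebook_task"]["base_parameters"] = {"env": "prod"}
isolated = "base_parameters" not in template_field["notebook_task"]
```

Here `leaked` and `isolated` both end up `True`: the shallow copy's nested write
shows through to the original, while the deep copy's does not.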
---
Drafted-by: Claude Code (Opus 4.8); reviewed by @Vamsi-klu before posting
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -992,26 +1121,32 @@ def _get_hook(self, caller: str) -> DatabricksHook:
         )
     def execute(self, context: Context):
+        json = self._get_merged_json()
+        self._validate_merged_json(json)
         hook = self._hook
-        if "job_name" in self.json:
-            job_id = hook.find_job_id_by_name(self.json["job_name"])
+        if "job_name" in json:
+            job_id = hook.find_job_id_by_name(json["job_name"])
             if job_id is None:
-                raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
-            self.json["job_id"] = job_id
-            del self.json["job_name"]
+                raise DatabricksOperatorPayloadError(
+                    f"Job ID for job name {json['job_name']} can not be found"
+                )
+            json["job_id"] = job_id
+            del json["job_name"]
         if self.cancel_previous_runs:
-            if (job_id := self.json.get("job_id")) is None:
+            if (job_id := json.get("job_id")) is None:
                 raise ValueError(
                     "cancel_previous_runs=True requires either job_id or job_name to be provided."
                 )
             hook.cancel_all_runs(job_id)
-        if not self.json.get("job_parameters") and self.params:
-            self.json["job_parameters"] = dict(self.params)
+        json = cast("dict[str, Any]", normalise_json_content(json))
Review Comment:
Fixed. normalise_json_content(json) now runs right after
_validate_merged_json, before hook = self._hook and any
find_job_id_by_name/cancel_all_runs call, so an invalid payload type fails fast
with no remote side-effects. For parity I added the same side-effect-free
validation pass to DatabricksSubmitRunOperator.execute() before
find_pipeline_id_by_name. test_exec_invalid_payload_fails_before_api_call
(RunNow) and test_exec_invalid_payload_fails_before_pipeline_lookup (SubmitRun)
assert the relevant hook methods are never called on a bad-type payload, and
test_init_with_bad_type now asserts the hook is not even constructed.
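
A hedged sketch of the fail-fast ordering, using `unittest.mock` in place of a
real hook. `validate_payload` and `run_now` are hypothetical simplifications of
the normalisation pass and `execute()`, not the provider's implementation:

```python
from typing import Any
from unittest import mock

_ALLOWED = (str, int, float, bool, list, dict, type(None))


def validate_payload(payload: dict[str, Any]) -> None:
    """Hypothetical side-effect-free type check standing in for normalisation."""
    for key, value in payload.items():
        if not isinstance(value, _ALLOWED):
            raise TypeError(f"Unsupported type {type(value).__name__} for key {key!r}")


def run_now(payload: dict[str, Any], hook: Any) -> Any:
    # Validation runs first, so a bad payload raises before any remote call.
    validate_payload(payload)
    if "job_name" in payload:
        payload = dict(payload)
        payload["job_id"] = hook.find_job_id_by_name(payload.pop("job_name"))
    return hook.run_now(payload)


# Usage: a bad-type payload must not touch the (mocked) hook at all.
bad_hook = mock.MagicMock()
try:
    run_now({"job_name": "nightly", "bad": object()}, bad_hook)
except TypeError:
    pass
```

After this runs, `bad_hook.find_job_id_by_name` and `bad_hook.run_now` have not
been called, which is the property the tests named above assert.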
---
Drafted-by: Claude Code (Opus 4.8); reviewed by @Vamsi-klu before posting
##########
providers/databricks/src/airflow/providers/databricks/operators/databricks.py:
##########
@@ -992,26 +1121,32 @@ def _get_hook(self, caller: str) -> DatabricksHook:
         )
     def execute(self, context: Context):
+        json = self._get_merged_json()
+        self._validate_merged_json(json)
         hook = self._hook
-        if "job_name" in self.json:
-            job_id = hook.find_job_id_by_name(self.json["job_name"])
+        if "job_name" in json:
+            job_id = hook.find_job_id_by_name(json["job_name"])
             if job_id is None:
-                raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
-            self.json["job_id"] = job_id
-            del self.json["job_name"]
+                raise DatabricksOperatorPayloadError(
+                    f"Job ID for job name {json['job_name']} can not be found"
+                )
+            json["job_id"] = job_id
+            del json["job_name"]
         if self.cancel_previous_runs:
-            if (job_id := self.json.get("job_id")) is None:
+            if (job_id := json.get("job_id")) is None:
                 raise ValueError(
                     "cancel_previous_runs=True requires either job_id or job_name to be provided."
                 )
             hook.cancel_all_runs(job_id)
-        if not self.json.get("job_parameters") and self.params:
-            self.json["job_parameters"] = dict(self.params)
+        json = cast("dict[str, Any]", normalise_json_content(json))
+        if not json.get("job_parameters") and self.params:
+            json["job_parameters"] = dict(self.params)
-        self.run_id = hook.run_now(self.json)
+        self.json = json
Review Comment:
Fixed, and this addresses the deferrable caveat directly. execute() stores
the merged payload on a transient, non-template self._merged_json (not
self.json), and execute_complete() no longer reads self.json. It reconstructs
from the re-rendered template fields plus named params via _get_merged_json().
That is strictly more correct than the old read: on a deferral resume the
worker is a fresh process where any self.json write from execute() is already
gone, and _get_merged_json() now also recovers a job_parameters supplied via
the named job_parameters= argument (the old code only saw it inside json=).
Guarded by test_execute_complete_repair_includes_named_job_parameters (named
arg, fails on the old code and passes now) and
test_sync_repair_reads_job_parameters_from_merged_json (the synchronous repair
branch, previously uncovered). The existing
test_execute_complete_repair_includes_job_parameters still passes.
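
A simplified sketch of why rebuilding the payload on resume is more robust than
reading state written by execute(). `RunNowSketch` is a hypothetical stand-in,
and its merge rule (a named argument overriding the `json` key) is an assumption
for illustration:

```python
from typing import Any, Optional


class RunNowSketch:
    """Hypothetical stand-in for the operator; not the provider's real class."""

    def __init__(
        self,
        json: Optional[dict[str, Any]] = None,
        job_parameters: Optional[dict[str, Any]] = None,
    ) -> None:
        self.json = json or {}
        self.job_parameters = job_parameters

    def _get_merged_json(self) -> dict[str, Any]:
        merged = dict(self.json)
        if self.job_parameters is not None:
            merged["job_parameters"] = self.job_parameters
        return merged

    def execute_complete(self) -> Optional[dict[str, Any]]:
        # Rebuild from the constructor-level fields instead of relying on
        # anything execute() may have stored on the instance.
        return self._get_merged_json().get("job_parameters")


# Simulates a resume in a fresh process: execute() was never called here.
resumed = RunNowSketch(json={"job_id": 1}, job_parameters={"env": "prod"})
repair_parameters = resumed.execute_complete()
```

Because nothing depends on instance state from `execute()`, the resumed
instance still sees the `job_parameters` passed as a named argument.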
---
Drafted-by: Claude Code (Opus 4.8); reviewed by @Vamsi-klu before posting
##########
providers/databricks/tests/unit/databricks/operators/test_databricks.py:
##########
@@ -867,7 +893,37 @@ def test_init_with_templating(self):
"run_name": TASK_ID,
}
)
-        assert expected == utils.normalise_json_content(op.json)
+        assert expected == utils.normalise_json_content(op._get_merged_json())
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_exec_with_xcom_arg_json_and_templated_named_parameters(self, db_mock_class):
+        with DAG("test", schedule=None, start_date=datetime.now()):
Review Comment:
Fixed. One correction: DEFAULT_DATE did not actually exist in this file yet
(only DATE = "2017-04-20", a string used for ds), so I added DEFAULT_DATE =
datetime(2024, 1, 1) near the other constants and used it for the new test's
start_date. The other pre-existing start_date=datetime.now() occurrences in
this file were not introduced by this PR; happy to clean those up in a
follow-up to keep this diff focused.
---
Drafted-by: Claude Code (Opus 4.8); reviewed by @Vamsi-klu before posting