This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 458ac63f7cb Supporting Variable.set in execution time with task SDK
(#49005)
458ac63f7cb is described below
commit 458ac63f7cb5b671483f0e521f1cb0ad57508675
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Apr 10 12:56:00 2025 +0530
Supporting Variable.set in execution time with task SDK (#49005)
* Supporting Variable.set in execution time with task SDK
* adding tests for variable set
* fixing static checks and review comments from ash
---
.../src/airflow/dag_processing/processor.py | 5 ++-
airflow-core/src/airflow/models/variable.py | 24 +++++++++++
.../tests/unit/dag_processing/test_processor.py | 30 ++++++++++++++
task-sdk/src/airflow/sdk/definitions/variable.py | 10 +++++
task-sdk/src/airflow/sdk/execution_time/context.py | 46 ++++++++++++++++++++++
.../tests/task_sdk/definitions/test_variables.py | 36 ++++++++++++++++-
6 files changed, 149 insertions(+), 2 deletions(-)
diff --git a/airflow-core/src/airflow/dag_processing/processor.py
b/airflow-core/src/airflow/dag_processing/processor.py
index b6e07242dce..365d771a081 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -38,6 +38,7 @@ from airflow.sdk.execution_time.comms import (
ErrorResponse,
GetConnection,
GetVariable,
+ PutVariable,
VariableResult,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
@@ -53,7 +54,7 @@ if TYPE_CHECKING:
from airflow.typing_compat import Self
ToManager = Annotated[
- Union["DagFileParsingResult", GetConnection, GetVariable],
+ Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable],
Field(discriminator="type"),
]
@@ -290,6 +291,8 @@ class DagFileProcessorProcess(WatchedSubprocess):
dump_opts = {"exclude_unset": True}
else:
resp = var
+ elif isinstance(msg, PutVariable):
+ self.client.variables.set(msg.key, msg.value, msg.description)
else:
log.error("Unhandled request", msg=msg)
return
diff --git a/airflow-core/src/airflow/models/variable.py
b/airflow-core/src/airflow/models/variable.py
index e0b91686285..9c6af8ab727 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -201,6 +201,30 @@ class Variable(Base, LoggingMixin):
"""
# check if the secret exists in the custom secrets' backend.
Variable.check_for_write_conflict(key=key)
+
+ # TODO: This is not the best way of having compat, but it's "better
than erroring" for now. This still
+ # means SQLA etc is loaded, but we can't avoid that unless/until we
add import shims as a big
+ # back-compat layer
+
+    # If this is set it means we are in some kind of execution context (Task,
Dag Parse or Triggerer perhaps)
+ # and should use the Task SDK API server path
+ if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"),
"SUPERVISOR_COMMS"):
+ warnings.warn(
+                "Using Variable.set from `airflow.models` is deprecated.
Please use `from airflow.sdk import "
+                "Variable` instead",
+ DeprecationWarning,
+ stacklevel=1,
+ )
+ from airflow.sdk import Variable as TaskSDKVariable
+
+ TaskSDKVariable.set(
+ key=key,
+ value=value,
+ description=description,
+ serialize_json=serialize_json,
+ )
+ return
+
if serialize_json:
stored_value = json.dumps(value, indent=2)
else:
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py
b/airflow-core/tests/unit/dag_processing/test_processor.py
index ca9670f81ae..4135c617f39 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -181,6 +181,36 @@ class TestDagFileProcessor:
if result.import_errors:
assert "VARIABLE_NOT_FOUND" in
next(iter(result.import_errors.values()))
+ def test_top_level_variable_set(self, tmp_path: pathlib.Path):
+ from airflow.models.variable import Variable as VariableORM
+
+ logger_filehandle = MagicMock()
+
+ def dag_in_a_fn():
+ from airflow.sdk import DAG, Variable
+
+ Variable.set(key="mykey", value="myvalue")
+ with DAG(f"test_{Variable.get('mykey')}"):
+ ...
+
+ path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
+ proc = DagFileProcessorProcess.start(
+ id=1, path=path, bundle_path=tmp_path, callbacks=[],
logger_filehandle=logger_filehandle
+ )
+
+ while not proc.is_ready:
+ proc._service_subprocess(0.1)
+
+ with create_session() as session:
+ result = proc.parsing_result
+ assert result is not None
+ assert result.import_errors == {}
+ assert result.serialized_dags[0].dag_id == "test_myvalue"
+
+ all_vars = session.query(VariableORM).all()
+ assert len(all_vars) == 1
+ assert all_vars[0].key == "mykey"
+
def test_top_level_connection_access(self, tmp_path: pathlib.Path,
monkeypatch: pytest.MonkeyPatch):
logger_filehandle = MagicMock()
diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py
b/task-sdk/src/airflow/sdk/definitions/variable.py
index 87b0ee29fab..9e80b8a5667 100644
--- a/task-sdk/src/airflow/sdk/definitions/variable.py
+++ b/task-sdk/src/airflow/sdk/definitions/variable.py
@@ -55,3 +55,13 @@ class Variable:
if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is
not NOTSET:
return default
raise
+
+ @classmethod
+ def set(cls, key: str, value: Any, description: str | None = None,
serialize_json: bool = False) -> None:
+ from airflow.sdk.exceptions import AirflowRuntimeError
+ from airflow.sdk.execution_time.context import _set_variable
+
+ try:
+ return _set_variable(key, value, description,
serialize_json=serialize_json)
+ except AirflowRuntimeError as e:
+ log.exception(e)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index 2ff7dbbda08..fee91f3efaf 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -211,6 +211,52 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
return variable.value
+def _set_variable(key: str, value: Any, description: str | None = None,
serialize_json: bool = False) -> None:
+ # TODO: This should probably be moved to a separate module like
`airflow.sdk.execution_time.comms`
+ # or `airflow.sdk.execution_time.variable`
+ # A reason to not move it to `airflow.sdk.execution_time.comms` is that
it
+ # will make that module depend on Task SDK, which is not ideal because
we intend to
+    # keep Task SDK as a separate package from execution-time mods.
+ import json
+
+ from airflow.sdk.execution_time.comms import PutVariable
+ from airflow.sdk.execution_time.supervisor import
ensure_secrets_backend_loaded
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ # check for write conflicts on the worker
+ for secrets_backend in ensure_secrets_backend_loaded():
+ try:
+ var_val = secrets_backend.get_variable(key=key)
+ if var_val is not None:
+ _backend_name = type(secrets_backend).__name__
+ log.warning(
+ "The variable %s is defined in the %s secrets backend,
which takes "
+ "precedence over reading from the database. The value in
the database will be "
+ "updated, but to read it you have to delete the
conflicting variable "
+ "from %s",
+ key,
+ _backend_name,
+ _backend_name,
+ )
+ except Exception:
+ log.exception(
+ "Unable to retrieve variable from secrets backend (%s).
Checking subsequent secrets backend.",
+ type(secrets_backend).__name__,
+ )
+
+ try:
+ if serialize_json:
+ value = json.dumps(value, indent=2)
+ except Exception as e:
+ log.exception(e)
+
+    # It is best to have the lock everywhere or nowhere on SUPERVISOR_COMMS;
the lock was
+ # primarily added for triggers but it doesn't make sense to have it in
some places
+ # and not in the rest. A lot of this will be simplified by
https://github.com/apache/airflow/issues/46426
+ with SUPERVISOR_COMMS.lock:
+ SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key,
value=value, description=description))
+
+
class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py
b/task-sdk/tests/task_sdk/definitions/test_variables.py
index 6560bdee903..242c5af407b 100644
--- a/task-sdk/tests/task_sdk/definitions/test_variables.py
+++ b/task-sdk/tests/task_sdk/definitions/test_variables.py
@@ -17,13 +17,14 @@
# under the License.
from __future__ import annotations
+import json
from unittest import mock
import pytest
from airflow.configuration import initialize_secrets_backends
from airflow.sdk import Variable
-from airflow.sdk.execution_time.comms import VariableResult
+from airflow.sdk.execution_time.comms import PutVariable, VariableResult
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
from tests_common.test_utils.config import conf_vars
@@ -55,6 +56,39 @@ class TestVariables:
assert var is not None
assert var == expected_value
+ @pytest.mark.parametrize(
+ "key, value, description, serialize_json",
+ [
+ pytest.param(
+ "key",
+ "value",
+ "description",
+ False,
+ id="simple-value",
+ ),
+ pytest.param(
+ "key2",
+ {"hi": "there", "hello": 42, "flag": True},
+ "description2",
+ True,
+ id="serialize-json-value",
+ ),
+ ],
+ )
+ def test_var_set(self, key, value, description, serialize_json,
mock_supervisor_comms):
+ Variable.set(key=key, value=value, description=description,
serialize_json=serialize_json)
+
+ expected_value = value
+ if serialize_json:
+ expected_value = json.dumps(value, indent=2)
+
+ mock_supervisor_comms.send_request.assert_called_once_with(
+ log=mock.ANY,
+ msg=PutVariable(
+ key=key, value=expected_value, description=description,
serialize_json=serialize_json
+ ),
+ )
+
class TestVariableFromSecrets:
def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):