This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 74ec32ff950 Enable ruff B008 (function-call-in-default-argument) and
fix violations (#66979)
74ec32ff950 is described below
commit 74ec32ff95092316123ddfc2d9ca31ae59f22d18
Author: Shahar Epstein <[email protected]>
AuthorDate: Sat May 16 20:54:23 2026 +0300
Enable ruff B008 (function-call-in-default-argument) and fix violations
(#66979)
Co-authored-by: Jens Scheffler <[email protected]>
---
.../src/airflow/api_fastapi/common/parameters.py | 18 ++--
.../unit/fab/auth_manager/api_fastapi/conftest.py | 4 +-
.../providers/google/cloud/hooks/vertex_ai/ray.py | 4 +-
.../google/cloud/operators/vertex_ai/ray.py | 4 +-
.../providers/google/cloud/transfers/s3_to_gcs.py | 2 +-
.../unit/google/cloud/hooks/vertex_ai/test_ray.py | 17 ++++
.../google/cloud/operators/vertex_ai/test_ray.py | 99 ++++++++++++++++++++++
.../tests/system/openlineage/operator.py | 4 +-
pyproject.toml | 13 +++
.../observability/metrics/datadog_logger.py | 8 +-
.../observability/metrics/otel_logger.py | 4 +-
.../observability/metrics/statsd_logger.py | 8 +-
task-sdk/src/airflow/sdk/bases/sensor.py | 4 +-
.../airflow/sdk/definitions/operator_resources.py | 16 ++--
task-sdk/tests/task_sdk/bases/test_sensor.py | 14 +++
.../definitions/test_operator_resources.py | 38 +++++++++
16 files changed, 222 insertions(+), 35 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py
b/airflow-core/src/airflow/api_fastapi/common/parameters.py
index a93ec040e1c..dfd0cd473cb 100644
--- a/airflow-core/src/airflow/api_fastapi/common/parameters.py
+++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py
@@ -77,6 +77,8 @@ if TYPE_CHECKING:
T = TypeVar("T")
+_FALLBACK_PAGE_LIMIT: int = conf.getint("api", "fallback_page_limit")
+
class BaseParam(OrmClause[T], ABC):
"""Base class for path or query parameters with ORM transformation."""
@@ -106,7 +108,7 @@ class LimitFilter(BaseParam[NonNegativeInt]):
return select.limit(self.value)
@classmethod
- def depends(cls, limit: NonNegativeInt = conf.getint("api",
"fallback_page_limit")) -> LimitFilter:
+ def depends(cls, limit: NonNegativeInt = _FALLBACK_PAGE_LIMIT) ->
LimitFilter:
return cls().set_value(min(limit, conf.getint("api",
"maximum_page_limit")))
@@ -607,13 +609,13 @@ class SortParam(BaseParam[list[str]]):
else:
default_list = list(default)
- def inner(
- order_by: list[str] = Query(
- default=default_list,
- description=f"Attributes to order by, multi criteria sort is
supported. Prefix with `-` for descending order. "
- f"Supported attributes: `{', '.join(all_attrs) if all_attrs
else self.get_primary_key_string()}`",
- ),
- ) -> SortParam:
+ _order_by_query = Query(
+ default=default_list,
+ description=f"Attributes to order by, multi criteria sort is
supported. Prefix with `-` for descending order. "
+ f"Supported attributes: `{', '.join(all_attrs) if all_attrs else
self.get_primary_key_string()}`",
+ )
+
+ def inner(order_by: list[str] = _order_by_query) -> SortParam:
return self.set_value(order_by)
return inner
diff --git a/providers/fab/tests/unit/fab/auth_manager/api_fastapi/conftest.py
b/providers/fab/tests/unit/fab/auth_manager/api_fastapi/conftest.py
index fb884b8ede7..af2615f7581 100644
--- a/providers/fab/tests/unit/fab/auth_manager/api_fastapi/conftest.py
+++ b/providers/fab/tests/unit/fab/auth_manager/api_fastapi/conftest.py
@@ -63,7 +63,9 @@ def override_deps(test_client):
@pytest.fixture
def as_user(override_deps):
@contextmanager
- def _as(u=types.SimpleNamespace(id=1, username="tester")):
+ def _as(u=None):
+ if u is None:
+ u = types.SimpleNamespace(id=1, username="tester")
with override_deps({get_user_dep: lambda: u}):
yield u
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
index de722ed5959..9d854804145 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/vertex_ai/ray.py
@@ -66,7 +66,7 @@ class RayHook(GoogleBaseHook):
self,
project_id: str,
location: str,
- head_node_type: resources.Resources = resources.Resources(),
+ head_node_type: resources.Resources | None = None,
python_version: str = "3.10",
ray_version: str = "2.33",
network: str | None = None,
@@ -115,7 +115,7 @@ class RayHook(GoogleBaseHook):
"""
aiplatform.init(project=project_id, location=location,
credentials=self.get_credentials())
cluster_path = vertex_ray.create_ray_cluster(
- head_node_type=head_node_type,
+ head_node_type=head_node_type or resources.Resources(),
python_version=python_version,
ray_version=ray_version,
network=network,
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
index 4d6723977af..0a90637f103 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/ray.py
@@ -140,7 +140,7 @@ class CreateRayClusterOperator(RayBaseOperator):
self,
python_version: str,
ray_version: Literal["2.9.3", "2.33", "2.42"],
- head_node_type: resources.Resources = resources.Resources(),
+ head_node_type: resources.Resources | None = None,
network: str | None = None,
service_account: str | None = None,
cluster_name: str | None = None,
@@ -155,7 +155,7 @@ class CreateRayClusterOperator(RayBaseOperator):
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
- self.head_node_type = head_node_type
+ self.head_node_type = head_node_type or resources.Resources()
self.python_version = python_version
self.ray_version = ray_version
self.network = network
diff --git
a/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
b/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
index 861769d5768..b7df46631dd 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
@@ -165,7 +165,7 @@ class S3ToGCSOperator(S3ListOperator):
replace=False,
gzip=False,
google_impersonation_chain: str | Sequence[str] | None = None,
- deferrable=conf.getboolean("operators", "default_deferrable",
fallback=False),
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
poll_interval: int = 10,
return_gcs_uris: bool = False,
**kwargs,
diff --git
a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
index d3d071c4b3f..26205939a85 100644
--- a/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
+++ b/providers/google/tests/unit/google/cloud/hooks/vertex_ai/test_ray.py
@@ -94,6 +94,23 @@ class TestRayWithDefaultProjectIdHook:
labels=None,
)
+ @mock.patch(RAY_STRING.format("vertex_ray.create_ray_cluster"))
+ @mock.patch(RAY_STRING.format("aiplatform.init"))
+ def test_create_ray_cluster_default_head_node_type(
+ self, mock_aiplatform_init, mock_create_ray_cluster
+ ) -> None:
+ self.hook.create_ray_cluster(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ head_node_type=None,
+ python_version=TEST_PYTHON_VERSION,
+ ray_version=TEST_RAY_VERSION,
+ cluster_name=TEST_CLUSTER_NAME,
+ )
+ mock_aiplatform_init.assert_called_once()
+ call_kwargs = mock_create_ray_cluster.call_args.kwargs
+ assert isinstance(call_kwargs["head_node_type"], Resources)
+
@mock.patch(RAY_STRING.format("vertex_ray.delete_ray_cluster"))
@mock.patch(RAY_STRING.format("aiplatform.init"))
@mock.patch(RAY_STRING.format("PersistentResourceServiceClient.persistent_resource_path"))
diff --git
a/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_ray.py
b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_ray.py
new file mode 100644
index 00000000000..889c81849e0
--- /dev/null
+++ b/providers/google/tests/unit/google/cloud/operators/vertex_ai/test_ray.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+pytest.importorskip("google.cloud.aiplatform.vertex_ray.util.resources")
+from google.cloud.aiplatform.vertex_ray.util.resources import Resources
+
+from airflow.providers.google.cloud.operators.vertex_ai.ray import
CreateRayClusterOperator
+
+TEST_GCP_CONN_ID = "test-gcp-conn-id"
+TEST_LOCATION = "us-central1"
+TEST_PROJECT_ID = "test-project-id"
+TEST_PYTHON_VERSION = "3.10"
+TEST_RAY_VERSION = "2.33"
+TEST_CLUSTER_NAME = "test-cluster-name"
+
+VERTEX_AI_RAY_OP_PATH =
"airflow.providers.google.cloud.operators.vertex_ai.ray.{}"
+
+
+class TestCreateRayClusterOperator:
+ @mock.patch(VERTEX_AI_RAY_OP_PATH.format("RayHook"))
+ def test_create_ray_cluster_with_explicit_head_node_type(self,
mock_hook_cls):
+ explicit_head = Resources()
+ op = CreateRayClusterOperator(
+ task_id="test-task",
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ python_version=TEST_PYTHON_VERSION,
+ ray_version=TEST_RAY_VERSION,
+ head_node_type=explicit_head,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ )
+ assert op.head_node_type is explicit_head
+
+ @mock.patch(VERTEX_AI_RAY_OP_PATH.format("RayHook"))
+ def
test_create_ray_cluster_default_head_node_type_is_fresh_resources(self,
mock_hook_cls):
+ op1 = CreateRayClusterOperator(
+ task_id="test-task-1",
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ python_version=TEST_PYTHON_VERSION,
+ ray_version=TEST_RAY_VERSION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ )
+ op2 = CreateRayClusterOperator(
+ task_id="test-task-2",
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ python_version=TEST_PYTHON_VERSION,
+ ray_version=TEST_RAY_VERSION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ )
+ assert isinstance(op1.head_node_type, Resources)
+ assert isinstance(op2.head_node_type, Resources)
+ assert op1.head_node_type is not op2.head_node_type
+
+ @mock.patch(VERTEX_AI_RAY_OP_PATH.format("VertexAIRayClusterLink"))
+ @mock.patch(VERTEX_AI_RAY_OP_PATH.format("RayHook"))
+ def test_execute_without_head_node_type_passes_default_resources(self,
mock_hook_cls, mock_link):
+ mock_hook = mock_hook_cls.return_value
+ mock_hook.create_ray_cluster.return_value = (
+
f"projects/{TEST_PROJECT_ID}/locations/{TEST_LOCATION}/persistentResources/{TEST_CLUSTER_NAME}"
+ )
+ mock_hook.extract_cluster_id.return_value = TEST_CLUSTER_NAME
+
+ op = CreateRayClusterOperator(
+ task_id="test-task",
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ python_version=TEST_PYTHON_VERSION,
+ ray_version=TEST_RAY_VERSION,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ )
+
+ ti_mock = mock.MagicMock()
+ context = {"ti": ti_mock, "task": mock.MagicMock()}
+ op.execute(context=context)
+
+ call_kwargs = mock_hook.create_ray_cluster.call_args.kwargs
+ assert isinstance(call_kwargs["head_node_type"], Resources)
diff --git a/providers/openlineage/tests/system/openlineage/operator.py
b/providers/openlineage/tests/system/openlineage/operator.py
index 0bc8aa7c003..51ec3fd67ec 100644
--- a/providers/openlineage/tests/system/openlineage/operator.py
+++ b/providers/openlineage/tests/system/openlineage/operator.py
@@ -197,7 +197,7 @@ class OpenLineageTestOperator(BaseOperator):
self,
event_templates: dict[str, dict] | None = None,
file_path: str | None = None,
- env: Environment = setup_jinja(),
+ env: Environment | None = None,
allow_duplicate_events_regex: str | None = None,
clear_variables: bool = True,
**kwargs,
@@ -205,7 +205,7 @@ class OpenLineageTestOperator(BaseOperator):
super().__init__(**kwargs)
self.event_templates = event_templates
self.file_path = file_path
- self.env = env
+ self.env = env or setup_jinja()
self.allow_duplicate_events_regex = allow_duplicate_events_regex
self.clear_variables = clear_variables
if self.event_templates and self.file_path:
diff --git a/pyproject.toml b/pyproject.toml
index 12436ce5bf5..8d2c55b9e97 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -655,6 +655,7 @@ extend-select = [
"B004", # Checks for use of hasattr(x, "__call__") and replaces it with
callable(x)
"B006", # Checks for uses of mutable objects as function argument defaults.
"B007", # Checks for unused variables in the loop
+ "B008", # Do not perform function call in argument defaults (use
extend-immutable-calls for FastAPI DI)
"B012", # Checks for `break`, `continue`, and `return` statements in
`finally` blocks
"B017", # Checks for pytest.raises context managers that catch Exception
or BaseException.
"B019", # Use of functools.lru_cache or functools.cache on methods can
lead to memory leaks
@@ -704,6 +705,18 @@ unfixable = [
"PT022",
]
+[tool.ruff.lint.flake8-bugbear]
+# FastAPI dependency injection uses function calls in argument defaults
intentionally.
+# SHA256 is a stateless algorithm descriptor (cryptography library).
+extend-immutable-calls = [
+ "fastapi.Body",
+ "fastapi.Depends",
+ "fastapi.Query",
+ "fastapi.Path",
+ "fastapi.Security",
+ "cryptography.hazmat.primitives.hashes.SHA256",
+]
+
[tool.ruff.format]
docstring-code-format = true
diff --git
a/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
b/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
index e6daf53b793..129354c3a43 100644
---
a/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
+++
b/shared/observability/src/airflow_shared/observability/metrics/datadog_logger.py
@@ -45,16 +45,16 @@ class SafeDogStatsdLogger:
def __init__(
self,
dogstatsd_client: DogStatsd,
- metrics_validator: ListValidator = PatternAllowListValidator(),
+ metrics_validator: ListValidator | None = None,
metrics_tags: bool = False,
- metric_tags_validator: ListValidator = PatternAllowListValidator(),
+ metric_tags_validator: ListValidator | None = None,
stat_name_handler: Callable[[str], str] | None = None,
statsd_influxdb_enabled: bool = False,
) -> None:
self.dogstatsd = dogstatsd_client
- self.metrics_validator = metrics_validator
+ self.metrics_validator = metrics_validator or
PatternAllowListValidator()
self.metrics_tags = metrics_tags
- self.metric_tags_validator = metric_tags_validator
+ self.metric_tags_validator = metric_tags_validator or
PatternAllowListValidator()
self.stat_name_handler = stat_name_handler
self.statsd_influxdb_enabled = statsd_influxdb_enabled
diff --git
a/shared/observability/src/airflow_shared/observability/metrics/otel_logger.py
b/shared/observability/src/airflow_shared/observability/metrics/otel_logger.py
index 5aaa77741f0..8d25b23372a 100644
---
a/shared/observability/src/airflow_shared/observability/metrics/otel_logger.py
+++
b/shared/observability/src/airflow_shared/observability/metrics/otel_logger.py
@@ -175,13 +175,13 @@ class SafeOtelLogger:
self,
otel_provider,
prefix: str = DEFAULT_METRIC_NAME_PREFIX,
- metrics_validator: ListValidator = PatternAllowListValidator(),
+ metrics_validator: ListValidator | None = None,
stat_name_handler: Callable[[str], str] | None = None,
statsd_influxdb_enabled: bool = False,
):
self.otel: Callable = otel_provider
self.prefix: str = prefix
- self.metrics_validator = metrics_validator
+ self.metrics_validator = metrics_validator or
PatternAllowListValidator()
self.meter = otel_provider.get_meter(__name__)
self.metrics_map = MetricsMap(self.meter)
self.stat_name_handler = stat_name_handler
diff --git
a/shared/observability/src/airflow_shared/observability/metrics/statsd_logger.py
b/shared/observability/src/airflow_shared/observability/metrics/statsd_logger.py
index 7e8d29f3a26..3500a04dc7b 100644
---
a/shared/observability/src/airflow_shared/observability/metrics/statsd_logger.py
+++
b/shared/observability/src/airflow_shared/observability/metrics/statsd_logger.py
@@ -67,16 +67,16 @@ class SafeStatsdLogger:
def __init__(
self,
statsd_client: StatsClient,
- metrics_validator: ListValidator = PatternAllowListValidator(),
+ metrics_validator: ListValidator | None = None,
influxdb_tags_enabled: bool = False,
- metric_tags_validator: ListValidator = PatternAllowListValidator(),
+ metric_tags_validator: ListValidator | None = None,
stat_name_handler: Callable[[str], str] | None = None,
statsd_influxdb_enabled: bool = False,
) -> None:
self.statsd = statsd_client
- self.metrics_validator = metrics_validator
+ self.metrics_validator = metrics_validator or
PatternAllowListValidator()
self.influxdb_tags_enabled = influxdb_tags_enabled
- self.metric_tags_validator = metric_tags_validator
+ self.metric_tags_validator = metric_tags_validator or
PatternAllowListValidator()
self.stat_name_handler = stat_name_handler
self.statsd_influxdb_enabled = statsd_influxdb_enabled
diff --git a/task-sdk/src/airflow/sdk/bases/sensor.py
b/task-sdk/src/airflow/sdk/bases/sensor.py
index 3a877dd98af..3f1f0842e89 100644
--- a/task-sdk/src/airflow/sdk/bases/sensor.py
+++ b/task-sdk/src/airflow/sdk/bases/sensor.py
@@ -116,7 +116,7 @@ class BaseSensorOperator(BaseOperator):
self,
*,
poke_interval: timedelta | float = 60,
- timeout: timedelta | float = conf.getfloat("sensors",
"default_timeout"),
+ timeout: timedelta | float | None = None,
soft_fail: bool = False,
mode: str = "poke",
exponential_backoff: bool = False,
@@ -128,6 +128,8 @@ class BaseSensorOperator(BaseOperator):
super().__init__(**kwargs)
self.poke_interval =
self._coerce_poke_interval(poke_interval).total_seconds()
self.soft_fail = soft_fail
+ if timeout is None:
+ timeout = conf.getfloat("sensors", "default_timeout")
self.timeout: int | float =
self._coerce_timeout(timeout).total_seconds()
self.mode = mode
self.exponential_backoff = exponential_backoff
diff --git a/task-sdk/src/airflow/sdk/definitions/operator_resources.py
b/task-sdk/src/airflow/sdk/definitions/operator_resources.py
index d6cbf10039d..d8a0d02a1b4 100644
--- a/task-sdk/src/airflow/sdk/definitions/operator_resources.py
+++ b/task-sdk/src/airflow/sdk/definitions/operator_resources.py
@@ -125,15 +125,15 @@ class Resources:
def __init__(
self,
- cpus=conf.getint("operators", "default_cpus"),
- ram=conf.getint("operators", "default_ram"),
- disk=conf.getint("operators", "default_disk"),
- gpus=conf.getint("operators", "default_gpus"),
+ cpus=None,
+ ram=None,
+ disk=None,
+ gpus=None,
):
- self.cpus = CpuResource(cpus)
- self.ram = RamResource(ram)
- self.disk = DiskResource(disk)
- self.gpus = GpuResource(gpus)
+ self.cpus = CpuResource(cpus if cpus is not None else
conf.getint("operators", "default_cpus"))
+ self.ram = RamResource(ram if ram is not None else
conf.getint("operators", "default_ram"))
+ self.disk = DiskResource(disk if disk is not None else
conf.getint("operators", "default_disk"))
+ self.gpus = GpuResource(gpus if gpus is not None else
conf.getint("operators", "default_gpus"))
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
diff --git a/task-sdk/tests/task_sdk/bases/test_sensor.py
b/task-sdk/tests/task_sdk/bases/test_sensor.py
index 8a6fbcc15ac..5e7f588a155 100644
--- a/task-sdk/tests/task_sdk/bases/test_sensor.py
+++ b/task-sdk/tests/task_sdk/bases/test_sensor.py
@@ -41,6 +41,8 @@ from airflow.sdk.exceptions import (
from airflow.sdk.execution_time.comms import RescheduleTask,
TaskRescheduleStartDate
from airflow.sdk.timezone import datetime
+from tests_common.test_utils.config import conf_vars
+
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
@@ -358,6 +360,18 @@ class TestBaseSensor:
task_id="test_sensor_task_3", return_value=None, poke_interval=10,
timeout=positive_timeout
)
+ def test_sensor_timeout_default_read_from_conf_at_instantiation(self):
+ """When ``timeout`` is not supplied, it should be read from
``sensors.default_timeout``
+ at instantiation time (not at module import time).
+ """
+ with conf_vars({("sensors", "default_timeout"): "12345"}):
+ sensor = DummySensor(task_id="test_sensor_default_timeout",
return_value=None, poke_interval=10)
+ assert sensor.timeout == 12345
+
+ with conf_vars({("sensors", "default_timeout"): "67"}):
+ sensor = DummySensor(task_id="test_sensor_default_timeout_2",
return_value=None, poke_interval=10)
+ assert sensor.timeout == 67
+
def test_sensor_with_exponential_backoff_off(self):
sensor = DummySensor(
task_id=SENSOR_OP, return_value=None, poke_interval=5, timeout=60,
exponential_backoff=False
diff --git a/task-sdk/tests/task_sdk/definitions/test_operator_resources.py
b/task-sdk/tests/task_sdk/definitions/test_operator_resources.py
index 9e0875cf076..389ab8701f6 100644
--- a/task-sdk/tests/task_sdk/definitions/test_operator_resources.py
+++ b/task-sdk/tests/task_sdk/definitions/test_operator_resources.py
@@ -19,6 +19,8 @@ from __future__ import annotations
from airflow.sdk.definitions.operator_resources import Resources
+from tests_common.test_utils.config import conf_vars
+
class TestResources:
def test_resource_eq(self):
@@ -41,3 +43,39 @@ class TestResources:
"disk": {"name": "Disk", "qty": 1024, "units_str": "MB"},
"gpus": {"name": "GPU", "qty": 1, "units_str": "gpu(s)"},
}
+
+ def test_defaults_read_from_conf_at_instantiation(self):
+ """When fields are omitted, ``Resources`` should read defaults from
the ``operators``
+ section at instantiation time (not at module import time).
+ """
+ with conf_vars(
+ {
+ ("operators", "default_cpus"): "7",
+ ("operators", "default_ram"): "5120",
+ ("operators", "default_disk"): "8192",
+ ("operators", "default_gpus"): "3",
+ }
+ ):
+ r = Resources()
+ assert r.cpus.qty == 7
+ assert r.ram.qty == 5120
+ assert r.disk.qty == 8192
+ assert r.gpus.qty == 3
+
+ def test_falsy_zero_values_are_preserved(self):
+ """Explicit ``0`` for a resource must not be replaced by the config
default —
+ only ``None`` (the sentinel for "not supplied") should fall back to
config.
+ """
+ with conf_vars(
+ {
+ ("operators", "default_cpus"): "4",
+ ("operators", "default_ram"): "2048",
+ ("operators", "default_disk"): "1024",
+ ("operators", "default_gpus"): "2",
+ }
+ ):
+ r = Resources(cpus=0, ram=0, disk=0, gpus=0)
+ assert r.cpus.qty == 0
+ assert r.ram.qty == 0
+ assert r.disk.qty == 0
+ assert r.gpus.qty == 0