This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f8104325b7 Activate RUF019 that checks for unnecessary key check
(#38950)
f8104325b7 is described below
commit f8104325b7a66d4e98ff3a6c3555f90c796071c6
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Apr 15 09:53:34 2024 +0200
Activate RUF019 that checks for unnecessary key check (#38950)
---
airflow/example_dags/example_params_trigger_ui.py | 2 +-
airflow/models/dag.py | 10 +++-------
airflow/providers/fab/auth_manager/fab_auth_manager.py | 4 ++--
airflow/providers/google/cloud/hooks/bigquery.py | 6 ++----
airflow/providers/snowflake/hooks/snowflake_sql_api.py | 11 ++++++-----
pyproject.toml | 5 +++--
tests/providers/google/cloud/hooks/test_bigquery.py | 8 ++++++--
tests/providers/google/cloud/triggers/test_bigquery.py | 4 +++-
8 files changed, 26 insertions(+), 24 deletions(-)
diff --git a/airflow/example_dags/example_params_trigger_ui.py
b/airflow/example_dags/example_params_trigger_ui.py
index 2a1e6c9b34..47465ad39d 100644
--- a/airflow/example_dags/example_params_trigger_ui.py
+++ b/airflow/example_dags/example_params_trigger_ui.py
@@ -73,7 +73,7 @@ with DAG(
dag_run: DagRun = ti.dag_run
selected_languages = []
for lang in ["english", "german", "french"]:
- if lang in dag_run.conf and dag_run.conf[lang]:
+ if dag_run.conf.get(lang):
selected_languages.append(f"generate_{lang}_greeting")
return selected_languages
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9dc0083196..4e94432aef 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -582,8 +582,7 @@ class DAG(LoggingMixin):
if start_date and start_date.tzinfo:
tzinfo = None if start_date.tzinfo else settings.TIMEZONE
tz = pendulum.instance(start_date, tz=tzinfo).timezone
- elif "start_date" in self.default_args and
self.default_args["start_date"]:
- date = self.default_args["start_date"]
+ elif date := self.default_args.get("start_date"):
if not isinstance(date, datetime):
date = timezone.parse(date)
self.default_args["start_date"] = date
@@ -594,11 +593,8 @@ class DAG(LoggingMixin):
self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE
# Apply the timezone we settled on to end_date if it wasn't supplied
- if "end_date" in self.default_args and self.default_args["end_date"]:
- if isinstance(self.default_args["end_date"], str):
- self.default_args["end_date"] = timezone.parse(
- self.default_args["end_date"], timezone=self.timezone
- )
+ if isinstance(_end_date := self.default_args.get("end_date"), str):
+ self.default_args["end_date"] = timezone.parse(_end_date,
timezone=self.timezone)
self.start_date = timezone.convert_to_utc(start_date)
self.end_date = timezone.convert_to_utc(end_date)
diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py
b/airflow/providers/fab/auth_manager/fab_auth_manager.py
index 87e80d3dcd..d01b3526bf 100644
--- a/airflow/providers/fab/auth_manager/fab_auth_manager.py
+++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py
@@ -350,8 +350,8 @@ class FabAuthManager(BaseAuthManager):
"""Return the login page url."""
if not self.security_manager.auth_view:
raise AirflowException("`auth_view` not defined in the security
manager.")
- if "next_url" in kwargs and kwargs["next_url"]:
- return
url_for(f"{self.security_manager.auth_view.endpoint}.login",
next=kwargs["next_url"])
+ if next_url := kwargs.get("next_url"):
+ return
url_for(f"{self.security_manager.auth_view.endpoint}.login", next=next_url)
else:
return url_for(f"{self.security_manager.auth_view.endpoint}.login")
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index a39025931d..0594ce4351 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -2849,11 +2849,10 @@ class BigQueryCursor(BigQueryBaseCursor):
return None
query_results = self._get_query_result()
- if "rows" in query_results and query_results["rows"]:
+ if rows := query_results.get("rows"):
self.page_token = query_results.get("pageToken")
fields = query_results["schema"]["fields"]
col_types = [field["type"] for field in fields]
- rows = query_results["rows"]
for dict_row in rows:
typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs
in enumerate(dict_row["f"])]
@@ -3396,8 +3395,7 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
:param as_dict: if True returns the result as a list of dictionaries,
otherwise as list of lists.
"""
buffer: list[Any] = []
- if "rows" in query_results and query_results["rows"]:
- rows = query_results["rows"]
+ if rows := query_results.get("rows"):
fields = query_results["schema"]["fields"]
fields_names = [field["name"] for field in fields]
col_types = [field["type"] for field in fields]
diff --git a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
index 018182302e..6eec055eb5 100644
--- a/airflow/providers/snowflake/hooks/snowflake_sql_api.py
+++ b/airflow/providers/snowflake/hooks/snowflake_sql_api.py
@@ -283,11 +283,12 @@ class SnowflakeSqlApiHook(SnowflakeHook):
elif status_code == 422:
return {"status": "error", "message": resp["message"]}
elif status_code == 200:
- statement_handles = []
- if "statementHandles" in resp and resp["statementHandles"]:
- statement_handles = resp["statementHandles"]
- elif "statementHandle" in resp and resp["statementHandle"]:
- statement_handles.append(resp["statementHandle"])
+ if resp_statement_handles := resp.get("statementHandles"):
+ statement_handles = resp_statement_handles
+ elif resp_statement_handle := resp.get("statementHandle"):
+ statement_handles = [resp_statement_handle]
+ else:
+ statement_handles = []
return {
"status": "success",
"message": resp["message"],
diff --git a/pyproject.toml b/pyproject.toml
index 0cf12cbe0c..1b505517d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -271,6 +271,9 @@ extend-select = [
"PT", # flake8-pytest-style rules
"TID25", # flake8-tidy-imports rules
# Per rule enables
+ "RUF006", # Checks for asyncio dangling task
+ "RUF015", # Checks for unnecessary iterable allocation for first element
+ "RUF019", # Checks for unnecessary key check
"RUF100", # Unused noqa (auto-fixable)
# We ignore more pydocstyle than we enable, so be more selective at what
we enable
"D101",
@@ -292,8 +295,6 @@ extend-select = [
"B019", # Use of functools.lru_cache or functools.cache on methods can
lead to memory leaks
"B028", # No explicit stacklevel keyword argument found
"TRY002", # Prohibit use of `raise Exception`, use specific exceptions
instead.
- "RUF006", # Checks for asyncio dangling task
- "RUF015", # Checks for unnecessary iterable allocation for first element
]
ignore = [
"D203",
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py
b/tests/providers/google/cloud/hooks/test_bigquery.py
index b02222b350..9118116287 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -1794,7 +1794,9 @@ class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
- def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor._get_query_result")
+ def test_hook_uses_legacy_sql_by_default(self, mock_get_query_result,
mock_insert, _):
+ mock_get_query_result.return_value = {}
self.hook.get_first("query")
_, kwargs = mock_insert.call_args
assert kwargs["configuration"]["query"]["useLegacySql"] is True
@@ -1805,9 +1807,11 @@ class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor._get_query_result")
def test_legacy_sql_override_propagates_properly(
- self, mock_insert, mock_build, mock_get_creds_and_proj_id
+ self, mock_get_query_result, mock_insert, mock_build,
mock_get_creds_and_proj_id
):
+ mock_get_query_result.return_value = {}
bq_hook = BigQueryHook(use_legacy_sql=False)
bq_hook.get_first("query")
_, kwargs = mock_insert.call_args
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py
b/tests/providers/google/cloud/triggers/test_bigquery.py
index 8c4318fde4..9eec245f83 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -499,12 +499,14 @@ class TestBigQueryIntervalCheckTrigger:
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_records")
async def test_interval_check_trigger_success(
- self, mock_get_job_output, mock_job_status, interval_check_trigger
+ self, mock_get_records, mock_get_job_output, mock_job_status,
interval_check_trigger
):
"""
Tests the BigQueryIntervalCheckTrigger only fires once the query
execution reaches a successful state.
"""
+ mock_get_records.return_value = {}
mock_job_status.return_value = {"status": "success", "message": "Job
completed"}
mock_get_job_output.return_value = ["0"]