This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5a0494f83e Refactor: Simplify code in models (#33181)
5a0494f83e is described below
commit 5a0494f83e8ad0e5cbf0d3dcad3022a3ea89d789
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Mon Aug 7 21:54:15 2023 +0000
Refactor: Simplify code in models (#33181)
---
airflow/models/base.py | 2 +-
airflow/models/baseoperator.py | 2 +-
airflow/models/dag.py | 31 ++++++++++++++-----------------
airflow/models/dagbag.py | 2 +-
airflow/models/expandinput.py | 2 +-
airflow/models/taskmixin.py | 2 +-
6 files changed, 19 insertions(+), 22 deletions(-)
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 5f6b7e9893..934b9b1b74 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -69,7 +69,7 @@ def get_id_collation_args():
# We cannot use session/dialect as at this point we are trying to determine the right connection
# parameters, so we use the connection
conn = conf.get("database", "sql_alchemy_conn", fallback="")
- if conn.startswith("mysql") or conn.startswith("mariadb"):
+ if conn.startswith(("mysql", "mariadb")):
return {"collation": "utf8mb3_bin"}
return {}
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 7e861e20a6..45462bf726 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -414,7 +414,7 @@ class BaseOperatorMeta(abc.ABCMeta):
if arg not in kwargs and arg in default_args:
kwargs[arg] = default_args[arg]
- missing_args = non_optional_args - set(kwargs)
+ missing_args = non_optional_args.difference(kwargs)
if len(missing_args) == 1:
raise AirflowException(f"missing keyword argument {missing_args.pop()!r}")
elif missing_args:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 10f4c595d7..1cb9220e58 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -740,7 +740,7 @@ class DAG(LoggingMixin):
for c in self._comps:
# task_ids returns a list and lists can't be hashed
if c == "task_ids":
- val = tuple(self.task_dict.keys())
+ val = tuple(self.task_dict)
else:
val = getattr(self, c, None)
try:
@@ -1256,7 +1256,7 @@ class DAG(LoggingMixin):
@property
def task_ids(self) -> list[str]:
- return list(self.task_dict.keys())
+ return list(self.task_dict)
@property
def teardowns(self) -> list[Operator]:
@@ -2897,7 +2897,7 @@ class DAG(LoggingMixin):
log.info("Sync %s DAGs", len(dags))
dag_by_ids = {dag.dag_id: dag for dag in dags}
- dag_ids = set(dag_by_ids.keys())
+ dag_ids = set(dag_by_ids)
query = (
select(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
@@ -3235,7 +3235,7 @@ class DAG(LoggingMixin):
"auto_register",
"fail_stop",
}
- cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
+ cls.__serialized_fields = frozenset(vars(DAG(dag_id="test"))) - exclusion_list
return cls.__serialized_fields
def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
@@ -3594,21 +3594,18 @@ class DagModel(Base):
.having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
)
}
- dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
+ dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
if dataset_triggered_dag_ids:
- exclusion_list = {
- x
- for x in (
- session.scalars(
- select(DagModel.dag_id)
- .join(DagRun.dag_model)
- .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
- .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
- .group_by(DagModel.dag_id)
- .having(func.count() >= func.max(DagModel.max_active_runs))
- )
+ exclusion_list = set(
+ session.scalars(
+ select(DagModel.dag_id)
+ .join(DagRun.dag_model)
+ .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
+ .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+ .group_by(DagModel.dag_id)
+ .having(func.count() >= func.max(DagModel.max_active_runs))
)
- }
+ )
if exclusion_list:
dataset_triggered_dag_ids -= exclusion_list
dataset_triggered_dag_info = {
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index e53c2ce3bd..a8f5b4d6fc 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -169,7 +169,7 @@ class DagBag(LoggingMixin):
:return: a list of DAG IDs in this bag
"""
- return list(self.dags.keys())
+ return list(self.dags)
@provide_session
def get_dag(self, dag_id, session: Session = None):
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 8a9a3d8740..36fb5f4165 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -168,7 +168,7 @@ class DictOfListsExpandInput(NamedTuple):
def _find_index_for_this_field(index: int) -> int:
# Need to use the original user input to retain argument order.
- for mapped_key in reversed(list(self.value)):
+ for mapped_key in reversed(self.value):
mapped_length = all_lengths[mapped_key]
if mapped_length < 1:
raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index f52749c7ff..8c19749104 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
class DependencyMixin:
- """Mixing implementing common dependency setting methods methods like >> and <<."""
+ """Mixing implementing common dependency setting methods like >> and <<."""
@property
def roots(self) -> Sequence[DependencyMixin]: