This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a0494f83e Refactor: Simplify code in models (#33181)
5a0494f83e is described below

commit 5a0494f83e8ad0e5cbf0d3dcad3022a3ea89d789
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Mon Aug 7 21:54:15 2023 +0000

    Refactor: Simplify code in models (#33181)
---
 airflow/models/base.py         |  2 +-
 airflow/models/baseoperator.py |  2 +-
 airflow/models/dag.py          | 31 ++++++++++++++-----------------
 airflow/models/dagbag.py       |  2 +-
 airflow/models/expandinput.py  |  2 +-
 airflow/models/taskmixin.py    |  2 +-
 6 files changed, 19 insertions(+), 22 deletions(-)

diff --git a/airflow/models/base.py b/airflow/models/base.py
index 5f6b7e9893..934b9b1b74 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -69,7 +69,7 @@ def get_id_collation_args():
         # We cannot use session/dialect as at this point we are trying to determine the right connection
         # parameters, so we use the connection
         conn = conf.get("database", "sql_alchemy_conn", fallback="")
-        if conn.startswith("mysql") or conn.startswith("mariadb"):
+        if conn.startswith(("mysql", "mariadb")):
             return {"collation": "utf8mb3_bin"}
         return {}
 
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 7e861e20a6..45462bf726 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -414,7 +414,7 @@ class BaseOperatorMeta(abc.ABCMeta):
                 if arg not in kwargs and arg in default_args:
                     kwargs[arg] = default_args[arg]
 
-            missing_args = non_optional_args - set(kwargs)
+            missing_args = non_optional_args.difference(kwargs)
             if len(missing_args) == 1:
                 raise AirflowException(f"missing keyword argument {missing_args.pop()!r}")
             elif missing_args:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 10f4c595d7..1cb9220e58 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -740,7 +740,7 @@ class DAG(LoggingMixin):
         for c in self._comps:
             # task_ids returns a list and lists can't be hashed
             if c == "task_ids":
-                val = tuple(self.task_dict.keys())
+                val = tuple(self.task_dict)
             else:
                 val = getattr(self, c, None)
             try:
@@ -1256,7 +1256,7 @@ class DAG(LoggingMixin):
 
     @property
     def task_ids(self) -> list[str]:
-        return list(self.task_dict.keys())
+        return list(self.task_dict)
 
     @property
     def teardowns(self) -> list[Operator]:
@@ -2897,7 +2897,7 @@ class DAG(LoggingMixin):
         log.info("Sync %s DAGs", len(dags))
         dag_by_ids = {dag.dag_id: dag for dag in dags}
 
-        dag_ids = set(dag_by_ids.keys())
+        dag_ids = set(dag_by_ids)
         query = (
             select(DagModel)
             .options(joinedload(DagModel.tags, innerjoin=False))
@@ -3235,7 +3235,7 @@ class DAG(LoggingMixin):
                 "auto_register",
                 "fail_stop",
             }
-            cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
+            cls.__serialized_fields = frozenset(vars(DAG(dag_id="test"))) - exclusion_list
         return cls.__serialized_fields
 
     def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
@@ -3594,21 +3594,18 @@ class DagModel(Base):
                 .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
             )
         }
-        dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
+        dataset_triggered_dag_ids = set(dataset_triggered_dag_info)
         if dataset_triggered_dag_ids:
-            exclusion_list = {
-                x
-                for x in (
-                    session.scalars(
-                        select(DagModel.dag_id)
-                        .join(DagRun.dag_model)
-                        .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
-                        .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
-                        .group_by(DagModel.dag_id)
-                        .having(func.count() >= func.max(DagModel.max_active_runs))
-                    )
+            exclusion_list = set(
+                session.scalars(
+                    select(DagModel.dag_id)
+                    .join(DagRun.dag_model)
+                    .where(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
+                    .where(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+                    .group_by(DagModel.dag_id)
+                    .having(func.count() >= func.max(DagModel.max_active_runs))
                 )
-            }
+            )
             if exclusion_list:
                 dataset_triggered_dag_ids -= exclusion_list
                 dataset_triggered_dag_info = {
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index e53c2ce3bd..a8f5b4d6fc 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -169,7 +169,7 @@ class DagBag(LoggingMixin):
 
         :return: a list of DAG IDs in this bag
         """
-        return list(self.dags.keys())
+        return list(self.dags)
 
     @provide_session
     def get_dag(self, dag_id, session: Session = None):
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 8a9a3d8740..36fb5f4165 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -168,7 +168,7 @@ class DictOfListsExpandInput(NamedTuple):
 
         def _find_index_for_this_field(index: int) -> int:
             # Need to use the original user input to retain argument order.
-            for mapped_key in reversed(list(self.value)):
+            for mapped_key in reversed(self.value):
                 mapped_length = all_lengths[mapped_key]
                 if mapped_length < 1:
                     raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index f52749c7ff..8c19749104 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 
 
 class DependencyMixin:
-    """Mixing implementing common dependency setting methods methods like >> and <<."""
+    """Mixing implementing common dependency setting methods like >> and <<."""
 
     @property
     def roots(self) -> Sequence[DependencyMixin]:

Reply via email to