jscheffl commented on code in PR #36029:
URL: https://github.com/apache/airflow/pull/36029#discussion_r1520500377


##########
airflow/models/baseoperator.py:
##########
@@ -575,6 +579,11 @@ class derived from this one results in the creation of a task object,
         significantly speeding up the task creation process as for very large
         DAGs. Options can be set as string or using the constants defined in
         the static class ``airflow.utils.WeightRule``
+    :param priority_weight_strategy: weighting method used for the effective total priority weight

Review Comment:
   I do not fully understand why we need to rename the field. Even with a 
custom implementation, does the old name no longer fit? After all, the same 
parameters apply.
   With the rename we force all customers to change their DAG definitions when 
we migrate the parameter name, even though in 95% of cases the parameter and 
function do not change.
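
A minimal sketch of the alternative implied above, assuming the existing parameter name is kept and accepts either a built-in rule name or a custom object. `resolve_weight_rule`, `BUILTIN_RULES`, and the `(own, related)` call signature are hypothetical stand-ins for illustration, not Airflow's API.

```python
from __future__ import annotations

from typing import Callable, Sequence

# Hypothetical signature: (task's own weight, weights of related tasks) -> effective weight.
WeightFn = Callable[[int, Sequence[int]], int]

# Assumed built-in rules, keyed by the string values existing DAGs already pass.
# "downstream" and "upstream" share a formula here; the caller decides which
# related tasks to pass in.
BUILTIN_RULES: dict[str, WeightFn] = {
    "absolute": lambda own, related: own,
    "downstream": lambda own, related: own + sum(related),
    "upstream": lambda own, related: own + sum(related),
}


def resolve_weight_rule(weight_rule: str | WeightFn) -> WeightFn:
    """Accept the old string values unchanged, or a custom callable."""
    if callable(weight_rule):
        return weight_rule
    try:
        return BUILTIN_RULES[weight_rule]
    except KeyError:
        raise ValueError(f"Unknown weight rule: {weight_rule!r}") from None
```

With something along these lines, existing DAGs that pass a string keep working, and only DAGs that want custom behavior pass a different kind of value under the same name.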



##########
airflow/migrations/versions/0137_2_9_0_add_priority_weight_strategy_to_task.py:
##########
@@ -0,0 +1,47 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""add priority_weight_strategy to task_instance
+
+Revision ID: 624ecf3b6a5e
+Revises: 1fd565369930
+Create Date: 2023-10-29 02:01:34.774596
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "624ecf3b6a5e"
+down_revision = "ab34f260b71c"
+branch_labels = None
+depends_on = None
+airflow_version = "2.9.0"
+
+
+def upgrade():
+    """Apply add priority_weight_strategy to task_instance"""
+    with op.batch_alter_table("task_instance") as batch_op:
+        batch_op.add_column(sa.Column("_priority_weight_strategy", sa.JSON()))

Review Comment:
   I understand we need to store some context information about the selected 
priority weight strategy - but do we really need to add this to the DB? 
TaskInstance is the largest table in the DB schema and potentially 
contains millions of rows. Do we really want to store the same value in 
millions of rows? Or can we leave it `NULL` and store a value only if we have 
this special rule and data needs to be stored?
   
   I like the approach of this PR in general but fear this will create a lot of 
overhead in the DB, especially as it is a JSON field.
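
To make the `NULL` idea concrete, here is a rough sketch using an in-memory SQLite table rather than the real `task_instance` table. The table name `ti_sketch`, the helper `column_payload`, the assumed default `"downstream"`, and the strategy name `my_plugin.DecayStrategy` are all hypothetical and only illustrate the storage pattern.

```python
from __future__ import annotations

import json
import sqlite3
from typing import Any

DEFAULT_STRATEGY = "downstream"  # assumed default, for illustration only


def column_payload(strategy_name: str, state: dict[str, Any]) -> str | None:
    """Return JSON text for the column, or None for the stateless default."""
    if strategy_name == DEFAULT_STRATEGY and not state:
        return None
    return json.dumps({"type": strategy_name, "var": state})


def count_non_null(rows: list[tuple[str, str | None]]) -> int:
    """Insert rows into an in-memory table and count those with a stored payload."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE ti_sketch (task_id TEXT, priority_weight_strategy TEXT)")
    conn.executemany("INSERT INTO ti_sketch VALUES (?, ?)", rows)
    (count,) = conn.execute(
        "SELECT COUNT(*) FROM ti_sketch WHERE priority_weight_strategy IS NOT NULL"
    ).fetchone()
    conn.close()
    return count


rows = [
    ("extract", column_payload("downstream", {})),
    ("transform", column_payload("downstream", {})),
    ("load", column_payload("my_plugin.DecayStrategy", {"factor": 2})),
]
```

Only the row with a non-default strategy carries a payload; the others stay `NULL`.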



##########
airflow/serialization/serialized_objects.py:
##########
@@ -228,6 +257,40 @@ def decode_timetable(var: dict[str, Any]) -> Timetable:
     return timetable_class.deserialize(var[Encoding.VAR])
 
 
+def _encode_priority_weight_strategy(var: PriorityWeightStrategy) -> dict[str, Any]:
+    """
+    Encode a priority weight strategy instance.
+
+    This delegates most of the serialization work to the type, so the behavior
+    can be completely controlled by a custom subclass.
+    """
+    priority_weight_strategy_class = type(var)
+    importable_string = qualname(priority_weight_strategy_class)
+    if _get_registered_priority_weight_strategy(importable_string) is None:
+        raise _PriorityWeightStrategyNotRegistered(importable_string)
+    return {Encoding.TYPE: importable_string, Encoding.VAR: var.serialize()}

Review Comment:
   Do we assume we need to store state for the strategy? Would it not be enough 
just to use the task context information and the Python code? Do we assume 
that real "data" needs to be serialized, stored, and retrieved for a custom 
strategy? This feels like a lot of overhead for very special use cases where 
strategy state information needs to be persisted.
   Do you have a use case in mind where per-task data (other than the context) 
needs to be persisted in a custom strategy? I feel this is a niche use case - 
but maybe I just cannot think of one.
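
For discussion, one conceivable case is a strategy configured with a constructor argument, which would be lost on deserialization unless it is persisted. The sketch below uses stand-in classes that mirror the shape of the `serialize`/`deserialize` hooks in the diff. `PriorityWeightStrategySketch`, `DecreasingOnRetryStrategy`, the `step` parameter, and the simplified `get_weight(base_weight, try_number)` signature are hypothetical and not part of this PR.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class PriorityWeightStrategySketch(ABC):
    """Stand-in for the interface in the diff, with a simplified get_weight."""

    @abstractmethod
    def get_weight(self, base_weight: int, try_number: int) -> int:
        ...

    @classmethod
    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategySketch:
        return cls(**data)

    def serialize(self) -> dict[str, Any]:
        return {}


class DecreasingOnRetryStrategy(PriorityWeightStrategySketch):
    """Lower the weight by `step` for every retry, never below zero."""

    def __init__(self, step: int = 1) -> None:
        self.step = step

    def get_weight(self, base_weight: int, try_number: int) -> int:
        return max(base_weight - self.step * (try_number - 1), 0)

    def serialize(self) -> dict[str, Any]:
        # `step` is the only state that must survive a serialization round trip.
        return {"step": self.step}


original = DecreasingOnRetryStrategy(step=3)
restored = DecreasingOnRetryStrategy.deserialize(original.serialize())
```

If no strategy is ever expected to take such arguments, the hooks would indeed be unnecessary.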



##########
airflow/task/priority_strategy.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Priority weight strategies for task scheduling."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from airflow.exceptions import AirflowException
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+
+
+class PriorityWeightStrategy(ABC):
+    """Priority weight strategy interface."""
+
+    @abstractmethod
+    def get_weight(self, ti: TaskInstance):
+        """Get the priority weight of a task."""
+        ...
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
+        """Deserialize a priority weight strategy from data.
+
+        This is called when a serialized DAG is deserialized. ``data`` will be whatever
+        was returned by ``serialize`` during DAG serialization. The default
+        implementation constructs the priority weight strategy without any arguments.
+        """
+        return cls(**data)  # type: ignore[call-arg]
+
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the priority weight strategy for JSON encoding.
+
+        This is called during DAG serialization to store priority weight strategy information
+        in the database. This should return a JSON-serializable dict that will be fed into
+        ``deserialize`` when the DAG is deserialized. The default implementation returns
+        an empty dict.
+        """
+        return {}

Review Comment:
   If there is a real use case where state data needs to be persisted (see other 
comments - do we really need this?), can we optimize so that `None` is returned 
when no data needs to be persisted? That would at least leave the DB 
column empty in 95% of cases, I feel.
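
A rough sketch of what that could look like, using stand-in classes rather than the PR's actual code. `StrategySketch` and `ScaledStrategy` are hypothetical names, and treating `None` as "no arguments" in `deserialize` is an assumption about how the two hooks would be paired.

```python
from __future__ import annotations

from typing import Any


class StrategySketch:
    """Hypothetical base class; only the serialization hooks are shown."""

    def serialize(self) -> dict[str, Any] | None:
        # None signals "nothing to persist", so the column can stay NULL.
        return None

    @classmethod
    def deserialize(cls, data: dict[str, Any] | None) -> StrategySketch:
        # Treat a missing payload as "construct with no arguments".
        return cls(**(data or {}))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, type(self)):
            return False
        return self.serialize() == other.serialize()


class ScaledStrategy(StrategySketch):
    """Example of a strategy that does carry state worth persisting."""

    def __init__(self, factor: int = 1) -> None:
        self.factor = factor

    def serialize(self) -> dict[str, Any] | None:
        return {"factor": self.factor}
```

Stateless strategies then persist nothing, while a subclass with state opts in by overriding `serialize`.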



##########
airflow/task/priority_strategy.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Priority weight strategies for task scheduling."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from airflow.exceptions import AirflowException
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+
+
+class PriorityWeightStrategy(ABC):
+    """Priority weight strategy interface."""
+
+    @abstractmethod
+    def get_weight(self, ti: TaskInstance):
+        """Get the priority weight of a task."""
+        ...
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
+        """Deserialize a priority weight strategy from data.
+
+        This is called when a serialized DAG is deserialized. ``data`` will be whatever
+        was returned by ``serialize`` during DAG serialization. The default
+        implementation constructs the priority weight strategy without any arguments.
+        """
+        return cls(**data)  # type: ignore[call-arg]
+
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the priority weight strategy for JSON encoding.
+
+        This is called during DAG serialization to store priority weight strategy information
+        in the database. This should return a JSON-serializable dict that will be fed into
+        ``deserialize`` when the DAG is deserialized. The default implementation returns
+        an empty dict.
+        """
+        return {}
+
+    def __eq__(self, other: object) -> bool:
+        """Equality comparison."""
+        if not isinstance(other, type(self)):
+            return False
+        return self.serialize() == other.serialize()
+
+
+class AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
+    """Priority weight strategy that uses the task's priority weight directly."""
+
+    def get_weight(self, ti: TaskInstance):

Review Comment:
   ```suggestion
       def get_weight(self, ti: TaskInstance) -> int:
   ```



##########
airflow/task/priority_strategy.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Priority weight strategies for task scheduling."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from airflow.exceptions import AirflowException
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+
+
+class PriorityWeightStrategy(ABC):
+    """Priority weight strategy interface."""
+
+    @abstractmethod
+    def get_weight(self, ti: TaskInstance):
+        """Get the priority weight of a task."""
+        ...
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
+        """Deserialize a priority weight strategy from data.
+
+        This is called when a serialized DAG is deserialized. ``data`` will be whatever
+        was returned by ``serialize`` during DAG serialization. The default
+        implementation constructs the priority weight strategy without any arguments.
+        """
+        return cls(**data)  # type: ignore[call-arg]
+
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the priority weight strategy for JSON encoding.
+
+        This is called during DAG serialization to store priority weight strategy information
+        in the database. This should return a JSON-serializable dict that will be fed into
+        ``deserialize`` when the DAG is deserialized. The default implementation returns
+        an empty dict.
+        """
+        return {}
+
+    def __eq__(self, other: object) -> bool:
+        """Equality comparison."""
+        if not isinstance(other, type(self)):
+            return False
+        return self.serialize() == other.serialize()
+
+
+class AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
+    """Priority weight strategy that uses the task's priority weight directly."""
+
+    def get_weight(self, ti: TaskInstance):
+        return ti.task.priority_weight
+
+
+class DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
+    """Priority weight strategy that uses the sum of the priority weights of all downstream tasks."""
+
+    def get_weight(self, ti: TaskInstance):

Review Comment:
   ```suggestion
       def get_weight(self, ti: TaskInstance) -> int:
   ```



##########
tests/serialization/test_dag_serialization.py:
##########
@@ -487,32 +488,32 @@ def sorted_serialized_dag(dag_dict: dict):
         expected = json.loads(json.dumps(sorted_serialized_dag(expected)))
         return actual, expected
 
-    def test_deserialization_across_process(self):
-        """A serialized DAG can be deserialized in another process."""
-
-        # Since we need to parse the dags twice here (once in the subprocess,
-        # and once here to get a DAG to compare to) we don't want to load all
-        # dags.
-        queue = multiprocessing.Queue()
-        proc = multiprocessing.Process(target=serialize_subprocess, args=(queue, "airflow/example_dags"))
-        proc.daemon = True
-        proc.start()
-
-        stringified_dags = {}
-        while True:
-            v = queue.get()
-            if v is None:
-                break
-            dag = SerializedDAG.from_json(v)
-            assert isinstance(dag, DAG)
-            stringified_dags[dag.dag_id] = dag
-
-        dags = collect_dags("airflow/example_dags")
-        assert set(stringified_dags.keys()) == set(dags.keys())
-
-        # Verify deserialized DAGs.
-        for dag_id in stringified_dags:
-            self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
+    # def test_deserialization_across_process(self):
+    #     """A serialized DAG can be deserialized in another process."""
+    #
+    #     # Since we need to parse the dags twice here (once in the subprocess,
+    #     # and once here to get a DAG to compare to) we don't want to load all
+    #     # dags.
+    #     queue = multiprocessing.Queue()
+    #     proc = multiprocessing.Process(target=serialize_subprocess, args=(queue, "airflow/example_dags"))
+    #     proc.daemon = True
+    #     proc.start()
+    #
+    #     stringified_dags = {}
+    #     while True:
+    #         v = queue.get()
+    #         if v is None:
+    #             break
+    #         dag = SerializedDAG.from_json(v)
+    #         assert isinstance(dag, DAG)
+    #         stringified_dags[dag.dag_id] = dag
+    #
+    #     dags = collect_dags("airflow/example_dags")
+    #     assert set(stringified_dags.keys()) == set(dags.keys())
+    #
+    #     # Verify deserialized DAGs.
+    #     for dag_id in stringified_dags:
+    #         self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])

Review Comment:
   Commented-out code - does this need to be fixed, or should it be deleted?
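
If the test is meant to come back later, one option is to keep it as real code and mark it skipped with a stated reason, so it still shows up in test reports. This is a minimal sketch using the standard library's `unittest.skip`. The class name, the skip reason, and the elided body are placeholders, not taken from the PR.

```python
import unittest


class TestSerializationSketch(unittest.TestCase):
    # Placeholder reason; the real one would say why the test is disabled.
    @unittest.skip("temporarily disabled, pending a fix (hypothetical reason)")
    def test_deserialization_across_process(self):
        """A serialized DAG can be deserialized in another process."""
        raise AssertionError("body elided in this sketch")


def run_sketch() -> unittest.TestResult:
    """Run the sketch test case and return the collected result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestSerializationSketch)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

A skipped test is reported as skipped instead of silently disappearing, which makes it harder to forget.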



##########
airflow/task/priority_strategy.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Priority weight strategies for task scheduling."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from airflow.exceptions import AirflowException
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+
+
+class PriorityWeightStrategy(ABC):
+    """Priority weight strategy interface."""
+
+    @abstractmethod
+    def get_weight(self, ti: TaskInstance):
+        """Get the priority weight of a task."""

Review Comment:
   Two suggestions:
   1) Return type missing? What is returned?
   2) Is it a relative or absolute weight?
   ```suggestion
       def get_weight(self, ti: TaskInstance) -> int:
           """Get the absolute priority weight of a task."""
   ```



##########
airflow/task/priority_strategy.py:
##########
@@ -0,0 +1,144 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Priority weight strategies for task scheduling."""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+from airflow.exceptions import AirflowException
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstance
+
+
+class PriorityWeightStrategy(ABC):
+    """Priority weight strategy interface."""
+
+    @abstractmethod
+    def get_weight(self, ti: TaskInstance):
+        """Get the priority weight of a task."""
+        ...
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> PriorityWeightStrategy:
+        """Deserialize a priority weight strategy from data.
+
+        This is called when a serialized DAG is deserialized. ``data`` will be whatever
+        was returned by ``serialize`` during DAG serialization. The default
+        implementation constructs the priority weight strategy without any arguments.
+        """
+        return cls(**data)  # type: ignore[call-arg]
+
+    def serialize(self) -> dict[str, Any]:
+        """Serialize the priority weight strategy for JSON encoding.
+
+        This is called during DAG serialization to store priority weight strategy information
+        in the database. This should return a JSON-serializable dict that will be fed into
+        ``deserialize`` when the DAG is deserialized. The default implementation returns
+        an empty dict.
+        """
+        return {}
+
+    def __eq__(self, other: object) -> bool:
+        """Equality comparison."""
+        if not isinstance(other, type(self)):
+            return False
+        return self.serialize() == other.serialize()
+
+
+class AbsolutePriorityWeightStrategy(PriorityWeightStrategy):
+    """Priority weight strategy that uses the task's priority weight directly."""
+
+    def get_weight(self, ti: TaskInstance):
+        return ti.task.priority_weight
+
+
+class DownstreamPriorityWeightStrategy(PriorityWeightStrategy):
+    """Priority weight strategy that uses the sum of the priority weights of all downstream tasks."""
+
+    def get_weight(self, ti: TaskInstance):
+        dag = ti.task.get_dag()
+        if dag is None:
+            return ti.task.priority_weight
+        return ti.task.priority_weight + sum(
+            dag.task_dict[task_id].priority_weight
+            for task_id in ti.task.get_flat_relative_ids(upstream=False)
+        )
+
+
+class UpstreamPriorityWeightStrategy(PriorityWeightStrategy):
+    """Priority weight strategy that uses the sum of the priority weights of all upstream tasks."""
+
+    def get_weight(self, ti: TaskInstance):

Review Comment:
   ```suggestion
       def get_weight(self, ti: TaskInstance) -> int:
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
