uranusjr commented on code in PR #66848: URL: https://github.com/apache/airflow/pull/66848#discussion_r3370856422
##########
airflow-core/src/airflow/partition_mappers/wait_policy.py:
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any
+
+
+class WaitPolicy:
+    """
+    An object the scheduler asks whether a partitioned Dag run should fire.
+
+    Concrete policies are ``WaitForAll`` and ``MinimumCount``. Each implements
+    ``is_satisfied(matched, expected)`` and ``is_unreachable(expected)``; the
+    scheduler calls these methods directly in the hot path on every tick.
+    """
+
+    def is_satisfied(self, matched: int, expected: int) -> bool:
+        raise NotImplementedError
+
+    def is_unreachable(self, expected: int) -> bool:
+        raise NotImplementedError
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> WaitPolicy:
+        raise NotImplementedError
+
+
+class WaitForAll(WaitPolicy):
+    """
+    Fires only when every expected upstream key has arrived.
+
+    ``matched == expected`` is the satisfaction condition, including the
+    vacuously-true case where both are zero (empty window).
+    ``is_unreachable`` always returns ``False`` — even an empty window
+    satisfies vacuously.
+    """
+
+    def is_satisfied(self, matched: int, expected: int) -> bool:
+        return matched == expected
+
+    def is_unreachable(self, expected: int) -> bool:
+        return False
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> WaitForAll:
+        return cls()
+
+    def __repr__(self) -> str:
+        return "WaitForAll()"
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, WaitForAll)
+
+    def __hash__(self) -> int:
+        return hash(type(self))

Review Comment:
   This is where `dataclasses` or `attrs` come in handy.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
