This is an automated email from the ASF dual-hosted git repository. dianfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new a77d052 [FLINK-22912][python] Support state ttl in Python DataStream API a77d052 is described below commit a77d0523ffd8aac907e062ace259f9699a276388 Author: huangxingbo <hxbks...@gmail.com> AuthorDate: Mon Aug 2 15:01:10 2021 +0800 [FLINK-22912][python] Support state ttl in Python DataStream API This closes #16667. --- .../docs/dev/datastream/fault-tolerance/state.md | 48 +- .../docs/dev/datastream/fault-tolerance/state.md | 48 +- flink-python/pyflink/common/__init__.py | 3 +- flink-python/pyflink/common/time.py | 41 +- flink-python/pyflink/datastream/state.py | 488 ++++++++++++++++++++- .../pyflink/datastream/tests/test_data_stream.py | 31 +- .../fn_execution/datastream/runtime_context.py | 19 +- .../pyflink/fn_execution/flink_fn_execution_pb2.py | 476 +++++++++++++++++++- flink-python/pyflink/fn_execution/state_impl.py | 97 ++-- ...b2_synced.py => test_flink_fn_execution_pb2.py} | 38 +- .../pyflink/proto/flink-fn-execution.proto | 89 ++++ .../python/beam/SimpleStateRequestHandler.java | 26 +- .../flink/streaming/api/utils/ProtoUtils.java | 93 +++- .../flink/streaming/api/utils/ProtoUtilsTest.java | 122 ++++++ 14 files changed, 1548 insertions(+), 71 deletions(-) diff --git a/docs/content.zh/docs/dev/datastream/fault-tolerance/state.md b/docs/content.zh/docs/dev/datastream/fault-tolerance/state.md index 31e5b40..5b45613 100644 --- a/docs/content.zh/docs/dev/datastream/fault-tolerance/state.md +++ b/docs/content.zh/docs/dev/datastream/fault-tolerance/state.md @@ -342,6 +342,22 @@ val stateDescriptor = new ValueStateDescriptor[String]("text state", classOf[Str stateDescriptor.enableTimeToLive(ttlConfig) ``` {{< /tab >}} +{{< tab "Python" >}} +```python +from pyflink.common.time import Time +from pyflink.common.typeinfo import Types +from pyflink.datastream.state import ValueStateDescriptor, StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \ + .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \ + .build() + +state_descriptor = ValueStateDescriptor("text state", Types.STRING()) +state_descriptor.enable_time_to_live(ttl_config) +``` +{{< /tab >}} {{< /tabs >}} TTL 配置有以下几个选项: @@ -404,7 +420,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL 当前在 PyFlink DataStream API 中还不支持。 +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .disable_cleanup_in_background() \ + .build() ``` {{< /tab >}} {{< /tabs >}} @@ -441,7 +463,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL 当前在 PyFlink DataStream API 中还不支持。 +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .cleanup_full_snapshot() \ + .build() ``` {{< /tab >}} {{< /tabs >}} @@ -479,7 +507,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL 当前在 PyFlink DataStream API 中还不支持。 +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .cleanup_incrementally(10, True) \ + .build() ``` {{< /tab >}} {{< /tabs >}} @@ -524,7 +558,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL 当前在 PyFlink DataStream API 中还不支持。 +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .cleanup_in_rocksdb_compact_filter(1000) \ + .build() ``` {{< /tab >}} {{< /tabs >}} diff --git a/docs/content/docs/dev/datastream/fault-tolerance/state.md b/docs/content/docs/dev/datastream/fault-tolerance/state.md index 24e35b2..0137a6e 100644 --- a/docs/content/docs/dev/datastream/fault-tolerance/state.md +++ b/docs/content/docs/dev/datastream/fault-tolerance/state.md @@ -369,6 +369,22 @@ val stateDescriptor = new ValueStateDescriptor[String]("text state", classOf[Str stateDescriptor.enableTimeToLive(ttlConfig) ``` {{< /tab >}} +{{< tab "Python" >}} +```python +from pyflink.common.time import Time +from pyflink.common.typeinfo import Types +from pyflink.datastream.state import ValueStateDescriptor, StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \ + .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \ + .build() + +state_descriptor = ValueStateDescriptor("text state", Types.STRING()) +state_descriptor.enable_time_to_live(ttl_config) +``` +{{< /tab >}} {{< /tabs >}} The configuration has several options to consider: @@ -438,7 +454,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL is still not supported in PyFlink DataStream API. +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .disable_cleanup_in_background() \ + .build() ``` {{< /tab >}} {{< /tabs >}} @@ -478,7 +500,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL is still not supported in PyFlink DataStream API. +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .cleanup_full_snapshot() \ + .build() ``` {{< /tab >}} {{< /tabs >}} @@ -522,7 +550,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL is still not supported in PyFlink DataStream API. +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .cleanup_incrementally(10, True) \ + .build() ``` {{< /tab >}} {{< /tabs >}} @@ -574,7 +608,13 @@ val ttlConfig = StateTtlConfig {{< /tab >}} {{< tab "Python" >}} ```python -State TTL is still not supported in PyFlink DataStream API. +from pyflink.common.time import Time +from pyflink.datastream.state import StateTtlConfig + +ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .cleanup_in_rocksdb_compact_filter(1000) \ + .build() ``` {{< /tab >}} {{< /tabs >}} diff --git a/flink-python/pyflink/common/__init__.py b/flink-python/pyflink/common/__init__.py index d60c5e3..968316b 100644 --- a/flink-python/pyflink/common/__init__.py +++ b/flink-python/pyflink/common/__init__.py @@ -35,7 +35,7 @@ from pyflink.common.job_status import JobStatus from pyflink.common.restart_strategy import RestartStrategies, RestartStrategyConfiguration from pyflink.common.typeinfo import Types, TypeInformation from pyflink.common.types import Row, RowKind -from pyflink.common.time import Duration, Instant +from pyflink.common.time import Duration, Instant, Time from pyflink.common.watermark_strategy import WatermarkStrategy __all__ = [ @@ -59,4 +59,5 @@ __all__ = [ "Types", "TypeInformation", "Instant", + "Time" ] diff --git a/flink-python/pyflink/common/time.py b/flink-python/pyflink/common/time.py index 41e4f0f..938a046 100644 --- a/flink-python/pyflink/common/time.py +++ b/flink-python/pyflink/common/time.py @@ -17,7 +17,7 @@ ################################################################################ from pyflink.java_gateway import get_gateway -__all__ = ['Duration', 'Instant'] +__all__ = ['Duration', 'Instant', 'Time'] class Duration(object): @@ -83,3 +83,42 @@ class Instant(object): def __repr__(self): return 'Instant<{}, {}>'.format(self.seconds, self.nanos) + + +class Time(object): + """ + The definition of a time interval. + """ + + def __init__(self, milliseconds: int): + self._milliseconds = milliseconds + + def to_milliseconds(self) -> int: + return self._milliseconds + + @staticmethod + def milliseconds(milliseconds: int): + return Time(milliseconds) + + @staticmethod + def seconds(seconds: int): + return Time.milliseconds(seconds * 1000) + + @staticmethod + def minutes(minutes: int): + return Time.seconds(minutes * 60) + + @staticmethod + def hours(hours: int): + return Time.minutes(hours * 60) + + @staticmethod + def days(days: int): + return Time.hours(days * 24) + + def __eq__(self, other): + return (self.__class__ == other.__class__ and + self._milliseconds == other._milliseconds) + + def __str__(self): + return "{} ms".format(self._milliseconds) diff --git a/flink-python/pyflink/datastream/state.py b/flink-python/pyflink/datastream/state.py index a576d41..bd1c263 100644 --- a/flink-python/pyflink/datastream/state.py +++ b/flink-python/pyflink/datastream/state.py @@ -16,10 +16,12 @@ # limitations under the License. ################################################################################ from abc import ABC, abstractmethod +from enum import Enum -from typing import TypeVar, Generic, Iterable, List, Iterator, Dict, Tuple +from typing import TypeVar, Generic, Iterable, List, Iterator, Dict, Tuple, Optional from pyflink.common.typeinfo import TypeInformation, Types +from pyflink.common.time import Time __all__ = [ 'ValueStateDescriptor', @@ -31,7 +33,8 @@ __all__ = [ 'ReducingStateDescriptor', 'ReducingState', 'AggregatingStateDescriptor', - 'AggregatingState' + 'AggregatingState', + 'StateTtlConfig' ] T = TypeVar('T') @@ -293,6 +296,7 @@ class StateDescriptor(ABC): """ self.name = name self.type_info = type_info + self._ttl_config = None # type: Optional[StateTtlConfig] def get_name(self) -> str: """ @@ -302,6 +306,17 @@ class StateDescriptor(ABC): """ return self.name + def enable_time_to_live(self, ttl_config: 'StateTtlConfig'): + """ + Configures optional activation of state time-to-live (TTL). + + State user value will expire, become unavailable and be cleaned up in storage depending on + configured StateTtlConfig. + + :param ttl_config: Configuration of state TTL + """ + self._ttl_config = ttl_config + class ValueStateDescriptor(StateDescriptor): """ @@ -402,3 +417,472 @@ class AggregatingStateDescriptor(StateDescriptor): def get_agg_function(self): return self._agg_function + + +class StateTtlConfig(object): + class UpdateType(Enum): + """ + This option value configures when to update last access timestamp which prolongs state TTL. + """ + + Disabled = 0 + """ + TTL is disabled. State does not expire. + """ + + OnCreateAndWrite = 1 + """ + Last access timestamp is initialised when state is created and updated on every write + operation. + """ + + OnReadAndWrite = 2 + """ + The same as OnCreateAndWrite but also updated on read. + """ + + def _to_proto(self): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + return getattr(StateDescriptor.StateTTLConfig.UpdateType, self.name) + + @staticmethod + def _from_proto(proto): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + update_type_name = StateDescriptor.StateTTLConfig.UpdateType.Name(proto) + return StateTtlConfig.UpdateType[update_type_name] + + class StateVisibility(Enum): + """ + This option configures whether expired user value can be returned or not. + """ + + ReturnExpiredIfNotCleanedUp = 0 + """ + Return expired user value if it is not cleaned up yet. + """ + + NeverReturnExpired = 1 + """ + Never return expired user value. + """ + + def _to_proto(self): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + return getattr(StateDescriptor.StateTTLConfig.StateVisibility, self.name) + + @staticmethod + def _from_proto(proto): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + state_visibility_name = StateDescriptor.StateTTLConfig.StateVisibility.Name(proto) + return StateTtlConfig.StateVisibility[state_visibility_name] + + class TtlTimeCharacteristic(Enum): + """ + This option configures time scale to use for ttl. + """ + + ProcessingTime = 0 + """ + Processing time + """ + + def _to_proto(self): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + return getattr(StateDescriptor.StateTTLConfig.TtlTimeCharacteristic, self.name) + + @staticmethod + def _from_proto(proto): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + ttl_time_characteristic_name = \ + StateDescriptor.StateTTLConfig.TtlTimeCharacteristic.Name(proto) + return StateTtlConfig.TtlTimeCharacteristic[ttl_time_characteristic_name] + + def __init__(self, + update_type: UpdateType, + state_visibility: StateVisibility, + ttl_time_characteristic: TtlTimeCharacteristic, + ttl: Time, + cleanup_strategies: 'StateTtlConfig.CleanupStrategies'): + self._update_type = update_type + self._state_visibility = state_visibility + self._ttl_time_characteristic = ttl_time_characteristic + self._ttl = ttl + self._cleanup_strategies = cleanup_strategies + + @staticmethod + def new_builder(ttl: Time): + return StateTtlConfig.Builder(ttl) + + def get_update_type(self) -> 'StateTtlConfig.UpdateType': + return self._update_type + + def get_state_visibility(self) -> 'StateTtlConfig.StateVisibility': + return self._state_visibility + + def get_ttl(self) -> Time: + return self._ttl + + def get_ttl_time_characteristic(self) -> 'StateTtlConfig.TtlTimeCharacteristic': + return self._ttl_time_characteristic + + def is_enabled(self) -> bool: + return self._update_type.value != StateTtlConfig.UpdateType.Disabled.value + + def get_cleanup_strategies(self) -> 'StateTtlConfig.CleanupStrategies': + return self._cleanup_strategies + + def _to_proto(self): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + state_ttl_config = StateDescriptor.StateTTLConfig() + state_ttl_config.update_type = self._update_type._to_proto() + state_ttl_config.state_visibility = self._state_visibility._to_proto() + state_ttl_config.ttl_time_characteristic = self._ttl_time_characteristic._to_proto() + state_ttl_config.ttl = self._ttl.to_milliseconds() + state_ttl_config.cleanup_strategies.CopyFrom(self._cleanup_strategies._to_proto()) + return state_ttl_config + + @staticmethod + def _from_proto(proto): + update_type = StateTtlConfig.UpdateType._from_proto(proto.update_type) + state_visibility = StateTtlConfig.StateVisibility._from_proto(proto.state_visibility) + ttl_time_characteristic = \ + StateTtlConfig.TtlTimeCharacteristic._from_proto(proto.ttl_time_characteristic) + ttl = Time.milliseconds(proto.ttl) + cleanup_strategies = StateTtlConfig.CleanupStrategies._from_proto(proto.cleanup_strategies) + builder = StateTtlConfig.new_builder(ttl) \ + .set_update_type(update_type) \ + .set_state_visibility(state_visibility) \ + .set_ttl_time_characteristic(ttl_time_characteristic) + builder._strategies = cleanup_strategies._strategies + builder._is_cleanup_in_background = cleanup_strategies._is_cleanup_in_background + return builder.build() + + def __repr__(self): + return "StateTtlConfig<" \ + "update_type={}," \ + " state_visibility={}," \ + "ttl_time_characteristic ={}," \ + "ttl={}>".format(self._update_type, + self._state_visibility, + self._ttl_time_characteristic, + self._ttl) + + class Builder(object): + """ + Builder for the StateTtlConfig. + """ + + def __init__(self, ttl: Time): + self._ttl = ttl + self._update_type = StateTtlConfig.UpdateType.OnCreateAndWrite + self._state_visibility = StateTtlConfig.StateVisibility.NeverReturnExpired + self._ttl_time_characteristic = StateTtlConfig.TtlTimeCharacteristic.ProcessingTime + self._is_cleanup_in_background = True + self._strategies = {} # type: Dict + + def set_update_type(self, + update_type: 'StateTtlConfig.UpdateType') -> 'StateTtlConfig.Builder': + """ + Sets the ttl update type. + + :param update_type: The ttl update type configures when to update last access timestamp + which prolongs state TTL. + """ + self._update_type = update_type + return self + + def update_ttl_on_create_and_write(self) -> 'StateTtlConfig.Builder': + return self.set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) + + def update_ttl_on_read_and_write(self) -> 'StateTtlConfig.Builder': + return self.set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) + + def set_state_visibility( + self, + state_visibility: 'StateTtlConfig.StateVisibility') -> 'StateTtlConfig.Builder': + """ + Sets the state visibility. + + :param state_visibility: The state visibility configures whether expired user value can + be returned or not. + """ + + self._state_visibility = state_visibility + return self + + def return_expired_if_not_cleaned_up(self) -> 'StateTtlConfig.Builder': + return self.set_state_visibility( + StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp) + + def never_return_expired(self) -> 'StateTtlConfig.Builder': + return self.set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) + + def set_ttl_time_characteristic( + self, + ttl_time_characteristic: 'StateTtlConfig.TtlTimeCharacteristic') \ + -> 'StateTtlConfig.Builder': + """ + Sets the time characteristic. + + :param ttl_time_characteristic: The time characteristic configures time scale to use for + ttl. + """ + self._ttl_time_characteristic = ttl_time_characteristic + return self + + def use_processing_time(self) -> 'StateTtlConfig.Builder': + return self.set_ttl_time_characteristic( + StateTtlConfig.TtlTimeCharacteristic.ProcessingTime) + + def cleanup_full_snapshot(self) -> 'StateTtlConfig.Builder': + """ + Cleanup expired state in full snapshot on checkpoint. + """ + self._strategies[ + StateTtlConfig.CleanupStrategies.Strategies.FULL_STATE_SCAN_SNAPSHOT] = \ + StateTtlConfig.CleanupStrategies.EMPTY_STRATEGY + return self + + def cleanup_incrementally(self, + cleanup_size: int, + run_cleanup_for_every_record) -> 'StateTtlConfig.Builder': + """ + Cleanup expired state incrementally cleanup local state. + + Upon every state access this cleanup strategy checks a bunch of state keys for + expiration and cleans up expired ones. It keeps a lazy iterator through all keys with + relaxed consistency if backend supports it. This way all keys should be regularly + checked and cleaned eventually over time if any state is constantly being accessed. + + Additionally to the incremental cleanup upon state access, it can also run per every + record. Caution: if there are a lot of registered states using this option, they all + will be iterated for every record to check if there is something to cleanup. + + if no access happens to this state or no records are processed in case of + run_cleanup_for_every_record, expired state will persist. + + Time spent for the incremental cleanup increases record processing latency. + + Note: + + At the moment incremental cleanup is implemented only for Heap state backend. + Setting it for RocksDB will have no effect. + + Note: + + If heap state backend is used with synchronous snapshotting, the global iterator keeps a + copy of all keys while iterating because of its specific implementation which does not + support concurrent modifications. Enabling of this feature will increase memory + consumption then. Asynchronous snapshotting does not have this problem. + + :param cleanup_size: max number of keys pulled from queue for clean up upon state touch + for any key + :param run_cleanup_for_every_record: run incremental cleanup per each processed record + """ + self._strategies[StateTtlConfig.CleanupStrategies.Strategies.INCREMENTAL_CLEANUP] = \ + StateTtlConfig.CleanupStrategies.IncrementalCleanupStrategy( + cleanup_size, run_cleanup_for_every_record) + return self + + def cleanup_in_rocksdb_compact_filter( + self, + query_time_after_num_entries) -> 'StateTtlConfig.Builder': + """ + Cleanup expired state while Rocksdb compaction is running. + + RocksDB compaction filter will query current timestamp, used to check expiration, from + Flink every time after processing {@code queryTimeAfterNumEntries} number of state + entries. Updating the timestamp more often can improve cleanup speed but it decreases + compaction performance because it uses JNI call from native code. + + :param query_time_after_num_entries: number of state entries to process by compaction + filter before updating current timestamp + :return: + """ + self._strategies[ + StateTtlConfig.CleanupStrategies.Strategies.ROCKSDB_COMPACTION_FILTER] = \ + StateTtlConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy( + query_time_after_num_entries) + return self + + def disable_cleanup_in_background(self) -> 'StateTtlConfig.Builder': + """ + Disable default cleanup of expired state in background (enabled by default). + + If some specific cleanup is configured, e.g. :func:`cleanup_incrementally` or + :func:`cleanup_in_rocksdb_compact_filter`, this setting does not disable it. + """ + self._is_cleanup_in_background = False + return self + + def set_ttl(self, ttl: Time) -> 'StateTtlConfig.Builder': + """ + Sets the ttl time. + + :param ttl: The ttl time. + """ + self._ttl = ttl + return self + + def build(self) -> 'StateTtlConfig': + return StateTtlConfig( + self._update_type, + self._state_visibility, + self._ttl_time_characteristic, + self._ttl, + StateTtlConfig.CleanupStrategies(self._strategies, self._is_cleanup_in_background) + ) + + class CleanupStrategies(object): + """ + TTL cleanup strategies. + + This class configures when to cleanup expired state with TTL. By default, state is always + cleaned up on explicit read access if found expired. Currently cleanup of state full + snapshot can be additionally activated. + """ + + class Strategies(Enum): + """ + Fixed strategies ordinals in strategies config field. + """ + + FULL_STATE_SCAN_SNAPSHOT = 0 + INCREMENTAL_CLEANUP = 1 + ROCKSDB_COMPACTION_FILTER = 2 + + def _to_proto(self): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + return getattr( + StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies, self.name) + + @staticmethod + def _from_proto(proto): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + strategies_name = \ + StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies.Name(proto) + return StateTtlConfig.CleanupStrategies.Strategies[strategies_name] + + class CleanupStrategy(ABC): + """ + Base interface for cleanup strategies configurations. + """ + pass + + class EmptyCleanupStrategy(CleanupStrategy): + pass + + class IncrementalCleanupStrategy(CleanupStrategy): + """ + Configuration of cleanup strategy while taking the full snapshot. + """ + + def __init__(self, cleanup_size: int, run_cleanup_for_every_record: int): + self._cleanup_size = cleanup_size + self._run_cleanup_for_every_record = run_cleanup_for_every_record + + def get_cleanup_size(self) -> int: + return self._cleanup_size + + def run_cleanup_for_every_record(self) -> int: + return self._run_cleanup_for_every_record + + class RocksdbCompactFilterCleanupStrategy(CleanupStrategy): + """ + Configuration of cleanup strategy using custom compaction filter in RocksDB. + """ + + def __init__(self, query_time_after_num_entries: int): + self._query_time_after_num_entries = query_time_after_num_entries + + def get_query_time_after_num_entries(self) -> int: + return self._query_time_after_num_entries + + EMPTY_STRATEGY = EmptyCleanupStrategy() + + def __init__(self, + strategies: Dict[Strategies, CleanupStrategy], + is_cleanup_in_background: bool): + self._strategies = strategies + self._is_cleanup_in_background = is_cleanup_in_background + + def is_cleanup_in_background(self) -> bool: + return self._is_cleanup_in_background + + def in_full_snapshot(self) -> bool: + return (StateTtlConfig.CleanupStrategies.Strategies.FULL_STATE_SCAN_SNAPSHOT in + self._strategies) + + def get_incremental_cleanup_strategy(self) \ + -> 'StateTtlConfig.CleanupStrategies.IncrementalCleanupStrategy': + if self._is_cleanup_in_background: + default_strategy = \ + StateTtlConfig.CleanupStrategies.IncrementalCleanupStrategy(5, False) + else: + default_strategy = None + return self._strategies.get( # type: ignore + StateTtlConfig.CleanupStrategies.Strategies.INCREMENTAL_CLEANUP, + default_strategy) + + def get_rocksdb_compact_filter_cleanup_strategy(self) \ + -> 'StateTtlConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy': + if self._is_cleanup_in_background: + default_strategy = \ + StateTtlConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy(1000) + else: + default_strategy = None + return self._strategies.get( # type: ignore + StateTtlConfig.CleanupStrategies.Strategies.ROCKSDB_COMPACTION_FILTER, + default_strategy) + + def _to_proto(self): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + DescriptorCleanupStrategies = StateDescriptor.StateTTLConfig.CleanupStrategies + CleanupStrategies = StateTtlConfig.CleanupStrategies + + cleanup_strategies = StateDescriptor.StateTTLConfig.CleanupStrategies() + cleanup_strategies.is_cleanup_in_background = self._is_cleanup_in_background + for k, v in self._strategies.items(): + cleanup_strategy = cleanup_strategies.strategies.add() + cleanup_strategy.strategy = k._to_proto() + if isinstance(v, CleanupStrategies.EmptyCleanupStrategy): + empty_strategy = DescriptorCleanupStrategies.EmptyCleanupStrategy.EMPTY_STRATEGY + cleanup_strategy.empty_strategy = empty_strategy + elif isinstance(v, CleanupStrategies.IncrementalCleanupStrategy): + incremental_cleanup_strategy = \ + DescriptorCleanupStrategies.IncrementalCleanupStrategy() + incremental_cleanup_strategy.cleanup_size = v._cleanup_size + incremental_cleanup_strategy.run_cleanup_for_every_record = \ + v._run_cleanup_for_every_record + cleanup_strategy.incremental_cleanup_strategy.CopyFrom( + incremental_cleanup_strategy) + elif isinstance(v, CleanupStrategies.RocksdbCompactFilterCleanupStrategy): + rocksdb_compact_filter_cleanup_strategy = \ + DescriptorCleanupStrategies.RocksdbCompactFilterCleanupStrategy() + rocksdb_compact_filter_cleanup_strategy.query_time_after_num_entries = \ + v._query_time_after_num_entries + cleanup_strategy.rocksdb_compact_filter_cleanup_strategy.CopyFrom( + rocksdb_compact_filter_cleanup_strategy) + return cleanup_strategies + + @staticmethod + def _from_proto(proto): + CleanupStrategies = StateTtlConfig.CleanupStrategies + + strategies = {} + is_cleanup_in_background = proto.is_cleanup_in_background + for strategy_entry in proto.strategies: + strategy = CleanupStrategies.Strategies._from_proto(strategy_entry.strategy) + if strategy_entry.HasField('empty_strategy'): + strategies[strategy] = CleanupStrategies.EmptyCleanupStrategy + elif strategy_entry.HasField('incremental_cleanup_strategy'): + incremental_cleanup_strategy = strategy_entry.incremental_cleanup_strategy + strategies[strategy] = CleanupStrategies.IncrementalCleanupStrategy( + incremental_cleanup_strategy.cleanup_size, + incremental_cleanup_strategy.run_cleanup_for_every_record) + elif strategy_entry.HasField('rocksdb_compact_filter_cleanup_strategy'): + rocksdb_compact_filter_cleanup_strategy = \ + strategy_entry.rocksdb_compact_filter_cleanup_strategy + strategies[strategy] = CleanupStrategies.RocksdbCompactFilterCleanupStrategy( + rocksdb_compact_filter_cleanup_strategy.query_time_after_num_entries) + return CleanupStrategies(strategies, is_cleanup_in_background) diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py index 097428d..63f1f4a 100644 --- a/flink-python/pyflink/datastream/tests/test_data_stream.py +++ b/flink-python/pyflink/datastream/tests/test_data_stream.py @@ -22,7 +22,8 @@ import unittest import uuid from typing import Collection, Iterable -from pyflink.common import Row +from pyflink.common import Row, Configuration +from pyflink.common.time import Time from pyflink.common.serializer import TypeSerializer from pyflink.common.typeinfo import Types from pyflink.common.watermark_strategy import WatermarkStrategy, TimestampAssigner @@ -36,7 +37,7 @@ from pyflink.datastream.functions import (AggregateFunction, CoMapFunction, CoFl WindowFunction) from pyflink.datastream.state import (ValueStateDescriptor, ListStateDescriptor, MapStateDescriptor, ReducingStateDescriptor, ReducingState, AggregatingState, - AggregatingStateDescriptor) + AggregatingStateDescriptor, StateTtlConfig) from pyflink.datastream.window import (CountWindowSerializer, MergingWindowAssigner, TimeWindow, TimeWindowSerializer) from pyflink.java_gateway import get_gateway @@ -620,6 +621,12 @@ class DataStreamTests(object): list_state_descriptor = ListStateDescriptor('list_state', Types.INT()) self.list_state = runtime_context.get_list_state(list_state_descriptor) map_state_descriptor = MapStateDescriptor('map_state', Types.INT(), Types.STRING()) + state_ttl_config = StateTtlConfig \ + .new_builder(Time.seconds(1)) \ + .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \ + .disable_cleanup_in_background() \ + .build() + map_state_descriptor.enable_time_to_live(state_ttl_config) self.map_state = runtime_context.get_map_state(map_state_descriptor) def process_element(self, value, ctx): @@ -727,20 +734,32 @@ class DataStreamTests(object): self.aggregating_state = None # type: AggregatingState def open(self, runtime_context: RuntimeContext): - self.aggregating_state = runtime_context.get_aggregating_state( - AggregatingStateDescriptor( - 'aggregating_state', MyAggregateFunction(), Types.INT())) + descriptor = AggregatingStateDescriptor( + 'aggregating_state', MyAggregateFunction(), Types.INT()) + state_ttl_config = StateTtlConfig \ + .new_builder(Time.milliseconds(1)) \ + .set_update_type(StateTtlConfig.UpdateType.OnReadAndWrite) \ + .disable_cleanup_in_background() \ + .build() + descriptor.enable_time_to_live(state_ttl_config) + self.aggregating_state = runtime_context.get_aggregating_state(descriptor) def process_element(self, value, ctx): self.aggregating_state.add(value[0]) yield self.aggregating_state.get(), value[1] + config = Configuration( + j_configuration=get_j_env_configuration(self.env._j_stream_execution_environment)) + config.set_integer("python.fn-execution.bundle.size", 1) data_stream.key_by(lambda x: x[1], key_type=Types.STRING()) \ .process(MyProcessFunction(), output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \ .add_sink(self.test_sink) self.env.execute('test_aggregating_state') results = self.test_sink.get_results() - expected = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)'] + if isinstance(self, PyFlinkBatchTestCase): + expected = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', '(9,hi)', '(12,hello)'] + else: + expected = ['(1,hi)', '(2,hello)', '(3,hi)', '(4,hello)', '(5,hi)', '(6,hello)'] self.assert_equals_sorted(expected, results) def test_count_window(self): diff --git a/flink-python/pyflink/fn_execution/datastream/runtime_context.py b/flink-python/pyflink/fn_execution/datastream/runtime_context.py index 6ff7118..86969f2 100644 --- a/flink-python/pyflink/fn_execution/datastream/runtime_context.py +++ b/flink-python/pyflink/fn_execution/datastream/runtime_context.py @@ -105,7 +105,9 @@ class StreamingRuntimeContext(RuntimeContext): def get_state(self, state_descriptor: ValueStateDescriptor) -> ValueState: if self._keyed_state_backend: return self._keyed_state_backend.get_value_state( - state_descriptor.name, from_type_info(state_descriptor.type_info)) + state_descriptor.name, + from_type_info(state_descriptor.type_info), + state_descriptor._ttl_config) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") @@ -113,7 +115,9 @@ class StreamingRuntimeContext(RuntimeContext): if self._keyed_state_backend: array_coder = from_type_info(state_descriptor.type_info) # type: GenericArrayCoder return self._keyed_state_backend.get_list_state( - state_descriptor.name, array_coder._elem_coder) + state_descriptor.name, + array_coder._elem_coder, + state_descriptor._ttl_config) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") @@ -123,7 +127,10 @@ class StreamingRuntimeContext(RuntimeContext): key_coder = map_coder._key_coder value_coder = map_coder._value_coder return self._keyed_state_backend.get_map_state( - state_descriptor.name, key_coder, value_coder) + state_descriptor.name, + key_coder, + value_coder, + state_descriptor._ttl_config) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") @@ -132,7 +139,8 @@ class StreamingRuntimeContext(RuntimeContext): return self._keyed_state_backend.get_reducing_state( state_descriptor.get_name(), from_type_info(state_descriptor.type_info), - state_descriptor.get_reduce_function()) + state_descriptor.get_reduce_function(), + state_descriptor._ttl_config) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") @@ -142,7 +150,8 @@ class StreamingRuntimeContext(RuntimeContext): return self._keyed_state_backend.get_aggregating_state( state_descriptor.get_name(), from_type_info(state_descriptor.type_info), - state_descriptor.get_agg_function()) + state_descriptor.get_agg_function(), + state_descriptor._ttl_config) else: raise Exception("This state is only accessible by functions executed on a KeyedStream.") diff --git a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py index 894afaf..c087b64 100644 --- a/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py +++ b/flink-python/pyflink/fn_execution/flink_fn_execution_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( name='flink-fn-execution.proto', package='org.apache.flink.fn_execution.v1', syntax='proto3', - serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"\xa8\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12\x37\n\x06inputs\x18\x02 \x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...] + serialized_pb=_b('\n\x18\x66link-fn-execution.proto\x12 org.apache.flink.fn_execution.v1\"\x86\x01\n\x05Input\x12\x44\n\x03udf\x18\x01 \x01(\x0b\x32\x35.org.apache.flink.fn_execution.v1.UserDefinedFunctionH\x00\x12\x15\n\x0binputOffset\x18\x02 \x01(\x05H\x00\x12\x17\n\rinputConstant\x18\x03 \x01(\x0cH\x00\x42\x07\n\x05input\"\xa8\x01\n\x13UserDefinedFunction\x12\x0f\n\x07payload\x18\x01 \x01(\x0c\x12\x37\n\x06inputs\x18\x02 \x03(\x0b\x32\'.org.apache.flink.fn_execution.v1.Input\x12\x14 [...] ) @@ -385,6 +385,116 @@ _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE = _descriptor.EnumDescriptor( ) _sym_db.RegisterEnumDescriptor(_USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE) +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES = _descriptor.EnumDescriptor( + name='Strategies', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='FULL_STATE_SCAN_SNAPSHOT', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INCREMENTAL_CLEANUP', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='ROCKSDB_COMPACTION_FILTER', index=2, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=8416, + serialized_end=8514, +) +_sym_db.RegisterEnumDescriptor(_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES) + +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY = _descriptor.EnumDescriptor( + name='EmptyCleanupStrategy', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.EmptyCleanupStrategy', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='EMPTY_STRATEGY', index=0, number=0, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=8516, + serialized_end=8558, +) +_sym_db.RegisterEnumDescriptor(_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY) + +_STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE = _descriptor.EnumDescriptor( + name='UpdateType', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.UpdateType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='Disabled', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='OnCreateAndWrite', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='OnReadAndWrite', index=2, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=8560, + serialized_end=8628, +) +_sym_db.RegisterEnumDescriptor(_STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE) + +_STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY = _descriptor.EnumDescriptor( + name='StateVisibility', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.StateVisibility', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='ReturnExpiredIfNotCleanedUp', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='NeverReturnExpired', index=1, number=1, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=8630, + serialized_end=8704, +) +_sym_db.RegisterEnumDescriptor(_STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY) + +_STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC = _descriptor.EnumDescriptor( + name='TtlTimeCharacteristic', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.TtlTimeCharacteristic', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='ProcessingTime', index=0, number=0, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=8706, + serialized_end=8749, +) +_sym_db.RegisterEnumDescriptor(_STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC) + _CODERINFODESCRIPTOR_MODE = _descriptor.EnumDescriptor( name='Mode', full_name='org.apache.flink.fn_execution.v1.CoderInfoDescriptor.Mode', @@ -402,8 +512,8 @@ _CODERINFODESCRIPTOR_MODE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=7821, - serialized_end=7853, + serialized_start=9716, + serialized_end=9748, ) _sym_db.RegisterEnumDescriptor(_CODERINFODESCRIPTOR_MODE) @@ -1905,6 +2015,265 @@ _USERDEFINEDDATASTREAMFUNCTION = _descriptor.Descriptor( ) +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY = _descriptor.Descriptor( + name='IncrementalCleanupStrategy', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.IncrementalCleanupStrategy', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='cleanup_size', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.IncrementalCleanupStrategy.cleanup_size', index=0, + number=1, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='run_cleanup_for_every_record', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.IncrementalCleanupStrategy.run_cleanup_for_every_record', index=1, + number=2, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=7638, + serialized_end=7726, +) + +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY = _descriptor.Descriptor( + name='RocksdbCompactFilterCleanupStrategy', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='query_time_after_num_entries', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy.query_time_after_num_entries', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=7728, + serialized_end=7803, +) + +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY = _descriptor.Descriptor( + name='MapStrategiesEntry', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='strategy', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry.strategy', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='empty_strategy', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry.empty_strategy', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='incremental_cleanup_strategy', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry.incremental_cleanup_strategy', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='rocksdb_compact_filter_cleanup_strategy', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry.rocksdb_compact_filter_cleanup_strategy', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='CleanupStrategy', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry.CleanupStrategy', + index=0, containing_type=None, fields=[]), + ], + serialized_start=7806, + serialized_end=8414, +) + +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES = _descriptor.Descriptor( + name='CleanupStrategies', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='is_cleanup_in_background', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.is_cleanup_in_background', index=0, + number=1, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='strategies', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.strategies', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY, _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY, _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY, ], + enum_types=[ + _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES, + _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY, + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=7460, + serialized_end=8558, +) + +_STATEDESCRIPTOR_STATETTLCONFIG = _descriptor.Descriptor( + name='StateTTLConfig', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='update_type', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.update_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='state_visibility', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.state_visibility', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='ttl_time_characteristic', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.ttl_time_characteristic', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='ttl', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.ttl', index=3, + number=4, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='cleanup_strategies', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.cleanup_strategies', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES, ], + enum_types=[ + _STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE, + _STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY, + _STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC, + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=6989, + serialized_end=8749, +) + +_STATEDESCRIPTOR = _descriptor.Descriptor( + name='StateDescriptor', + full_name='org.apache.flink.fn_execution.v1.StateDescriptor', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='state_name', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.state_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='state_ttl_config', full_name='org.apache.flink.fn_execution.v1.StateDescriptor.state_ttl_config', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_STATEDESCRIPTOR_STATETTLCONFIG, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=6857, + serialized_end=8749, +) + + _CODERINFODESCRIPTOR_FLATTENROWTYPE = _descriptor.Descriptor( name='FlattenRowType', full_name='org.apache.flink.fn_execution.v1.CoderInfoDescriptor.FlattenRowType', @@ -1931,8 +2300,8 @@ _CODERINFODESCRIPTOR_FLATTENROWTYPE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=7450, - serialized_end=7524, + serialized_start=9345, + serialized_end=9419, ) _CODERINFODESCRIPTOR_ROWTYPE = _descriptor.Descriptor( @@ -1961,8 +2330,8 @@ _CODERINFODESCRIPTOR_ROWTYPE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=7526, - serialized_end=7593, + serialized_start=9421, + serialized_end=9488, ) _CODERINFODESCRIPTOR_ARROWTYPE = _descriptor.Descriptor( @@ -1991,8 +2360,8 @@ _CODERINFODESCRIPTOR_ARROWTYPE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=7595, - serialized_end=7664, + serialized_start=9490, + serialized_end=9559, ) _CODERINFODESCRIPTOR_OVERWINDOWARROWTYPE = _descriptor.Descriptor( @@ -2021,8 +2390,8 @@ _CODERINFODESCRIPTOR_OVERWINDOWARROWTYPE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=7666, - serialized_end=7745, + serialized_start=9561, + serialized_end=9640, ) _CODERINFODESCRIPTOR_RAWTYPE = _descriptor.Descriptor( @@ -2051,8 +2420,8 @@ _CODERINFODESCRIPTOR_RAWTYPE = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=7747, - serialized_end=7819, + serialized_start=9642, + serialized_end=9714, ) _CODERINFODESCRIPTOR = _descriptor.Descriptor( @@ -2127,8 +2496,8 @@ _CODERINFODESCRIPTOR = _descriptor.Descriptor( name='data_type', full_name='org.apache.flink.fn_execution.v1.CoderInfoDescriptor.data_type', index=0, containing_type=None, fields=[]), ], - serialized_start=6857, - serialized_end=7866, + serialized_start=8752, + serialized_end=9761, ) _INPUT.fields_by_name['udf'].message_type = _USERDEFINEDFUNCTION @@ -2269,6 +2638,35 @@ _USERDEFINEDDATASTREAMFUNCTION.fields_by_name['function_type'].enum_type = _USER _USERDEFINEDDATASTREAMFUNCTION.fields_by_name['runtime_context'].message_type = _USERDEFINEDDATASTREAMFUNCTION_RUNTIMECONTEXT _USERDEFINEDDATASTREAMFUNCTION.fields_by_name['key_type_info'].message_type = _TYPEINFO _USERDEFINEDDATASTREAMFUNCTION_FUNCTIONTYPE.containing_type = _USERDEFINEDDATASTREAMFUNCTION +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['strategy'].enum_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['empty_strategy'].enum_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['incremental_cleanup_strategy'].message_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['rocksdb_compact_filter_cleanup_strategy'].message_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.oneofs_by_name['CleanupStrategy'].fields.append( + _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['empty_strategy']) +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['empty_strategy'].containing_oneof = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.oneofs_by_name['CleanupStrategy'] +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.oneofs_by_name['CleanupStrategy'].fields.append( + _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['incremental_cleanup_strategy']) +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['incremental_cleanup_strategy'].containing_oneof = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.oneofs_by_name['CleanupStrategy'] +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.oneofs_by_name['CleanupStrategy'].fields.append( + _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['rocksdb_compact_filter_cleanup_strategy']) +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.fields_by_name['rocksdb_compact_filter_cleanup_strategy'].containing_oneof = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY.oneofs_by_name['CleanupStrategy'] +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES.fields_by_name['strategies'].message_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_STRATEGIES.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_EMPTYCLEANUPSTRATEGY.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG.fields_by_name['update_type'].enum_type = _STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE +_STATEDESCRIPTOR_STATETTLCONFIG.fields_by_name['state_visibility'].enum_type = _STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY +_STATEDESCRIPTOR_STATETTLCONFIG.fields_by_name['ttl_time_characteristic'].enum_type = _STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC +_STATEDESCRIPTOR_STATETTLCONFIG.fields_by_name['cleanup_strategies'].message_type = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES +_STATEDESCRIPTOR_STATETTLCONFIG.containing_type = _STATEDESCRIPTOR +_STATEDESCRIPTOR_STATETTLCONFIG_UPDATETYPE.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG +_STATEDESCRIPTOR_STATETTLCONFIG_STATEVISIBILITY.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG +_STATEDESCRIPTOR_STATETTLCONFIG_TTLTIMECHARACTERISTIC.containing_type = _STATEDESCRIPTOR_STATETTLCONFIG +_STATEDESCRIPTOR.fields_by_name['state_ttl_config'].message_type = _STATEDESCRIPTOR_STATETTLCONFIG _CODERINFODESCRIPTOR_FLATTENROWTYPE.fields_by_name['schema'].message_type = _SCHEMA _CODERINFODESCRIPTOR_FLATTENROWTYPE.containing_type = _CODERINFODESCRIPTOR _CODERINFODESCRIPTOR_ROWTYPE.fields_by_name['schema'].message_type = _SCHEMA @@ -2311,6 +2709,7 @@ DESCRIPTOR.message_types_by_name['UserDefinedAggregateFunctions'] = _USERDEFINED DESCRIPTOR.message_types_by_name['Schema'] = _SCHEMA DESCRIPTOR.message_types_by_name['TypeInfo'] = _TYPEINFO DESCRIPTOR.message_types_by_name['UserDefinedDataStreamFunction'] = _USERDEFINEDDATASTREAMFUNCTION +DESCRIPTOR.message_types_by_name['StateDescriptor'] = _STATEDESCRIPTOR DESCRIPTOR.message_types_by_name['CoderInfoDescriptor'] = _CODERINFODESCRIPTOR _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -2552,6 +2951,53 @@ _sym_db.RegisterMessage(UserDefinedDataStreamFunction) _sym_db.RegisterMessage(UserDefinedDataStreamFunction.JobParameter) _sym_db.RegisterMessage(UserDefinedDataStreamFunction.RuntimeContext) +StateDescriptor = _reflection.GeneratedProtocolMessageType('StateDescriptor', (_message.Message,), dict( + + StateTTLConfig = _reflection.GeneratedProtocolMessageType('StateTTLConfig', (_message.Message,), dict( + + CleanupStrategies = _reflection.GeneratedProtocolMessageType('CleanupStrategies', (_message.Message,), dict( + + IncrementalCleanupStrategy = _reflection.GeneratedProtocolMessageType('IncrementalCleanupStrategy', (_message.Message,), dict( + DESCRIPTOR = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_INCREMENTALCLEANUPSTRATEGY, + __module__ = 'flink_fn_execution_pb2' + # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.IncrementalCleanupStrategy) + )) + , + + RocksdbCompactFilterCleanupStrategy = _reflection.GeneratedProtocolMessageType('RocksdbCompactFilterCleanupStrategy', (_message.Message,), dict( + DESCRIPTOR = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_ROCKSDBCOMPACTFILTERCLEANUPSTRATEGY, + __module__ = 'flink_fn_execution_pb2' + # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy) + )) + , + + MapStrategiesEntry = _reflection.GeneratedProtocolMessageType('MapStrategiesEntry', (_message.Message,), dict( + DESCRIPTOR = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES_MAPSTRATEGIESENTRY, + __module__ = 'flink_fn_execution_pb2' + # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry) + )) + , + DESCRIPTOR = _STATEDESCRIPTOR_STATETTLCONFIG_CLEANUPSTRATEGIES, + __module__ = 'flink_fn_execution_pb2' + # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig.CleanupStrategies) + )) + , + DESCRIPTOR = _STATEDESCRIPTOR_STATETTLCONFIG, + __module__ = 'flink_fn_execution_pb2' + # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.StateDescriptor.StateTTLConfig) + )) + , + DESCRIPTOR = _STATEDESCRIPTOR, + __module__ = 'flink_fn_execution_pb2' + # @@protoc_insertion_point(class_scope:org.apache.flink.fn_execution.v1.StateDescriptor) + )) +_sym_db.RegisterMessage(StateDescriptor) +_sym_db.RegisterMessage(StateDescriptor.StateTTLConfig) +_sym_db.RegisterMessage(StateDescriptor.StateTTLConfig.CleanupStrategies) +_sym_db.RegisterMessage(StateDescriptor.StateTTLConfig.CleanupStrategies.IncrementalCleanupStrategy) +_sym_db.RegisterMessage(StateDescriptor.StateTTLConfig.CleanupStrategies.RocksdbCompactFilterCleanupStrategy) +_sym_db.RegisterMessage(StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry) + CoderInfoDescriptor = _reflection.GeneratedProtocolMessageType('CoderInfoDescriptor', (_message.Message,), dict( FlattenRowType = _reflection.GeneratedProtocolMessageType('FlattenRowType', (_message.Message,), dict( diff --git a/flink-python/pyflink/fn_execution/state_impl.py b/flink-python/pyflink/fn_execution/state_impl.py index 58580c1..a9c0fee 100644 --- a/flink-python/pyflink/fn_execution/state_impl.py +++ b/flink-python/pyflink/fn_execution/state_impl.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################ +import base64 import collections from abc import ABC, abstractmethod from enum import Enum @@ -29,6 +30,7 @@ from typing import List, Tuple, Any, Dict, Collection from pyflink.datastream import ReduceFunction from pyflink.datastream.functions import AggregateFunction +from pyflink.datastream.state import StateTtlConfig from pyflink.fn_execution.beam.beam_coders import FlinkCoder from pyflink.fn_execution.coders import FieldCoder from pyflink.fn_execution.internal_state import InternalKvState, N, InternalValueState, \ @@ -96,6 +98,7 @@ class SynchronousKvRuntimeState(InternalKvState, ABC): self._remote_state_backend = remote_state_backend self._internal_state = None self.namespace = None + self._ttl_config = None def set_current_namespace(self, namespace: N) -> None: if namespace == self.namespace: @@ -106,6 +109,9 @@ class SynchronousKvRuntimeState(InternalKvState, ABC): self.namespace = namespace self._internal_state = None + def enable_time_to_live(self, ttl_config: StateTtlConfig): + self._ttl_config = ttl_config + @abstractmethod def get_internal_state(self): pass @@ -122,7 +128,7 @@ class SynchronousBagKvRuntimeState(SynchronousKvRuntimeState, ABC): def get_internal_state(self): if self._internal_state is None: self._internal_state = self._remote_state_backend._get_internal_bag_state( - self.name, self.namespace, self._value_coder) + self.name, self.namespace, self._value_coder, self._ttl_config) return self._internal_state @@ -158,7 +164,7 @@ class SynchronousMergingRuntimeState(SynchronousBagKvRuntimeState, InternalMergi name, value_coder, remote_state_backend) def merge_namespaces(self, target: N, sources: Collection[N]) -> None: - self._remote_state_backend.merge_namespaces(self, target, sources) + self._remote_state_backend.merge_namespaces(self, target, sources, self._ttl_config) class SynchronousListRuntimeState(SynchronousMergingRuntimeState, InternalListState): @@ -867,7 +873,11 @@ class SynchronousMapRuntimeState(SynchronousKvRuntimeState, InternalMapState): def get_internal_state(self): if self._internal_state is None: self._internal_state = self._remote_state_backend._get_internal_map_state( - self.name, self.namespace, self._map_key_coder, self._map_value_coder) + self.name, + self.namespace, + self._map_key_coder, + self._map_value_coder, + self._ttl_config) return self._internal_state def get(self, key): @@ -938,31 +948,47 @@ class RemoteKeyedStateBackend(object): side_input_id="clear_iterators", key=self._encoded_current_key)) - def get_list_state(self, name, element_coder): + def get_list_state(self, name, element_coder, ttl_config=None): return self._wrap_internal_bag_state( - name, element_coder, SynchronousListRuntimeState, SynchronousListRuntimeState) + name, + element_coder, + SynchronousListRuntimeState, + SynchronousListRuntimeState, + ttl_config) - def get_value_state(self, name, value_coder): + def get_value_state(self, name, value_coder, ttl_config=None): return self._wrap_internal_bag_state( - name, value_coder, SynchronousValueRuntimeState, SynchronousValueRuntimeState) + name, + value_coder, + SynchronousValueRuntimeState, + SynchronousValueRuntimeState, + ttl_config) - def get_map_state(self, name, map_key_coder, map_value_coder): + def get_map_state(self, name, map_key_coder, map_value_coder, ttl_config=None): if name in self._all_states: self.validate_map_state(name, map_key_coder, map_value_coder) return self._all_states[name] map_state = SynchronousMapRuntimeState(name, map_key_coder, map_value_coder, self) + if ttl_config is not None: + map_state.enable_time_to_live(ttl_config) self._all_states[name] = map_state return map_state - def get_reducing_state(self, name, coder, reduce_function): + def get_reducing_state(self, name, coder, reduce_function, ttl_config=None): return self._wrap_internal_bag_state( - name, coder, SynchronousReducingRuntimeState, - partial(SynchronousReducingRuntimeState, reduce_function=reduce_function)) + name, + coder, + SynchronousReducingRuntimeState, + partial(SynchronousReducingRuntimeState, reduce_function=reduce_function), + ttl_config) - def get_aggregating_state(self, name, coder, agg_function): + def get_aggregating_state(self, name, coder, agg_function, ttl_config=None): return self._wrap_internal_bag_state( - name, coder, SynchronousAggregatingRuntimeState, - partial(SynchronousAggregatingRuntimeState, agg_function=agg_function)) + name, + coder, + SynchronousAggregatingRuntimeState, + partial(SynchronousAggregatingRuntimeState, agg_function=agg_function), + ttl_config) def validate_state(self, name, coder, expected_type): if name in self._all_states: @@ -983,15 +1009,18 @@ class RemoteKeyedStateBackend(object): state._map_value_coder != map_value_coder: raise Exception("State name corrupted: %s" % name) - def _wrap_internal_bag_state(self, name, element_coder, wrapper_type, wrap_method): + def _wrap_internal_bag_state( + self, name, element_coder, wrapper_type, wrap_method, ttl_config): if name in self._all_states: self.validate_state(name, element_coder, wrapper_type) return self._all_states[name] wrapped_state = wrap_method(name, element_coder, self) + if ttl_config is not None: + wrapped_state.enable_time_to_live(ttl_config) self._all_states[name] = wrapped_state return wrapped_state - def _get_internal_bag_state(self, name, namespace, element_coder): + def _get_internal_bag_state(self, name, namespace, element_coder, ttl_config): encoded_namespace = self._encode_namespace(namespace) cached_state = self._internal_state_cache.get( (name, self._encoded_current_key, encoded_namespace)) @@ -1004,39 +1033,45 @@ class RemoteKeyedStateBackend(object): if isinstance(element_coder, FieldCoder): element_coder = FlinkCoder(element_coder) state_spec = userstate.BagStateSpec(name, element_coder) - internal_state = self._create_bag_state(state_spec, encoded_namespace) + internal_state = self._create_bag_state(state_spec, encoded_namespace, ttl_config) return internal_state - def _get_internal_map_state(self, name, namespace, map_key_coder, map_value_coder): + def _get_internal_map_state(self, name, namespace, map_key_coder, map_value_coder, ttl_config): encoded_namespace = self._encode_namespace(namespace) cached_state = self._internal_state_cache.get( (name, self._encoded_current_key, encoded_namespace)) if cached_state is not None: return cached_state internal_map_state = self._create_internal_map_state( - name, encoded_namespace, map_key_coder, map_value_coder) + name, encoded_namespace, map_key_coder, map_value_coder, ttl_config) return internal_map_state - def _create_bag_state(self, state_spec: userstate.StateSpec, encoded_namespace) \ + def _create_bag_state(self, state_spec: userstate.StateSpec, encoded_namespace, ttl_config) \ -> userstate.AccumulatingRuntimeState: if isinstance(state_spec, userstate.BagStateSpec): bag_state = SynchronousBagRuntimeState( self._state_handler, state_key=self.get_bag_state_key( - state_spec.name, self._encoded_current_key, encoded_namespace), + state_spec.name, self._encoded_current_key, encoded_namespace, ttl_config), value_coder=state_spec.coder) return bag_state else: raise NotImplementedError(state_spec) - def _create_internal_map_state(self, name, encoded_namespace, map_key_coder, map_value_coder): + def _create_internal_map_state( + self, name, encoded_namespace, map_key_coder, map_value_coder, ttl_config): # Currently the `beam_fn_api.proto` does not support MapState, so we use the # the `MultimapSideInput` message to mark the state as a MapState for now. + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + state_proto = StateDescriptor() + state_proto.state_name = name + if ttl_config is not None: + state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto()) state_key = beam_fn_api_pb2.StateKey( multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput( transform_id="", window=encoded_namespace, - side_input_id=name, + side_input_id=base64.b64encode(state_proto.SerializeToString()), key=self._encoded_current_key)) return InternalSynchronousMapRuntimeState( self._map_state_handler, @@ -1087,7 +1122,7 @@ class RemoteKeyedStateBackend(object): self._clear_iterator_mark.multimap_side_input.key = self._encoded_current_key self._map_state_handler.clear(self._clear_iterator_mark) - def merge_namespaces(self, state: SynchronousMergingRuntimeState, target, sources): + def merge_namespaces(self, state: SynchronousMergingRuntimeState, target, sources, ttl_config): state.set_current_namespace(target) self.commit_internal_state(state.get_internal_state()) encoded_target_namespace = self._encode_namespace(target) @@ -1097,7 +1132,7 @@ class RemoteKeyedStateBackend(object): self.clear_state_cache(state, encoded_namespaces) state_key = self.get_bag_state_key( - state.name, self._encoded_current_key, encoded_target_namespace) + state.name, self._encoded_current_key, encoded_target_namespace, ttl_config) state_key.bag_user_state.transform_id = self.MERGE_NAMESAPCES_MARK encoded_namespaces_writer = BytesIO() @@ -1118,18 +1153,22 @@ class RemoteKeyedStateBackend(object): (name, self._encoded_current_key, encoded_namespace)) # currently all the SynchronousMergingRuntimeState is based on bag state state_key = self.get_bag_state_key( - name, self._encoded_current_key, encoded_namespace) + name, self._encoded_current_key, encoded_namespace, None) # clear the read cache, the read cache is shared between map state handler and bag # state handler. So we can use the map state handler instead. self._map_state_handler.clear_read_cache(state_key) - @staticmethod - def get_bag_state_key(name, encoded_key, encoded_namespace): + def get_bag_state_key(self, name, encoded_key, encoded_namespace, ttl_config): + from pyflink.fn_execution.flink_fn_execution_pb2 import StateDescriptor + state_proto = StateDescriptor() + state_proto.state_name = name + if ttl_config is not None: + state_proto.state_ttl_config.CopyFrom(ttl_config._to_proto()) return beam_fn_api_pb2.StateKey( bag_user_state=beam_fn_api_pb2.StateKey.BagUserState( transform_id="", window=encoded_namespace, - user_state_id=name, + user_state_id=base64.b64encode(state_proto.SerializeToString()), key=encoded_key)) @staticmethod diff --git a/flink-python/pyflink/fn_execution/tests/test_flink_fn_execution_pb2_synced.py b/flink-python/pyflink/fn_execution/tests/test_flink_fn_execution_pb2.py similarity index 51% rename from flink-python/pyflink/fn_execution/tests/test_flink_fn_execution_pb2_synced.py rename to flink-python/pyflink/fn_execution/tests/test_flink_fn_execution_pb2.py index 544e4f3..313747f 100644 --- a/flink-python/pyflink/fn_execution/tests/test_flink_fn_execution_pb2_synced.py +++ b/flink-python/pyflink/fn_execution/tests/test_flink_fn_execution_pb2.py @@ -22,7 +22,7 @@ from pyflink.gen_protos import generate_proto_files from pyflink.testing.test_case_utils import PyFlinkTestCase -class FlinkFnExecutionSyncTests(PyFlinkTestCase): +class FlinkFnExecutionTests(PyFlinkTestCase): """ Tests whether flink_fn_exeution_pb2.py is synced with flink-fn-execution.proto. """ @@ -41,3 +41,39 @@ class FlinkFnExecutionSyncTests(PyFlinkTestCase): % (self.flink_fn_execution_pb2_file_name, self.gen_protos_script, self.flink_fn_execution_proto_file_name)) + + def test_state_ttl_config_proto(self): + from pyflink.datastream.state import StateTtlConfig + from pyflink.common.time import Time + state_ttl_config = StateTtlConfig \ + .new_builder(Time.milliseconds(1000)) \ + .set_update_type(StateTtlConfig.UpdateType.OnCreateAndWrite) \ + .set_state_visibility(StateTtlConfig.StateVisibility.NeverReturnExpired) \ + .cleanup_full_snapshot() \ + .cleanup_incrementally(10, True) \ + .cleanup_in_rocksdb_compact_filter(1000) \ + .build() + state_ttl_config_proto = state_ttl_config._to_proto() + state_ttl_config = StateTtlConfig._from_proto(state_ttl_config_proto) + self.assertEqual(state_ttl_config.get_ttl(), Time.milliseconds(1000)) + self.assertEqual( + state_ttl_config.get_update_type(), StateTtlConfig.UpdateType.OnCreateAndWrite) + self.assertEqual( + state_ttl_config.get_state_visibility(), + StateTtlConfig.StateVisibility.NeverReturnExpired) + self.assertEqual( + state_ttl_config.get_ttl_time_characteristic(), + StateTtlConfig.TtlTimeCharacteristic.ProcessingTime) + + cleanup_strategies = state_ttl_config.get_cleanup_strategies() + self.assertTrue(cleanup_strategies.is_cleanup_in_background()) + self.assertTrue(cleanup_strategies.in_full_snapshot()) + + incremental_cleanup_strategy = cleanup_strategies.get_incremental_cleanup_strategy() + self.assertEqual(incremental_cleanup_strategy.get_cleanup_size(), 10) + self.assertTrue(incremental_cleanup_strategy.run_cleanup_for_every_record()) + + rocksdb_compact_filter_cleanup_strategy = \ + cleanup_strategies.get_rocksdb_compact_filter_cleanup_strategy() + self.assertEqual( + rocksdb_compact_filter_cleanup_strategy.get_query_time_after_num_entries(), 1000) diff --git a/flink-python/pyflink/proto/flink-fn-execution.proto b/flink-python/pyflink/proto/flink-fn-execution.proto index d4c449f..95ee2f5 100644 --- a/flink-python/pyflink/proto/flink-fn-execution.proto +++ b/flink-python/pyflink/proto/flink-fn-execution.proto @@ -373,6 +373,95 @@ message UserDefinedDataStreamFunction { bool profile_enabled = 6; } +// A representation of State +message StateDescriptor { + message StateTTLConfig { + // This option value configures when to update last access timestamp which prolongs state TTL. + enum UpdateType { + // TTL is disabled. State does not expire. + Disabled = 0; + + // Last access timestamp is initialised when state is created and updated on every write operation. + OnCreateAndWrite = 1; + + // The same as OnCreateAndWrite but also updated on read. + OnReadAndWrite = 2; + } + + // This option configures whether expired user value can be returned or not. + enum StateVisibility { + // Return expired user value if it is not cleaned up yet. + ReturnExpiredIfNotCleanedUp = 0; + + // Never return expired user value. + NeverReturnExpired = 1; + } + + // This option configures time scale to use for ttl. + enum TtlTimeCharacteristic { + // Processing time + ProcessingTime = 0; + } + + // TTL cleanup strategies. + message CleanupStrategies { + // Fixed strategies ordinals in strategies config field. + enum Strategies { + FULL_STATE_SCAN_SNAPSHOT = 0; + INCREMENTAL_CLEANUP = 1; + ROCKSDB_COMPACTION_FILTER = 2; + } + + enum EmptyCleanupStrategy { + EMPTY_STRATEGY = 0; + } + + // Configuration of cleanup strategy while taking the full snapshot. + message IncrementalCleanupStrategy { + // Max number of keys pulled from queue for clean up upon state touch for any key. + int32 cleanup_size = 1; + + // Whether to run incremental cleanup per each processed record. + bool run_cleanup_for_every_record = 2; + } + + // Configuration of cleanup strategy using custom compaction filter in RocksDB. + message RocksdbCompactFilterCleanupStrategy { + // Number of state entries to process by compaction filter before updating current timestamp. + int64 query_time_after_num_entries = 1; + } + + message MapStrategiesEntry { + Strategies strategy = 1; + + oneof CleanupStrategy { + EmptyCleanupStrategy empty_strategy = 2; + IncrementalCleanupStrategy incremental_cleanup_strategy = 3; + RocksdbCompactFilterCleanupStrategy rocksdb_compact_filter_cleanup_strategy = 4; + } + } + + bool is_cleanup_in_background = 1; + + repeated MapStrategiesEntry strategies = 2; + } + + UpdateType update_type = 1; + + StateVisibility state_visibility = 2; + + TtlTimeCharacteristic ttl_time_characteristic = 3; + + int64 ttl = 4; + + CleanupStrategies cleanup_strategies = 5; + } + + string state_name = 1; + + StateTTLConfig state_ttl_config = 2; +} + // ------------------------------------------------------------------------ // Common of Table API and DataStream API // ------------------------------------------------------------------------ diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/SimpleStateRequestHandler.java b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/SimpleStateRequestHandler.java index 95f571f..6f5e752 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/SimpleStateRequestHandler.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/SimpleStateRequestHandler.java @@ -25,6 +25,7 @@ import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.StateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.typeutils.runtime.RowSerializer; @@ -32,6 +33,7 @@ import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.fnexecution.v1.FlinkFnApi; import org.apache.flink.python.PythonOptions; import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.VoidNamespace; @@ -39,6 +41,7 @@ import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.internal.InternalMergingState; import org.apache.flink.streaming.api.utils.ByteArrayWrapper; import org.apache.flink.streaming.api.utils.ByteArrayWrapperSerializer; +import org.apache.flink.streaming.api.utils.ProtoUtils; import org.apache.flink.table.data.RowData; import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer; import org.apache.flink.table.runtime.typeutils.RowDataSerializer; @@ -48,6 +51,7 @@ import org.apache.beam.runners.fnexecution.state.StateRequestHandler; import org.apache.beam.vendor.grpc.v1p26p0.com.google.common.base.Charsets; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; +import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -294,13 +298,22 @@ public class SimpleStateRequestHandler implements StateRequestHandler { private ListState<byte[]> getListState(BeamFnApi.StateRequest request) throws Exception { BeamFnApi.StateKey.BagUserState bagUserState = request.getStateKey().getBagUserState(); - String stateName = PYTHON_STATE_PREFIX + bagUserState.getUserStateId(); + byte[] data = Base64.getDecoder().decode(bagUserState.getUserStateId()); + FlinkFnApi.StateDescriptor stateDescriptor = FlinkFnApi.StateDescriptor.parseFrom(data); + String stateName = PYTHON_STATE_PREFIX + stateDescriptor.getStateName(); ListStateDescriptor<byte[]> listStateDescriptor; StateDescriptor cachedStateDescriptor = stateDescriptorCache.get(stateName); if (cachedStateDescriptor instanceof ListStateDescriptor) { listStateDescriptor = (ListStateDescriptor<byte[]>) cachedStateDescriptor; } else if (cachedStateDescriptor == null) { listStateDescriptor = new ListStateDescriptor<>(stateName, valueSerializer); + if (stateDescriptor.hasStateTtlConfig()) { + FlinkFnApi.StateDescriptor.StateTTLConfig stateTtlConfigProto = + stateDescriptor.getStateTtlConfig(); + StateTtlConfig stateTtlConfig = + ProtoUtils.parseStateTtlConfigFromProto(stateTtlConfigProto); + listStateDescriptor.enableTimeToLive(stateTtlConfig); + } stateDescriptorCache.put(stateName, listStateDescriptor); } else { throw new RuntimeException( @@ -587,7 +600,9 @@ public class SimpleStateRequestHandler implements StateRequestHandler { throws Exception { BeamFnApi.StateKey.MultimapSideInput mapUserState = request.getStateKey().getMultimapSideInput(); - String stateName = PYTHON_STATE_PREFIX + mapUserState.getSideInputId(); + byte[] data = Base64.getDecoder().decode(mapUserState.getSideInputId()); + FlinkFnApi.StateDescriptor stateDescriptor = FlinkFnApi.StateDescriptor.parseFrom(data); + String stateName = PYTHON_STATE_PREFIX + stateDescriptor.getStateName(); StateDescriptor cachedStateDescriptor = stateDescriptorCache.get(stateName); MapStateDescriptor<ByteArrayWrapper, byte[]> mapStateDescriptor; if (cachedStateDescriptor instanceof MapStateDescriptor) { @@ -597,6 +612,13 @@ public class SimpleStateRequestHandler implements StateRequestHandler { mapStateDescriptor = new MapStateDescriptor<>( stateName, ByteArrayWrapperSerializer.INSTANCE, valueSerializer); + if (stateDescriptor.hasStateTtlConfig()) { + FlinkFnApi.StateDescriptor.StateTTLConfig stateTtlConfigProto = + stateDescriptor.getStateTtlConfig(); + StateTtlConfig stateTtlConfig = + ProtoUtils.parseStateTtlConfigFromProto(stateTtlConfigProto); + mapStateDescriptor.enableTimeToLive(stateTtlConfig); + } stateDescriptorCache.put(stateName, mapStateDescriptor); } else { throw new RuntimeException( diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/utils/ProtoUtils.java b/flink-python/src/main/java/org/apache/flink/streaming/api/utils/ProtoUtils.java index 9c60500..070d6a9 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/utils/ProtoUtils.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/utils/ProtoUtils.java @@ -20,6 +20,8 @@ package org.apache.flink.streaming.api.utils; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.time.Time; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.fnexecution.v1.FlinkFnApi; import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo; @@ -41,7 +43,7 @@ import java.util.stream.Collectors; import static org.apache.flink.python.Constants.FLINK_CODER_URN; import static org.apache.flink.table.runtime.typeutils.PythonTypeUtils.toProtoType; -/** Utilities used to construct protobuf objects. */ +/** Utilities used to construct protobuf objects or construct objects from protobuf objects. */ @Internal public enum ProtoUtils { ; @@ -348,4 +350,93 @@ public enum ProtoUtils { builder.setSeparatedWithEndMessage(separatedWithEndMessage); return builder.build(); } + + public static StateTtlConfig parseStateTtlConfigFromProto( + FlinkFnApi.StateDescriptor.StateTTLConfig stateTTLConfigProto) { + StateTtlConfig.Builder builder = + StateTtlConfig.newBuilder(Time.milliseconds(stateTTLConfigProto.getTtl())) + .setUpdateType( + parseUpdateTypeFromProto(stateTTLConfigProto.getUpdateType())) + .setStateVisibility( + parseStateVisibilityFromProto( + stateTTLConfigProto.getStateVisibility())) + .setTtlTimeCharacteristic( + parseTtlTimeCharacteristicFromProto( + stateTTLConfigProto.getTtlTimeCharacteristic())); + + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies cleanupStrategiesProto = + stateTTLConfigProto.getCleanupStrategies(); + + if (!cleanupStrategiesProto.getIsCleanupInBackground()) { + builder.disableCleanupInBackground(); + } + + for (FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.MapStrategiesEntry + mapStrategiesEntry : cleanupStrategiesProto.getStrategiesList()) { + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies strategyProto = + mapStrategiesEntry.getStrategy(); + if (strategyProto + == FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies + .FULL_STATE_SCAN_SNAPSHOT) { + builder.cleanupFullSnapshot(); + } else if (strategyProto + == FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies + .INCREMENTAL_CLEANUP) { + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies + .IncrementalCleanupStrategy + incrementalCleanupStrategyProto = + mapStrategiesEntry.getIncrementalCleanupStrategy(); + builder.cleanupIncrementally( + incrementalCleanupStrategyProto.getCleanupSize(), + incrementalCleanupStrategyProto.getRunCleanupForEveryRecord()); + } else if (strategyProto + == FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.Strategies + .ROCKSDB_COMPACTION_FILTER) { + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies + .RocksdbCompactFilterCleanupStrategy + rocksdbCompactFilterCleanupStrategyProto = + mapStrategiesEntry.getRocksdbCompactFilterCleanupStrategy(); + builder.cleanupInRocksdbCompactFilter( + rocksdbCompactFilterCleanupStrategyProto.getQueryTimeAfterNumEntries()); + } + } + + return builder.build(); + } + + private static StateTtlConfig.UpdateType parseUpdateTypeFromProto( + FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType updateType) { + if (updateType == FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType.Disabled) { + return StateTtlConfig.UpdateType.Disabled; + } else if (updateType + == FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType.OnCreateAndWrite) { + return StateTtlConfig.UpdateType.OnCreateAndWrite; + } else if (updateType + == FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType.OnReadAndWrite) { + return StateTtlConfig.UpdateType.OnReadAndWrite; + } + throw new RuntimeException("Unknown UpdateType " + updateType); + } + + private static StateTtlConfig.StateVisibility parseStateVisibilityFromProto( + FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility stateVisibility) { + if (stateVisibility + == FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility + .ReturnExpiredIfNotCleanedUp) { + return StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp; + } else if (stateVisibility + == FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility.NeverReturnExpired) { + return StateTtlConfig.StateVisibility.NeverReturnExpired; + } + throw new RuntimeException("Unknown StateVisibility " + stateVisibility); + } + + private static StateTtlConfig.TtlTimeCharacteristic parseTtlTimeCharacteristicFromProto( + FlinkFnApi.StateDescriptor.StateTTLConfig.TtlTimeCharacteristic ttlTimeCharacteristic) { + if (ttlTimeCharacteristic + == FlinkFnApi.StateDescriptor.StateTTLConfig.TtlTimeCharacteristic.ProcessingTime) { + return StateTtlConfig.TtlTimeCharacteristic.ProcessingTime; + } + throw new RuntimeException("Unknown TtlTimeCharacteristic " + ttlTimeCharacteristic); + } } diff --git a/flink-python/src/test/java/org/apache/flink/streaming/api/utils/ProtoUtilsTest.java b/flink-python/src/test/java/org/apache/flink/streaming/api/utils/ProtoUtilsTest.java new file mode 100644 index 0000000..84a2d99 --- /dev/null +++ b/flink-python/src/test/java/org/apache/flink/streaming/api/utils/ProtoUtilsTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.utils; + +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.fnexecution.v1.FlinkFnApi; + +import org.junit.Test; + +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Test class for testing utilities used to construct protobuf objects or construct objects from + * protobuf objects. + */ +public class ProtoUtilsTest { + @Test + public void testParseStateTtlConfigFromProto() { + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies cleanupStrategiesProto = + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies.newBuilder() + .setIsCleanupInBackground(true) + .addStrategies( + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies + .MapStrategiesEntry.newBuilder() + .setStrategy( + FlinkFnApi.StateDescriptor.StateTTLConfig + .CleanupStrategies.Strategies + .FULL_STATE_SCAN_SNAPSHOT) + .setEmptyStrategy( + FlinkFnApi.StateDescriptor.StateTTLConfig + .CleanupStrategies.EmptyCleanupStrategy + .EMPTY_STRATEGY)) + .addStrategies( + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies + .MapStrategiesEntry.newBuilder() + .setStrategy( + FlinkFnApi.StateDescriptor.StateTTLConfig + .CleanupStrategies.Strategies + .INCREMENTAL_CLEANUP) + .setIncrementalCleanupStrategy( + FlinkFnApi.StateDescriptor.StateTTLConfig + .CleanupStrategies + .IncrementalCleanupStrategy.newBuilder() + .setCleanupSize(10) + .setRunCleanupForEveryRecord(true) + .build())) + .addStrategies( + FlinkFnApi.StateDescriptor.StateTTLConfig.CleanupStrategies + .MapStrategiesEntry.newBuilder() + .setStrategy( + FlinkFnApi.StateDescriptor.StateTTLConfig + .CleanupStrategies.Strategies + .ROCKSDB_COMPACTION_FILTER) + .setRocksdbCompactFilterCleanupStrategy( + FlinkFnApi.StateDescriptor.StateTTLConfig + .CleanupStrategies + .RocksdbCompactFilterCleanupStrategy + .newBuilder() + .setQueryTimeAfterNumEntries(1000) + .build())) + .build(); + FlinkFnApi.StateDescriptor.StateTTLConfig stateTTLConfigProto = + FlinkFnApi.StateDescriptor.StateTTLConfig.newBuilder() + .setTtl(Time.of(1000, TimeUnit.MILLISECONDS).toMilliseconds()) + .setUpdateType( + FlinkFnApi.StateDescriptor.StateTTLConfig.UpdateType + .OnCreateAndWrite) + .setStateVisibility( + FlinkFnApi.StateDescriptor.StateTTLConfig.StateVisibility + .NeverReturnExpired) + .setCleanupStrategies(cleanupStrategiesProto) + .build(); + + StateTtlConfig stateTTLConfig = + ProtoUtils.parseStateTtlConfigFromProto(stateTTLConfigProto); + + assertEquals(stateTTLConfig.getUpdateType(), StateTtlConfig.UpdateType.OnCreateAndWrite); + assertEquals( + stateTTLConfig.getStateVisibility(), + StateTtlConfig.StateVisibility.NeverReturnExpired); + assertEquals(stateTTLConfig.getTtl(), Time.milliseconds(1000)); + assertEquals( + stateTTLConfig.getTtlTimeCharacteristic(), + StateTtlConfig.TtlTimeCharacteristic.ProcessingTime); + + StateTtlConfig.CleanupStrategies cleanupStrategies = stateTTLConfig.getCleanupStrategies(); + assertTrue(cleanupStrategies.isCleanupInBackground()); + assertTrue(cleanupStrategies.inFullSnapshot()); + + StateTtlConfig.IncrementalCleanupStrategy incrementalCleanupStrategy = + cleanupStrategies.getIncrementalCleanupStrategy(); + assertNotNull(incrementalCleanupStrategy); + assertEquals(incrementalCleanupStrategy.getCleanupSize(), 10); + assertTrue(incrementalCleanupStrategy.runCleanupForEveryRecord()); + + StateTtlConfig.RocksdbCompactFilterCleanupStrategy rocksdbCompactFilterCleanupStrategy = + cleanupStrategies.getRocksdbCompactFilterCleanupStrategy(); + assertNotNull(rocksdbCompactFilterCleanupStrategy); + assertEquals(rocksdbCompactFilterCleanupStrategy.getQueryTimeAfterNumEntries(), 1000); + } +}