lukecwik commented on code in PR #22924:
URL: https://github.com/apache/beam/pull/22924#discussion_r962022318
##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -20,245 +20,180 @@
# mypy: disallow-untyped-defs
import collections
+import gc
import logging
import threading
-from typing import TYPE_CHECKING
from typing import Any
-from typing import Callable
-from typing import Generic
-from typing import Hashable
from typing import List
from typing import Optional
-from typing import Set
from typing import Tuple
-from typing import TypeVar
-from apache_beam.metrics import monitoring_infos
-
-if TYPE_CHECKING:
- from apache_beam.portability.api import metrics_pb2
+import objsize
_LOGGER = logging.getLogger(__name__)
-CallableT = TypeVar('CallableT', bound='Callable')
-KT = TypeVar('KT')
-VT = TypeVar('VT')
+class WeightedValue(object):
+ """Value type that stores corresponding weight.
-class Metrics(object):
- """Metrics container for state cache metrics."""
+ :arg value The value to be stored.
+ :arg weight The associated weight of the value. If unspecified, the objects
+ size will be used.
+ """
+ def __init__(self, value, weight):
+ # type: (Any, int) -> None
+ self._value = value
+ if weight <= 0:
+ raise ValueError(
+ 'Expected weight to be > 0 for %s but received %d' % (value, weight))
+ self._weight = weight
+
+ def weight(self):
+ # type: () -> int
+ return self._weight
- # A set of all registered metrics
- ALL_METRICS = set() # type: Set[Hashable]
- PREFIX = "beam:metric:statecache:"
+ def value(self):
+ # type: () -> Any
+ return self._value
- def __init__(self):
- # type: () -> None
- self._context = threading.local()
- def initialize(self):
+class CacheAware(object):
+ def __init__(self):
# type: () -> None
+ pass
- """Needs to be called once per thread to initialize the local metrics
cache.
- """
- if hasattr(self._context, 'metrics'):
- return # Already initialized
- self._context.metrics = collections.defaultdict(int)
-
- def count(self, name):
- # type: (str) -> None
- self._context.metrics[name] += 1
-
- def hit_miss(self, total_name, hit_miss_name):
- # type: (str, str) -> None
- self._context.metrics[total_name] += 1
- self._context.metrics[hit_miss_name] += 1
+ def get_referents_for_cache(self):
+ # type: () -> List[Any]
- def get_monitoring_infos(self, cache_size, cache_capacity):
- # type: (int, int) -> List[metrics_pb2.MonitoringInfo]
+ """Returns the list of objects accounted during cache measurement."""
+ raise NotImplementedError()
- """Returns the metrics scoped to the current bundle."""
- metrics = self._context.metrics
- if len(metrics) == 0:
- # No metrics collected, do not report
- return []
- # Add all missing metrics which were not reported
- for key in Metrics.ALL_METRICS:
- if key not in metrics:
- metrics[key] = 0
- # Gauges which reflect the state since last queried
- gauges = [
- monitoring_infos.int64_gauge(self.PREFIX + name, val) for name,
- val in metrics.items()
- ]
- gauges.append(
- monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size))
- gauges.append(
- monitoring_infos.int64_gauge(self.PREFIX + 'capacity', cache_capacity))
- # Counters for the summary across all metrics
- counters = [
- monitoring_infos.int64_counter(self.PREFIX + name + '_total', val)
- for name,
- val in metrics.items()
- ]
- # Reinitialize metrics for this thread/bundle
- metrics.clear()
- return gauges + counters
- @staticmethod
- def counter_hit_miss(total_name, hit_name, miss_name):
- # type: (str, str, str) -> Callable[[CallableT], CallableT]
+def get_referents_for_cache(*objs):
+ # type: (List[Any]) -> List[Any]
- """Decorator for counting function calls and whether
- the return value equals None (=miss) or not (=hit)."""
- Metrics.ALL_METRICS.update([total_name, hit_name, miss_name])
+ """Returns the list of objects accounted during cache measurement.
- def decorator(function):
- # type: (CallableT) -> CallableT
- def reporter(self, *args, **kwargs):
- # type: (StateCache, Any, Any) -> Any
- value = function(self, *args, **kwargs)
- if value is None:
- self._metrics.hit_miss(total_name, miss_name)
- else:
- self._metrics.hit_miss(total_name, hit_name)
- return value
-
- return reporter # type: ignore[return-value]
-
- return decorator
-
- @staticmethod
- def counter(metric_name):
- # type: (str) -> Callable[[CallableT], CallableT]
-
- """Decorator for counting function calls."""
- Metrics.ALL_METRICS.add(metric_name)
-
- def decorator(function):
- # type: (CallableT) -> CallableT
- def reporter(self, *args, **kwargs):
- # type: (StateCache, Any, Any) -> Any
- self._metrics.count(metric_name)
- return function(self, *args, **kwargs)
-
- return reporter # type: ignore[return-value]
-
- return decorator
+ Users can inherit CacheAware to override which referrents should be
+ used when measuring the deep size of the object. The default is to
+ use gc.get_referents(*objs).
+ """
+ # print(objs)
+ rval = []
+ for obj in objs:
+ if isinstance(obj, CacheAware):
+ rval.extend(obj.get_referents_for_cache())
+ else:
+ rval.extend(gc.get_referents(obj))
+ return rval
class StateCache(object):
- """ Cache for Beam state access, scoped by state key and cache_token.
- Assumes a bag state implementation.
+ """Cache for Beam state access, scoped by state key and cache_token.
+ Assumes a bag state implementation.
- For a given state_key, caches a (cache_token, value) tuple and allows to
+ For a given state_key and cache_token, caches a value and allows to
a) read from the cache (get),
if the currently stored cache_token matches the provided
- a) write to the cache (put),
+ b) write to the cache (put),
storing the new value alongside with a cache token
- c) append to the currently cache item (extend),
- if the currently stored cache_token matches the provided
c) empty a cached element (clear),
if the currently stored cache_token matches the provided
- d) evict a cached element (evict)
+ d) invalidate a cached element (invalidate)
+ e) invalidate all cached elements (invalidate_all)
The operations on the cache are thread-safe for use by multiple workers.
- :arg max_entries The maximum number of entries to store in the cache.
- TODO Memory-based caching: https://github.com/apache/beam/issues/19857
+ :arg max_weight The maximum weight of entries to store in the cache in bytes.
"""
- def __init__(self, max_entries):
+ def __init__(self, max_weight):
# type: (int) -> None
- _LOGGER.info('Creating state cache with size %s', max_entries)
- self._missing = None
- self._cache = self.LRUCache[Tuple[bytes, Optional[bytes]],
- Any](max_entries, self._missing)
+ _LOGGER.info('Creating state cache with size %s', max_weight)
+ self._max_weight = max_weight
+ self._current_weight = 0
+ self._cache = collections.OrderedDict(
+ ) # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]],
WeightedValue]
+ self._hit_count = 0
+ self._miss_count = 0
+ self._evict_count = 0
self._lock = threading.RLock()
- self._metrics = Metrics()
- @Metrics.counter_hit_miss("get", "hit", "miss")
def get(self, state_key, cache_token):
# type: (bytes, Optional[bytes]) -> Any
assert cache_token and self.is_cache_enabled()
+ key = (state_key, cache_token)
with self._lock:
- return self._cache.get((state_key, cache_token))
+ value = self._cache.get(key, None)
+ if value is None:
+ self._miss_count += 1
+ return None
+ self._cache.move_to_end(key)
+ self._hit_count += 1
+ return value.value()
- @Metrics.counter("put")
def put(self, state_key, cache_token, value):
# type: (bytes, Optional[bytes], Any) -> None
assert cache_token and self.is_cache_enabled()
+ if not isinstance(value, WeightedValue):
+ weight = objsize.get_deep_size(
+ value, get_referents_func=get_referents_for_cache)
+ if weight <= 0:
+ _LOGGER.warning(
Review Comment:
Incorrect usage of WeightedValue is a programming bug that is easy to fix,
while users may not even know they are using get_deep_size, and fixing a
possibly deep and complex object type issue is hard. It seems like we should
log and continue here instead of causing the pipeline to get stuck.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]