lukecwik commented on code in PR #22924:
URL: https://github.com/apache/beam/pull/22924#discussion_r962020219
##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -20,245 +20,180 @@
# mypy: disallow-untyped-defs
import collections
+import gc
import logging
import threading
-from typing import TYPE_CHECKING
from typing import Any
-from typing import Callable
-from typing import Generic
-from typing import Hashable
from typing import List
from typing import Optional
-from typing import Set
from typing import Tuple
-from typing import TypeVar
-from apache_beam.metrics import monitoring_infos
-
-if TYPE_CHECKING:
- from apache_beam.portability.api import metrics_pb2
+import objsize
_LOGGER = logging.getLogger(__name__)
-CallableT = TypeVar('CallableT', bound='Callable')
-KT = TypeVar('KT')
-VT = TypeVar('VT')
+class WeightedValue(object):
+ """Value type that stores corresponding weight.
-class Metrics(object):
- """Metrics container for state cache metrics."""
+ :arg value The value to be stored.
+ :arg weight The associated weight of the value. If unspecified, the object's
+ size will be used.
+ """
+ def __init__(self, value, weight):
+ # type: (Any, int) -> None
+ self._value = value
+ if weight <= 0:
+ raise ValueError(
+ 'Expected weight to be > 0 for %s but received %d' % (value, weight))
+ self._weight = weight
+
+ def weight(self):
+ # type: () -> int
+ return self._weight
- # A set of all registered metrics
- ALL_METRICS = set() # type: Set[Hashable]
- PREFIX = "beam:metric:statecache:"
+ def value(self):
+ # type: () -> Any
+ return self._value
- def __init__(self):
- # type: () -> None
- self._context = threading.local()
- def initialize(self):
+class CacheAware(object):
+ def __init__(self):
# type: () -> None
+ pass
- """Needs to be called once per thread to initialize the local metrics
-    cache.
- """
- if hasattr(self._context, 'metrics'):
- return # Already initialized
- self._context.metrics = collections.defaultdict(int)
-
- def count(self, name):
- # type: (str) -> None
- self._context.metrics[name] += 1
-
- def hit_miss(self, total_name, hit_miss_name):
- # type: (str, str) -> None
- self._context.metrics[total_name] += 1
- self._context.metrics[hit_miss_name] += 1
+ def get_referents_for_cache(self):
+ # type: () -> List[Any]
- def get_monitoring_infos(self, cache_size, cache_capacity):
- # type: (int, int) -> List[metrics_pb2.MonitoringInfo]
+ """Returns the list of objects accounted during cache measurement."""
+ raise NotImplementedError()
- """Returns the metrics scoped to the current bundle."""
- metrics = self._context.metrics
- if len(metrics) == 0:
- # No metrics collected, do not report
- return []
- # Add all missing metrics which were not reported
- for key in Metrics.ALL_METRICS:
- if key not in metrics:
- metrics[key] = 0
- # Gauges which reflect the state since last queried
- gauges = [
- monitoring_infos.int64_gauge(self.PREFIX + name, val) for name,
- val in metrics.items()
- ]
- gauges.append(
- monitoring_infos.int64_gauge(self.PREFIX + 'size', cache_size))
- gauges.append(
- monitoring_infos.int64_gauge(self.PREFIX + 'capacity', cache_capacity))
- # Counters for the summary across all metrics
- counters = [
- monitoring_infos.int64_counter(self.PREFIX + name + '_total', val)
- for name,
- val in metrics.items()
- ]
- # Reinitialize metrics for this thread/bundle
- metrics.clear()
- return gauges + counters
- @staticmethod
- def counter_hit_miss(total_name, hit_name, miss_name):
- # type: (str, str, str) -> Callable[[CallableT], CallableT]
+def get_referents_for_cache(*objs):
+ # type: (List[Any]) -> List[Any]
- """Decorator for counting function calls and whether
- the return value equals None (=miss) or not (=hit)."""
- Metrics.ALL_METRICS.update([total_name, hit_name, miss_name])
+ """Returns the list of objects accounted during cache measurement.
- def decorator(function):
- # type: (CallableT) -> CallableT
- def reporter(self, *args, **kwargs):
- # type: (StateCache, Any, Any) -> Any
- value = function(self, *args, **kwargs)
- if value is None:
- self._metrics.hit_miss(total_name, miss_name)
- else:
- self._metrics.hit_miss(total_name, hit_name)
- return value
-
- return reporter # type: ignore[return-value]
-
- return decorator
-
- @staticmethod
- def counter(metric_name):
- # type: (str) -> Callable[[CallableT], CallableT]
-
- """Decorator for counting function calls."""
- Metrics.ALL_METRICS.add(metric_name)
-
- def decorator(function):
- # type: (CallableT) -> CallableT
- def reporter(self, *args, **kwargs):
- # type: (StateCache, Any, Any) -> Any
- self._metrics.count(metric_name)
- return function(self, *args, **kwargs)
-
- return reporter # type: ignore[return-value]
-
- return decorator
+ Users can inherit CacheAware to override which referents should be
+ used when measuring the deep size of the object. The default is to
+ use gc.get_referents(*objs).
+ """
+ # print(objs)
Review Comment:
Done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]