This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f46382110d7 Fix incorrect object size calculation in StateCache (#23000) (#23886)
f46382110d7 is described below

commit f46382110d78a6a8f03491ea2165a24dc1d5739c
Author: Minbo Bae <49642083+baemi...@users.noreply.github.com>
AuthorDate: Mon Oct 31 19:48:36 2022 -0700

    Fix incorrect object size calculation in StateCache (#23000) (#23886)
---
 .../apache_beam/runners/worker/statecache.py       |  3 +-
 .../apache_beam/runners/worker/statecache_test.py  | 51 ++++++++++++++++++++++
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/runners/worker/statecache.py b/sdks/python/apache_beam/runners/worker/statecache.py
index e3f37fec114..dde4243057d 100644
--- a/sdks/python/apache_beam/runners/worker/statecache.py
+++ b/sdks/python/apache_beam/runners/worker/statecache.py
@@ -176,7 +176,7 @@ def get_deep_size(*objs):
 
   """Calculates the deep size of all the arguments in bytes."""
   return objsize.get_deep_size(
-      objs,
+      *objs,
       get_size_func=_size_func,
       get_referents_func=_get_referents_func,
       filter_func=_filter_func)
@@ -274,6 +274,7 @@ class StateCache(object):
     self._miss_count += 1
     loading_value = _LoadingValue()
     self._cache[key] = loading_value
+    self._current_weight += loading_value.weight()
 
     # Ensure that we unlock the lock while loading to allow for parallel gets
     self._lock.release()
diff --git a/sdks/python/apache_beam/runners/worker/statecache_test.py b/sdks/python/apache_beam/runners/worker/statecache_test.py
index 6850cb21284..a5d1ff2e01e 100644
--- a/sdks/python/apache_beam/runners/worker/statecache_test.py
+++ b/sdks/python/apache_beam/runners/worker/statecache_test.py
@@ -20,11 +20,13 @@
 
 import logging
 import re
+import sys
 import threading
 import time
 import unittest
 import weakref
 
+import objsize
 from hamcrest import assert_that
 from hamcrest import contains_string
 
@@ -32,6 +34,7 @@ from apache_beam.runners.worker.statecache import CacheAware
 from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.runners.worker.statecache import WeightedValue
 from apache_beam.runners.worker.statecache import _LoadingValue
+from apache_beam.runners.worker.statecache import get_deep_size
 
 
 class StateCacheTest(unittest.TestCase):
@@ -356,6 +359,55 @@ class StateCacheTest(unittest.TestCase):
             'used/max 1/5 MB, hit 100.00%, lookups 0, '
             'avg load time 0 ns, loads 0, evictions 0'))
 
+  def test_get_deep_size_builtin_objects(self):
+    """
+    `statecache.get_deep_size` should behave the same as `objsize` unless
+    `objs` contains a `CacheAware` or filtered object, so the two must return
+    the same size for built-in objects.
+    """
+    primitive_test_objects = [
+        1,                    # int
+        2.0,                  # float
+        1+1j,                 # complex
+        True,                 # bool
+        'hello,world',        # str
+        b'\00\01\02',         # bytes
+    ]
+
+    collection_test_objects = [
+        [3, 4, 5],            # list
+        (6, 7),               # tuple
+        {'a', 'b', 'c'},      # set
+        {'k': 8, 'l': 9},     # dict
+    ]
+
+    for obj in primitive_test_objects:
+      self.assertEqual(
+          get_deep_size(obj),
+          objsize.get_deep_size(obj),
+          f'different size for obj: `{obj}`, type: {type(obj)}')
+      self.assertEqual(
+          get_deep_size(obj),
+          sys.getsizeof(obj),
+          f'different size for obj: `{obj}`, type: {type(obj)}')
+
+    for obj in collection_test_objects:
+      self.assertEqual(
+          get_deep_size(obj),
+          objsize.get_deep_size(obj),
+          f'different size for obj: `{obj}`, type: {type(obj)}')
+
+  def test_current_weight_between_get_and_put(self):
+    """`get` with a loader and `put` must add the same cache weight."""
+    value = 1234567
+    get_cache = StateCache(100)
+    get_cache.get("key", lambda k: value)
+
+    put_cache = StateCache(100)
+    put_cache.put("key", value)
+
+    self.assertEqual(get_cache._current_weight, put_cache._current_weight)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to