ryanthompson591 commented on code in PR #23046:
URL: https://github.com/apache/beam/pull/23046#discussion_r982763780


##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -111,28 +141,104 @@ def __init__(self, max_weight):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:
       value = self._cache.get(key, None)
-      if value is None:
+      if value is None or isinstance(value, _LoadingValue):
         self._miss_count += 1
         return None
+
+      self._cache.move_to_end(key)
+      self._hit_count += 1
+    return value.value()
+
+  def get(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> Any
+    assert self.is_cache_enabled() and callable(loading_fn)
+
+    self._lock.acquire()
+    value = self._cache.get(key, None)
+    if value is None:
+      self._miss_count += 1
+      loading_value = _LoadingValue()
+      self._cache[key] = loading_value
+
+      # Ensure that we unlock the lock while loading to allow for parallel gets
+      self._lock.release()
+
+      start_time_ns = time.time_ns()
+      loading_value.load(key, loading_fn)
+      elapsed_time_ns = time.time_ns() - start_time_ns
+
+      try:
+        value = loading_value.value()
+      except Exception as err:
+        # If loading failed then delete the value from the cache allowing for
+        # the next lookup to possibly succeed.
+        with self._lock:
+          self._load_count += 1
+          self._load_time_ns += elapsed_time_ns
+          # Don't remove values that have already been replaced with a different
+          # value by a put/invalidate that occurred concurrently with the load.
+          # The put/invalidate will have been responsible for updating the
+          # cache weight appropriately already.
+          old_value = self._cache.get(key, None)
+          if old_value is not loading_value:
+            raise err
+          self._current_weight -= loading_value.weight()
+          del self._cache[key]
+        raise err
+
+      # Replace the value in the cache with a weighted value now that the
+      # loading has completed successfully.
+      weight = objsize.get_deep_size(
+          value, get_referents_func=get_referents_for_cache)
+      if weight <= 0:
+        _LOGGER.warning(
+            'Expected object size to be >= 0 for %s but received %d.',
+            value,
+            weight)
+        weight = 8
+      value = WeightedValue(value, weight)
+      with self._lock:
+        self._load_count += 1
+        self._load_time_ns += elapsed_time_ns
+        # Don't replace values that have already been replaced with a different
+        # value by a put/invalidate that occurred concurrently with the load.
+        # The put/invalidate will have been responsible for updating the
+        # cache weight appropriately already.
+        old_value = self._cache.get(key, None)
+        if old_value is not loading_value:
+          return value.value()

Review Comment:
   What could make this happen, i.e. when would `old_value is not loading_value` be true here?



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -111,28 +141,104 @@ def __init__(self, max_weight):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:
       value = self._cache.get(key, None)
-      if value is None:
+      if value is None or isinstance(value, _LoadingValue):
         self._miss_count += 1
         return None
+
+      self._cache.move_to_end(key)
+      self._hit_count += 1
+    return value.value()
+
+  def get(self, key, loading_fn):
+    # type: (Any, Callable[[Any], Any]) -> Any
+    assert self.is_cache_enabled() and callable(loading_fn)
+
+    self._lock.acquire()
+    value = self._cache.get(key, None)
+    if value is None:
+      self._miss_count += 1
+      loading_value = _LoadingValue()
+      self._cache[key] = loading_value
+
+      # Ensure that we unlock the lock while loading to allow for parallel gets
+      self._lock.release()
+
+      start_time_ns = time.time_ns()
+      loading_value.load(key, loading_fn)
+      elapsed_time_ns = time.time_ns() - start_time_ns
+
+      try:
+        value = loading_value.value()
+      except Exception as err:
+        # If loading failed then delete the value from the cache allowing for
+        # the next lookup to possibly succeed.
+        with self._lock:
+          self._load_count += 1
+          self._load_time_ns += elapsed_time_ns
+          # Don't remove values that have already been replaced with a different
+          # value by a put/invalidate that occurred concurrently with the load.
+          # The put/invalidate will have been responsible for updating the
+          # cache weight appropriately already.
+          old_value = self._cache.get(key, None)
+          if old_value is not loading_value:
+            raise err
+          self._current_weight -= loading_value.weight()
+          del self._cache[key]
+        raise err
+
+      # Replace the value in the cache with a weighted value now that the
+      # loading has completed successfully.
+      weight = objsize.get_deep_size(
+          value, get_referents_func=get_referents_for_cache)
+      if weight <= 0:
+        _LOGGER.warning(
+            'Expected object size to be >= 0 for %s but received %d.',
+            value,
+            weight)
+        weight = 8
+      value = WeightedValue(value, weight)
+      with self._lock:
+        self._load_count += 1
+        self._load_time_ns += elapsed_time_ns
+        # Don't replace values that have already been replaced with a different
+        # value by a put/invalidate that occurred concurrently with the load.
+        # The put/invalidate will have been responsible for updating the
+        # cache weight appropriately already.
+        old_value = self._cache.get(key, None)
+        if old_value is not loading_value:
+          return value.value()
+
+        self._current_weight -= loading_value.weight()
+        self._cache[key] = value
+        self._current_weight += value.weight()
+        while self._current_weight > self._max_weight:
+          (_, weighted_value) = self._cache.popitem(last=False)
+          self._current_weight -= weighted_value.weight()
+          self._evict_count += 1
+
+    else:

Review Comment:
   This is such a long method that I had trouble seeing which `if` this `else` matches.
   Maybe add a comment like `# value is not None`.
   
   Or potentially just handle the cache-hit case at the top of the method:
   
   with self._lock:
     value = self._cache.get(key, None)
     if value is not None:
       self._cache.move_to_end(key)
       self._hit_count += 1
       return value.value()
     # ... all the code for the case where loading needs to happen.
   
   It's up to you.
   



##########
sdks/python/apache_beam/runners/worker/statecache.py:
##########
@@ -111,28 +141,104 @@ def __init__(self, max_weight):
     self._max_weight = max_weight
     self._current_weight = 0
     self._cache = collections.OrderedDict(
-    )  # type: collections.OrderedDict[Tuple[bytes, Optional[bytes]], WeightedValue]
+    )  # type: collections.OrderedDict[Any, WeightedValue]
     self._hit_count = 0
     self._miss_count = 0
     self._evict_count = 0
+    self._load_time_ns = 0
+    self._load_count = 0
     self._lock = threading.RLock()
 
-  def get(self, state_key, cache_token):
-    # type: (bytes, Optional[bytes]) -> Any
-    assert cache_token and self.is_cache_enabled()
-    key = (state_key, cache_token)
+  def peek(self, key):
+    # type: (Any) -> Any
+    assert self.is_cache_enabled()
     with self._lock:
       value = self._cache.get(key, None)
-      if value is None:
+      if value is None or isinstance(value, _LoadingValue):
         self._miss_count += 1
         return None
+
+      self._cache.move_to_end(key)

Review Comment:
   What I was asking about earlier was this line. I just want to confirm that we
   want `peek` to modify the cache, both by moving the key to the most-recently-used
   end (`move_to_end`) and by counting a peek as a hit.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to