rohitsinha54 commented on code in PR #33196:
URL: https://github.com/apache/beam/pull/33196#discussion_r1877471925


##########
sdks/python/apache_beam/metrics/cells.py:
##########
@@ -630,3 +662,216 @@ def singleton(value: str) -> "StringSetData":
   @staticmethod
   def identity_element() -> "StringSetData":
     return StringSetData()
+
+
+class _BoundedTrieNode(object):
+  def __init__(self):
+    # invariant: size = len(self.flattened()) = min(1, sum(size of children))
+    self._size = 1
+    self._children: Optional[dict[str, '_BoundedTrieNode']] = {}
+    self._truncated = False
+
+  def to_proto(self) -> metrics_pb2.BoundedTrieNode:
+    return metrics_pb2.BoundedTrieNode(
+        truncated=self._truncated,
+        children={
+            name: child.to_proto()
+            for name, child in self._children.items()
+        } if self._children else None)
+
+  @staticmethod
+  def from_proto(proto: metrics_pb2.BoundedTrieNode) -> '_BoundedTrieNode':
+    node = _BoundedTrieNode()
+    if proto.truncated:
+      node._truncated = True
+      node._children = None
+    else:
+      node._children = {
+          name: _BoundedTrieNode.from_proto(child)
+          for name,
+          child in proto.children.items()
+      }
+      node._size = min(1, sum(child._size for child in node._children.values()))
+    return node
+
+  def size(self):
+    return self._size
+
+  def add(self, segments) -> int:
+    if self._truncated or not segments:
+      return 0
+    head, *tail = segments
+    was_empty = not self._children
+    child = self._children.get(head, None)  # type: ignore[union-attr]
+    if child is None:
+      child = self._children[head] = _BoundedTrieNode()  # type: ignore[index]
+      delta = 0 if was_empty else 1
+    else:
+      delta = 0
+    if tail:
+      delta += child.add(tail)
+    self._size += delta
+    return delta
+
+  def add_all(self, segments_iter):
+    return sum(self.add(segments) for segments in segments_iter)
+
+  def trim(self) -> int:
+    if not self._children:
+      return 0
+    max_child = max(self._children.values(), key=lambda child: child._size)
+    if max_child._size == 1:
+      delta = 1 - self._size
+      self._truncated = True
+      self._children = None
+    else:
+      delta = max_child.trim()
+    self._size += delta
+    return delta
+
+  def merge(self, other: '_BoundedTrieNode') -> int:
+    if self._truncated:
+      delta = 0
+    elif other._truncated:
+      delta = 1 - self._size
+      self._truncated = True
+      self._children = None
+    elif not other._children:
+      delta = 0
+    elif not self._children:
+      self._children = other._children
+      delta = self._size - other._size
+    else:
+      delta = 0
+      other_child: '_BoundedTrieNode'
+      self_child: Optional['_BoundedTrieNode']
+      for prefix, other_child in other._children.items():
+        self_child = self._children.get(prefix, None)
+        if self_child is None:
+          self._children[prefix] = other_child
+          delta += other_child._size
+        else:
+          delta += self_child.merge(other_child)
+    self._size += delta
+    return delta
+
+  def flattened(self):
+    if self._truncated:
+      yield (True, )
+    elif not self._children:
+      yield (False, )
+    else:
+      for prefix, child in sorted(self._children.items()):
+        for flattened in child.flattened():
+          yield (prefix, ) + flattened
+
+  def __hash__(self):
+    return self._truncated or hash(sorted(self._children.items()))
+
+  def __eq__(self, other):
+    if isinstance(other, _BoundedTrieNode):
+      return (
+          self._truncated == other._truncated and
+          self._children == other._children)
+    else:
+      return False
+
+  def __repr__(self):
+    return repr(set(''.join(str(s) for s in t) for t in self.flattened()))
+
+
+class BoundedTrieData(object):
+  _DEFAULT_BOUND = 100
+
+  def __init__(self, *, root=None, singleton=None, bound=_DEFAULT_BOUND):
+    assert singleton is None or root is None
+    self._singleton = singleton
+    self._root = root
+    self._bound = bound
+
+  def to_proto(self) -> metrics_pb2.BoundedTrie:
+    return metrics_pb2.BoundedTrie(
+        bound=self._bound,
+        singleton=self._singlton if self._singleton else None,
+        root=self._root.to_proto() if self._root else None)
+
+  @staticmethod
+  def from_proto(proto: metrics_pb2.BoundedTrie) -> 'BoundedTrieData':
+    return BoundedTrieData(
+        bound=proto.bound,
+        singleton=tuple(proto.singleton) if proto.singleton else None,
+        root=_BoundedTrieNode.from_proto(proto.root) if proto.root else None)
+
+  def as_trie(self):
+    if self._root is not None:
+      return self._root
+    else:
+      root = _BoundedTrieNode()
+      if self._singleton is not None:
+        root.add(self._singleton)
+      return root
+
+  def __eq__(self, other: object) -> bool:
+    if isinstance(other, BoundedTrieData):
+      return self.as_trie() == other.as_trie()
+    else:
+      return False
+
+  def __hash__(self) -> int:
+    return hash(self.as_trie())
+
+  def __repr__(self) -> str:
+    return 'BoundedTrieData({})'.format(self.as_trie())
+
+  def get_cumulative(self) -> "BoundedTrieData":
+    return copy.deepcopy(self)
+
+  def get_result(self) -> set[tuple]:
+    if self._root is None:
+      if self._singleton is None:
+        return set()
+      else:
+        return set([self._singleton + (False, )])
+    else:
+      return set(self._root.flattened())
+
+  def add(self, segments):
+    if self._root is None and self._singleton is None:
+      self._singleton = segments
+    elif self._singleton is not None and self._singleton == segments:

Review Comment:
   Nice!



##########
sdks/python/apache_beam/metrics/cells.py:
##########
@@ -314,6 +315,35 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id):
         ptransform=transform_id)
 
 
+class BoundedTrieCell(AbstractMetricCell):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Tracks the current value for a StringSet metric.

Review Comment:
   typo StringSet -> BoundedTrie



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to