This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 17ef888a783 Add StringSet metrics to Python SDK (#31969)
17ef888a783 is described below

commit 17ef888a7830f304fe617a904418baf5923aae6e
Author: Yi Hu <[email protected]>
AuthorDate: Tue Jul 30 10:01:45 2024 -0400

    Add StringSet metrics to Python SDK (#31969)
    
    * Add StringSet metrics to Python SDK
    
    * Address comments
    
    * Use string_set everywhere
    
    * fix leftover SET_STRING_TYPE -> STRING_SET_TYPE
---
 .../java/org/apache/beam/sdk/metrics/Metrics.java  | 10 +--
 sdks/python/apache_beam/metrics/cells.pxd          |  6 ++
 sdks/python/apache_beam/metrics/cells.py           | 75 ++++++++++++++++++++++
 sdks/python/apache_beam/metrics/cells_test.py      | 24 +++++++
 sdks/python/apache_beam/metrics/execution.py       | 20 +++++-
 sdks/python/apache_beam/metrics/execution_test.py  |  9 +++
 sdks/python/apache_beam/metrics/metric.py          | 31 ++++++++-
 sdks/python/apache_beam/metrics/metricbase.py      | 16 ++++-
 .../python/apache_beam/metrics/monitoring_infos.py | 53 +++++++++++++--
 .../apache_beam/metrics/monitoring_infos_test.py   | 25 ++++++++
 .../runners/dataflow/dataflow_metrics.py           | 15 ++++-
 .../apache_beam/runners/direct/direct_metrics.py   | 15 ++++-
 .../runners/direct/direct_runner_test.py           |  9 +++
 .../runners/portability/fn_api_runner/fn_runner.py | 11 +++-
 .../portability/fn_api_runner/fn_runner_test.py    |  9 ++-
 .../runners/portability/portable_metrics.py        | 13 ++--
 .../runners/portability/portable_runner.py         |  5 +-
 17 files changed, 315 insertions(+), 31 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java
index dc80a66c055..a963015e98a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/metrics/Metrics.java
@@ -93,18 +93,12 @@ public class Metrics {
     return new DelegatingGauge(MetricName.named(namespace, name));
   }
 
-  /**
-   * Create a metric that can have its new value set, and is aggregated by taking the last reported
-   * value.
-   */
+  /** Create a metric that accumulates and reports a set of unique string values. */
   public static StringSet stringSet(String namespace, String name) {
     return new DelegatingStringSet(MetricName.named(namespace, name));
   }
 
-  /**
-   * Create a metric that can have its new value set, and is aggregated by taking the last reported
-   * value.
-   */
+  /** Create a metric that accumulates and reports a set of unique string values. */
   public static StringSet stringSet(Class<?> namespace, String name) {
     return new DelegatingStringSet(MetricName.named(namespace, name));
   }
diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
index 0eaa890c02a..a8f4003d898 100644
--- a/sdks/python/apache_beam/metrics/cells.pxd
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -44,6 +44,12 @@ cdef class GaugeCell(MetricCell):
   cdef readonly object data
 
 
+cdef class StringSetCell(MetricCell):
+  cdef readonly set data
+
+  cdef inline bint _update(self, value) except -1
+
+
 cdef class DistributionData(object):
   cdef readonly libc.stdint.int64_t sum
   cdef readonly libc.stdint.int64_t count
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index 53b6fc84959..d836d4cee58 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -268,6 +268,62 @@ class GaugeCell(MetricCell):
         ptransform=transform_id)
 
 
+class StringSetCell(MetricCell):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Tracks the current value for a StringSet metric.
+
+  Each cell tracks the state of a metric independently per context per bundle.
+  Therefore, each metric has a different cell in each bundle, that is later
+  aggregated.
+
+  This class is thread safe.
+  """
+  def __init__(self, *args):
+    super().__init__(*args)
+    self.data = StringSetAggregator.identity_element()
+
+  def add(self, value):
+    self.update(value)
+
+  def update(self, value):
+    # type: (str) -> None
+    if cython.compiled:
+      # We will hold the GIL throughout the entire _update.
+      self._update(value)
+    else:
+      with self._lock:
+        self._update(value)
+
+  def _update(self, value):
+    self.data.add(value)
+
+  def get_cumulative(self):
+    # type: () -> set
+    with self._lock:
+      return set(self.data)
+
+  def combine(self, other):
+    # type: (StringSetCell) -> StringSetCell
+    combined = StringSetAggregator().combine(self.data, other.data)
+    result = StringSetCell()
+    result.data = combined
+    return result
+
+  def to_runner_api_monitoring_info_impl(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+
+    return monitoring_infos.user_set_string(
+        name.namespace,
+        name.name,
+        self.get_cumulative(),
+        ptransform=transform_id)
+
+  def reset(self):
+    # type: () -> None
+    self.data = StringSetAggregator.identity_element()
+
+
 class DistributionResult(object):
   """The result of a Distribution metric."""
   def __init__(self, data):
@@ -553,3 +609,22 @@ class GaugeAggregator(MetricAggregator):
   def result(self, x):
     # type: (GaugeData) -> GaugeResult
     return GaugeResult(x.get_cumulative())
+
+
+class StringSetAggregator(MetricAggregator):
+  @staticmethod
+  def identity_element():
+    # type: () -> set
+    return set()
+
+  def combine(self, x, y):
+    # type: (set, set) -> set
+    if len(x) == 0:
+      return y
+    elif len(y) == 0:
+      return x
+    else:
+      return set.union(x, y)
+
+  def result(self, x):
+    return x
diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py
index 3d4d81c3d12..052ff051bf9 100644
--- a/sdks/python/apache_beam/metrics/cells_test.py
+++ b/sdks/python/apache_beam/metrics/cells_test.py
@@ -25,6 +25,7 @@ from apache_beam.metrics.cells import DistributionCell
 from apache_beam.metrics.cells import DistributionData
 from apache_beam.metrics.cells import GaugeCell
 from apache_beam.metrics.cells import GaugeData
+from apache_beam.metrics.cells import StringSetCell
 from apache_beam.metrics.metricbase import MetricName
 
 
@@ -169,5 +170,28 @@ class TestGaugeCell(unittest.TestCase):
     self.assertGreater(mi.start_time.seconds, 0)
 
 
+class TestStringSetCell(unittest.TestCase):
+  def test_not_leak_mutable_set(self):
+    c = StringSetCell()
+    c.add('test')
+    c.add('another')
+    s = c.get_cumulative()
+    self.assertEqual(s, set(('test', 'another')))
+    s.add('yet another')
+    self.assertEqual(c.get_cumulative(), set(('test', 'another')))
+
+  def test_combine_appropriately(self):
+    s1 = StringSetCell()
+    s1.add('1')
+    s1.add('2')
+
+    s2 = StringSetCell()
+    s2.add('1')
+    s2.add('3')
+
+    result = s2.combine(s1)
+    self.assertEqual(result.data, set(('1', '2', '3')))
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 4202f7996c7..74890b822bc 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -48,6 +48,7 @@ from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.cells import CounterCell
 from apache_beam.metrics.cells import DistributionCell
 from apache_beam.metrics.cells import GaugeCell
+from apache_beam.metrics.cells import StringSetCell
 from apache_beam.runners.worker import statesampler
 from apache_beam.runners.worker.statesampler import get_current_tracker
 
@@ -259,6 +260,12 @@ class MetricsContainer(object):
         GaugeCell,
         self.get_metric_cell(_TypedMetricName(GaugeCell, metric_name)))
 
+  def get_string_set(self, metric_name):
+    # type: (MetricName) -> StringSetCell
+    return cast(
+        StringSetCell,
+        self.get_metric_cell(_TypedMetricName(StringSetCell, metric_name)))
+
   def get_metric_cell(self, typed_metric_name):
     # type: (_TypedMetricName) -> MetricCell
     cell = self.metrics.get(typed_metric_name, None)
@@ -292,7 +299,13 @@ class MetricsContainer(object):
         v in self.metrics.items() if k.cell_type == GaugeCell
     }
 
-    return MetricUpdates(counters, distributions, gauges)
+    string_sets = {
+        MetricKey(self.step_name, k.metric_name): v.get_cumulative()
+        for k,
+        v in self.metrics.items() if k.cell_type == StringSetCell
+    }
+
+    return MetricUpdates(counters, distributions, gauges, string_sets)
 
   def to_runner_api(self):
     return [
@@ -344,7 +357,8 @@ class MetricUpdates(object):
       self,
       counters=None,  # type: Optional[Dict[MetricKey, int]]
       distributions=None,  # type: Optional[Dict[MetricKey, DistributionData]]
-      gauges=None  # type: Optional[Dict[MetricKey, GaugeData]]
+      gauges=None,  # type: Optional[Dict[MetricKey, GaugeData]]
+      string_sets=None,  # type: Optional[Dict[MetricKey, set]]
   ):
     # type: (...) -> None
 
@@ -354,7 +368,9 @@ class MetricUpdates(object):
       counters: Dictionary of MetricKey:MetricUpdate updates.
       distributions: Dictionary of MetricKey:MetricUpdate objects.
       gauges: Dictionary of MetricKey:MetricUpdate objects.
+      string_sets: Dictionary of MetricKey:MetricUpdate objects.
     """
     self.counters = counters or {}
     self.distributions = distributions or {}
     self.gauges = gauges or {}
+    self.string_sets = string_sets or {}
diff --git a/sdks/python/apache_beam/metrics/execution_test.py b/sdks/python/apache_beam/metrics/execution_test.py
index a888376e709..b157aeb20e9 100644
--- a/sdks/python/apache_beam/metrics/execution_test.py
+++ b/sdks/python/apache_beam/metrics/execution_test.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import functools
 import unittest
 
 from apache_beam.metrics.execution import MetricKey
@@ -88,10 +89,12 @@ class TestMetricsContainer(unittest.TestCase):
       distribution = mc.get_distribution(
           MetricName('namespace', 'name{}'.format(i)))
       gauge = mc.get_gauge(MetricName('namespace', 'name{}'.format(i)))
+      str_set = mc.get_string_set(MetricName('namespace', 'name{}'.format(i)))
 
       counter.inc(i)
       distribution.update(i)
       gauge.set(i)
+      str_set.add(str(i % 7))
       all_values.append(i)
 
     # Retrieve ALL updates.
@@ -99,6 +102,7 @@ class TestMetricsContainer(unittest.TestCase):
     self.assertEqual(len(cumulative.counters), 10)
     self.assertEqual(len(cumulative.distributions), 10)
     self.assertEqual(len(cumulative.gauges), 10)
+    self.assertEqual(len(cumulative.string_sets), 10)
 
     self.assertEqual(
         set(all_values), {v
@@ -106,6 +110,11 @@ class TestMetricsContainer(unittest.TestCase):
     self.assertEqual(
         set(all_values), {v.value
                           for _, v in cumulative.gauges.items()})
+    self.assertEqual({str(i % 7)
+                      for i in all_values},
+                     functools.reduce(
+                         set.union,
+                         (v for _, v in cumulative.string_sets.items())))
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/metrics/metric.py b/sdks/python/apache_beam/metrics/metric.py
index 3722af6dc17..77cafb8bd64 100644
--- a/sdks/python/apache_beam/metrics/metric.py
+++ b/sdks/python/apache_beam/metrics/metric.py
@@ -44,6 +44,7 @@ from apache_beam.metrics.metricbase import Counter
 from apache_beam.metrics.metricbase import Distribution
 from apache_beam.metrics.metricbase import Gauge
 from apache_beam.metrics.metricbase import MetricName
+from apache_beam.metrics.metricbase import StringSet
 
 if TYPE_CHECKING:
   from apache_beam.metrics.execution import MetricKey
@@ -115,6 +116,23 @@ class Metrics(object):
     namespace = Metrics.get_namespace(namespace)
     return Metrics.DelegatingGauge(MetricName(namespace, name))
 
+  @staticmethod
+  def string_set(
+      namespace: Union[Type, str], name: str) -> 'Metrics.DelegatingStringSet':
+    """Obtains or creates a String set metric.
+
+    String set metrics are restricted to string values.
+
+    Args:
+      namespace: A class or string that gives the namespace to a metric
+      name: A string that gives a unique name to a metric
+
+    Returns:
+      A StringSet object.
+    """
+    namespace = Metrics.get_namespace(namespace)
+    return Metrics.DelegatingStringSet(MetricName(namespace, name))
+
   class DelegatingCounter(Counter):
     """Metrics Counter that Delegates functionality to MetricsEnvironment."""
     def __init__(
@@ -138,11 +156,18 @@ class Metrics(object):
       super().__init__(metric_name)
       self.set = MetricUpdater(cells.GaugeCell, metric_name)  # type: ignore[assignment]
 
+  class DelegatingStringSet(StringSet):
+    """Metrics StringSet that Delegates functionality to MetricsEnvironment."""
+    def __init__(self, metric_name: MetricName) -> None:
+      super().__init__(metric_name)
+      self.add = MetricUpdater(cells.StringSetCell, metric_name)  # type: ignore[assignment]
+
 
 class MetricResults(object):
   COUNTERS = "counters"
   DISTRIBUTIONS = "distributions"
   GAUGES = "gauges"
+  STRINGSETS = "string_sets"
 
   @staticmethod
   def _matches_name(filter: 'MetricsFilter', metric_key: 'MetricKey') -> bool:
@@ -207,11 +232,13 @@ class MetricResults(object):
         {
           "counters": [MetricResult(counter_key, committed, attempted), ...],
           "distributions": [MetricResult(dist_key, committed, attempted), ...],
-          "gauges": []  // Empty list if nothing matched the filter.
+          "gauges": [],  // Empty list if nothing matched the filter.
+          "string_sets": [MetricResult(string_set_key, committed, attempted),
+                          ...]
         }
 
     The committed / attempted values are DistributionResult / GaugeResult / int
-    objects.
+    / set objects.
     """
     raise NotImplementedError
 
diff --git a/sdks/python/apache_beam/metrics/metricbase.py b/sdks/python/apache_beam/metrics/metricbase.py
index 53da01f3955..7819dbb093a 100644
--- a/sdks/python/apache_beam/metrics/metricbase.py
+++ b/sdks/python/apache_beam/metrics/metricbase.py
@@ -38,7 +38,13 @@ from typing import Dict
 from typing import Optional
 
 __all__ = [
-    'Metric', 'Counter', 'Distribution', 'Gauge', 'Histogram', 'MetricName'
+    'Metric',
+    'Counter',
+    'Distribution',
+    'Gauge',
+    'StringSet',
+    'Histogram',
+    'MetricName'
 ]
 
 
@@ -138,6 +144,14 @@ class Gauge(Metric):
     raise NotImplementedError
 
 
+class StringSet(Metric):
+  """StringSet Metric interface.
+
+  Reports a set of unique string values during pipeline execution."""
+  def add(self, value):
+    raise NotImplementedError
+
+
 class Histogram(Metric):
   """Histogram Metric interface.
 
diff --git a/sdks/python/apache_beam/metrics/monitoring_infos.py b/sdks/python/apache_beam/metrics/monitoring_infos.py
index 7bc7cced280..72640c8f92a 100644
--- a/sdks/python/apache_beam/metrics/monitoring_infos.py
+++ b/sdks/python/apache_beam/metrics/monitoring_infos.py
@@ -50,8 +50,13 @@ USER_COUNTER_URN = common_urns.monitoring_info_specs.USER_SUM_INT64.spec.urn
 USER_DISTRIBUTION_URN = (
     common_urns.monitoring_info_specs.USER_DISTRIBUTION_INT64.spec.urn)
 USER_GAUGE_URN = common_urns.monitoring_info_specs.USER_LATEST_INT64.spec.urn
-USER_METRIC_URNS = set(
-    [USER_COUNTER_URN, USER_DISTRIBUTION_URN, USER_GAUGE_URN])
+USER_STRING_SET_URN = common_urns.monitoring_info_specs.USER_SET_STRING.spec.urn
+USER_METRIC_URNS = set([
+    USER_COUNTER_URN,
+    USER_DISTRIBUTION_URN,
+    USER_GAUGE_URN,
+    USER_STRING_SET_URN
+])
 WORK_REMAINING_URN = common_urns.monitoring_info_specs.WORK_REMAINING.spec.urn
 WORK_COMPLETED_URN = common_urns.monitoring_info_specs.WORK_COMPLETED.spec.urn
 DATA_CHANNEL_READ_INDEX = (
@@ -67,10 +72,12 @@ DISTRIBUTION_INT64_TYPE = (
     common_urns.monitoring_info_types.DISTRIBUTION_INT64_TYPE.urn)
 LATEST_INT64_TYPE = common_urns.monitoring_info_types.LATEST_INT64_TYPE.urn
 PROGRESS_TYPE = common_urns.monitoring_info_types.PROGRESS_TYPE.urn
+STRING_SET_TYPE = common_urns.monitoring_info_types.SET_STRING_TYPE.urn
 
 COUNTER_TYPES = set([SUM_INT64_TYPE])
 DISTRIBUTION_TYPES = set([DISTRIBUTION_INT64_TYPE])
 GAUGE_TYPES = set([LATEST_INT64_TYPE])
+STRING_SET_TYPES = set([STRING_SET_TYPE])
 
 # TODO(migryz) extract values from beam_fn_api.proto::MonitoringInfoLabels
 PCOLLECTION_LABEL = (
@@ -149,6 +156,14 @@ def extract_distribution(monitoring_info_proto):
       coders.VarIntCoder(), monitoring_info_proto.payload)
 
 
+def extract_string_set_value(monitoring_info_proto):
+  if not is_string_set(monitoring_info_proto):
+    raise ValueError('Unsupported type %s' % monitoring_info_proto.type)
+
+  coder = coders.IterableCoder(coders.StrUtf8Coder())
+  return set(coder.decode(monitoring_info_proto.payload))
+
+
 def create_labels(ptransform=None, namespace=None, name=None, pcollection=None):
   """Create the label dictionary based on the provided values.
 
@@ -243,8 +258,8 @@ def int64_user_gauge(namespace, name, metric, ptransform=None):
   """Return the gauge monitoring info for the URN, metric and labels.
 
   Args:
-    namespace: User-defined namespace of counter.
-    name: Name of counter.
+    namespace: User-defined namespace of gauge metric.
+    name: Name of gauge metric.
     metric: The GaugeData containing the metrics.
     ptransform: The ptransform id used as a label.
   """
@@ -286,6 +301,24 @@ def int64_gauge(urn, metric, ptransform=None):
   return create_monitoring_info(urn, LATEST_INT64_TYPE, payload, labels)
 
 
+def user_set_string(namespace, name, metric, ptransform=None):
+  """Return the string set monitoring info for the URN, metric and labels.
+
+  Args:
+    namespace: User-defined namespace of StringSet.
+    name: Name of StringSet.
+    metric: The set representing the metrics.
+    ptransform: The ptransform id used as a label.
+  """
+  labels = create_labels(ptransform=ptransform, namespace=namespace, name=name)
+  if isinstance(metric, set):
+    metric = list(metric)
+  if isinstance(metric, list):
+    metric = coders.IterableCoder(coders.StrUtf8Coder()).encode(metric)
+  return create_monitoring_info(
+      USER_STRING_SET_URN, STRING_SET_TYPE, metric, labels)
+
+
 def create_monitoring_info(urn, type_urn, payload, labels=None):
   # type: (...) -> metrics_pb2.MonitoringInfo
 
@@ -322,15 +355,21 @@ def is_distribution(monitoring_info_proto):
   return monitoring_info_proto.type in DISTRIBUTION_TYPES
 
 
+def is_string_set(monitoring_info_proto):
+  """Returns true if the monitoring info is a StringSet metric."""
+  return monitoring_info_proto.type in STRING_SET_TYPES
+
+
 def is_user_monitoring_info(monitoring_info_proto):
   """Returns true if the monitoring info is a user metric."""
   return monitoring_info_proto.urn in USER_METRIC_URNS
 
 
 def extract_metric_result_map_value(monitoring_info_proto):
-  # type: (...) -> Union[None, int, DistributionResult, GaugeResult]
+  # type: (...) -> Union[None, int, DistributionResult, GaugeResult, set]
 
-  """Returns the relevant GaugeResult, DistributionResult or int value.
+  """Returns the relevant GaugeResult, DistributionResult, int value (for a
+  counter metric) or set (for a StringSet metric).
 
   These are the proper format for use in the MetricResult.query() result.
   """
@@ -344,6 +383,8 @@ def extract_metric_result_map_value(monitoring_info_proto):
   if is_gauge(monitoring_info_proto):
     (timestamp, value) = extract_gauge_value(monitoring_info_proto)
     return GaugeResult(GaugeData(value, timestamp))
+  if is_string_set(monitoring_info_proto):
+    return extract_string_set_value(monitoring_info_proto)
   return None
 
 
diff --git a/sdks/python/apache_beam/metrics/monitoring_infos_test.py b/sdks/python/apache_beam/metrics/monitoring_infos_test.py
index d19e8bc10df..022943f417c 100644
--- a/sdks/python/apache_beam/metrics/monitoring_infos_test.py
+++ b/sdks/python/apache_beam/metrics/monitoring_infos_test.py
@@ -21,6 +21,7 @@ import unittest
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.cells import CounterCell
 from apache_beam.metrics.cells import GaugeCell
+from apache_beam.metrics.cells import StringSetCell
 
 
 class MonitoringInfosTest(unittest.TestCase):
@@ -64,6 +65,17 @@ class MonitoringInfosTest(unittest.TestCase):
     self.assertEqual(namespace, "counternamespace")
     self.assertEqual(name, "countername")
 
+  def test_parse_namespace_and_name_for_user_string_set_metric(self):
+    urn = monitoring_infos.USER_STRING_SET_URN
+    labels = {}
+    labels[monitoring_infos.NAMESPACE_LABEL] = "stringsetnamespace"
+    labels[monitoring_infos.NAME_LABEL] = "stringsetname"
+    input = monitoring_infos.create_monitoring_info(
+        urn, "typeurn", None, labels)
+    namespace, name = monitoring_infos.parse_namespace_and_name(input)
+    self.assertEqual(namespace, "stringsetnamespace")
+    self.assertEqual(name, "stringsetname")
+
   def test_int64_user_gauge(self):
     metric = GaugeCell().get_cumulative()
     result = monitoring_infos.int64_user_gauge(
@@ -105,6 +117,19 @@ class MonitoringInfosTest(unittest.TestCase):
     self.assertEqual(0, counter_value)
     self.assertEqual(result.labels, expected_labels)
 
+  def test_user_set_string(self):
+    expected_labels = {}
+    expected_labels[monitoring_infos.NAMESPACE_LABEL] = "stringsetnamespace"
+    expected_labels[monitoring_infos.NAME_LABEL] = "stringsetname"
+
+    metric = StringSetCell().get_cumulative()
+    result = monitoring_infos.user_set_string(
+        'stringsetnamespace', 'stringsetname', metric)
+    string_set_value = monitoring_infos.extract_string_set_value(result)
+
+    self.assertEqual(set(), string_set_value)
+    self.assertEqual(result.labels, expected_labels)
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py
index 7e6a11c4abf..78c3b64595b 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_metrics.py
@@ -90,6 +90,10 @@ class DataflowMetrics(MetricResults):
   def _is_distribution(metric_result):
     return isinstance(metric_result.attempted, DistributionResult)
 
+  @staticmethod
+  def _is_string_set(metric_result):
+    return isinstance(metric_result.attempted, set)
+
   def _translate_step_name(self, internal_name):
     """Translate between internal step names (e.g. "s1") and user step 
names."""
     if not self._job_graph:
@@ -233,6 +237,8 @@ class DataflowMetrics(MetricResults):
                 lambda x: x.key == 'sum').value.double_value)
       return DistributionResult(
           DistributionData(dist_sum, dist_count, dist_min, dist_max))
+      #TODO(https://github.com/apache/beam/issues/31788) support StringSet after
+      #  re-generate apiclient
     else:
       return None
 
@@ -277,8 +283,13 @@ class DataflowMetrics(MetricResults):
             elm for elm in metric_results if self.matches(filter, elm.key) and
             DataflowMetrics._is_distribution(elm)
         ],
-        self.GAUGES: []
-    }  # TODO(pabloem): Add Gauge support for dataflow.
+        # TODO(pabloem): Add Gauge support for dataflow.
+        self.GAUGES: [],
+        self.STRINGSETS: [
+            elm for elm in metric_results if self.matches(filter, elm.key) and
+            DataflowMetrics._is_string_set(elm)
+        ]
+    }
 
 
 def main(argv):
diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py
index e4fd4405311..f715ce3bf52 100644
--- a/sdks/python/apache_beam/runners/direct/direct_metrics.py
+++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py
@@ -28,6 +28,7 @@ from collections import defaultdict
 from apache_beam.metrics.cells import CounterAggregator
 from apache_beam.metrics.cells import DistributionAggregator
 from apache_beam.metrics.cells import GaugeAggregator
+from apache_beam.metrics.cells import StringSetAggregator
 from apache_beam.metrics.execution import MetricKey
 from apache_beam.metrics.execution import MetricResult
 from apache_beam.metrics.metric import MetricResults
@@ -39,6 +40,7 @@ class DirectMetrics(MetricResults):
     self._distributions = defaultdict(
         lambda: DirectMetric(DistributionAggregator()))
     self._gauges = defaultdict(lambda: DirectMetric(GaugeAggregator()))
+    self._string_sets = defaultdict(lambda: DirectMetric(StringSetAggregator()))
 
   def _apply_operation(self, bundle, updates, op):
     for k, v in updates.counters.items():
@@ -50,6 +52,9 @@ class DirectMetrics(MetricResults):
     for k, v in updates.gauges.items():
       op(self._gauges[k], bundle, v)
 
+    for k, v in updates.string_sets.items():
+      op(self._string_sets[k], bundle, v)
+
   def commit_logical(self, bundle, updates):
     op = lambda obj, bundle, update: obj.commit_logical(bundle, update)
     self._apply_operation(bundle, updates, op)
@@ -84,11 +89,19 @@ class DirectMetrics(MetricResults):
             v.extract_latest_attempted()) for k,
         v in self._gauges.items() if self.matches(filter, k)
     ]
+    string_sets = [
+        MetricResult(
+            MetricKey(k.step, k.metric),
+            v.extract_committed(),
+            v.extract_latest_attempted()) for k,
+        v in self._string_sets.items() if self.matches(filter, k)
+    ]
 
     return {
         self.COUNTERS: counters,
         self.DISTRIBUTIONS: distributions,
-        self.GAUGES: gauges
+        self.GAUGES: gauges,
+        self.STRINGSETS: string_sets
     }
 
 
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
index 58cec732d3f..d8f1ea097b8 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
@@ -76,6 +76,8 @@ class DirectPipelineResultTest(unittest.TestCase):
         count.inc()
         distro = Metrics.distribution(self.__class__, 'element_dist')
         distro.update(element)
+        str_set = Metrics.string_set(self.__class__, 'element_str_set')
+        str_set.add(str(element % 4))
         return [element]
 
     p = Pipeline(DirectRunner())
@@ -115,6 +117,13 @@ class DirectPipelineResultTest(unittest.TestCase):
     hc.assert_that(gauge_result.committed.value, hc.equal_to(5))
     hc.assert_that(gauge_result.attempted.value, hc.equal_to(5))
 
+    str_set_result = metrics['string_sets'][0]
+    hc.assert_that(
+        str_set_result.key,
+        hc.equal_to(MetricKey('Do', MetricName(namespace, 'element_str_set'))))
+    hc.assert_that(len(str_set_result.committed), hc.equal_to(4))
+    hc.assert_that(len(str_set_result.attempted), hc.equal_to(4))
+
   def test_create_runner(self):
     self.assertTrue(isinstance(create_runner('DirectRunner'), DirectRunner))
     self.assertTrue(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 8b313d624a5..1ed21942d28 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -1535,15 +1535,17 @@ class FnApiMetrics(metric.MetricResults):
     self._counters = {}
     self._distributions = {}
     self._gauges = {}
+    self._string_sets = {}
     self._user_metrics_only = user_metrics_only
     self._monitoring_infos = step_monitoring_infos
 
     for smi in step_monitoring_infos.values():
-      counters, distributions, gauges = \
+      counters, distributions, gauges, string_sets = \
           portable_metrics.from_monitoring_infos(smi, user_metrics_only)
       self._counters.update(counters)
       self._distributions.update(distributions)
       self._gauges.update(gauges)
+      self._string_sets.update(string_sets)
 
   def query(self, filter=None):
     counters = [
@@ -1558,11 +1560,16 @@ class FnApiMetrics(metric.MetricResults):
         MetricResult(k, v, v) for k,
         v in self._gauges.items() if self.matches(filter, k)
     ]
+    string_sets = [
+        MetricResult(k, v, v) for k,
+        v in self._string_sets.items() if self.matches(filter, k)
+    ]
 
     return {
         self.COUNTERS: counters,
         self.DISTRIBUTIONS: distributions,
-        self.GAUGES: gauges
+        self.GAUGES: gauges,
+        self.STRINGSETS: string_sets
     }
 
   def monitoring_infos(self) -> List[metrics_pb2.MonitoringInfo]:
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
index 97b10b83e05..4a737feaf28 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
@@ -1212,13 +1212,16 @@ class FnApiRunnerTest(unittest.TestCase):
     counter = beam.metrics.Metrics.counter('ns', 'counter')
     distribution = beam.metrics.Metrics.distribution('ns', 'distribution')
     gauge = beam.metrics.Metrics.gauge('ns', 'gauge')
+    string_set = beam.metrics.Metrics.string_set('ns', 'string_set')
 
-    pcoll = p | beam.Create(['a', 'zzz'])
+    elements = ['a', 'zzz']
+    pcoll = p | beam.Create(elements)
     # pylint: disable=expression-not-assigned
     pcoll | 'count1' >> beam.FlatMap(lambda x: counter.inc())
     pcoll | 'count2' >> beam.FlatMap(lambda x: counter.inc(len(x)))
     pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x)))
     pcoll | 'gauge' >> beam.FlatMap(lambda x: gauge.set(3))
+    pcoll | 'string_set' >> beam.FlatMap(lambda x: string_set.add(x))
 
     res = p.run()
     res.wait_until_finish()
@@ -1238,6 +1241,10 @@ class FnApiRunnerTest(unittest.TestCase):
                                   .with_name('gauge'))['gauges']
       self.assertEqual(gaug.committed.value, 3)
 
+    str_set, = res.metrics().query(beam.metrics.MetricsFilter()
+                                  .with_name('string_set'))['string_sets']
+    self.assertEqual(str_set.committed, set(elements))
+
   def test_callbacks_with_exception(self):
     elements_list = ['1', '2']
 
diff --git a/sdks/python/apache_beam/runners/portability/portable_metrics.py b/sdks/python/apache_beam/runners/portability/portable_metrics.py
index d7d330dd7e7..5bc3e053918 100644
--- a/sdks/python/apache_beam/runners/portability/portable_metrics.py
+++ b/sdks/python/apache_beam/runners/portability/portable_metrics.py
@@ -27,18 +27,21 @@ _LOGGER = logging.getLogger(__name__)
 
 
 def from_monitoring_infos(monitoring_info_list, user_metrics_only=False):
-  """Groups MonitoringInfo objects into counters, distributions and gauges.
+  """Groups MonitoringInfo objects into counters, distributions, gauges and
+  string sets.
 
   Args:
     monitoring_info_list: An iterable of MonitoringInfo objects.
     user_metrics_only: If true, includes user metrics only.
   Returns:
-    A tuple containing three dictionaries: counters, distributions and gauges,
-    respectively. Each dictionary contains (MetricKey, metric result) pairs.
+    A tuple containing four dictionaries: counters, distributions, gauges and
+    string sets, respectively. Each dictionary contains (MetricKey, metric
+    result) pairs.
   """
   counters = {}
   distributions = {}
   gauges = {}
+  string_sets = {}
 
   for mi in monitoring_info_list:
     if (user_metrics_only and not monitoring_infos.is_user_monitoring_info(mi)):
@@ -57,8 +60,10 @@ def from_monitoring_infos(monitoring_info_list, user_metrics_only=False):
       distributions[key] = metric_result
     elif monitoring_infos.is_gauge(mi):
       gauges[key] = metric_result
+    elif monitoring_infos.is_string_set(mi):
+      string_sets[key] = metric_result
 
-  return counters, distributions, gauges
+  return counters, distributions, gauges, string_sets
 
 
 def _create_metric_key(monitoring_info):
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 92f123697a9..ba48bbec6d3 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -437,7 +437,7 @@ class PortableMetrics(metric.MetricResults):
     ]
 
   def query(self, filter=None):
-    counters, distributions, gauges = [
+    counters, distributions, gauges, stringsets = [
         self._combine(x, y, filter)
         for x, y in zip(self.committed, self.attempted)
     ]
@@ -445,7 +445,8 @@ class PortableMetrics(metric.MetricResults):
     return {
         self.COUNTERS: counters,
         self.DISTRIBUTIONS: distributions,
-        self.GAUGES: gauges
+        self.GAUGES: gauges,
+        self.STRINGSETS: stringsets
     }
 
 

Reply via email to