This is an automated email from the ASF dual-hosted git repository.

bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 814a5ded8c4 Add record_metrics argument to utils.BatchElements (#23701)
814a5ded8c4 is described below

commit 814a5ded8c493d55edeaf350c808c131289165e8
Author: fab-jul <[email protected]>
AuthorDate: Mon Nov 21 21:32:05 2022 +0100

    Add record_metrics argument to utils.BatchElements (#23701)
    
    * Add record_metrics argument to BatchTransform
    
    * Also add test
    
    * Apply suggestions from code review
    
    Co-authored-by: Brian Hulette <[email protected]>
    
    * Add variable.
    
    * Maybe make formatter happy
    
    * Update test to check for metrics.
    
    * Make linter happy
    
    * Make linter happy.
    
    * Update sdks/python/apache_beam/transforms/util.py
    
    * Update util_test.py
    
    Co-authored-by: Brian Hulette <[email protected]>
---
 sdks/python/apache_beam/transforms/util.py      | 28 +++++++++++++-------
 sdks/python/apache_beam/transforms/util_test.py | 35 ++++++++++++++++++++-----
 2 files changed, 47 insertions(+), 16 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index dca5628118d..d91ce112471 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -317,7 +317,8 @@ class _BatchSizeEstimator(object):
       target_batch_duration_secs=1,
       variance=0.25,
       clock=time.time,
-      ignore_first_n_seen_per_batch_size=0):
+      ignore_first_n_seen_per_batch_size=0,
+      record_metrics=True):
     if min_batch_size > max_batch_size:
       raise ValueError(
           "Minimum (%s) must not be greater than maximum (%s)" %
@@ -350,11 +351,15 @@ class _BatchSizeEstimator(object):
         ignore_first_n_seen_per_batch_size)
     self._batch_size_num_seen = {}
     self._replay_last_batch_size = None
+    self._record_metrics = record_metrics
 
-    self._size_distribution = Metrics.distribution(
-        'BatchElements', 'batch_size')
-    self._time_distribution = Metrics.distribution(
-        'BatchElements', 'msec_per_batch')
+    if record_metrics:
+      self._size_distribution = Metrics.distribution(
+          'BatchElements', 'batch_size')
+      self._time_distribution = Metrics.distribution(
+          'BatchElements', 'msec_per_batch')
+    else:
+      self._size_distribution = self._time_distribution = None
     # Beam distributions only accept integer values, so we use this to
     # accumulate under-reported values until they add up to whole milliseconds.
     # (Milliseconds are chosen because that's conventionally used elsewhere in
@@ -375,8 +380,9 @@ class _BatchSizeEstimator(object):
     yield
     elapsed = self._clock() - start
     elapsed_msec = 1e3 * elapsed + self._remainder_msecs
-    self._size_distribution.update(batch_size)
-    self._time_distribution.update(int(elapsed_msec))
+    if self._record_metrics:
+      self._size_distribution.update(batch_size)
+      self._time_distribution.update(int(elapsed_msec))
     self._remainder_msecs = elapsed_msec - int(elapsed_msec)
     # If we ignore the next timing, replay the batch size to get accurate
     # timing.
@@ -642,6 +648,8 @@ class BatchElements(PTransform):
         linear interpolation
     clock: (optional) an alternative to time.time for measuring the cost of
        downstream operations (mostly for testing)
+    record_metrics: (optional) whether or not to record beam metrics on
+        distributions of the batch size. Defaults to True.
   """
   def __init__(
       self,
@@ -652,14 +660,16 @@ class BatchElements(PTransform):
       *,
       element_size_fn=lambda x: 1,
       variance=0.25,
-      clock=time.time):
+      clock=time.time,
+      record_metrics=True):
     self._batch_size_estimator = _BatchSizeEstimator(
         min_batch_size=min_batch_size,
         max_batch_size=max_batch_size,
         target_batch_overhead=target_batch_overhead,
         target_batch_duration_secs=target_batch_duration_secs,
         variance=variance,
-        clock=clock)
+        clock=clock,
+        record_metrics=record_metrics)
     self._element_size_fn = element_size_fn
 
   def expand(self, pcoll):
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 180b7ae5f8e..fb799cfb7a5 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -36,6 +36,7 @@ from apache_beam import GroupByKey
 from apache_beam import Map
 from apache_beam import WindowInto
 from apache_beam.coders import coders
+from apache_beam.metrics import MetricsFilter
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.portability import common_urns
@@ -189,13 +190,33 @@ class FakeClock(object):
 class BatchElementsTest(unittest.TestCase):
   def test_constant_batch(self):
     # Assumes a single bundle...
-    with TestPipeline() as p:
-      res = (
-          p
-          | beam.Create(range(35))
-          | util.BatchElements(min_batch_size=10, max_batch_size=10)
-          | beam.Map(len))
-      assert_that(res, equal_to([10, 10, 10, 5]))
+    p = TestPipeline()
+    output = (
+        p
+        | beam.Create(range(35))
+        | util.BatchElements(min_batch_size=10, max_batch_size=10)
+        | beam.Map(len))
+    assert_that(output, equal_to([10, 10, 10, 5]))
+    res = p.run()
+    res.wait_until_finish()
+    metrics = res.metrics()
+    results = metrics.query(MetricsFilter().with_name("batch_size"))
+    self.assertEqual(len(results["distributions"]), 1)
+
+  def test_constant_batch_no_metrics(self):
+    p = TestPipeline()
+    output = (
+        p
+        | beam.Create(range(35))
+        | util.BatchElements(
+            min_batch_size=10, max_batch_size=10, record_metrics=False)
+        | beam.Map(len))
+    assert_that(output, equal_to([10, 10, 10, 5]))
+    res = p.run()
+    res.wait_until_finish()
+    metrics = res.metrics()
+    results = metrics.query(MetricsFilter().with_name("batch_size"))
+    self.assertEqual(len(results["distributions"]), 0)
 
   def test_grows_to_max_batch(self):
     # Assumes a single bundle...

Reply via email to