This is an automated email from the ASF dual-hosted git repository.

udim pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 23fe94b3af2 Implement Exception Sampling in the Python SDK (#27280)
23fe94b3af2 is described below

commit 23fe94b3af28d1a7d55f994dbddc9669a6567cb1
Author: Sam Rohde <[email protected]>
AuthorDate: Wed Jun 28 16:48:13 2023 -0700

    Implement Exception Sampling in the Python SDK (#27280)
    
    * Python exception sampling implementation
    
    * add to cython def
    
    * add more cython defs
    
    * address comments
    
    * fix circular imports
    
    * linter
    
    * fix tests
    
    * remove print
    
    * Add traceback to exception
    
    * fix tests
    
    ---------
    
    Co-authored-by: Sam Rohde <[email protected]>
---
 sdks/python/apache_beam/runners/common.pxd         |   2 +
 sdks/python/apache_beam/runners/common.py          |  35 ++-
 .../apache_beam/runners/worker/bundle_processor.py |  20 +-
 .../runners/worker/bundle_processor_test.py        | 142 +++++++--
 .../apache_beam/runners/worker/data_sampler.py     | 139 +++++++--
 .../runners/worker/data_sampler_test.py            | 330 +++++++++++++--------
 .../apache_beam/runners/worker/operations.pxd      |   4 +
 .../apache_beam/runners/worker/operations.py       |  48 +--
 .../apache_beam/runners/worker/sdk_worker.py       |  10 +-
 .../apache_beam/runners/worker/sdk_worker_test.py  |  17 +-
 10 files changed, 535 insertions(+), 212 deletions(-)

diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index d1745970d26..9fb44af6377 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -121,6 +121,8 @@ cdef class DoFnRunner:
   cdef list side_inputs
   cdef DoFnInvoker do_fn_invoker
   cdef public object bundle_finalizer_param
+  cdef str transform_id
+  cdef object execution_context
   cpdef process(self, WindowedValue windowed_value)
 
 
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 75eb85f2110..e91199787e5 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -66,6 +66,7 @@ from apache_beam.utils.windowed_value import WindowedBatch
 from apache_beam.utils.windowed_value import WindowedValue
 
 if TYPE_CHECKING:
+  from apache_beam.runners.worker.bundle_processor import ExecutionContext
   from apache_beam.transforms import sideinputs
   from apache_beam.transforms.core import TimerSpec
   from apache_beam.io.iobase import RestrictionProgress
@@ -1338,7 +1339,8 @@ class DoFnRunner:
                state=None,
                scoped_metrics_container=None,
                operation_name=None,
-               user_state_context=None  # type: Optional[userstate.UserStateContext]
+               transform_id=None,
+               user_state_context=None,  # type: Optional[userstate.UserStateContext]
               ):
     """Initializes a DoFnRunner.
 
@@ -1354,6 +1356,7 @@ class DoFnRunner:
       state: handle for accessing DoFn state
       scoped_metrics_container: DEPRECATED
       operation_name: The system name assigned by the runner for this operation.
+      transform_id: The PTransform Id in the pipeline proto for this DoFn.
       user_state_context: The UserStateContext instance for the current
                           Stateful DoFn.
     """
@@ -1361,8 +1364,10 @@ class DoFnRunner:
     side_inputs = list(side_inputs)
 
     self.step_name = step_name
+    self.transform_id = transform_id
     self.context = DoFnContext(step_name, state=state)
     self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
+    self.execution_context = None  # type: Optional[ExecutionContext]
 
     do_fn_signature = DoFnSignature(fn)
 
@@ -1417,9 +1422,25 @@ class DoFnRunner:
     try:
       return self.do_fn_invoker.invoke_process(windowed_value)
     except BaseException as exn:
-      self._reraise_augmented(exn)
+      self._reraise_augmented(exn, windowed_value)
       return []
 
+  def _maybe_sample_exception(
+      self, exn: BaseException, windowed_value: WindowedValue) -> None:
+
+    if self.execution_context is None:
+      return
+
+    output_sampler = self.execution_context.output_sampler
+    if output_sampler is None:
+      return
+
+    output_sampler.sample_exception(
+        windowed_value,
+        exn,
+        self.transform_id,
+        self.execution_context.instruction_id)
+
   def process_batch(self, windowed_batch):
     # type: (WindowedBatch) -> None
     try:
@@ -1487,7 +1508,7 @@ class DoFnRunner:
     # type: () -> None
     self.bundle_finalizer_param.finalize_bundle()
 
-  def _reraise_augmented(self, exn):
+  def _reraise_augmented(self, exn, windowed_value=None):
     if getattr(exn, '_tagged_with_step', False) or not self.step_name:
       raise exn
     step_annotation = " [while running '%s']" % self.step_name
@@ -1504,8 +1525,12 @@ class DoFnRunner:
           traceback.format_exception_only(type(exn), exn)[-1].strip() +
           step_annotation)
       new_exn._tagged_with_step = True
-    _, _, tb = sys.exc_info()
-    raise new_exn.with_traceback(tb)
+    exc_info = sys.exc_info()
+    _, _, tb = exc_info
+
+    new_exn = new_exn.with_traceback(tb)
+    self._maybe_sample_exception(exc_info, windowed_value)
+    raise new_exn
 
 
 class OutputHandler(object):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 3adb26552d0..b3ca92aa9a1 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -27,6 +27,8 @@ import json
 import logging
 import random
 import threading
+from dataclasses import dataclass
+from dataclasses import field
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
@@ -986,7 +988,7 @@ class BundleProcessor(object):
         expected_input_ops.append(op)
 
     try:
-      execution_context = ExecutionContext()
+      execution_context = ExecutionContext(instruction_id=instruction_id)
       self.current_instruction_id = instruction_id
       self.state_sampler.start()
       # Start all operations.
@@ -1181,10 +1183,18 @@ class BundleProcessor(object):
       op.teardown()
 
 
-class ExecutionContext(object):
-  def __init__(self):
-    self.delayed_applications = [
-    ]  # type: List[Tuple[operations.DoOperation, common.SplitResultResidual]]
+@dataclass
+class ExecutionContext:
+  # Any splits to be processed later.
+  delayed_applications: List[Tuple[operations.DoOperation,
+                                   common.SplitResultResidual]] = field(
+                                       default_factory=list)
+
+  # The exception sampler for the currently executing PTransform.
+  output_sampler: Optional[data_sampler.OutputSampler] = None
+
+  # The current instruction being executed.
+  instruction_id: Optional[str] = None
 
 
 class BeamTransformFactory(object):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
index db9e35a0baf..8b81c9f17ac 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
@@ -18,15 +18,16 @@
 """Unit tests for bundle processing."""
 # pytype: skip-file
 
-import time
 import unittest
 from typing import Dict
 from typing import List
 
+import apache_beam as beam
 from apache_beam.coders.coders import FastPrimitivesCoder
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners import common
+from apache_beam.runners.worker import bundle_processor
 from apache_beam.runners.worker import operations
 from apache_beam.runners.worker.bundle_processor import BeamTransformFactory
 from apache_beam.runners.worker.bundle_processor import BundleProcessor
@@ -240,6 +241,25 @@ def create_test_op(factory, transform_id, transform_proto, payload, consumers):
       payload)
 
 
+@BeamTransformFactory.register_urn('beam:internal:testexn:v1', bytes)
+def create_exception_dofn(
+    factory, transform_id, transform_proto, payload, consumers):
+  """Returns a test DoFn that raises the given exception."""
+  class RaiseException(beam.DoFn):
+    def __init__(self, msg):
+      self.msg = msg.decode()
+
+    def process(self, _):
+      raise RuntimeError(self.msg)
+
+  return bundle_processor._create_simple_pardo_operation(
+      factory,
+      transform_id,
+      transform_proto,
+      consumers,
+      RaiseException(payload))
+
+
 class DataSamplingTest(unittest.TestCase):
   def test_disabled_by_default(self):
     """Test that not providing the sampler does not enable Data Sampling.
@@ -252,32 +272,12 @@ class DataSamplingTest(unittest.TestCase):
     _ = BundleProcessor(descriptor, None, None)
     self.assertEqual(len(descriptor.transforms), 0)
 
-  def wait_for_samples(self, data_sampler: DataSampler,
-                       pcollection_id: str) -> Dict[str, List[bytes]]:
-    """Waits for samples from the given PCollection to exist."""
-    now = time.time()
-    end = now + 30
-
-    samples = {}
-    while now < end:
-      time.sleep(0.1)
-      now = time.time()
-      samples.update(data_sampler.samples([pcollection_id]))
-
-      if samples:
-        return samples
-
-    self.assertLess(
-        now, end, 'Timed out waiting for samples for {}'.format(pcollection_id))
-    return {}
-
   def test_can_sample(self):
     """Test that elements are sampled.
 
     This is a small integration test with the BundleProcessor and the
-    DataSampler. It ensures that the BundleProcessor correctly makes
-    DataSamplingOperations and samples are taken from in-flight elements. These
-    elements are then finally queried.
+    DataSampler. It ensures that samples are taken from in-flight elements.
+    These elements are then finally queried.
     """
     data_sampler = DataSampler(sample_every_sec=0.1)
     descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
@@ -306,8 +306,100 @@ class DataSamplingTest(unittest.TestCase):
           descriptor, None, None, data_sampler=data_sampler)
       processor.process_bundle('instruction_id')
 
-      samples = self.wait_for_samples(data_sampler, PCOLLECTION_ID)
-      self.assertEqual(samples, {PCOLLECTION_ID: [b'\rhello, world!']})
+      samples = data_sampler.wait_for_samples([PCOLLECTION_ID])
+      expected = beam_fn_api_pb2.SampleDataResponse(
+          element_samples={
+              PCOLLECTION_ID: beam_fn_api_pb2.SampleDataResponse.ElementList(
+                  elements=[
+                      beam_fn_api_pb2.SampledElement(
+                          element=b'\rhello, world!')
+                  ])
+          })
+      self.assertEqual(samples, expected)
+    finally:
+      data_sampler.stop()
+
+  def test_can_sample_exceptions(self):
+    """Test that exceptions are sampled."""
+    data_sampler = DataSampler(sample_every_sec=0.1)
+    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
+
+    # Boiler plate for the DoFn.
+    WINDOWING_ID = 'window'
+    WINDOW_CODER_ID = 'cw'
+    window = descriptor.windowing_strategies[WINDOWING_ID]
+    window.window_fn.urn = common_urns.global_windows.urn
+    window.window_coder_id = WINDOW_CODER_ID
+    window.trigger.default.SetInParent()
+    window_coder = descriptor.coders[WINDOW_CODER_ID]
+    window_coder.spec.urn = common_urns.StandardCoders.Enum.GLOBAL_WINDOW.urn
+
+    # Input collection to the exception raising DoFn.
+    INPUT_PCOLLECTION_ID = 'pc-in'
+    INPUT_CODER_ID = 'c-in'
+    descriptor.pcollections[
+        INPUT_PCOLLECTION_ID].unique_name = INPUT_PCOLLECTION_ID
+    descriptor.pcollections[INPUT_PCOLLECTION_ID].coder_id = INPUT_CODER_ID
+    descriptor.pcollections[
+        INPUT_PCOLLECTION_ID].windowing_strategy_id = WINDOWING_ID
+    descriptor.coders[
+        INPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn
+
+    # Output collection to the exception raising DoFn. Because the transform
+    # "failed" to process the input element, we do NOT expect to see a sample in
+    # this PCollection.
+    OUTPUT_PCOLLECTION_ID = 'pc-out'
+    OUTPUT_CODER_ID = 'c-out'
+    descriptor.pcollections[
+        OUTPUT_PCOLLECTION_ID].unique_name = OUTPUT_PCOLLECTION_ID
+    descriptor.pcollections[OUTPUT_PCOLLECTION_ID].coder_id = OUTPUT_CODER_ID
+    descriptor.pcollections[
+        OUTPUT_PCOLLECTION_ID].windowing_strategy_id = WINDOWING_ID
+    descriptor.coders[
+        OUTPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn
+
+    # Add a simple transform to inject an element into the data sampler. This
+    # doesn't use the FnApi, so this uses a simple operation to forward its
+    # payload to consumers.
+    TEST_OP_TRANSFORM_ID = 'test_op'
+    test_transform = descriptor.transforms[TEST_OP_TRANSFORM_ID]
+    test_transform.outputs['None'] = INPUT_PCOLLECTION_ID
+    test_transform.spec.urn = 'beam:internal:testop:v1'
+    test_transform.spec.payload = b'hello, world!'
+
+    # Add the DoFn to create an exception to sample from.
+    TEST_EXCEPTION_TRANSFORM_ID = 'test_transform'
+    test_transform = descriptor.transforms[TEST_EXCEPTION_TRANSFORM_ID]
+    test_transform.inputs['0'] = INPUT_PCOLLECTION_ID
+    test_transform.outputs['None'] = OUTPUT_PCOLLECTION_ID
+    test_transform.spec.urn = 'beam:internal:testexn:v1'
+    test_transform.spec.payload = b'expected exception'
+
+    try:
+      # Create and process a fake bundle. The instruction id doesn't matter
+      # here.
+      processor = BundleProcessor(
+          descriptor, None, None, data_sampler=data_sampler)
+
+      with self.assertRaisesRegex(RuntimeError, 'expected exception'):
+        processor.process_bundle('instruction_id')
+
+      # NOTE: The expected sample comes from the input PCollection. This is very
+      # important because there can be coder issues if the sample is put in the
+      # wrong PCollection.
+      samples = data_sampler.wait_for_samples([INPUT_PCOLLECTION_ID])
+      self.assertEqual(len(samples.element_samples), 1)
+
+      element = samples.element_samples[INPUT_PCOLLECTION_ID].elements[0]
+      self.assertEqual(element.element, b'\rhello, world!')
+      self.assertTrue(element.HasField('exception'))
+
+      exception = element.exception
+      self.assertEqual(exception.instruction_id, 'instruction_id')
+      self.assertEqual(exception.transform_id, TEST_EXCEPTION_TRANSFORM_ID)
+      self.assertRegex(
+          exception.error, 'Traceback(\n|.)*RuntimeError: expected exception')
+
     finally:
       data_sampler.stop()
 
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler.py b/sdks/python/apache_beam/runners/worker/data_sampler.py
index 9c74188a699..a5992b9ceba 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler.py
@@ -24,14 +24,17 @@ from __future__ import annotations
 import collections
 import logging
 import threading
+import time
+import traceback
+from dataclasses import dataclass
 from threading import Timer
 from typing import Any
-from typing import DefaultDict
 from typing import Deque
 from typing import Dict
 from typing import Iterable
 from typing import List
 from typing import Optional
+from typing import Tuple
 from typing import Union
 
 from apache_beam.coders.coder_impl import CoderImpl
@@ -50,19 +53,32 @@ class SampleTimer:
     self._timer = Timer(self._timeout_secs, self.sample)
     self._sampler = sampler
 
-  def reset(self):
+  def reset(self) -> None:
     self._timer.cancel()
     self._timer = Timer(self._timeout_secs, self.sample)
     self._timer.start()
 
-  def stop(self):
+  def stop(self) -> None:
     self._timer.cancel()
 
-  def sample(self):
+  def sample(self) -> None:
     self._sampler.sample()
     self.reset()
 
 
+@dataclass
+class ExceptionMetadata:
+  # The repr-ified Exception.
+  msg: str
+
+  # The transform where the exception occured.
+  transform_id: str
+
+  # The instruction when the exception occured.
+  instruction_id: str
+
+
+@dataclass
 class ElementSampler:
   """Record class to hold sampled elements.
 
@@ -71,12 +87,12 @@ class ElementSampler:
   """
 
   # Is true iff the `el` has been set with a sample.
-  has_element: bool
+  has_element: bool = False
 
   # The sampled element. Note that `None` is a valid element and cannot be uesd
   # as a sentintel to check if there is a sample. Use the `has_element` flag to
   # check for this case.
-  el: Any
+  el: Any = None
 
 
 class OutputSampler:
@@ -96,6 +112,8 @@ class OutputSampler:
     self._sample_timer = SampleTimer(sample_every_sec, self)
     self.element_sampler = ElementSampler()
     self.element_sampler.has_element = False
+    self._exceptions: Deque[Tuple[Any, ExceptionMetadata]] = collections.deque(
+        maxlen=max_samples)
 
     # For testing, it's easier to disable the Timer and manually sample.
     if sample_every_sec > 0:
@@ -115,26 +133,64 @@ class OutputSampler:
       el = el.value
     return el
 
-  def flush(self, clear: bool = True) -> List[bytes]:
+  def flush(self, clear: bool = True) -> List[beam_fn_api_pb2.SampledElement]:
     """Returns all samples and optionally clears buffer if clear is True."""
     with self._samples_lock:
+      # TODO(rohdesamuel): There can duplicates between the exceptions and
+      # samples. This happens when the OutputSampler samples during an
+      # exception. The fix is to create a OutputSampler per process bundle.
+      # Until then use a set to keep track of the elements.
+      seen = set(id(el) for el, _ in self._exceptions)
       if isinstance(self._coder_impl, WindowedValueCoderImpl):
-        samples = [s for s in self._samples]
+        exceptions = [s for s in self._exceptions]
+        samples = [s for s in self._samples if id(s) not in seen]
       else:
-        samples = [self.remove_windowed_value(s) for s in self._samples]
+        exceptions = [
+            (self.remove_windowed_value(a), b) for a, b in self._exceptions
+        ]
+        samples = [
+            self.remove_windowed_value(s) for s in self._samples
+            if id(s) not in seen
+        ]
 
       # Encode in the nested context b/c this ensures that the SDK can decode
       # the bytes with the ToStringFn.
       if clear:
         self._samples.clear()
-      return [self._coder_impl.encode_nested(s) for s in samples]
+        self._exceptions.clear()
+
+      ret = [
+          beam_fn_api_pb2.SampledElement(
+              element=self._coder_impl.encode_nested(s),
+          ) for s in samples
+      ]
+
+      ret.extend(
+          beam_fn_api_pb2.SampledElement(
+              element=self._coder_impl.encode_nested(s),
+              exception=beam_fn_api_pb2.SampledElement.Exception(
+                  instruction_id=exn.instruction_id,
+                  transform_id=exn.transform_id,
+                  error=exn.msg)) for s,
+          exn in exceptions)
+
+      return ret
 
   def sample(self) -> None:
     """Samples the given element to an internal buffer."""
     with self._samples_lock:
       if self.element_sampler.has_element:
-        self.element_sampler.has_element = False
         self._samples.append(self.element_sampler.el)
+        self.element_sampler.has_element = False
+
+  def sample_exception(
+      self, el: Any, exc_info: Any, transform_id: str,
+      instruction_id: str) -> None:
+    """Adds the given exception to the samples."""
+    with self._samples_lock:
+      err_string = ''.join(traceback.format_exception(*exc_info))
+      self._exceptions.append(
+          (el, ExceptionMetadata(err_string, transform_id, instruction_id)))
 
 
 class DataSampler:
@@ -160,7 +216,7 @@ class DataSampler:
     self._samplers_lock: threading.Lock = threading.Lock()
     self._max_samples = max_samples
     self._sample_every_sec = sample_every_sec
-    self._element_samplers: Dict[str, List[ElementSampler]] = {}
+    self._samplers_by_output: Dict[str, List[OutputSampler]] = {}
     self._clock = clock
 
   def stop(self) -> None:
@@ -171,24 +227,26 @@ class DataSampler:
       for sampler in self._samplers.values():
         sampler.stop()
 
-  def sampler_for_output(
-      self, transform_id: str, output_index: int) -> ElementSampler:
-    """Returns the ElementSampler for the given output."""
+  def sampler_for_output(self, transform_id: str,
+                         output_index: int) -> Optional[OutputSampler]:
+    """Returns the OutputSampler for the given output."""
     try:
-      return self._element_samplers[transform_id][output_index]
+      with self._samplers_lock:
+        outputs = self._samplers_by_output[transform_id]
+        return outputs[output_index]
     except KeyError:
       _LOGGER.warning(
           f'Out-of-bounds access for transform "{transform_id}" ' +
-          'and output "{output_index}" ElementSampler. This may ' +
+          'and output "{output_index}" OutputSampler. This may ' +
           'indicate that the transform was improperly ' +
           'initialized with the DataSampler.')
-      return ElementSampler()
+      return None
 
   def initialize_samplers(
       self,
       transform_id: str,
       descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
-      coder_factory) -> List[ElementSampler]:
+      coder_factory) -> List[OutputSampler]:
     """Creates the OutputSamplers for the given PTransform.
 
     This initializes the samplers only once per PCollection Id. Note that an
@@ -198,6 +256,9 @@ class DataSampler:
     """
     transform_proto = descriptor.transforms[transform_id]
     with self._samplers_lock:
+      if transform_id in self._samplers_by_output:
+        return self._samplers_by_output[transform_id]
+
       # Initialize the samplers.
       for pcoll_id in transform_proto.outputs.values():
         # Only initialize new PCollections.
@@ -215,28 +276,22 @@ class DataSampler:
       # Operations look up the ElementSampler for an output based on the index
       # of the tag in the PTransform's outputs. The following code intializes
       # the array with ElementSamplers in the correct indices.
-      if transform_id in self._element_samplers:
-        return self._element_samplers[transform_id]
-
       outputs = transform_proto.outputs
-      samplers = [
-          self._samplers[pcoll_id].element_sampler
-          for pcoll_id in outputs.values()
-      ]
-      self._element_samplers[transform_id] = samplers
+      samplers = [self._samplers[pcoll_id] for pcoll_id in outputs.values()]
+      self._samplers_by_output[transform_id] = samplers
 
       return samplers
 
   def samples(
       self,
       pcollection_ids: Optional[Iterable[str]] = None
-  ) -> Dict[str, List[bytes]]:
+  ) -> beam_fn_api_pb2.SampleDataResponse:
     """Returns samples filtered PCollection ids.
 
     All samples from the given PCollections are returned. Empty lists are
     wildcards.
     """
-    ret: DefaultDict[str, List[bytes]] = collections.defaultdict(lambda: [])
+    ret = beam_fn_api_pb2.SampleDataResponse()
 
     with self._samplers_lock:
       samplers = self._samplers.copy()
@@ -247,6 +302,28 @@ class DataSampler:
 
       samples = samplers[pcoll_id].flush()
       if samples:
-        ret[pcoll_id].extend(samples)
+        ret.element_samples[pcoll_id].elements.extend(samples)
+
+    return ret
+
+  def wait_for_samples(
+      self, pcollection_ids: List[str]) -> beam_fn_api_pb2.SampleDataResponse:
+    """Waits for samples to exist for the given PCollections (only testing)."""
+    now = time.time()
+    end = now + 30
+
+    samples = beam_fn_api_pb2.SampleDataResponse()
+    while now < end:
+      time.sleep(0.1)
+      now = time.time()
+      samples.MergeFrom(self.samples(pcollection_ids))
+
+      if not samples:
+        continue
+
+      has_all = all(
+          pcoll_id in samples.element_samples for pcoll_id in pcollection_ids)
+      if has_all:
+        break
 
-    return dict(ret)
+    return samples
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler_test.py b/sdks/python/apache_beam/runners/worker/data_sampler_test.py
index 346251ee216..b6793612183 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler_test.py
@@ -17,7 +17,9 @@
 
 # pytype: skip-file
 
+import sys
 import time
+import traceback
 import unittest
 from typing import Any
 from typing import Dict
@@ -37,14 +39,6 @@ MAIN_PCOLLECTION_ID = 'pcoll'
 PRIMITIVES_CODER = FastPrimitivesCoder()
 
 
-class FakeClock:
-  def __init__(self):
-    self.clock = 0
-
-  def time(self):
-    return self.clock
-
-
 class DataSamplerTest(unittest.TestCase):
   def make_test_descriptor(
       self,
@@ -68,32 +62,6 @@ class DataSamplerTest(unittest.TestCase):
   def tearDown(self):
     self.data_sampler.stop()
 
-  def wait_for_samples(
-      self, data_sampler: DataSampler,
-      pcollection_ids: List[str]) -> Dict[str, List[bytes]]:
-    """Waits for samples to exist for the given PCollections."""
-    now = time.time()
-    end = now + 30
-
-    samples = {}
-    while now < end:
-      time.sleep(0.1)
-      now = time.time()
-      samples.update(data_sampler.samples(pcollection_ids))
-
-      if not samples:
-        continue
-
-      has_all = all(pcoll_id in samples for pcoll_id in pcollection_ids)
-      if has_all:
-        return samples
-
-    self.assertLess(
-        now,
-        end,
-        'Timed out waiting for samples for {}'.format(pcollection_ids))
-    return {}
-
   def primitives_coder_factory(self, _):
     return PRIMITIVES_CODER
 
@@ -105,7 +73,7 @@ class DataSamplerTest(unittest.TestCase):
       transform_id: str = MAIN_TRANSFORM_ID):
     """Generates a sample for the given transform's output."""
     element_sampler = self.data_sampler.sampler_for_output(
-        transform_id, output_index)
+        transform_id, output_index).element_sampler
     element_sampler.el = element
     element_sampler.has_element = True
 
@@ -117,10 +85,15 @@ class DataSamplerTest(unittest.TestCase):
 
     self.gen_sample(self.data_sampler, 'a', output_index=0)
 
-    expected_sample = {
-        MAIN_PCOLLECTION_ID: [PRIMITIVES_CODER.encode_nested('a')]
-    }
-    samples = self.wait_for_samples(self.data_sampler, [MAIN_PCOLLECTION_ID])
+    expected_sample = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            MAIN_PCOLLECTION_ID: beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('a'))
+                ])
+        })
+    samples = self.data_sampler.wait_for_samples([MAIN_PCOLLECTION_ID])
     self.assertEqual(samples, expected_sample)
 
   def test_not_initialized(self):
@@ -154,18 +127,21 @@ class DataSamplerTest(unittest.TestCase):
     # samplers.
     index = outputs['o0']
     self.assertEqual(
-        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
-        samplers[index])
+        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
+                                             index).element_sampler,
+        samplers[index].element_sampler)
 
     index = outputs['o1']
     self.assertEqual(
-        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
-        samplers[index])
+        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
+                                             index).element_sampler,
+        samplers[index].element_sampler)
 
     index = outputs['o2']
     self.assertEqual(
-        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
-        samplers[index])
+        self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
+                                             index).element_sampler,
+        samplers[index].element_sampler)
 
   def test_multiple_outputs(self):
     """Tests that multiple PCollections have their own sampler."""
@@ -180,12 +156,25 @@ class DataSamplerTest(unittest.TestCase):
     self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
     self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
 
-    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1', 'o2'])
-    expected_samples = {
-        'o0': [PRIMITIVES_CODER.encode_nested('a')],
-        'o1': [PRIMITIVES_CODER.encode_nested('b')],
-        'o2': [PRIMITIVES_CODER.encode_nested('c')],
-    }
+    samples = self.data_sampler.wait_for_samples(['o0', 'o1', 'o2'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('a'))
+                ]),
+            'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('b'))
+                ]),
+            'o2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('c'))
+                ]),
+        })
     self.assertEqual(samples, expected_samples)
 
   def test_multiple_transforms(self):
@@ -218,11 +207,20 @@ class DataSamplerTest(unittest.TestCase):
         'd',
         output_index=t1_outputs['o1'],
         transform_id='t1')
-    expected_samples = {
-        'o0': [PRIMITIVES_CODER.encode_nested('a')],
-        'o1': [PRIMITIVES_CODER.encode_nested('d')],
-    }
-    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('a'))
+                ]),
+            'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('d'))
+                ]),
+        })
+    samples = self.data_sampler.wait_for_samples(['o0', 'o1'])
     self.assertEqual(samples, expected_samples)
 
     self.gen_sample(
@@ -235,11 +233,20 @@ class DataSamplerTest(unittest.TestCase):
         'c',
         output_index=t1_outputs['o0'],
         transform_id='t1')
-    expected_samples = {
-        'o0': [PRIMITIVES_CODER.encode_nested('c')],
-        'o1': [PRIMITIVES_CODER.encode_nested('b')],
-    }
-    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('c'))
+                ]),
+            'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('b'))
+                ]),
+        })
+    samples = self.data_sampler.wait_for_samples(['o0', 'o1'])
     self.assertEqual(samples, expected_samples)
 
   def test_sample_filters_single_pcollection_ids(self):
@@ -255,22 +262,37 @@ class DataSamplerTest(unittest.TestCase):
     self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
     self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
 
-    samples = self.wait_for_samples(self.data_sampler, ['o0'])
-    expected_samples = {
-        'o0': [PRIMITIVES_CODER.encode_nested('a')],
-    }
+    samples = self.data_sampler.wait_for_samples(['o0'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('a'))
+                ]),
+        })
     self.assertEqual(samples, expected_samples)
 
-    samples = self.wait_for_samples(self.data_sampler, ['o1'])
-    expected_samples = {
-        'o1': [PRIMITIVES_CODER.encode_nested('b')],
-    }
+    samples = self.data_sampler.wait_for_samples(['o1'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('b'))
+                ]),
+        })
     self.assertEqual(samples, expected_samples)
 
-    samples = self.wait_for_samples(self.data_sampler, ['o2'])
-    expected_samples = {
-        'o2': [PRIMITIVES_CODER.encode_nested('c')],
-    }
+    samples = self.data_sampler.wait_for_samples(['o2'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('c'))
+                ]),
+        })
     self.assertEqual(samples, expected_samples)
 
   def test_sample_filters_multiple_pcollection_ids(self):
@@ -286,24 +308,45 @@ class DataSamplerTest(unittest.TestCase):
     self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
     self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
 
-    samples = self.wait_for_samples(self.data_sampler, ['o0', 'o2'])
-    expected_samples = {
-        'o0': [PRIMITIVES_CODER.encode_nested('a')],
-        'o2': [PRIMITIVES_CODER.encode_nested('c')],
-    }
+    samples = self.data_sampler.wait_for_samples(['o0', 'o2'])
+    expected_samples = beam_fn_api_pb2.SampleDataResponse(
+        element_samples={
+            'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('a'))
+                ]),
+            'o2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                elements=[
+                    beam_fn_api_pb2.SampledElement(
+                        element=PRIMITIVES_CODER.encode_nested('c'))
+                ]),
+        })
     self.assertEqual(samples, expected_samples)
 
+  def test_can_sample_exceptions(self):
+    """Tests that exceptions sampled can be queried by the DataSampler."""
+    descriptor = self.make_test_descriptor()
+    self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
+    sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
+
+    exc_info = None
+    try:
+      raise Exception('test')
+    except Exception:
+      exc_info = sys.exc_info()
+
+    sampler.sample_exception('a', exc_info, MAIN_TRANSFORM_ID, 'instid')
+
+    samples = self.data_sampler.wait_for_samples([MAIN_PCOLLECTION_ID])
+    self.assertGreater(len(samples.element_samples), 0)
 
-class OutputSamplerTest(unittest.TestCase):
-  def setUp(self):
-    self.fake_clock = FakeClock()
 
+class OutputSamplerTest(unittest.TestCase):
   def tearDown(self):
     self.sampler.stop()
 
-  def control_time(self, new_time):
-    self.fake_clock.clock = new_time
-
   def wait_for_samples(self, output_sampler: OutputSampler, expected_num: int):
     """Waits for the expected number of samples for the given sampler."""
     now = time.time()
@@ -322,31 +365,6 @@ class OutputSamplerTest(unittest.TestCase):
 
     self.assertLess(now, end, 'Timed out waiting for samples')
 
-  def ensure_sample(
-      self, output_sampler: OutputSampler, sample: Any, expected_num: int):
-    """Generates a sample and waits for it to be available."""
-
-    element_sampler = output_sampler.element_sampler
-
-    now = time.time()
-    end = now + 30
-
-    while now < end:
-      element_sampler.el = sample
-      element_sampler.has_element = True
-      time.sleep(0.1)
-      now = time.time()
-      samples = output_sampler.flush(clear=False)
-
-      if not samples:
-        continue
-
-      if len(samples) == expected_num:
-        return samples
-
-    self.assertLess(
-        now, end, 'Timed out waiting for sample "{sample}" to be generated.')
-
   def test_can_sample(self):
     """Tests that the underlying timer can sample."""
     self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0.05)
@@ -354,8 +372,13 @@ class OutputSamplerTest(unittest.TestCase):
     element_sampler.el = 'a'
     element_sampler.has_element = True
 
-    samples = self.wait_for_samples(self.sampler, expected_num=1)
-    self.assertEqual(samples, [PRIMITIVES_CODER.encode_nested('a')])
+    self.wait_for_samples(self.sampler, expected_num=1)
+    self.assertEqual(
+        self.sampler.flush(),
+        [
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested('a'))
+        ])
 
   def test_acts_like_circular_buffer(self):
     """Tests that the buffer overwrites old samples."""
@@ -370,19 +393,28 @@ class OutputSamplerTest(unittest.TestCase):
 
     self.assertEqual(
         self.sampler.flush(),
-        [PRIMITIVES_CODER.encode_nested(i) for i in (8, 9)])
+        [
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested(i)) for i in (8, 9)
+        ])
 
   def test_samples_multiple_times(self):
-    """Tests that the buffer overwrites old samples."""
+    """Tests that the underlying timer repeats."""
     self.sampler = OutputSampler(
         PRIMITIVES_CODER, max_samples=10, sample_every_sec=0.05)
 
     # Always samples the first ten.
     for i in range(10):
-      self.ensure_sample(self.sampler, i, i + 1)
+      self.sampler.element_sampler.el = i
+      self.sampler.element_sampler.has_element = True
+      self.wait_for_samples(self.sampler, i + 1)
+
     self.assertEqual(
         self.sampler.flush(),
-        [PRIMITIVES_CODER.encode_nested(i) for i in range(10)])
+        [
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested(i)) for i in range(10)
+        ])
 
   def test_can_sample_windowed_value(self):
     """Tests that values with WindowedValueCoders are sampled wholesale."""
@@ -395,7 +427,9 @@ class OutputSamplerTest(unittest.TestCase):
     element_sampler.has_element = True
     self.sampler.sample()
 
-    self.assertEqual(self.sampler.flush(), [coder.encode_nested(value)])
+    self.assertEqual(
+        self.sampler.flush(),
+        [beam_fn_api_pb2.SampledElement(element=coder.encode_nested(value))])
 
   def test_can_sample_non_windowed_value(self):
     """Tests that windowed values with WindowedValueCoders sample only the
@@ -414,7 +448,69 @@ class OutputSamplerTest(unittest.TestCase):
     self.sampler.sample()
 
     self.assertEqual(
-        self.sampler.flush(), [PRIMITIVES_CODER.encode_nested('Hello, World!')])
+        self.sampler.flush(),
+        [
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested('Hello, World!'))
+        ])
+
+  def test_can_sample_exceptions(self):
+    """Tests that exceptions are sampled."""
+    val = WindowedValue('Hello, World!', 0, [GlobalWindow()])
+    exc_info = None
+    try:
+      raise Exception('test')
+    except Exception:
+      exc_info = sys.exc_info()
+      err_string = ''.join(traceback.format_exception(*exc_info))
+
+    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
+    self.sampler.sample_exception(
+        el=val, exc_info=exc_info, transform_id='tid', instruction_id='instid')
+
+    self.assertEqual(
+        self.sampler.flush(),
+        [
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested('Hello, World!'),
+                exception=beam_fn_api_pb2.SampledElement.Exception(
+                    instruction_id='instid',
+                    transform_id='tid',
+                    error=err_string))
+        ])
+
+  def test_can_sample_multiple_exceptions(self):
+    """Tests that multiple exceptions in the same PCollection are sampled."""
+    exc_info = None
+    try:
+      raise Exception('test')
+    except Exception:
+      exc_info = sys.exc_info()
+      err_string = ''.join(traceback.format_exception(*exc_info))
+
+    self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
+    self.sampler.sample_exception(
+        el='a', exc_info=exc_info, transform_id='tid', instruction_id='instid')
+
+    self.sampler.sample_exception(
+        el='b', exc_info=exc_info, transform_id='tid', instruction_id='instid')
+
+    self.assertEqual(
+        self.sampler.flush(),
+        [
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested('a'),
+                exception=beam_fn_api_pb2.SampledElement.Exception(
+                    instruction_id='instid',
+                    transform_id='tid',
+                    error=err_string)),
+            beam_fn_api_pb2.SampledElement(
+                element=PRIMITIVES_CODER.encode_nested('b'),
+                exception=beam_fn_api_pb2.SampledElement.Exception(
+                    instruction_id='instid',
+                    transform_id='tid',
+                    error=err_string)),
+        ])
 
 
 if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index 725c5d2346a..bc8197ef84f 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -34,7 +34,9 @@ cdef class ConsumerSet(Receiver):
   cdef public step_name
   cdef public output_index
   cdef public coder
+  cdef public object output_sampler
   cdef public object element_sampler
+  cdef public object execution_context
 
   cpdef update_counters_start(self, WindowedValue windowed_value)
   cpdef update_counters_finish(self)
@@ -82,6 +84,8 @@ cdef class Operation(object):
   cdef readonly object scoped_process_state
   cdef readonly object scoped_finish_state
 
+  cdef readonly object data_sampler
+
   cpdef start(self)
   cpdef process(self, WindowedValue windowed_value)
   cpdef finish(self)
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index ca5e6552973..7eb14d29593 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -53,7 +53,7 @@ from apache_beam.runners.common import Receiver
 from apache_beam.runners.worker import opcounters
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import sideinputs
-from apache_beam.runners.worker.data_sampler import ElementSampler
+from apache_beam.runners.worker.data_sampler import DataSampler
 from apache_beam.transforms import sideinputs as apache_sideinputs
 from apache_beam.transforms import combiners
 from apache_beam.transforms import core
@@ -70,7 +70,6 @@ if TYPE_CHECKING:
   from apache_beam.runners.sdf_utils import SplitResultPrimary
   from apache_beam.runners.sdf_utils import SplitResultResidual
   from apache_beam.runners.worker.bundle_processor import ExecutionContext
-  from apache_beam.runners.worker.data_sampler import DataSampler
   from apache_beam.runners.worker.data_sampler import OutputSampler
   from apache_beam.runners.worker.statesampler import StateSampler
   from apache_beam.transforms.userstate import TimerSpec
@@ -126,7 +125,7 @@ class ConsumerSet(Receiver):
              coder,
              producer_type_hints,
              producer_batch_converter, # type: Optional[BatchConverter]
-             element_sampler=None,  # type: Optional[ElementSampler]
+             output_sampler=None,  # type: Optional[OutputSampler]
              ):
     # type: (...) -> ConsumerSet
     if len(consumers) == 1:
@@ -144,7 +143,7 @@ class ConsumerSet(Receiver):
             consumer,
             coder,
             producer_type_hints,
-            element_sampler)
+            output_sampler)
 
     return GeneralPurposeConsumerSet(
         counter_factory,
@@ -154,7 +153,7 @@ class ConsumerSet(Receiver):
         producer_type_hints,
         consumers,
         producer_batch_converter,
-        element_sampler)
+        output_sampler)
 
   def __init__(self,
                counter_factory,
@@ -164,7 +163,7 @@ class ConsumerSet(Receiver):
                coder,
                producer_type_hints,
                producer_batch_converter,
-               element_sampler
+               output_sampler
                ):
     self.opcounter = opcounters.OperationCounters(
         counter_factory,
@@ -178,7 +177,10 @@ class ConsumerSet(Receiver):
     self.output_index = output_index
     self.coder = coder
     self.consumers = consumers
-    self.element_sampler = element_sampler
+    self.output_sampler = output_sampler
+    self.element_sampler = (
+        output_sampler.element_sampler if output_sampler else None)
+    self.execution_context = None  # type: Optional[ExecutionContext]
 
   def try_split(self, fraction_of_remainder):
     # type: (...) -> Optional[Any]
@@ -211,9 +213,11 @@ class ConsumerSet(Receiver):
     # between here and the DataSampler as an additional operation. The tradeoff
     # is that some samples might be dropped, but it is better than the
     # alternative which is double sampling the same element.
-    if self.element_sampler is not None:
-      self.element_sampler.el = windowed_value
-      self.element_sampler.has_element = True
+    if self.element_sampler is not None and self.execution_context is not None:
+      self.execution_context.output_sampler = self.output_sampler
+      if not self.element_sampler.has_element:
+        self.element_sampler.el = windowed_value
+        self.element_sampler.has_element = True
 
   def update_counters_finish(self):
     # type: () -> None
@@ -242,7 +246,7 @@ class SingletonElementConsumerSet(ConsumerSet):
                consumer,  # type: Operation
                coder,
                producer_type_hints,
-               element_sampler
+               output_sampler
                ):
     super().__init__(
         counter_factory,
@@ -251,7 +255,7 @@ class SingletonElementConsumerSet(ConsumerSet):
         coder,
         producer_type_hints,
         None,
-        element_sampler)
+        output_sampler)
     self.consumer = consumer
 
   def receive(self, windowed_value):
@@ -289,7 +293,7 @@ class GeneralPurposeConsumerSet(ConsumerSet):
                producer_type_hints,
                consumers,  # type: List[Operation]
                producer_batch_converter,
-               element_sampler):
+               output_sampler):
     super().__init__(
         counter_factory,
         step_name,
@@ -298,7 +302,7 @@ class GeneralPurposeConsumerSet(ConsumerSet):
         coder,
         producer_type_hints,
         producer_batch_converter,
-        element_sampler)
+        output_sampler)
 
     self.producer_batch_converter = producer_batch_converter
 
@@ -453,6 +457,7 @@ class Operation(object):
     # on the operation.
     self.setup_done = False
     self.step_name = None  # type: Optional[str]
+    self.data_sampler: Optional[DataSampler] = None
 
   def setup(self, data_sampler=None):
     # type: (Optional[DataSampler]) -> None
@@ -461,6 +466,7 @@ class Operation(object):
 
     This must be called before any other methods of the operation."""
     with self.scoped_start_state:
+      self.data_sampler = data_sampler
       self.debug_logging_enabled = logging.getLogger().isEnabledFor(
           logging.DEBUG)
       transform_id = self.name_context.transform_id
@@ -470,7 +476,7 @@ class Operation(object):
       #TODO(pabloem): Define better what step name is used here.
       if getattr(self.spec, 'output_coders', None):
 
-        def get_element_sampler(output_num):
+        def get_output_sampler(output_num):
           if data_sampler is None:
             return None
           return data_sampler.sampler_for_output(transform_id, output_num)
@@ -484,7 +490,7 @@ class Operation(object):
                 coder,
                 self._get_runtime_performance_hints(),
                 self.get_output_batch_converter(),
-                get_element_sampler(i)) for i,
+                get_output_sampler(i)) for i,
             coder in enumerate(self.spec.output_coders)
         ]
     self.setup_done = True
@@ -495,7 +501,13 @@ class Operation(object):
     """Start operation."""
     if not self.setup_done:
       # For legacy workers.
-      self.setup()
+      self.setup(self.data_sampler)
+
+    # The ExecutionContext is per instruction and so cannot be set at
+    # initialization time.
+    if self.data_sampler is not None:
+      for receiver in self.receivers:
+        receiver.execution_context = self.execution_context
 
   def get_batching_preference(self):
     # By default operations don't support batching, require Receiver to unbatch
@@ -908,6 +920,7 @@ class DoOperation(Operation):
           step_name=self.name_context.logging_name(),
           state=state,
           user_state_context=self.user_state_context,
+          transform_id=self.name_context.transform_id,
           operation_name=self.name_context.metrics_name())
       self.dofn_runner.setup()
 
@@ -915,6 +928,7 @@ class DoOperation(Operation):
     # type: () -> None
     with self.scoped_start_state:
       super(DoOperation, self).start()
+      self.dofn_runner.execution_context = self.execution_context
       self.dofn_runner.start()
 
   def get_batching_preference(self):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index f5b1456c251..bfd6544d802 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -380,18 +380,12 @@ class SdkHarness(object):
 
     def get_samples(request):
       # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse
-      samples: Dict[str, List[bytes]] = {}
+      samples = beam_fn_api_pb2.SampleDataResponse()
       if self.data_sampler is not None:
         samples = self.data_sampler.samples(request.sample_data.pcollection_ids)
 
-      sample_response = beam_fn_api_pb2.SampleDataResponse()
-      for pcoll_id in samples:
-        sample_response.element_samples[pcoll_id].elements.extend(
-            beam_fn_api_pb2.SampledElement(element=s)
-            for s in samples[pcoll_id])
-
       return beam_fn_api_pb2.InstructionResponse(
-          instruction_id=request.instruction_id, sample_data=sample_response)
+          instruction_id=request.instruction_id, sample_data=samples)
 
     self._execute(lambda: get_samples(request), request)
 
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index d12dc48ed29..ccfc27d0101 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -283,10 +283,19 @@ class SdkWorkerTest(unittest.TestCase):
 
     class FakeDataSampler:
       def samples(self, pcollection_ids):
-        return {
-            'pcoll_id_1': [coder.encode_nested('a')],
-            'pcoll_id_2': [coder.encode_nested('b')],
-        }
+        return beam_fn_api_pb2.SampleDataResponse(
+            element_samples={
+                'pcoll_id_1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                    elements=[
+                        beam_fn_api_pb2.SampledElement(
+                            element=coder.encode_nested('a'))
+                    ]),
+                'pcoll_id_2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+                    elements=[
+                        beam_fn_api_pb2.SampledElement(
+                            element=coder.encode_nested('b'))
+                    ])
+            })
 
       def stop(self):
         pass

Reply via email to