This is an automated email from the ASF dual-hosted git repository.
udim pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 23fe94b3af2 Implement Exception Sampling in the Python SDK (#27280)
23fe94b3af2 is described below
commit 23fe94b3af28d1a7d55f994dbddc9669a6567cb1
Author: Sam Rohde <[email protected]>
AuthorDate: Wed Jun 28 16:48:13 2023 -0700
Implement Exception Sampling in the Python SDK (#27280)
* Python exception sampling implementation
* add to cython def
* add more cython defs
* address comments
* fix circular imports
* linter
* fix tests
* remove print
* Add traceback to exception
* fix tests
---------
Co-authored-by: Sam Rohde <[email protected]>
---
sdks/python/apache_beam/runners/common.pxd | 2 +
sdks/python/apache_beam/runners/common.py | 35 ++-
.../apache_beam/runners/worker/bundle_processor.py | 20 +-
.../runners/worker/bundle_processor_test.py | 142 +++++++--
.../apache_beam/runners/worker/data_sampler.py | 139 +++++++--
.../runners/worker/data_sampler_test.py | 330 +++++++++++++--------
.../apache_beam/runners/worker/operations.pxd | 4 +
.../apache_beam/runners/worker/operations.py | 48 +--
.../apache_beam/runners/worker/sdk_worker.py | 10 +-
.../apache_beam/runners/worker/sdk_worker_test.py | 17 +-
10 files changed, 535 insertions(+), 212 deletions(-)
diff --git a/sdks/python/apache_beam/runners/common.pxd
b/sdks/python/apache_beam/runners/common.pxd
index d1745970d26..9fb44af6377 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -121,6 +121,8 @@ cdef class DoFnRunner:
cdef list side_inputs
cdef DoFnInvoker do_fn_invoker
cdef public object bundle_finalizer_param
+ cdef str transform_id
+ cdef object execution_context
cpdef process(self, WindowedValue windowed_value)
diff --git a/sdks/python/apache_beam/runners/common.py
b/sdks/python/apache_beam/runners/common.py
index 75eb85f2110..e91199787e5 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -66,6 +66,7 @@ from apache_beam.utils.windowed_value import WindowedBatch
from apache_beam.utils.windowed_value import WindowedValue
if TYPE_CHECKING:
+ from apache_beam.runners.worker.bundle_processor import ExecutionContext
from apache_beam.transforms import sideinputs
from apache_beam.transforms.core import TimerSpec
from apache_beam.io.iobase import RestrictionProgress
@@ -1338,7 +1339,8 @@ class DoFnRunner:
state=None,
scoped_metrics_container=None,
operation_name=None,
- user_state_context=None # type:
Optional[userstate.UserStateContext]
+ transform_id=None,
+ user_state_context=None, # type:
Optional[userstate.UserStateContext]
):
"""Initializes a DoFnRunner.
@@ -1354,6 +1356,7 @@ class DoFnRunner:
state: handle for accessing DoFn state
scoped_metrics_container: DEPRECATED
operation_name: The system name assigned by the runner for this
operation.
+ transform_id: The PTransform Id in the pipeline proto for this DoFn.
user_state_context: The UserStateContext instance for the current
Stateful DoFn.
"""
@@ -1361,8 +1364,10 @@ class DoFnRunner:
side_inputs = list(side_inputs)
self.step_name = step_name
+ self.transform_id = transform_id
self.context = DoFnContext(step_name, state=state)
self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
+ self.execution_context = None # type: Optional[ExecutionContext]
do_fn_signature = DoFnSignature(fn)
@@ -1417,9 +1422,25 @@ class DoFnRunner:
try:
return self.do_fn_invoker.invoke_process(windowed_value)
except BaseException as exn:
- self._reraise_augmented(exn)
+ self._reraise_augmented(exn, windowed_value)
return []
+  def _maybe_sample_exception(
+      self, exc_info: tuple, windowed_value: Optional[WindowedValue]) -> None:
+    """Records the failing element and its exception with the DataSampler.
+
+    This is a no-op when there is no ExecutionContext, when no OutputSampler
+    is attached to the context (data sampling disabled), or when there is no
+    element to attribute the exception to.
+
+    Args:
+      exc_info: The (type, value, traceback) tuple from ``sys.exc_info()``.
+      windowed_value: The element being processed when the exception was
+        raised, or None when the caller of ``_reraise_augmented`` did not
+        supply one.
+    """
+    if self.execution_context is None:
+      return
+
+    output_sampler = self.execution_context.output_sampler
+    if output_sampler is None:
+      return
+
+    # Without an element there is nothing meaningful to sample, and encoding
+    # None with the PCollection's coder may fail when the samples are flushed.
+    if windowed_value is None:
+      return
+
+    output_sampler.sample_exception(
+        windowed_value,
+        exc_info,
+        self.transform_id,
+        self.execution_context.instruction_id)
+
def process_batch(self, windowed_batch):
# type: (WindowedBatch) -> None
try:
@@ -1487,7 +1508,7 @@ class DoFnRunner:
# type: () -> None
self.bundle_finalizer_param.finalize_bundle()
- def _reraise_augmented(self, exn):
+ def _reraise_augmented(self, exn, windowed_value=None):
if getattr(exn, '_tagged_with_step', False) or not self.step_name:
raise exn
step_annotation = " [while running '%s']" % self.step_name
@@ -1504,8 +1525,12 @@ class DoFnRunner:
traceback.format_exception_only(type(exn), exn)[-1].strip() +
step_annotation)
new_exn._tagged_with_step = True
- _, _, tb = sys.exc_info()
- raise new_exn.with_traceback(tb)
+ exc_info = sys.exc_info()
+ _, _, tb = exc_info
+
+ new_exn = new_exn.with_traceback(tb)
+ self._maybe_sample_exception(exc_info, windowed_value)
+ raise new_exn
class OutputHandler(object):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py
b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 3adb26552d0..b3ca92aa9a1 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -27,6 +27,8 @@ import json
import logging
import random
import threading
+from dataclasses import dataclass
+from dataclasses import field
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
@@ -986,7 +988,7 @@ class BundleProcessor(object):
expected_input_ops.append(op)
try:
- execution_context = ExecutionContext()
+ execution_context = ExecutionContext(instruction_id=instruction_id)
self.current_instruction_id = instruction_id
self.state_sampler.start()
# Start all operations.
@@ -1181,10 +1183,18 @@ class BundleProcessor(object):
op.teardown()
-class ExecutionContext(object):
- def __init__(self):
- self.delayed_applications = [
- ] # type: List[Tuple[operations.DoOperation, common.SplitResultResidual]]
+@dataclass
+class ExecutionContext:
+ # Any splits to be processed later.
+ delayed_applications: List[Tuple[operations.DoOperation,
+ common.SplitResultResidual]] = field(
+ default_factory=list)
+
+ # The exception sampler for the currently executing PTransform.
+ output_sampler: Optional[data_sampler.OutputSampler] = None
+
+ # The current instruction being executed.
+ instruction_id: Optional[str] = None
class BeamTransformFactory(object):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
index db9e35a0baf..8b81c9f17ac 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
@@ -18,15 +18,16 @@
"""Unit tests for bundle processing."""
# pytype: skip-file
-import time
import unittest
from typing import Dict
from typing import List
+import apache_beam as beam
from apache_beam.coders.coders import FastPrimitivesCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.runners import common
+from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import operations
from apache_beam.runners.worker.bundle_processor import BeamTransformFactory
from apache_beam.runners.worker.bundle_processor import BundleProcessor
@@ -240,6 +241,25 @@ def create_test_op(factory, transform_id, transform_proto,
payload, consumers):
payload)
+@BeamTransformFactory.register_urn('beam:internal:testexn:v1', bytes)
+def create_exception_dofn(
+    factory, transform_id, transform_proto, payload, consumers):
+  """Returns a test ParDo operation whose DoFn always raises.
+
+  The transform's ``bytes`` payload is decoded and used as the message of the
+  ``RuntimeError`` raised for every processed element.
+  """
+  class RaiseException(beam.DoFn):
+    def __init__(self, msg):
+      self.msg = msg.decode()
+
+    def process(self, _):
+      raise RuntimeError(self.msg)
+
+  return bundle_processor._create_simple_pardo_operation(
+      factory,
+      transform_id,
+      transform_proto,
+      consumers,
+      RaiseException(payload))
+
+
class DataSamplingTest(unittest.TestCase):
def test_disabled_by_default(self):
"""Test that not providing the sampler does not enable Data Sampling.
@@ -252,32 +272,12 @@ class DataSamplingTest(unittest.TestCase):
_ = BundleProcessor(descriptor, None, None)
self.assertEqual(len(descriptor.transforms), 0)
- def wait_for_samples(self, data_sampler: DataSampler,
- pcollection_id: str) -> Dict[str, List[bytes]]:
- """Waits for samples from the given PCollection to exist."""
- now = time.time()
- end = now + 30
-
- samples = {}
- while now < end:
- time.sleep(0.1)
- now = time.time()
- samples.update(data_sampler.samples([pcollection_id]))
-
- if samples:
- return samples
-
- self.assertLess(
- now, end, 'Timed out waiting for samples for
{}'.format(pcollection_id))
- return {}
-
def test_can_sample(self):
"""Test that elements are sampled.
This is a small integration test with the BundleProcessor and the
- DataSampler. It ensures that the BundleProcessor correctly makes
- DataSamplingOperations and samples are taken from in-flight elements. These
- elements are then finally queried.
+ DataSampler. It ensures that samples are taken from in-flight elements.
+ These elements are then finally queried.
"""
data_sampler = DataSampler(sample_every_sec=0.1)
descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
@@ -306,8 +306,100 @@ class DataSamplingTest(unittest.TestCase):
descriptor, None, None, data_sampler=data_sampler)
processor.process_bundle('instruction_id')
- samples = self.wait_for_samples(data_sampler, PCOLLECTION_ID)
- self.assertEqual(samples, {PCOLLECTION_ID: [b'\rhello, world!']})
+ samples = data_sampler.wait_for_samples([PCOLLECTION_ID])
+ expected = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ PCOLLECTION_ID: beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=b'\rhello, world!')
+ ])
+ })
+ self.assertEqual(samples, expected)
+ finally:
+ data_sampler.stop()
+
+  def test_can_sample_exceptions(self):
+    """Test that exceptions are sampled together with the failing element."""
+    data_sampler = DataSampler(sample_every_sec=0.1)
+    descriptor = beam_fn_api_pb2.ProcessBundleDescriptor()
+
+    # Boilerplate for the DoFn.
+    WINDOWING_ID = 'window'
+    WINDOW_CODER_ID = 'cw'
+    window = descriptor.windowing_strategies[WINDOWING_ID]
+    window.window_fn.urn = common_urns.global_windows.urn
+    window.window_coder_id = WINDOW_CODER_ID
+    window.trigger.default.SetInParent()
+    window_coder = descriptor.coders[WINDOW_CODER_ID]
+    window_coder.spec.urn = common_urns.StandardCoders.Enum.GLOBAL_WINDOW.urn
+
+    # Input collection to the exception raising DoFn.
+    INPUT_PCOLLECTION_ID = 'pc-in'
+    INPUT_CODER_ID = 'c-in'
+    descriptor.pcollections[
+        INPUT_PCOLLECTION_ID].unique_name = INPUT_PCOLLECTION_ID
+    descriptor.pcollections[INPUT_PCOLLECTION_ID].coder_id = INPUT_CODER_ID
+    descriptor.pcollections[
+        INPUT_PCOLLECTION_ID].windowing_strategy_id = WINDOWING_ID
+    descriptor.coders[
+        INPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn
+
+    # Output collection to the exception raising DoFn. Because the transform
+    # "failed" to process the input element, we do NOT expect to see a sample
+    # in this PCollection.
+    OUTPUT_PCOLLECTION_ID = 'pc-out'
+    OUTPUT_CODER_ID = 'c-out'
+    descriptor.pcollections[
+        OUTPUT_PCOLLECTION_ID].unique_name = OUTPUT_PCOLLECTION_ID
+    descriptor.pcollections[OUTPUT_PCOLLECTION_ID].coder_id = OUTPUT_CODER_ID
+    descriptor.pcollections[
+        OUTPUT_PCOLLECTION_ID].windowing_strategy_id = WINDOWING_ID
+    descriptor.coders[
+        OUTPUT_CODER_ID].spec.urn = common_urns.StandardCoders.Enum.BYTES.urn
+
+    # Add a simple transform to inject an element into the data sampler. This
+    # doesn't use the FnApi, so this uses a simple operation to forward its
+    # payload to consumers.
+    TEST_OP_TRANSFORM_ID = 'test_op'
+    test_op_transform = descriptor.transforms[TEST_OP_TRANSFORM_ID]
+    test_op_transform.outputs['None'] = INPUT_PCOLLECTION_ID
+    test_op_transform.spec.urn = 'beam:internal:testop:v1'
+    test_op_transform.spec.payload = b'hello, world!'
+
+    # Add the DoFn to create an exception to sample from.
+    TEST_EXCEPTION_TRANSFORM_ID = 'test_transform'
+    exception_transform = descriptor.transforms[TEST_EXCEPTION_TRANSFORM_ID]
+    exception_transform.inputs['0'] = INPUT_PCOLLECTION_ID
+    exception_transform.outputs['None'] = OUTPUT_PCOLLECTION_ID
+    exception_transform.spec.urn = 'beam:internal:testexn:v1'
+    exception_transform.spec.payload = b'expected exception'
+
+    try:
+      # Create and process a fake bundle. The instruction id doesn't matter
+      # here.
+      processor = BundleProcessor(
+          descriptor, None, None, data_sampler=data_sampler)
+
+      with self.assertRaisesRegex(RuntimeError, 'expected exception'):
+        processor.process_bundle('instruction_id')
+
+      # NOTE: The expected sample comes from the input PCollection. This is
+      # very important because there can be coder issues if the sample is put
+      # in the wrong PCollection.
+      samples = data_sampler.wait_for_samples([INPUT_PCOLLECTION_ID])
+      self.assertEqual(len(samples.element_samples), 1)
+
+      element = samples.element_samples[INPUT_PCOLLECTION_ID].elements[0]
+      self.assertEqual(element.element, b'\rhello, world!')
+      self.assertTrue(element.HasField('exception'))
+
+      exception = element.exception
+      self.assertEqual(exception.instruction_id, 'instruction_id')
+      self.assertEqual(exception.transform_id, TEST_EXCEPTION_TRANSFORM_ID)
+      self.assertRegex(
+          exception.error, 'Traceback(\n|.)*RuntimeError: expected exception')
+
     finally:
       data_sampler.stop()
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler.py
b/sdks/python/apache_beam/runners/worker/data_sampler.py
index 9c74188a699..a5992b9ceba 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler.py
@@ -24,14 +24,17 @@ from __future__ import annotations
import collections
import logging
import threading
+import time
+import traceback
+from dataclasses import dataclass
from threading import Timer
from typing import Any
-from typing import DefaultDict
from typing import Deque
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
+from typing import Tuple
from typing import Union
from apache_beam.coders.coder_impl import CoderImpl
@@ -50,19 +53,32 @@ class SampleTimer:
self._timer = Timer(self._timeout_secs, self.sample)
self._sampler = sampler
- def reset(self):
+ def reset(self) -> None:
self._timer.cancel()
self._timer = Timer(self._timeout_secs, self.sample)
self._timer.start()
- def stop(self):
+ def stop(self) -> None:
self._timer.cancel()
- def sample(self):
+ def sample(self) -> None:
self._sampler.sample()
self.reset()
+@dataclass
+class ExceptionMetadata:
+ """Metadata stored alongside a sampled element whose processing raised."""
+ # The formatted traceback of the exception, as produced by
+ # `traceback.format_exception` in `OutputSampler.sample_exception`.
+ msg: str
+
+ # The transform where the exception occurred.
+ transform_id: str
+
+ # The instruction that was executing when the exception occurred.
+ instruction_id: str
+
+
+@dataclass
class ElementSampler:
"""Record class to hold sampled elements.
@@ -71,12 +87,12 @@ class ElementSampler:
"""
# Is true iff the `el` has been set with a sample.
- has_element: bool
+ has_element: bool = False
# The sampled element. Note that `None` is a valid element and cannot be uesd
# as a sentintel to check if there is a sample. Use the `has_element` flag to
# check for this case.
- el: Any
+ el: Any = None
class OutputSampler:
@@ -96,6 +112,8 @@ class OutputSampler:
self._sample_timer = SampleTimer(sample_every_sec, self)
self.element_sampler = ElementSampler()
self.element_sampler.has_element = False
+ self._exceptions: Deque[Tuple[Any, ExceptionMetadata]] = collections.deque(
+ maxlen=max_samples)
# For testing, it's easier to disable the Timer and manually sample.
if sample_every_sec > 0:
@@ -115,26 +133,64 @@ class OutputSampler:
el = el.value
return el
- def flush(self, clear: bool = True) -> List[bytes]:
+ def flush(self, clear: bool = True) -> List[beam_fn_api_pb2.SampledElement]:
"""Returns all samples and optionally clears buffer if clear is True."""
with self._samples_lock:
+ # TODO(rohdesamuel): There can duplicates between the exceptions and
+ # samples. This happens when the OutputSampler samples during an
+ # exception. The fix is to create a OutputSampler per process bundle.
+ # Until then use a set to keep track of the elements.
+ seen = set(id(el) for el, _ in self._exceptions)
if isinstance(self._coder_impl, WindowedValueCoderImpl):
- samples = [s for s in self._samples]
+ exceptions = [s for s in self._exceptions]
+ samples = [s for s in self._samples if id(s) not in seen]
else:
- samples = [self.remove_windowed_value(s) for s in self._samples]
+ exceptions = [
+ (self.remove_windowed_value(a), b) for a, b in self._exceptions
+ ]
+ samples = [
+ self.remove_windowed_value(s) for s in self._samples
+ if id(s) not in seen
+ ]
# Encode in the nested context b/c this ensures that the SDK can decode
# the bytes with the ToStringFn.
if clear:
self._samples.clear()
- return [self._coder_impl.encode_nested(s) for s in samples]
+ self._exceptions.clear()
+
+ ret = [
+ beam_fn_api_pb2.SampledElement(
+ element=self._coder_impl.encode_nested(s),
+ ) for s in samples
+ ]
+
+ ret.extend(
+ beam_fn_api_pb2.SampledElement(
+ element=self._coder_impl.encode_nested(s),
+ exception=beam_fn_api_pb2.SampledElement.Exception(
+ instruction_id=exn.instruction_id,
+ transform_id=exn.transform_id,
+ error=exn.msg)) for s,
+ exn in exceptions)
+
+ return ret
def sample(self) -> None:
"""Samples the given element to an internal buffer."""
with self._samples_lock:
if self.element_sampler.has_element:
- self.element_sampler.has_element = False
self._samples.append(self.element_sampler.el)
+ self.element_sampler.has_element = False
+
+  def sample_exception(
+      self, el: Any, exc_info: Any, transform_id: str,
+      instruction_id: str) -> None:
+    """Adds the given element and its exception to the samples.
+
+    Args:
+      el: The element that was being processed when the exception was raised.
+      exc_info: The (type, value, traceback) tuple from ``sys.exc_info()``.
+      transform_id: The id of the PTransform that raised the exception.
+      instruction_id: The id of the instruction being executed.
+    """
+    # Format the traceback before taking the lock so that the formatting work
+    # does not extend the critical section shared with sample() and flush().
+    err_string = ''.join(traceback.format_exception(*exc_info))
+    with self._samples_lock:
+      self._exceptions.append(
+          (el, ExceptionMetadata(err_string, transform_id, instruction_id)))
class DataSampler:
@@ -160,7 +216,7 @@ class DataSampler:
self._samplers_lock: threading.Lock = threading.Lock()
self._max_samples = max_samples
self._sample_every_sec = sample_every_sec
- self._element_samplers: Dict[str, List[ElementSampler]] = {}
+ self._samplers_by_output: Dict[str, List[OutputSampler]] = {}
self._clock = clock
def stop(self) -> None:
@@ -171,24 +227,26 @@ class DataSampler:
for sampler in self._samplers.values():
sampler.stop()
-  def sampler_for_output(
-      self, transform_id: str, output_index: int) -> ElementSampler:
-    """Returns the ElementSampler for the given output."""
+  def sampler_for_output(self, transform_id: str,
+                         output_index: int) -> Optional[OutputSampler]:
+    """Returns the OutputSampler for the given output.
+
+    Args:
+      transform_id: The PTransform whose output sampler to look up.
+      output_index: The index of the output within the PTransform's outputs,
+        as laid out by ``initialize_samplers``.
+
+    Returns:
+      The OutputSampler, or None (after logging a warning) if the transform
+      or the output index was never initialized with this DataSampler.
+    """
     try:
-      return self._element_samplers[transform_id][output_index]
-    except KeyError:
+      with self._samplers_lock:
+        outputs = self._samplers_by_output[transform_id]
+        return outputs[output_index]
+    except (KeyError, IndexError):
       _LOGGER.warning(
           f'Out-of-bounds access for transform "{transform_id}" ' +
-          'and output "{output_index}" ElementSampler. This may ' +
+          f'and output "{output_index}" OutputSampler. This may ' +
           'indicate that the transform was improperly ' +
           'initialized with the DataSampler.')
-      return ElementSampler()
+      return None
def initialize_samplers(
self,
transform_id: str,
descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
- coder_factory) -> List[ElementSampler]:
+ coder_factory) -> List[OutputSampler]:
"""Creates the OutputSamplers for the given PTransform.
This initializes the samplers only once per PCollection Id. Note that an
@@ -198,6 +256,9 @@ class DataSampler:
"""
transform_proto = descriptor.transforms[transform_id]
with self._samplers_lock:
+ if transform_id in self._samplers_by_output:
+ return self._samplers_by_output[transform_id]
+
# Initialize the samplers.
for pcoll_id in transform_proto.outputs.values():
# Only initialize new PCollections.
@@ -215,28 +276,22 @@ class DataSampler:
# Operations look up the ElementSampler for an output based on the index
# of the tag in the PTransform's outputs. The following code intializes
# the array with ElementSamplers in the correct indices.
- if transform_id in self._element_samplers:
- return self._element_samplers[transform_id]
-
outputs = transform_proto.outputs
- samplers = [
- self._samplers[pcoll_id].element_sampler
- for pcoll_id in outputs.values()
- ]
- self._element_samplers[transform_id] = samplers
+ samplers = [self._samplers[pcoll_id] for pcoll_id in outputs.values()]
+ self._samplers_by_output[transform_id] = samplers
return samplers
def samples(
self,
pcollection_ids: Optional[Iterable[str]] = None
- ) -> Dict[str, List[bytes]]:
+ ) -> beam_fn_api_pb2.SampleDataResponse:
"""Returns samples filtered PCollection ids.
All samples from the given PCollections are returned. Empty lists are
wildcards.
"""
- ret: DefaultDict[str, List[bytes]] = collections.defaultdict(lambda: [])
+ ret = beam_fn_api_pb2.SampleDataResponse()
with self._samplers_lock:
samplers = self._samplers.copy()
@@ -247,6 +302,28 @@ class DataSampler:
samples = samplers[pcoll_id].flush()
if samples:
- ret[pcoll_id].extend(samples)
+ ret.element_samples[pcoll_id].elements.extend(samples)
+
+ return ret
+
+  def wait_for_samples(
+      self,
+      pcollection_ids: List[str],
+      timeout_secs: float = 30) -> beam_fn_api_pb2.SampleDataResponse:
+    """Waits for samples to exist for the given PCollections (only testing).
+
+    Repeatedly polls ``samples()`` (which flushes and clears the polled
+    samplers) every 100ms, accumulating the returned elements, until every id
+    in ``pcollection_ids`` has at least one sample or ``timeout_secs`` elapses.
+
+    Returns:
+      The accumulated SampleDataResponse. On timeout a warning is logged and
+      the (possibly partial) samples collected so far are returned.
+    """
+    end = time.monotonic() + timeout_secs
+    samples = beam_fn_api_pb2.SampleDataResponse()
+    while True:
+      # Extend rather than MergeFrom: merging a map field replaces the value
+      # of an existing key, which would drop samples from earlier polls.
+      response = self.samples(pcollection_ids)
+      for pcoll_id, element_list in response.element_samples.items():
+        samples.element_samples[pcoll_id].elements.extend(
+            element_list.elements)
+
+      if all(pcoll_id in samples.element_samples
+             for pcoll_id in pcollection_ids):
+        return samples
+
+      if time.monotonic() >= end:
+        break
+      time.sleep(0.1)
+
+    _LOGGER.warning(
+        'Timed out waiting for samples for %s', list(pcollection_ids))
-    return dict(ret)
+    return samples
diff --git a/sdks/python/apache_beam/runners/worker/data_sampler_test.py
b/sdks/python/apache_beam/runners/worker/data_sampler_test.py
index 346251ee216..b6793612183 100644
--- a/sdks/python/apache_beam/runners/worker/data_sampler_test.py
+++ b/sdks/python/apache_beam/runners/worker/data_sampler_test.py
@@ -17,7 +17,9 @@
# pytype: skip-file
+import sys
import time
+import traceback
import unittest
from typing import Any
from typing import Dict
@@ -37,14 +39,6 @@ MAIN_PCOLLECTION_ID = 'pcoll'
PRIMITIVES_CODER = FastPrimitivesCoder()
-class FakeClock:
- def __init__(self):
- self.clock = 0
-
- def time(self):
- return self.clock
-
-
class DataSamplerTest(unittest.TestCase):
def make_test_descriptor(
self,
@@ -68,32 +62,6 @@ class DataSamplerTest(unittest.TestCase):
def tearDown(self):
self.data_sampler.stop()
- def wait_for_samples(
- self, data_sampler: DataSampler,
- pcollection_ids: List[str]) -> Dict[str, List[bytes]]:
- """Waits for samples to exist for the given PCollections."""
- now = time.time()
- end = now + 30
-
- samples = {}
- while now < end:
- time.sleep(0.1)
- now = time.time()
- samples.update(data_sampler.samples(pcollection_ids))
-
- if not samples:
- continue
-
- has_all = all(pcoll_id in samples for pcoll_id in pcollection_ids)
- if has_all:
- return samples
-
- self.assertLess(
- now,
- end,
- 'Timed out waiting for samples for {}'.format(pcollection_ids))
- return {}
-
def primitives_coder_factory(self, _):
return PRIMITIVES_CODER
@@ -105,7 +73,7 @@ class DataSamplerTest(unittest.TestCase):
transform_id: str = MAIN_TRANSFORM_ID):
"""Generates a sample for the given transform's output."""
element_sampler = self.data_sampler.sampler_for_output(
- transform_id, output_index)
+ transform_id, output_index).element_sampler
element_sampler.el = element
element_sampler.has_element = True
@@ -117,10 +85,15 @@ class DataSamplerTest(unittest.TestCase):
self.gen_sample(self.data_sampler, 'a', output_index=0)
- expected_sample = {
- MAIN_PCOLLECTION_ID: [PRIMITIVES_CODER.encode_nested('a')]
- }
- samples = self.wait_for_samples(self.data_sampler, [MAIN_PCOLLECTION_ID])
+ expected_sample = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ MAIN_PCOLLECTION_ID:
beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'))
+ ])
+ })
+ samples = self.data_sampler.wait_for_samples([MAIN_PCOLLECTION_ID])
self.assertEqual(samples, expected_sample)
def test_not_initialized(self):
@@ -154,18 +127,21 @@ class DataSamplerTest(unittest.TestCase):
# samplers.
index = outputs['o0']
self.assertEqual(
- self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
- samplers[index])
+ self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
+ index).element_sampler,
+ samplers[index].element_sampler)
index = outputs['o1']
self.assertEqual(
- self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
- samplers[index])
+ self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
+ index).element_sampler,
+ samplers[index].element_sampler)
index = outputs['o2']
self.assertEqual(
- self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, index),
- samplers[index])
+ self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID,
+ index).element_sampler,
+ samplers[index].element_sampler)
def test_multiple_outputs(self):
"""Tests that multiple PCollections have their own sampler."""
@@ -180,12 +156,25 @@ class DataSamplerTest(unittest.TestCase):
self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
- samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1', 'o2'])
- expected_samples = {
- 'o0': [PRIMITIVES_CODER.encode_nested('a')],
- 'o1': [PRIMITIVES_CODER.encode_nested('b')],
- 'o2': [PRIMITIVES_CODER.encode_nested('c')],
- }
+ samples = self.data_sampler.wait_for_samples(['o0', 'o1', 'o2'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'))
+ ]),
+ 'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('b'))
+ ]),
+ 'o2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('c'))
+ ]),
+ })
self.assertEqual(samples, expected_samples)
def test_multiple_transforms(self):
@@ -218,11 +207,20 @@ class DataSamplerTest(unittest.TestCase):
'd',
output_index=t1_outputs['o1'],
transform_id='t1')
- expected_samples = {
- 'o0': [PRIMITIVES_CODER.encode_nested('a')],
- 'o1': [PRIMITIVES_CODER.encode_nested('d')],
- }
- samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'))
+ ]),
+ 'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('d'))
+ ]),
+ })
+ samples = self.data_sampler.wait_for_samples(['o0', 'o1'])
self.assertEqual(samples, expected_samples)
self.gen_sample(
@@ -235,11 +233,20 @@ class DataSamplerTest(unittest.TestCase):
'c',
output_index=t1_outputs['o0'],
transform_id='t1')
- expected_samples = {
- 'o0': [PRIMITIVES_CODER.encode_nested('c')],
- 'o1': [PRIMITIVES_CODER.encode_nested('b')],
- }
- samples = self.wait_for_samples(self.data_sampler, ['o0', 'o1'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('c'))
+ ]),
+ 'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('b'))
+ ]),
+ })
+ samples = self.data_sampler.wait_for_samples(['o0', 'o1'])
self.assertEqual(samples, expected_samples)
def test_sample_filters_single_pcollection_ids(self):
@@ -255,22 +262,37 @@ class DataSamplerTest(unittest.TestCase):
self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
- samples = self.wait_for_samples(self.data_sampler, ['o0'])
- expected_samples = {
- 'o0': [PRIMITIVES_CODER.encode_nested('a')],
- }
+ samples = self.data_sampler.wait_for_samples(['o0'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'))
+ ]),
+ })
self.assertEqual(samples, expected_samples)
- samples = self.wait_for_samples(self.data_sampler, ['o1'])
- expected_samples = {
- 'o1': [PRIMITIVES_CODER.encode_nested('b')],
- }
+ samples = self.data_sampler.wait_for_samples(['o1'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('b'))
+ ]),
+ })
self.assertEqual(samples, expected_samples)
- samples = self.wait_for_samples(self.data_sampler, ['o2'])
- expected_samples = {
- 'o2': [PRIMITIVES_CODER.encode_nested('c')],
- }
+ samples = self.data_sampler.wait_for_samples(['o2'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('c'))
+ ]),
+ })
self.assertEqual(samples, expected_samples)
def test_sample_filters_multiple_pcollection_ids(self):
@@ -286,24 +308,45 @@ class DataSamplerTest(unittest.TestCase):
self.gen_sample(self.data_sampler, 'b', output_index=outputs['o1'])
self.gen_sample(self.data_sampler, 'c', output_index=outputs['o2'])
- samples = self.wait_for_samples(self.data_sampler, ['o0', 'o2'])
- expected_samples = {
- 'o0': [PRIMITIVES_CODER.encode_nested('a')],
- 'o2': [PRIMITIVES_CODER.encode_nested('c')],
- }
+ samples = self.data_sampler.wait_for_samples(['o0', 'o2'])
+ expected_samples = beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'o0': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'))
+ ]),
+ 'o2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('c'))
+ ]),
+ })
self.assertEqual(samples, expected_samples)
+  def test_can_sample_exceptions(self):
+    """Tests that exceptions sampled can be queried by the DataSampler."""
+    descriptor = self.make_test_descriptor()
+    self.data_sampler.initialize_samplers(
+        MAIN_TRANSFORM_ID, descriptor, self.primitives_coder_factory)
+    sampler = self.data_sampler.sampler_for_output(MAIN_TRANSFORM_ID, 0)
+
+    exc_info = None
+    try:
+      raise Exception('test')
+    except Exception:
+      exc_info = sys.exc_info()
+
+    sampler.sample_exception('a', exc_info, MAIN_TRANSFORM_ID, 'instid')
+
+    samples = self.data_sampler.wait_for_samples([MAIN_PCOLLECTION_ID])
+    elements = samples.element_samples[MAIN_PCOLLECTION_ID].elements
+    self.assertEqual(len(elements), 1)
+
+    # The sampled element must carry both the encoded element and the
+    # exception metadata it was recorded with.
+    element = elements[0]
+    self.assertEqual(element.element, PRIMITIVES_CODER.encode_nested('a'))
+    self.assertTrue(element.HasField('exception'))
+    self.assertEqual(element.exception.instruction_id, 'instid')
+    self.assertEqual(element.exception.transform_id, MAIN_TRANSFORM_ID)
+    self.assertIn('Exception: test', element.exception.error)
-class OutputSamplerTest(unittest.TestCase):
- def setUp(self):
- self.fake_clock = FakeClock()
+class OutputSamplerTest(unittest.TestCase):
def tearDown(self):
self.sampler.stop()
- def control_time(self, new_time):
- self.fake_clock.clock = new_time
-
def wait_for_samples(self, output_sampler: OutputSampler, expected_num: int):
"""Waits for the expected number of samples for the given sampler."""
now = time.time()
@@ -322,31 +365,6 @@ class OutputSamplerTest(unittest.TestCase):
self.assertLess(now, end, 'Timed out waiting for samples')
- def ensure_sample(
- self, output_sampler: OutputSampler, sample: Any, expected_num: int):
- """Generates a sample and waits for it to be available."""
-
- element_sampler = output_sampler.element_sampler
-
- now = time.time()
- end = now + 30
-
- while now < end:
- element_sampler.el = sample
- element_sampler.has_element = True
- time.sleep(0.1)
- now = time.time()
- samples = output_sampler.flush(clear=False)
-
- if not samples:
- continue
-
- if len(samples) == expected_num:
- return samples
-
- self.assertLess(
- now, end, 'Timed out waiting for sample "{sample}" to be generated.')
-
def test_can_sample(self):
"""Tests that the underlying timer can sample."""
self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0.05)
@@ -354,8 +372,13 @@ class OutputSamplerTest(unittest.TestCase):
element_sampler.el = 'a'
element_sampler.has_element = True
- samples = self.wait_for_samples(self.sampler, expected_num=1)
- self.assertEqual(samples, [PRIMITIVES_CODER.encode_nested('a')])
+ self.wait_for_samples(self.sampler, expected_num=1)
+ self.assertEqual(
+ self.sampler.flush(),
+ [
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'))
+ ])
def test_acts_like_circular_buffer(self):
"""Tests that the buffer overwrites old samples."""
@@ -370,19 +393,28 @@ class OutputSamplerTest(unittest.TestCase):
self.assertEqual(
self.sampler.flush(),
- [PRIMITIVES_CODER.encode_nested(i) for i in (8, 9)])
+ [
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested(i)) for i in (8, 9)
+ ])
def test_samples_multiple_times(self):
- """Tests that the buffer overwrites old samples."""
+ """Tests that the underlying timer repeats."""
self.sampler = OutputSampler(
PRIMITIVES_CODER, max_samples=10, sample_every_sec=0.05)
# Always samples the first ten.
for i in range(10):
- self.ensure_sample(self.sampler, i, i + 1)
+ self.sampler.element_sampler.el = i
+ self.sampler.element_sampler.has_element = True
+ self.wait_for_samples(self.sampler, i + 1)
+
self.assertEqual(
self.sampler.flush(),
- [PRIMITIVES_CODER.encode_nested(i) for i in range(10)])
+ [
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested(i)) for i in range(10)
+ ])
def test_can_sample_windowed_value(self):
"""Tests that values with WindowedValueCoders are sampled wholesale."""
@@ -395,7 +427,9 @@ class OutputSamplerTest(unittest.TestCase):
element_sampler.has_element = True
self.sampler.sample()
- self.assertEqual(self.sampler.flush(), [coder.encode_nested(value)])
+ self.assertEqual(
+ self.sampler.flush(),
+ [beam_fn_api_pb2.SampledElement(element=coder.encode_nested(value))])
def test_can_sample_non_windowed_value(self):
"""Tests that windowed values with WindowedValueCoders sample only the
@@ -414,7 +448,69 @@ class OutputSamplerTest(unittest.TestCase):
self.sampler.sample()
self.assertEqual(
- self.sampler.flush(), [PRIMITIVES_CODER.encode_nested('Hello,
World!')])
+ self.sampler.flush(),
+ [
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('Hello, World!'))
+ ])
+
+ def test_can_sample_exceptions(self):
+ """Tests that exceptions are sampled."""
+ val = WindowedValue('Hello, World!', 0, [GlobalWindow()])
+ exc_info = None
+ try:
+ raise Exception('test')
+ except Exception:
+ exc_info = sys.exc_info()
+ err_string = ''.join(traceback.format_exception(*exc_info))
+
+ self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
+ self.sampler.sample_exception(
+ el=val, exc_info=exc_info, transform_id='tid', instruction_id='instid')
+
+ self.assertEqual(
+ self.sampler.flush(),
+ [
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('Hello, World!'),
+ exception=beam_fn_api_pb2.SampledElement.Exception(
+ instruction_id='instid',
+ transform_id='tid',
+ error=err_string))
+ ])
+
+ def test_can_sample_multiple_exceptions(self):
+ """Tests that multiple exceptions in the same PCollection are sampled."""
+ exc_info = None
+ try:
+ raise Exception('test')
+ except Exception:
+ exc_info = sys.exc_info()
+ err_string = ''.join(traceback.format_exception(*exc_info))
+
+ self.sampler = OutputSampler(PRIMITIVES_CODER, sample_every_sec=0)
+ self.sampler.sample_exception(
+ el='a', exc_info=exc_info, transform_id='tid', instruction_id='instid')
+
+ self.sampler.sample_exception(
+ el='b', exc_info=exc_info, transform_id='tid', instruction_id='instid')
+
+ self.assertEqual(
+ self.sampler.flush(),
+ [
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('a'),
+ exception=beam_fn_api_pb2.SampledElement.Exception(
+ instruction_id='instid',
+ transform_id='tid',
+ error=err_string)),
+ beam_fn_api_pb2.SampledElement(
+ element=PRIMITIVES_CODER.encode_nested('b'),
+ exception=beam_fn_api_pb2.SampledElement.Exception(
+ instruction_id='instid',
+ transform_id='tid',
+ error=err_string)),
+ ])
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd
b/sdks/python/apache_beam/runners/worker/operations.pxd
index 725c5d2346a..bc8197ef84f 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -34,7 +34,9 @@ cdef class ConsumerSet(Receiver):
cdef public step_name
cdef public output_index
cdef public coder
+ cdef public object output_sampler
cdef public object element_sampler
+ cdef public object execution_context
cpdef update_counters_start(self, WindowedValue windowed_value)
cpdef update_counters_finish(self)
@@ -82,6 +84,8 @@ cdef class Operation(object):
cdef readonly object scoped_process_state
cdef readonly object scoped_finish_state
+ cdef readonly object data_sampler
+
cpdef start(self)
cpdef process(self, WindowedValue windowed_value)
cpdef finish(self)
diff --git a/sdks/python/apache_beam/runners/worker/operations.py
b/sdks/python/apache_beam/runners/worker/operations.py
index ca5e6552973..7eb14d29593 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -53,7 +53,7 @@ from apache_beam.runners.common import Receiver
from apache_beam.runners.worker import opcounters
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import sideinputs
-from apache_beam.runners.worker.data_sampler import ElementSampler
+from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.transforms import sideinputs as apache_sideinputs
from apache_beam.transforms import combiners
from apache_beam.transforms import core
@@ -70,7 +70,6 @@ if TYPE_CHECKING:
from apache_beam.runners.sdf_utils import SplitResultPrimary
from apache_beam.runners.sdf_utils import SplitResultResidual
from apache_beam.runners.worker.bundle_processor import ExecutionContext
- from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.runners.worker.data_sampler import OutputSampler
from apache_beam.runners.worker.statesampler import StateSampler
from apache_beam.transforms.userstate import TimerSpec
@@ -126,7 +125,7 @@ class ConsumerSet(Receiver):
coder,
producer_type_hints,
producer_batch_converter, # type: Optional[BatchConverter]
- element_sampler=None, # type: Optional[ElementSampler]
+ output_sampler=None, # type: Optional[OutputSampler]
):
# type: (...) -> ConsumerSet
if len(consumers) == 1:
@@ -144,7 +143,7 @@ class ConsumerSet(Receiver):
consumer,
coder,
producer_type_hints,
- element_sampler)
+ output_sampler)
return GeneralPurposeConsumerSet(
counter_factory,
@@ -154,7 +153,7 @@ class ConsumerSet(Receiver):
producer_type_hints,
consumers,
producer_batch_converter,
- element_sampler)
+ output_sampler)
def __init__(self,
counter_factory,
@@ -164,7 +163,7 @@ class ConsumerSet(Receiver):
coder,
producer_type_hints,
producer_batch_converter,
- element_sampler
+ output_sampler
):
self.opcounter = opcounters.OperationCounters(
counter_factory,
@@ -178,7 +177,10 @@ class ConsumerSet(Receiver):
self.output_index = output_index
self.coder = coder
self.consumers = consumers
- self.element_sampler = element_sampler
+ self.output_sampler = output_sampler
+ self.element_sampler = (
+ output_sampler.element_sampler if output_sampler else None)
+ self.execution_context = None # type: Optional[ExecutionContext]
def try_split(self, fraction_of_remainder):
# type: (...) -> Optional[Any]
@@ -211,9 +213,11 @@ class ConsumerSet(Receiver):
# between here and the DataSampler as an additional operation. The tradeoff
# is that some samples might be dropped, but it is better than the
# alternative which is double sampling the same element.
- if self.element_sampler is not None:
- self.element_sampler.el = windowed_value
- self.element_sampler.has_element = True
+ if self.element_sampler is not None and self.execution_context is not None:
+ self.execution_context.output_sampler = self.output_sampler
+ if not self.element_sampler.has_element:
+ self.element_sampler.el = windowed_value
+ self.element_sampler.has_element = True
def update_counters_finish(self):
# type: () -> None
@@ -242,7 +246,7 @@ class SingletonElementConsumerSet(ConsumerSet):
consumer, # type: Operation
coder,
producer_type_hints,
- element_sampler
+ output_sampler
):
super().__init__(
counter_factory,
@@ -251,7 +255,7 @@ class SingletonElementConsumerSet(ConsumerSet):
coder,
producer_type_hints,
None,
- element_sampler)
+ output_sampler)
self.consumer = consumer
def receive(self, windowed_value):
@@ -289,7 +293,7 @@ class GeneralPurposeConsumerSet(ConsumerSet):
producer_type_hints,
consumers, # type: List[Operation]
producer_batch_converter,
- element_sampler):
+ output_sampler):
super().__init__(
counter_factory,
step_name,
@@ -298,7 +302,7 @@ class GeneralPurposeConsumerSet(ConsumerSet):
coder,
producer_type_hints,
producer_batch_converter,
- element_sampler)
+ output_sampler)
self.producer_batch_converter = producer_batch_converter
@@ -453,6 +457,7 @@ class Operation(object):
# on the operation.
self.setup_done = False
self.step_name = None # type: Optional[str]
+ self.data_sampler: Optional[DataSampler] = None
def setup(self, data_sampler=None):
# type: (Optional[DataSampler]) -> None
@@ -461,6 +466,7 @@ class Operation(object):
This must be called before any other methods of the operation."""
with self.scoped_start_state:
+ self.data_sampler = data_sampler
self.debug_logging_enabled = logging.getLogger().isEnabledFor(
logging.DEBUG)
transform_id = self.name_context.transform_id
@@ -470,7 +476,7 @@ class Operation(object):
#TODO(pabloem): Define better what step name is used here.
if getattr(self.spec, 'output_coders', None):
- def get_element_sampler(output_num):
+ def get_output_sampler(output_num):
if data_sampler is None:
return None
return data_sampler.sampler_for_output(transform_id, output_num)
@@ -484,7 +490,7 @@ class Operation(object):
coder,
self._get_runtime_performance_hints(),
self.get_output_batch_converter(),
- get_element_sampler(i)) for i,
+ get_output_sampler(i)) for i,
coder in enumerate(self.spec.output_coders)
]
self.setup_done = True
@@ -495,7 +501,13 @@ class Operation(object):
"""Start operation."""
if not self.setup_done:
# For legacy workers.
- self.setup()
+ self.setup(self.data_sampler)
+
+ # The ExecutionContext is per instruction and so cannot be set at
+ # initialization time.
+ if self.data_sampler is not None:
+ for receiver in self.receivers:
+ receiver.execution_context = self.execution_context
def get_batching_preference(self):
# By default operations don't support batching, require Receiver to unbatch
@@ -908,6 +920,7 @@ class DoOperation(Operation):
step_name=self.name_context.logging_name(),
state=state,
user_state_context=self.user_state_context,
+ transform_id=self.name_context.transform_id,
operation_name=self.name_context.metrics_name())
self.dofn_runner.setup()
@@ -915,6 +928,7 @@ class DoOperation(Operation):
# type: () -> None
with self.scoped_start_state:
super(DoOperation, self).start()
+ self.dofn_runner.execution_context = self.execution_context
self.dofn_runner.start()
def get_batching_preference(self):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py
b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index f5b1456c251..bfd6544d802 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -380,18 +380,12 @@ class SdkHarness(object):
def get_samples(request):
# type: (beam_fn_api_pb2.InstructionRequest) ->
beam_fn_api_pb2.InstructionResponse
- samples: Dict[str, List[bytes]] = {}
+ samples = beam_fn_api_pb2.SampleDataResponse()
if self.data_sampler is not None:
samples =
self.data_sampler.samples(request.sample_data.pcollection_ids)
- sample_response = beam_fn_api_pb2.SampleDataResponse()
- for pcoll_id in samples:
- sample_response.element_samples[pcoll_id].elements.extend(
- beam_fn_api_pb2.SampledElement(element=s)
- for s in samples[pcoll_id])
-
return beam_fn_api_pb2.InstructionResponse(
- instruction_id=request.instruction_id, sample_data=sample_response)
+ instruction_id=request.instruction_id, sample_data=samples)
self._execute(lambda: get_samples(request), request)
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
index d12dc48ed29..ccfc27d0101 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -283,10 +283,19 @@ class SdkWorkerTest(unittest.TestCase):
class FakeDataSampler:
def samples(self, pcollection_ids):
- return {
- 'pcoll_id_1': [coder.encode_nested('a')],
- 'pcoll_id_2': [coder.encode_nested('b')],
- }
+ return beam_fn_api_pb2.SampleDataResponse(
+ element_samples={
+ 'pcoll_id_1': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=coder.encode_nested('a'))
+ ]),
+ 'pcoll_id_2': beam_fn_api_pb2.SampleDataResponse.ElementList(
+ elements=[
+ beam_fn_api_pb2.SampledElement(
+ element=coder.encode_nested('b'))
+ ])
+ })
def stop(self):
pass