Repository: beam
Updated Branches:
  refs/heads/master b2138b0d7 -> 2aa3b5c69


Add test to fix partial writeouts after a bundle retry


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ce6a18c1
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ce6a18c1
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ce6a18c1

Branch: refs/heads/master
Commit: ce6a18c19f913ff8afe0f6cc8bb68f3406aaa9ec
Parents: b2138b0
Author: Maria Garcia Herrero <[email protected]>
Authored: Sun Sep 10 23:28:21 2017 -0700
Committer: [email protected] <[email protected]>
Committed: Tue Sep 26 12:53:00 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/pipeline_test.py        | 52 ++++++++++++++++++++
 .../runners/direct/evaluation_context.py        |  6 ++-
 2 files changed, 56 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/ce6a18c1/sdks/python/apache_beam/pipeline_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 0917c78..9bbb0d7 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -20,6 +20,7 @@
 import logging
 import platform
 import unittest
+from collections import defaultdict
 
 import apache_beam as beam
 from apache_beam.io import Read
@@ -31,6 +32,9 @@ from apache_beam.pipeline import PTransformOverride
 from apache_beam.pvalue import AsSingleton
 from apache_beam.runners import DirectRunner
 from apache_beam.runners.dataflow.native_io.iobase import NativeSource
+from apache_beam.runners.direct.evaluation_context import _ExecutionContext
+from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
+from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -528,6 +532,54 @@ class DirectRunnerRetryTests(unittest.TestCase):
       p.run().wait_until_finish()
     assert count_b == count_c == 4
 
+  def test_no_partial_writeouts(self):
+
+    class TestTransformEvaluator(_TransformEvaluator):
+
+      def __init__(self):
+        self._execution_context = _ExecutionContext(None, {})
+
+      def start_bundle(self):
+        self.step_context = self._execution_context.get_step_context()
+
+      def process_element(self, element):
+        k, v = element
+        state = self.step_context.get_keyed_state(k)
+        state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
+
+    # Create instance and add key/value, key/value2
+    evaluator = TestTransformEvaluator()
+    evaluator.start_bundle()
+    self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
+    self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))
+
+    evaluator.process_element(['key', 'value'])
+    self.assertEqual(
+        evaluator.step_context.existing_keyed_state['key'].state,
+        defaultdict(lambda: defaultdict(list)))
+    self.assertEqual(
+        evaluator.step_context.partial_keyed_state['key'].state,
+        {None: {'elements':['value']}})
+
+    evaluator.process_element(['key', 'value2'])
+    self.assertEqual(
+        evaluator.step_context.existing_keyed_state['key'].state,
+        defaultdict(lambda: defaultdict(list)))
+    self.assertEqual(
+        evaluator.step_context.partial_keyed_state['key'].state,
+        {None: {'elements':['value', 'value2']}})
+
+    # Simulate an exception (redo key/value)
+    evaluator._execution_context.reset()
+    evaluator.start_bundle()
+    evaluator.process_element(['key', 'value'])
+    self.assertEqual(
+        evaluator.step_context.existing_keyed_state['key'].state,
+        defaultdict(lambda: defaultdict(list)))
+    self.assertEqual(
+        evaluator.step_context.partial_keyed_state['key'].state,
+        {None: {'elements':['value']}})
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.DEBUG)

http://git-wip-us.apache.org/repos/asf/beam/blob/ce6a18c1/sdks/python/apache_beam/runners/direct/evaluation_context.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/evaluation_context.py b/sdks/python/apache_beam/runners/direct/evaluation_context.py
index 2e8b33b..abb2dc4 100644
--- a/sdks/python/apache_beam/runners/direct/evaluation_context.py
+++ b/sdks/python/apache_beam/runners/direct/evaluation_context.py
@@ -44,6 +44,9 @@ class _ExecutionContext(object):
       self._step_context = DirectStepContext(self.keyed_states)
     return self._step_context
 
+  def reset(self):
+    self._step_context = None
+
 
 class _SideInputView(object):
 
@@ -335,6 +338,5 @@ class DirectStepContext(object):
     if not self.existing_keyed_state.get(key):
       self.existing_keyed_state[key] = DirectUnmergedState()
     if not self.partial_keyed_state.get(key):
-      self.partial_keyed_state[key] = (
-          self.existing_keyed_state[key].copy())
+      self.partial_keyed_state[key] = self.existing_keyed_state[key].copy()
     return self.partial_keyed_state[key]

Reply via email to