This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 847efa5  [BEAM-9487] Add trigger safety check to GroupByKey
     new 99aa83d  Merge pull request #14857 from [BEAM-9487] Add trigger safety check to GroupByKey
847efa5 is described below

commit 847efa58e21d0b275c5b72e6b17908464a94223a
Author: zhoufek <[email protected]>
AuthorDate: Thu May 20 12:44:00 2021 -0400

    [BEAM-9487] Add trigger safety check to GroupByKey
---
 .../examples/complete/game/leader_board_it_test.py |   1 +
 .../examples/complete/game/leader_board_test.py    |   4 +-
 .../apache_beam/examples/snippets/snippets_test.py |   8 +-
 .../apache_beam/io/gcp/bigquery_file_loads_test.py |   5 +-
 sdks/python/apache_beam/io/gcp/bigquery_test.py    |   3 +-
 .../python/apache_beam/options/pipeline_options.py |   9 ++
 sdks/python/apache_beam/pipeline.py                |   5 +
 .../python/apache_beam/testing/test_stream_test.py |   2 +
 .../transforms/combinefn_lifecycle_test.py         |   6 +-
 sdks/python/apache_beam/transforms/core.py         |  17 ++-
 .../apache_beam/transforms/ptransform_test.py      |  30 +++++
 sdks/python/apache_beam/transforms/trigger.py      | 112 ++++++++++++++++++
 sdks/python/apache_beam/transforms/trigger_test.py | 128 ++++++++++++++++++++-
 13 files changed, 320 insertions(+), 10 deletions(-)

diff --git a/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py b/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py
index afbaa18..8f5f91c 100644
--- a/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py
+++ b/sdks/python/apache_beam/examples/complete/game/leader_board_it_test.py
@@ -130,6 +130,7 @@ class LeaderBoardIT(unittest.TestCase):
         self.project, teams_query, self.DEFAULT_EXPECTED_CHECKSUM)
 
     extra_opts = {
+        'allow_unsafe_triggers': True,
         'subscription': self.input_sub.name,
         'dataset': self.dataset_ref.dataset_id,
         'topic': self.input_topic.name,
diff --git a/sdks/python/apache_beam/examples/complete/game/leader_board_test.py b/sdks/python/apache_beam/examples/complete/game/leader_board_test.py
index 167ce4a..1c1cd65 100644
--- a/sdks/python/apache_beam/examples/complete/game/leader_board_test.py
+++ b/sdks/python/apache_beam/examples/complete/game/leader_board_test.py
@@ -24,6 +24,7 @@ import unittest
 
 import apache_beam as beam
 from apache_beam.examples.complete.game import leader_board
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -59,7 +60,8 @@ class LeaderBoardTest(unittest.TestCase):
                     ('team3', 13)]))
 
   def test_leader_board_users(self):
-    with TestPipeline() as p:
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    with TestPipeline(options=test_options) as p:
       result = (
           self.create_data(p)
           | leader_board.CalculateUserScores(allowed_lateness=120))
diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py
index fddb24c..5bc63ae 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py
@@ -1057,8 +1057,8 @@ class SnippetsTest(unittest.TestCase):
       assert_that(counts, equal_to([('a', 4), ('b', 2), ('a', 1)]))
 
   def test_model_setting_trigger(self):
-    pipeline_options = PipelineOptions()
-    pipeline_options.view_as(StandardOptions).streaming = True
+    pipeline_options = PipelineOptions(
+        flags=['--streaming', '--allow_unsafe_triggers'])
 
     with TestPipeline(options=pipeline_options) as p:
       test_stream = (
@@ -1112,8 +1112,8 @@ class SnippetsTest(unittest.TestCase):
       assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 2), ('c', 2)]))
 
   def test_model_other_composite_triggers(self):
-    pipeline_options = PipelineOptions()
-    pipeline_options.view_as(StandardOptions).streaming = True
+    pipeline_options = PipelineOptions(
+        flags=['--streaming', '--allow_unsafe_triggers'])
 
     with TestPipeline(options=pipeline_options) as p:
       test_stream = (
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
index ff1b50e..9eb59b5 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -41,6 +41,7 @@ from apache_beam.io.gcp import bigquery_tools
 from apache_beam.io.gcp.internal.clients import bigquery as bigquery_api
 from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
 from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
 from apache_beam.runners.runner import PipelineState
@@ -648,8 +649,10 @@ class TestBigQueryFileLoads(_TestCaseWithTempDirCleanUp):
         with_auto_sharding=with_auto_sharding)
 
     # Need to test this with the DirectRunner to avoid serializing mocks
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    test_options.view_as(StandardOptions).streaming = is_streaming
     with TestPipeline(runner='BundleBasedDirectRunner',
-                      options=StandardOptions(streaming=is_streaming)) as p:
+                      options=test_options) as p:
       if is_streaming:
         _SIZE = len(_ELEMENTS)
         fisrt_batch = [
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
index c3178da..41bbfe2 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -1322,7 +1322,8 @@ class PubSubBigQueryIT(unittest.TestCase):
     args = self.test_pipeline.get_full_options_as_args(
         on_success_matcher=hc.all_of(*matchers),
         wait_until_finish_duration=self.WAIT_UNTIL_FINISH_DURATION,
-        streaming=True)
+        streaming=True,
+        allow_unsafe_triggers=True)
 
     def add_schema_info(element):
       yield {'number': element}
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 073014c..335cca8 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -532,6 +532,15 @@ class TypeOptions(PipelineOptions):
         'operations such as GropuByKey.  This is unsafe, as runners may group '
         'keys based on their encoded bytes, but is available for backwards '
         'compatibility. See BEAM-11719.')
+    parser.add_argument(
+        '--allow_unsafe_triggers',
+        default=False,
+        action='store_true',
+        help='Allow the use of unsafe triggers. Unsafe triggers have the '
+        'potential to cause data loss due to finishing and/or never having '
+        'their condition met. Some operations, such as GroupByKey, disallow '
+        'this. This exists for cases where such loss is acceptable and for '
+        'backwards compatibility. See BEAM-9487.')
 
   def validate(self, unused_validator):
     errors = []
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 35d38a7..0c6ef15 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -243,6 +243,11 @@ class Pipeline(object):
     # type: () -> PipelineOptions
     return self._options
 
+  @property
+  def allow_unsafe_triggers(self):
+    # type: () -> bool
+    return self._options.view_as(TypeOptions).allow_unsafe_triggers
+
   def _current_transform(self):
     # type: () -> AppliedPTransform
 
diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py
index ec6309d..94445dd 100644
--- a/sdks/python/apache_beam/testing/test_stream_test.py
+++ b/sdks/python/apache_beam/testing/test_stream_test.py
@@ -24,6 +24,7 @@ import unittest
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileHeader
 from apache_beam.portability.api.beam_interactive_api_pb2 import TestStreamFileRecord
@@ -427,6 +428,7 @@ class TestStreamTest(unittest.TestCase):
 
     options = PipelineOptions()
     options.view_as(StandardOptions).streaming = True
+    options.view_as(TypeOptions).allow_unsafe_triggers = True
     p = TestPipeline(options=options)
     records = (
         p
diff --git a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
index d3d177f..4e32324 100644
--- a/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
+++ b/sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
@@ -26,6 +26,7 @@ from nose.plugins.attrib import attr
 from parameterized import parameterized_class
 
 from apache_beam.options.pipeline_options import DebugOptions
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.runners.direct import direct_runner
 from apache_beam.runners.portability import fn_api_runner
@@ -88,7 +89,10 @@ class LocalCombineFnLifecycleTest(unittest.TestCase):
     self._assert_teardown_called()
 
   def test_non_liftable_combine(self):
-    run_combine(TestPipeline(runner=self.runner()), lift_combiners=False)
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    run_combine(
+        TestPipeline(runner=self.runner(), options=test_options),
+        lift_combiners=False)
     self._assert_teardown_called()
 
   def test_combining_value_state(self):
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 54fd139..45ecedd 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2314,15 +2314,30 @@ class GroupByKey(PTransform):
           key_type, typehints.WindowedValue[value_type]]]  # type: ignore[misc]
 
   def expand(self, pcoll):
+    from apache_beam.transforms.trigger import DataLossReason
     from apache_beam.transforms.trigger import DefaultTrigger
     windowing = pcoll.windowing
+    trigger = windowing.triggerfn
     if not pcoll.is_bounded and isinstance(
-        windowing.windowfn, GlobalWindows) and isinstance(windowing.triggerfn,
+        windowing.windowfn, GlobalWindows) and isinstance(trigger,
                                                           DefaultTrigger):
       raise ValueError(
           'GroupByKey cannot be applied to an unbounded ' +
           'PCollection with global windowing and a default trigger')
 
+    if not pcoll.pipeline.allow_unsafe_triggers:
+      unsafe_reason = trigger.may_lose_data(windowing)
+      if unsafe_reason != DataLossReason.NO_POTENTIAL_LOSS:
+        msg = 'Unsafe trigger: `{}` may lose data. '.format(trigger)
+        msg += 'Reason: {}. '.format(
+            str(unsafe_reason).replace('DataLossReason.', ''))
+        msg += 'This can be overriden with the --allow_unsafe_triggers flag.'
+        raise ValueError(msg)
+    else:
+      _LOGGER.warning(
+          'Skipping trigger safety check. '
+          'This could lead to incomplete or missing groups.')
+
     return pvalue.PCollection.from_(pcoll)
 
   def infer_output_type(self, input_type):
diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py
index 61e8a0a..a9b0486 100644
--- a/sdks/python/apache_beam/transforms/ptransform_test.py
+++ b/sdks/python/apache_beam/transforms/ptransform_test.py
@@ -41,13 +41,16 @@ import apache_beam.typehints as typehints
 from apache_beam.io.iobase import Read
 from apache_beam.metrics import Metrics
 from apache_beam.metrics.metric import MetricsFilter
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.testing.util import is_empty
 from apache_beam.transforms import WindowInto
+from apache_beam.transforms import trigger
 from apache_beam.transforms import window
 from apache_beam.transforms.display import DisplayData
 from apache_beam.transforms.display import DisplayDataItem
@@ -480,6 +483,33 @@ class PTransformTest(unittest.TestCase):
       with TestPipeline() as pipeline:
         pipeline | TestStream() | beam.GroupByKey()
 
+  def test_group_by_key_unsafe_trigger(self):
+    with self.assertRaisesRegex(ValueError, 'Unsafe trigger'):
+      with TestPipeline() as pipeline:
+        _ = (
+            pipeline
+            | beam.Create([(None, None)])
+            | WindowInto(
+                window.GlobalWindows(),
+                trigger=trigger.AfterCount(5),
+                accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
+            | beam.GroupByKey())
+
+  def test_group_by_key_allow_unsafe_triggers(self):
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    with TestPipeline(options=test_options) as pipeline:
+      pcoll = (
+          pipeline
+          | beam.Create([(1, 1), (1, 2), (1, 3), (1, 4)])
+          | WindowInto(
+              window.GlobalWindows(),
+              trigger=trigger.AfterCount(5),
+              accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
+          | beam.GroupByKey())
+      # We need five, but it only has four - Displays how this option is
+      # dangerous.
+      assert_that(pcoll, is_empty())
+
   def test_group_by_key_reiteration(self):
     class MyDoFn(beam.DoFn):
       def process(self, gbk_result):
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index 5567895..6569d3f 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -28,7 +28,11 @@ import logging
 import numbers
 from abc import ABCMeta
 from abc import abstractmethod
+from enum import Flag
+from enum import auto
+from functools import reduce
 from itertools import zip_longest
+from operator import or_
 
 from apache_beam.coders import coder_impl
 from apache_beam.coders import observable
@@ -156,6 +160,13 @@ class _WatermarkHoldStateTag(_StateTag):
         prefix + self.tag, self.timestamp_combiner_impl)
 
 
+class DataLossReason(Flag):
+  """Enum defining potential reasons that a trigger may cause data loss."""
+  NO_POTENTIAL_LOSS = 0
+  MAY_FINISH = auto()
+  CONDITION_NOT_GUARANTEED = auto()
+
+
 # pylint: disable=unused-argument
 # TODO(robertwb): Provisional API, Java likely to change as well.
 class TriggerFn(metaclass=ABCMeta):
@@ -237,6 +248,43 @@ class TriggerFn(metaclass=ABCMeta):
     """Clear any state and timers used by this TriggerFn."""
     pass
 
+  @abstractmethod
+  def may_lose_data(self, windowing):
+    # type: (core.Windowing) -> DataLossReason
+
+    """Returns whether or not this trigger could cause data loss.
+
+    A trigger can cause data loss in the following scenarios:
+
+        * The trigger has a chance to finish. For instance, AfterWatermark()
+          without a late trigger would cause all late data to be lost. This
+          scenario is only accounted for if the windowing strategy allows
+          late data. Otherwise, the trigger is not responsible for the data
+          loss.
+        * The trigger condition may not be met. For instance,
+          Repeatedly(AfterCount(N)) may not fire due to N not being met. This
+          is only accounted for if the condition itself led to data loss.
+          Repeatedly(AfterCount(1)) is safe, since it would only not fire if
+          there is no data to lose, but Repeatedly(AfterCount(2)) can cause
+          data loss if there is only one record.
+
+    Note that this only returns the potential for loss. It does not mean that
+    there will be data loss. It also only accounts for loss related to the
+    trigger, not other potential causes.
+
+    Args:
+      windowing: The Windowing that this trigger belongs to. It does not need
+        to be the top-level trigger.
+
+    Returns:
+      The DataLossReason. If there is no potential loss,
+        DataLossReason.NO_POTENTIAL_LOSS is returned. Otherwise, all the
+        potential reasons are returned as a single value. For instance, if
+        data loss can result from finishing or not having the condition met,
+        the result will be DataLossReason.MAY_FINISH|CONDITION_NOT_GUARANTEED.
+    """
+    pass
+
 
 # pylint: enable=unused-argument
 
@@ -290,6 +338,9 @@ class DefaultTrigger(TriggerFn):
   def reset(self, window, context):
     context.clear_timer(str(window), TimeDomain.WATERMARK)
 
+  def may_lose_data(self, unused_windowing):
+    return DataLossReason.NO_POTENTIAL_LOSS
+
   def __eq__(self, other):
     return type(self) == type(other)
 
@@ -338,6 +389,9 @@ class AfterProcessingTime(TriggerFn):
   def reset(self, window, context):
     pass
 
+  def may_lose_data(self, unused_windowing):
+    return DataLossReason.MAY_FINISH
+
   @staticmethod
   def from_runner_api(proto, context):
     return AfterProcessingTime(
@@ -389,6 +443,9 @@ class Always(TriggerFn):
   def on_fire(self, watermark, window, context):
     return False
 
+  def may_lose_data(self, unused_windowing):
+    return DataLossReason.NO_POTENTIAL_LOSS
+
   @staticmethod
   def from_runner_api(proto, context):
     return Always()
@@ -433,6 +490,14 @@ class _Never(TriggerFn):
   def on_fire(self, watermark, window, context):
     return True
 
+  def may_lose_data(self, unused_windowing):
+    """No potential data loss.
+
+    Though Never doesn't explicitly trigger, it still collects data on
+    windowing closing, so any data loss is due to windowing closing.
+    """
+    return DataLossReason.NO_POTENTIAL_LOSS
+
   @staticmethod
   def from_runner_api(proto, context):
     return _Never()
@@ -454,6 +519,7 @@ class AfterWatermark(TriggerFn):
   LATE_TAG = _CombiningValueStateTag('is_late', any)
 
   def __init__(self, early=None, late=None):
+    # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly
     self.early = Repeatedly(early) if early else None
     self.late = Repeatedly(late) if late else None
 
@@ -524,6 +590,20 @@ class AfterWatermark(TriggerFn):
     if self.late:
       self.late.reset(window, NestedContext(context, 'late'))
 
+  def may_lose_data(self, windowing):
+    """May cause data loss if the windowing allows lateness and either:
+
+      * The late trigger is not set
+      * The late trigger may cause data loss.
+
+    The second case is equivalent to Repeatedly(late).may_lose_data(windowing)
+    """
+    if windowing.allowed_lateness == 0:
+      return DataLossReason.NO_POTENTIAL_LOSS
+    if self.late is None:
+      return DataLossReason.MAY_FINISH
+    return self.late.may_lose_data(windowing)
+
   def __eq__(self, other):
     return (
         type(self) == type(other) and self.early == other.early and
@@ -593,6 +673,12 @@ class AfterCount(TriggerFn):
   def reset(self, window, context):
     context.clear_state(self.COUNT_TAG)
 
+  def may_lose_data(self, unused_windowing):
+    reason = DataLossReason.MAY_FINISH
+    if self.count > 1:
+      reason |= DataLossReason.CONDITION_NOT_GUARANTEED
+    return reason
+
   @staticmethod
   def from_runner_api(proto, unused_context):
     return AfterCount(proto.element_count.element_count)
@@ -637,6 +723,17 @@ class Repeatedly(TriggerFn):
   def reset(self, window, context):
     self.underlying.reset(window, context)
 
+  def may_lose_data(self, windowing):
+    """Repeatedly may only lose data if the underlying trigger may not have
+    its condition met.
+
+    For underlying triggers that may finish, Repeatedly overrides that
+    behavior.
+    """
+    return (
+        self.underlying.may_lose_data(windowing)
+        & DataLossReason.CONDITION_NOT_GUARANTEED)
+
   @staticmethod
   def from_runner_api(proto, context):
     return Repeatedly(
@@ -742,6 +839,15 @@ class AfterAny(_ParallelTriggerFn):
   """
   combine_op = any
 
+  def may_lose_data(self, windowing):
+    reason = DataLossReason.NO_POTENTIAL_LOSS
+    for trigger in self.triggers:
+      t_reason = trigger.may_lose_data(windowing)
+      if t_reason == DataLossReason.NO_POTENTIAL_LOSS:
+        return t_reason
+      reason |= t_reason
+    return reason
+
 
 class AfterAll(_ParallelTriggerFn):
   """Fires when all subtriggers have fired.
@@ -750,6 +856,9 @@ class AfterAll(_ParallelTriggerFn):
   """
   combine_op = all
 
+  def may_lose_data(self, windowing):
+    return reduce(or_, (t.may_lose_data(windowing) for t in self.triggers))
+
 
 class AfterEach(TriggerFn):
 
@@ -805,6 +914,9 @@ class AfterEach(TriggerFn):
     for ix, trigger in enumerate(self.triggers):
       trigger.reset(window, self._sub_context(context, ix))
 
+  def may_lose_data(self, windowing):
+    return reduce(or_, (t.may_lose_data(windowing) for t in self.triggers))
+
   @staticmethod
   def _sub_context(context, index):
     return NestedContext(context, '%d/' % index)
diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py
index a3dd438..9e1a569 100644
--- a/sdks/python/apache_beam/transforms/trigger_test.py
+++ b/sdks/python/apache_beam/transforms/trigger_test.py
@@ -39,6 +39,7 @@ from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import WindowInto
 from apache_beam.transforms import ptransform
 from apache_beam.transforms import trigger
 from apache_beam.transforms.core import Windowing
@@ -50,6 +51,7 @@ from apache_beam.transforms.trigger import AfterEach
 from apache_beam.transforms.trigger import AfterProcessingTime
 from apache_beam.transforms.trigger import AfterWatermark
 from apache_beam.transforms.trigger import Always
+from apache_beam.transforms.trigger import DataLossReason
 from apache_beam.transforms.trigger import DefaultTrigger
 from apache_beam.transforms.trigger import GeneralTriggerDriver
 from apache_beam.transforms.trigger import InMemoryUnmergedState
@@ -57,6 +59,7 @@ from apache_beam.transforms.trigger import Repeatedly
 from apache_beam.transforms.trigger import TriggerFn
 from apache_beam.transforms.trigger import _Never
 from apache_beam.transforms.window import FixedWindows
+from apache_beam.transforms.window import GlobalWindows
 from apache_beam.transforms.window import IntervalWindow
 from apache_beam.transforms.window import Sessions
 from apache_beam.transforms.window import TimestampCombiner
@@ -433,6 +436,128 @@ class TriggerTest(unittest.TestCase):
           pickle.loads(pickle.dumps(unwindowed)).value, list(range(10)))
 
 
+class MayLoseDataTest(unittest.TestCase):
+  def _test(self, trigger, lateness, expected):
+    windowing = WindowInto(
+        GlobalWindows(),
+        trigger=trigger,
+        accumulation_mode=AccumulationMode.ACCUMULATING,
+        allowed_lateness=lateness).windowing
+    self.assertEqual(trigger.may_lose_data(windowing), expected)
+
+  def test_default_trigger(self):
+    self._test(DefaultTrigger(), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_processing_time(self):
+    self._test(AfterProcessingTime(), 0, DataLossReason.MAY_FINISH)
+
+  def test_always(self):
+    self._test(Always(), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_never(self):
+    self._test(_Never(), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_no_allowed_lateness(self):
+    self._test(AfterWatermark(), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_late_none(self):
+    self._test(AfterWatermark(), 60, DataLossReason.MAY_FINISH)
+
+  def test_after_watermark_no_allowed_lateness_safe_late(self):
+    self._test(
+        AfterWatermark(late=DefaultTrigger()),
+        0,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_safe_late(self):
+    self._test(
+        AfterWatermark(late=DefaultTrigger()),
+        60,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_no_allowed_lateness_may_finish_late(self):
+    self._test(
+        AfterWatermark(late=AfterProcessingTime()),
+        0,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_may_finish_late(self):
+    self._test(
+        AfterWatermark(late=AfterProcessingTime()),
+        60,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_no_allowed_lateness_condition_late(self):
+    self._test(
+        AfterWatermark(late=AfterCount(5)), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_watermark_condition_late(self):
+    self._test(
+        AfterWatermark(late=AfterCount(5)),
+        60,
+        DataLossReason.CONDITION_NOT_GUARANTEED)
+
+  def test_after_count_one(self):
+    self._test(AfterCount(1), 0, DataLossReason.MAY_FINISH)
+
+  def test_after_count_gt_one(self):
+    self._test(
+        AfterCount(2),
+        0,
+        DataLossReason.MAY_FINISH | DataLossReason.CONDITION_NOT_GUARANTEED)
+
+  def test_repeatedly_safe_underlying(self):
+    self._test(
+        Repeatedly(DefaultTrigger()), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_repeatedly_may_finish_underlying(self):
+    self._test(Repeatedly(AfterCount(1)), 0, DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_repeatedly_condition_underlying(self):
+    self._test(
+        Repeatedly(AfterCount(2)), 0, DataLossReason.CONDITION_NOT_GUARANTEED)
+
+  def test_after_any_some_unsafe(self):
+    self._test(
+        AfterAny(AfterCount(1), DefaultTrigger()),
+        0,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_any_same_reason(self):
+    self._test(
+        AfterAny(AfterCount(1), AfterProcessingTime()),
+        0,
+        DataLossReason.MAY_FINISH)
+
+  def test_after_any_different_reasons(self):
+    self._test(
+        AfterAny(Repeatedly(AfterCount(2)), AfterProcessingTime()),
+        0,
+        DataLossReason.MAY_FINISH | DataLossReason.CONDITION_NOT_GUARANTEED)
+
+  def test_after_all_some_unsafe(self):
+    self._test(
+        AfterAll(AfterCount(1), DefaultTrigger()), 0, DataLossReason.MAY_FINISH)
+
+  def test_after_all_safe(self):
+    self._test(
+        AfterAll(Repeatedly(AfterCount(1)), DefaultTrigger()),
+        0,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+  def test_after_each_some_unsafe(self):
+    self._test(
+        AfterEach(AfterCount(1), DefaultTrigger()),
+        0,
+        DataLossReason.MAY_FINISH)
+
+  def test_after_each_all_safe(self):
+    self._test(
+        AfterEach(Repeatedly(AfterCount(1)), DefaultTrigger()),
+        0,
+        DataLossReason.NO_POTENTIAL_LOSS)
+
+
 class RunnerApiTest(unittest.TestCase):
   def test_trigger_encoding(self):
     for trigger_fn in (DefaultTrigger(),
@@ -451,7 +576,8 @@ class RunnerApiTest(unittest.TestCase):
 
 class TriggerPipelineTest(unittest.TestCase):
   def test_after_count(self):
-    with TestPipeline() as p:
+    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
+    with TestPipeline(options=test_options) as p:
 
       def construct_timestamped(k_t):
         return TimestampedValue((k_t[0], k_t[1]), k_t[1])

Reply via email to