boyuanzz commented on a change in pull request #11922:
URL: https://github.com/apache/beam/pull/11922#discussion_r435521657



##########
File path: 
sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
##########
@@ -1593,6 +1601,391 @@ public void 
testProcessElementForSizedElementAndRestriction() throws Exception {
     assertEquals(stateData, fakeClient.getData());
   }
 
+  @Test
+  public void testProcessElementForWindowedSizedElementAndRestriction() throws 
Exception {
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> singletonSideInputView = 
valuePCollection.apply(View.asSingleton());
+    TestSplittableDoFn doFn = new TestSplittableDoFn(singletonSideInputView);
+
+    valuePCollection
+        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+        .apply(TEST_TRANSFORM_ID, 
ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+    RunnerApi.Pipeline pProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            PipelineTranslation.toProto(p, 
SdkComponents.create(p.getOptions()), true),
+            SplittableParDoExpander.createSizedReplacement());
+    String expandedTransformId =
+        Iterables.find(
+                pProto.getComponents().getTransformsMap().entrySet(),
+                entry ->
+                    entry
+                            .getValue()
+                            .getSpec()
+                            .getUrn()
+                            .equals(
+                                PTransformTranslation
+                                    
.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                        && 
entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+            .getKey();
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+    String inputPCollectionId =
+        
pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    RunnerApi.PCollection inputPCollection =
+        pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+    RehydratedComponents rehydratedComponents =
+        RehydratedComponents.forComponents(pProto.getComponents());
+    Coder<WindowedValue> inputCoder =
+        WindowedValue.getFullCoder(
+            CoderTranslation.fromProto(
+                
pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                rehydratedComponents,
+                TranslationContext.DEFAULT),
+            (Coder)
+                CoderTranslation.fromProto(
+                    pProto
+                        .getComponents()
+                        .getCodersOrThrow(
+                            pProto
+                                .getComponents()
+                                .getWindowingStrategiesOrThrow(
+                                    inputPCollection.getWindowingStrategyId())
+                                .getWindowCoderId()),
+                    rehydratedComponents,
+                    TranslationContext.DEFAULT));
+    String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+    ImmutableMap<StateKey, ByteString> stateData =
+        ImmutableMap.of(
+            
multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), 
ByteString.EMPTY),
+            encode("8"));
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    consumers.register(
+        outputPCollectionId,
+        TEST_TRANSFORM_ID,
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) 
mainOutputValues::add);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), 
mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), 
mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+    List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
+    BundleSplitListener.InMemory splitListener = 
BundleSplitListener.InMemory.create();
+
+    new FnApiDoFnRunner.Factory<>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            null /* beamFnDataClient */,
+            fakeClient,
+            null /* beamFnTimerClient */,
+            TEST_TRANSFORM_ID,
+            pTransform,
+            Suppliers.ofInstance("57L")::get,
+            pProto.getComponents().getPcollectionsMap(),
+            pProto.getComponents().getCodersMap(),
+            pProto.getComponents().getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            progressRequestCallbacks::add,
+            splitListener,
+            null /* bundleFinalizer */);
+
+    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, 
outputPCollectionId));
+
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        consumers.getMultiplexingConsumer(inputPCollectionId);
+    assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+    BoundedWindow window1 = new IntervalWindow(new Instant(5), new 
Instant(10));
+    BoundedWindow window2 = new IntervalWindow(new Instant(6), new 
Instant(11));
+    {
+      // Check that before processing an element we don't report progress
+      
assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(),
 empty());
+      WindowedValue<?> firstValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(5, 10), 
GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+              window1,
+              window2);
+      mainInput.accept(firstValue);
+      // Check that after processing an element we don't report progress
+      
assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(),
 empty());
+
+      // Since the side input upperBound is 8 we will process 5, 6, and 7 then 
checkpoint.
+      // We expect that the watermark advances to MIN + 7 and that the primary 
represents [5, 8)
+      // with the original watermark while the residual represents [8, 10) 
with the new MIN + 7
+      // watermark.
+      assertEquals(2, splitListener.getPrimaryRoots().size());
+      assertEquals(2, splitListener.getResidualRoots().size());
+      for (int i = 0; i < splitListener.getPrimaryRoots().size(); ++i) {
+        BundleApplication primaryRoot = splitListener.getPrimaryRoots().get(i);
+        DelayedBundleApplication residualRoot = 
splitListener.getResidualRoots().get(i);
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), 
primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, 
residualRoot.getApplication().getTransformId());
+        Instant expectedOutputWatermark = 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+        assertEquals(
+            ImmutableMap.of(
+                "output",
+                
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                    .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                    .setNanos((int) (expectedOutputWatermark.getMillis() % 
1000) * 1000000)
+                    .build()),
+            residualRoot.getApplication().getOutputWatermarksMap());
+        assertEquals(
+            
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+                .setSeconds(54)
+                .setNanos(321000000)
+                .build(),
+            residualRoot.getRequestedTimeDelay());
+      }
+      assertThat(
+          Collections2.transform(
+              splitListener.getPrimaryRoots(), (root) -> decode(inputCoder, 
root.getElement())),
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(5, 8), 
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  firstValue.getTimestamp(),
+                  window1,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(5, 8), 
GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  firstValue.getTimestamp(),
+                  window2,
+                  firstValue.getPane())));
+      assertThat(
+          Collections2.transform(
+              splitListener.getResidualRoots(),
+              (root) -> decode(inputCoder, 
root.getApplication().getElement())),
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of(
+                          "5",
+                          KV.of(new OffsetRange(8, 10), 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  window1,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of(
+                          "5",
+                          KV.of(new OffsetRange(8, 10), 
GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  window2,
+                  firstValue.getPane())));
+      splitListener.clear();
+
+      // Check that before processing an element we don't report progress
+      
assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(),
 empty());
+      WindowedValue<?> secondValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("2", KV.of(new OffsetRange(0, 2), 
GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+              window1,
+              window2);
+      mainInput.accept(secondValue);
+      // Check that after processing an element we don't report progress
+      
assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(),
 empty());
+
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  "5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5), window1, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6), window1, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7), window1, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5), window2, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6), window2, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7), window2, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window2, 
firstValue.getPane()),
+              WindowedValue.of(
+                  "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window2, 
firstValue.getPane())));
+      assertTrue(splitListener.getPrimaryRoots().isEmpty());
+      assertTrue(splitListener.getResidualRoots().isEmpty());
+      mainOutputValues.clear();
+    }
+
+    {
+      // Setup and launch the trySplit thread.
+      ExecutorService executorService = Executors.newSingleThreadExecutor();
+      Future<HandlesSplits.SplitResult> trySplitFuture =
+          executorService.submit(
+              () -> {
+                try {
+                  doFn.waitForSplitElementToBeProcessed();
+                  // Currently processing "3" out of range [0, 5) elements.
+                  assertEquals(0.6, ((HandlesSplits) mainInput).getProgress(), 
0.01);
+
+                  // Check that during progressing of an element we report 
progress
+                  List<MonitoringInfo> mis =
+                      
Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
+                  MonitoringInfo.Builder expectedCompleted = 
MonitoringInfo.newBuilder();
+                  
expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
+                  
expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                  expectedCompleted.putLabels(
+                      MonitoringInfoConstants.Labels.PTRANSFORM, 
TEST_TRANSFORM_ID);
+                  expectedCompleted.setPayload(
+                      ByteString.copyFrom(
+                          CoderUtils.encodeToByteArray(
+                              IterableCoder.of(DoubleCoder.of()), 
Collections.singletonList(3.0))));
+                  MonitoringInfo.Builder expectedRemaining = 
MonitoringInfo.newBuilder();
+                  
expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
+                  
expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                  expectedRemaining.putLabels(
+                      MonitoringInfoConstants.Labels.PTRANSFORM, 
TEST_TRANSFORM_ID);
+                  expectedRemaining.setPayload(
+                      ByteString.copyFrom(
+                          CoderUtils.encodeToByteArray(
+                              IterableCoder.of(DoubleCoder.of()), 
Collections.singletonList(2.0))));
+                  assertThat(
+                      mis,
+                      containsInAnyOrder(expectedCompleted.build(), 
expectedRemaining.build()));
+
+                  return ((HandlesSplits) mainInput).trySplit(0);
+                } finally {
+                  doFn.releaseWaitingProcessElementThread();
+                }
+              });
+
+      // Check that before processing an element we don't report progress
+      
assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(),
 empty());
+      WindowedValue<?> splitValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 5), 
GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+              window1,
+              window2);
+      mainInput.accept(splitValue);
+      HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+
+      // Check that after processing an element we don't report progress
+      
assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(),
 empty());
+
+      // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split 
on the first window.

Review comment:
       Why does the split happen on the first window?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to