lukecwik commented on a change in pull request #11448: [BEAM-3836] Enable 
dynamic splitting/checkpointing within the Java SDK harness.
URL: https://github.com/apache/beam/pull/11448#discussion_r410556926
 
 

 ##########
 File path: 
runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
 ##########
 @@ -1246,16 +1249,230 @@ public void process(ProcessContext c) {
               StateRequestHandler.unsupported(),
               BundleProgressHandler.ignored())) {
         Iterables.getOnlyElement(bundle.getInputReceivers().values())
-            .accept(
-                WindowedValue.valueInGlobalWindow(
-                    CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
+            
.accept(valueInGlobalWindow(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), 
"X")));
       }
     }
     assertThat(
         outputValues,
         containsInAnyOrder(
-            WindowedValue.valueInGlobalWindow(KV.of("stream1X", "")),
-            WindowedValue.valueInGlobalWindow(KV.of("stream2X", ""))));
+            valueInGlobalWindow(KV.of("stream1X", "")),
+            valueInGlobalWindow(KV.of("stream2X", ""))));
+  }
+
+  /**
+   * A restriction tracker that will block making progress on {@link 
#WAIT_TILL_SPLIT} until a try
+   * split is invoked.
+   */
+  private static class WaitingTillSplitRestrictionTracker extends 
RestrictionTracker<String, Void> {
+    private static final String WAIT_TILL_SPLIT = "WaitTillSplit";
+    private static final String PRIMARY = "Primary";
+    private static final String RESIDUAL = "Residual";
+
+    private String currentRestriction;
+
+    private WaitingTillSplitRestrictionTracker(String restriction) {
+      this.currentRestriction = restriction;
+    }
+
+    @Override
+    public boolean tryClaim(Void position) {
+      return needsSplitting();
+    }
+
+    @Override
+    public String currentRestriction() {
+      return currentRestriction;
+    }
+
+    @Override
+    public SplitResult<String> trySplit(double fractionOfRemainder) {
+      if (!needsSplitting()) {
+        return null;
+      }
+      this.currentRestriction = PRIMARY;
+      return SplitResult.of(currentRestriction, RESIDUAL);
+    }
+
+    private boolean needsSplitting() {
+      return WAIT_TILL_SPLIT.equals(currentRestriction);
+    }
+
+    @Override
+    public void checkDone() throws IllegalStateException {
+      checkState(!needsSplitting(), "Expected for this restriction to have 
been split.");
+    }
+  }
+
+  @Test(timeout = 60000L)
+  public void testSplit() throws Exception {
+    Pipeline p = Pipeline.create();
+    p.apply("impulse", Impulse.create())
+        .apply(
+            "create",
+            ParDo.of(
+                new DoFn<byte[], String>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {
+                    ctxt.output("zero");
+                    
ctxt.output(WaitingTillSplitRestrictionTracker.WAIT_TILL_SPLIT);
+                    ctxt.output("two");
+                  }
+                }))
+        .apply(
+            "forceSplit",
+            ParDo.of(
+                new DoFn<String, String>() {
+                  @GetInitialRestriction
+                  public String getInitialRestriction(@Element String element) 
{
+                    return element;
+                  }
+
+                  @NewTracker
+                  public WaitingTillSplitRestrictionTracker newTracker(
+                      @Restriction String restriction) {
+                    return new WaitingTillSplitRestrictionTracker(restriction);
+                  }
+
+                  @ProcessElement
+                  public void process(
+                      RestrictionTracker<String, Void> tracker, ProcessContext 
context) {
+                    while (tracker.tryClaim(null)) {}
+                    context.output(tracker.currentRestriction());
+                  }
+                }))
+        .apply("addKeys", WithKeys.of("foo"))
+        // Use some unknown coders
+        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+        // Force the output to be materialized
+        .apply("gbk", GroupByKey.create());
+
+    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(p);
+    // Expand any splittable DoFns within the graph to enable sizing and 
splitting of bundles.
+    RunnerApi.Pipeline pipelineWithSdfExpanded =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            pipeline,
+            SplittableParDoExpander.createSizedReplacement());
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineWithSdfExpanded);
+
+    // Find the fused stage with the SDF ProcessSizedElementAndRestriction 
transform
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(
+            fused.getFusedStages(),
+            (ExecutableStage stage) ->
+                Iterables.filter(
+                        stage.getTransforms(),
+                        (PTransformNode node) ->
+                            PTransformTranslation
+                                
.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
+                                
.equals(node.getTransform().getSpec().getUrn()))
+                    .iterator()
+                    .hasNext());
+    checkState(
+        optionalStage.isPresent(), "Expected a stage with SDF 
ProcessSizedElementAndRestriction.");
+    ExecutableStage stage = optionalStage.get();
+
+    ExecutableProcessBundleDescriptor descriptor =
+        ProcessBundleDescriptors.fromExecutableStage(
+            "my_stage", stage, dataServer.getApiServiceDescriptor());
+
+    BundleProcessor processor =
+        controlClient.getProcessor(
+            descriptor.getProcessBundleDescriptor(), 
descriptor.getRemoteInputDestinations());
+    Map<String, ? super Coder<WindowedValue<?>>> remoteOutputCoders =
+        descriptor.getRemoteOutputCoders();
+    Map<String, Collection<? super WindowedValue<?>>> outputValues = new 
HashMap<>();
+    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+    for (Entry<String, ? super Coder<WindowedValue<?>>> remoteOutputCoder :
+        remoteOutputCoders.entrySet()) {
+      List<? super WindowedValue<?>> outputContents =
+          Collections.synchronizedList(new ArrayList<>());
+      outputValues.put(remoteOutputCoder.getKey(), outputContents);
+      outputReceivers.put(
+          remoteOutputCoder.getKey(),
+          RemoteOutputReceiver.of(
+              (Coder) remoteOutputCoder.getValue(),
+              (FnDataReceiver<? super WindowedValue<?>>) outputContents::add));
+    }
+
+    List<ProcessBundleSplitResponse> splitResponses = new ArrayList<>();
+    List<ProcessBundleResponse> checkpointResponses = new ArrayList<>();
+    List<String> requestsFinalization = new ArrayList<>();
+
+    ScheduledExecutorService executor = 
Executors.newSingleThreadScheduledExecutor();
+    ScheduledFuture<Object> future;
+
+    // Execute the remote bundle.
+    try (RemoteBundle bundle =
+        processor.newBundle(
+            outputReceivers,
+            Collections.emptyMap(),
+            StateRequestHandler.unsupported(),
+            BundleProgressHandler.ignored(),
+            splitResponses::add,
+            checkpointResponses::add,
+            requestsFinalization::add)) {
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          
.accept(valueInGlobalWindow(sdfSizedElementAndRestrictionForTest("zero")));
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          .accept(
+              valueInGlobalWindow(
+                  sdfSizedElementAndRestrictionForTest(
+                      WaitingTillSplitRestrictionTracker.WAIT_TILL_SPLIT)));
+      Iterables.getOnlyElement(bundle.getInputReceivers().values())
+          
.accept(valueInGlobalWindow(sdfSizedElementAndRestrictionForTest("two")));
+      // Keep sending splits until the bundle terminates, we specifically use 
0.5 so that we will
+      // choose a split point before the end of WAIT_TILL_SPLIT regardless of 
where we are during
+      // processing.
+      future =
+          (ScheduledFuture)
+              executor.scheduleWithFixedDelay(
+                  () -> bundle.split(0.5), 0L, 100L, TimeUnit.MILLISECONDS);
+    }
+    future.cancel(false);
+    executor.shutdown();
+
+    assertTrue(requestsFinalization.isEmpty());
+    assertTrue(checkpointResponses.isEmpty());
+
+    List<WindowedValue<KV<String, String>>> expectedOutputs = new 
ArrayList<>();
+
+    // We only validate the last split response since it is the only one that 
could possibly
+    // contain the SDF split, all others will be a reduction in the 
ChannelSplit
+    assertFalse(splitResponses.isEmpty());
+    ProcessBundleSplitResponse splitResponse = 
splitResponses.get(splitResponses.size() - 1);
+    ChannelSplit channelSplit = 
Iterables.getOnlyElement(splitResponse.getChannelSplitsList());
+
+    // There are only a few outcomes that could happen with splitting due to 
timing:
 
 Review comment:
   That's a good idea. I've updated the PR to reflect it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to