This is an automated email from the ASF dual-hosted git repository.
amaliujia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9c9903d [BEAM-10212] Integrate caching client (#15214)
9c9903d is described below
commit 9c9903d50b59a6ca956b9d43809dc26c490cb849
Author: anthonyqzhu <[email protected]>
AuthorDate: Fri Aug 6 13:03:28 2021 -0400
[BEAM-10212] Integrate caching client (#15214)
* [BEAM-10212] Add state cache to ProcessBundleHandler
---
.../fnexecution/control/RemoteExecutionTest.java | 407 +++++++++++++++++++++
.../fn/harness/control/ProcessBundleHandler.java | 52 ++-
2 files changed, 451 insertions(+), 8 deletions(-)
diff --git
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 5acb87d..f238601 100644
---
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++
b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -42,6 +42,7 @@ import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.UUID;
+import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
@@ -55,6 +56,7 @@ import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.beam.fn.harness.FnHarness;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import
org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import
org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse;
@@ -68,6 +70,7 @@ import
org.apache.beam.runners.core.construction.graph.FusedPipeline;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import
org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode;
import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SideInputReference;
import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
import org.apache.beam.runners.core.metrics.DistributionData;
import org.apache.beam.runners.core.metrics.ExecutionStateSampler;
@@ -532,6 +535,160 @@ public class RemoteExecutionTest implements Serializable {
}
}
+ /**
+  * Runs two bundles that read the same iterable side input while a side input cache token is
+  * supplied, and checks that only a single state request reaches the runner-side handler, i.e.
+  * that the second bundle's read is answered from the cache instead of being forwarded.
+  */
+ @Test
+ public void testExecutionWithSideInputCaching() throws Exception {
+   Pipeline p = Pipeline.create();
+   addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+   // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+   addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+   PCollection<String> input =
+       p.apply("impulse", Impulse.create())
+           .apply(
+               "create",
+               ParDo.of(
+                   new DoFn<byte[], String>() {
+                     @ProcessElement
+                     public void process(ProcessContext ctxt) {
+                       ctxt.output("zero");
+                       ctxt.output("one");
+                       ctxt.output("two");
+                     }
+                   }))
+           .setCoder(StringUtf8Coder.of());
+   PCollectionView<Iterable<String>> view = input.apply("createSideInput", View.asIterable());
+
+   input
+       .apply(
+           "readSideInput",
+           ParDo.of(
+                   new DoFn<String, KV<String, String>>() {
+                     @ProcessElement
+                     public void processElement(ProcessContext context) {
+                       for (String value : context.sideInput(view)) {
+                         context.output(KV.of(context.element(), value));
+                       }
+                     }
+                   })
+               .withSideInputs(view))
+       .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+       // Force the output to be materialized
+       .apply("gbk", GroupByKey.create());
+
+   RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+   FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+   Optional<ExecutableStage> optionalStage =
+       Iterables.tryFind(
+           fused.getFusedStages(), (ExecutableStage stage) -> !stage.getSideInputs().isEmpty());
+   checkState(optionalStage.isPresent(), "Expected a stage with side inputs.");
+   ExecutableStage stage = optionalStage.get();
+
+   ExecutableProcessBundleDescriptor descriptor =
+       ProcessBundleDescriptors.fromExecutableStage(
+           "test_stage",
+           stage,
+           dataServer.getApiServiceDescriptor(),
+           stateServer.getApiServiceDescriptor());
+
+   BundleProcessor processor =
+       controlClient.getProcessor(
+           descriptor.getProcessBundleDescriptor(),
+           descriptor.getRemoteInputDestinations(),
+           stateDelegator);
+   Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
+   Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
+   Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+   for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
+     List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
+     outputValues.put(remoteOutputCoder.getKey(), outputContents);
+     outputReceivers.put(
+         remoteOutputCoder.getKey(),
+         RemoteOutputReceiver.of(
+             (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
+   }
+
+   // Wrap the real side input handler so every request that reaches the runner is recorded.
+   StoringStateRequestHandler stateRequestHandler =
+       new StoringStateRequestHandler(
+           StateRequestHandlers.forSideInputHandlerFactory(
+               descriptor.getSideInputSpecs(),
+               new SideInputHandlerFactory() {
+                 @Override
+                 public <V, W extends BoundedWindow>
+                     IterableSideInputHandler<V, W> forIterableSideInput(
+                         String pTransformId,
+                         String sideInputId,
+                         Coder<V> elementCoder,
+                         Coder<W> windowCoder) {
+                   return new IterableSideInputHandler<V, W>() {
+                     @Override
+                     public Iterable<V> get(W window) {
+                       return (Iterable) Arrays.asList("A", "B", "C");
+                     }
+
+                     @Override
+                     public Coder<V> elementCoder() {
+                       return elementCoder;
+                     }
+                   };
+                 }
+
+                 @Override
+                 public <K, V, W extends BoundedWindow>
+                     MultimapSideInputHandler<K, V, W> forMultimapSideInput(
+                         String pTransformId,
+                         String sideInputId,
+                         KvCoder<K, V> elementCoder,
+                         Coder<W> windowCoder) {
+                   throw new UnsupportedOperationException();
+                 }
+               }));
+   SideInputReference sideInputReference = stage.getSideInputs().iterator().next();
+   String transformId = sideInputReference.transform().getId();
+   String sideInputId = sideInputReference.localName();
+   // Register a cache token for the side input; the same token is offered to both bundles.
+   stateRequestHandler.addCacheToken(
+       BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder()
+           .setSideInput(
+               BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder()
+                   .setSideInputId(sideInputId)
+                   .setTransformId(transformId)
+                   .build())
+           .setToken(ByteString.copyFromUtf8("SideInputToken"))
+           .build());
+   BundleProgressHandler progressHandler = BundleProgressHandler.ignored();
+
+   try (RemoteBundle bundle =
+       processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
+     Iterables.getOnlyElement(bundle.getInputReceivers().values())
+         .accept(valueInGlobalWindow("X"));
+   }
+
+   try (RemoteBundle bundle =
+       processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
+     Iterables.getOnlyElement(bundle.getInputReceivers().values())
+         .accept(valueInGlobalWindow("X"));
+   }
+   // Both bundles write into the same receivers, so each output appears twice.
+   for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
+     assertThat(
+         windowedValues,
+         containsInAnyOrder(
+             valueInGlobalWindow(KV.of("X", "A")),
+             valueInGlobalWindow(KV.of("X", "B")),
+             valueInGlobalWindow(KV.of("X", "C")),
+             valueInGlobalWindow(KV.of("X", "A")),
+             valueInGlobalWindow(KV.of("X", "B")),
+             valueInGlobalWindow(KV.of("X", "C"))));
+   }
+
+   // Only expect one read to the sideInput
+   assertEquals(1, stateRequestHandler.getRequestCount());
+   BeamFnApi.StateRequest receivedRequest = stateRequestHandler.receivedRequests.get(0);
+   // Expected value goes first so that a failure message labels the two sides correctly.
+   assertEquals(
+       BeamFnApi.StateKey.IterableSideInput.newBuilder()
+           .setSideInputId(sideInputId)
+           .setTransformId(transformId)
+           .build(),
+       receivedRequest.getStateKey().getIterableSideInput());
+ }
+
/**
* A {@link DoFn} that uses static maps of {@link CountDownLatch}es to block
execution allowing
* for synchronization during test execution. The expected flow is:
@@ -1041,6 +1198,256 @@ public class RemoteExecutionTest implements
Serializable {
}
@Test
+ public void testExecutionWithUserStateCaching() throws Exception {
+   // Runs two bundles for the same key against two bag user states and checks, via the recorded
+   // state requests, that only three requests reach the runner-side handler.
+   Pipeline p = Pipeline.create();
+   final String stateId = "foo";
+   final String stateId2 = "bar";
+
+   p.apply("impulse", Impulse.create())
+       .apply(
+           "create",
+           ParDo.of(
+               new DoFn<byte[], KV<String, String>>() {
+                 @ProcessElement
+                 public void process(ProcessContext ctxt) {}
+               }))
+       .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+       .apply(
+           "userState",
+           ParDo.of(
+               new DoFn<KV<String, String>, KV<String, String>>() {
+
+                 @StateId(stateId)
+                 private final StateSpec<BagState<String>> bufferState =
+                     StateSpecs.bag(StringUtf8Coder.of());
+
+                 @StateId(stateId2)
+                 private final StateSpec<BagState<String>> bufferState2 =
+                     StateSpecs.bag(StringUtf8Coder.of());
+
+                 @ProcessElement
+                 public void processElement(
+                     @Element KV<String, String> element,
+                     @StateId(stateId) BagState<String> state,
+                     @StateId(stateId2) BagState<String> state2,
+                     OutputReceiver<KV<String, String>> r) {
+                   for (String value : state.read()) {
+                     r.output(KV.of(element.getKey(), value));
+                   }
+                   ReadableState<Boolean> isEmpty = state2.isEmpty();
+                   if (isEmpty.read()) {
+                     r.output(KV.of(element.getKey(), "Empty"));
+                   } else {
+                     state2.clear();
+                   }
+                 }
+               }))
+       // Force the output to be materialized
+       .apply("gbk", GroupByKey.create());
+
+   RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+   FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+   Optional<ExecutableStage> optionalStage =
+       Iterables.tryFind(
+           fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
+   checkState(optionalStage.isPresent(), "Expected a stage with user state.");
+   ExecutableStage stage = optionalStage.get();
+
+   ExecutableProcessBundleDescriptor descriptor =
+       ProcessBundleDescriptors.fromExecutableStage(
+           "test_stage",
+           stage,
+           dataServer.getApiServiceDescriptor(),
+           stateServer.getApiServiceDescriptor());
+
+   BundleProcessor processor =
+       controlClient.getProcessor(
+           descriptor.getProcessBundleDescriptor(),
+           descriptor.getRemoteInputDestinations(),
+           stateDelegator);
+   Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
+   Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
+   Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+   for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
+     List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
+     outputValues.put(remoteOutputCoder.getKey(), outputContents);
+     outputReceivers.put(
+         remoteOutputCoder.getKey(),
+         RemoteOutputReceiver.of(
+             (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
+   }
+
+   // Nested-context encodings of the values held in the fake runner-side state.
+   ByteString encodedA =
+       ByteString.copyFrom(
+           CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
+   ByteString encodedB =
+       ByteString.copyFrom(
+           CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
+   ByteString encodedC =
+       ByteString.copyFrom(
+           CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
+   ByteString encodedD =
+       ByteString.copyFrom(
+           CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));
+
+   // The map itself is immutable, but each per-state list is a mutable ArrayList so that the
+   // handler below can append to and clear it.
+   Map<String, List<ByteString>> userStateData =
+       ImmutableMap.of(
+           stateId,
+           new ArrayList<>(Arrays.asList(encodedA, encodedB, encodedC)),
+           stateId2,
+           new ArrayList<>(Arrays.asList(encodedD)));
+
+   StoringStateRequestHandler stateRequestHandler =
+       new StoringStateRequestHandler(
+           StateRequestHandlers.forBagUserStateHandlerFactory(
+               descriptor,
+               new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
+                 @Override
+                 public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
+                     String pTransformId,
+                     String userStateId,
+                     Coder<ByteString> keyCoder,
+                     Coder<Object> valueCoder,
+                     Coder<BoundedWindow> windowCoder) {
+                   return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
+                     @Override
+                     public Iterable<Object> get(ByteString key, BoundedWindow window) {
+                       return (Iterable) userStateData.get(userStateId);
+                     }
+
+                     @Override
+                     public void append(
+                         ByteString key, BoundedWindow window, Iterator<Object> values) {
+                       Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
+                     }
+
+                     @Override
+                     public void clear(ByteString key, BoundedWindow window) {
+                       userStateData.get(userStateId).clear();
+                     }
+                   };
+                 }
+               }));
+
+   try (RemoteBundle bundle =
+       processor.newBundle(
+           outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
+     Iterables.getOnlyElement(bundle.getInputReceivers().values())
+         .accept(valueInGlobalWindow(KV.of("X", "Y")));
+   }
+   try (RemoteBundle bundle2 =
+       processor.newBundle(
+           outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
+     Iterables.getOnlyElement(bundle2.getInputReceivers().values())
+         .accept(valueInGlobalWindow(KV.of("X", "Z")));
+   }
+   // "Empty" is expected once: the first bundle clears state2, the second then sees it empty.
+   for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
+     assertThat(
+         windowedValues,
+         containsInAnyOrder(
+             valueInGlobalWindow(KV.of("X", "A")),
+             valueInGlobalWindow(KV.of("X", "B")),
+             valueInGlobalWindow(KV.of("X", "C")),
+             valueInGlobalWindow(KV.of("X", "A")),
+             valueInGlobalWindow(KV.of("X", "B")),
+             valueInGlobalWindow(KV.of("X", "C")),
+             valueInGlobalWindow(KV.of("X", "Empty"))));
+   }
+   assertThat(
+       userStateData.get(stateId),
+       IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC));
+   assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
+
+   // 3 Requests expected: state read, state2 read, and state2 clear
+   assertEquals(3, stateRequestHandler.getRequestCount());
+   ByteString.Output out = ByteString.newOutput();
+   StringUtf8Coder.of().encode("X", out);
+   ByteString expectedKey = out.toByteString();
+   List<BeamFnApi.StateRequest> requests = stateRequestHandler.receivedRequests;
+
+   // Expected values go first in assertEquals so failure messages label the sides correctly.
+   BeamFnApi.StateKey.BagUserState firstRead = requests.get(0).getStateKey().getBagUserState();
+   assertEquals(stateId, firstRead.getUserStateId());
+   assertEquals(expectedKey, firstRead.getKey());
+   assertTrue(requests.get(0).hasGet());
+
+   BeamFnApi.StateKey.BagUserState secondRead = requests.get(1).getStateKey().getBagUserState();
+   assertEquals(stateId2, secondRead.getUserStateId());
+   assertEquals(expectedKey, secondRead.getKey());
+   assertTrue(requests.get(1).hasGet());
+
+   BeamFnApi.StateKey.BagUserState clearCall = requests.get(2).getStateKey().getBagUserState();
+   assertEquals(stateId2, clearCall.getUserStateId());
+   assertEquals(expectedKey, clearCall.getKey());
+   assertTrue(requests.get(2).hasClear());
+ }
+
+ /**
+  * A state handler that stores each state request made - used to validate that cached requests
+  * are not forwarded to the state client.
+  *
+  * <p>Every request passed to {@link #handle} is recorded before being forwarded to the wrapped
+  * delegate. Cache tokens registered through {@link #addCacheToken} are reported in addition to
+  * those of the delegate.
+  */
+ private static class StoringStateRequestHandler implements StateRequestHandler {
+
+   private final StateRequestHandler stateRequestHandler;
+   // Synchronized in the same way as the output buffers in these tests: requests are recorded
+   // in handle() and read back by the test method, which may run on different threads.
+   private final List<BeamFnApi.StateRequest> receivedRequests =
+       Collections.synchronizedList(new ArrayList<>());
+   private final List<BeamFnApi.ProcessBundleRequest.CacheToken> cacheTokens = new ArrayList<>();
+
+   /** Wraps {@code delegate}, which answers every request this handler records. */
+   StoringStateRequestHandler(StateRequestHandler delegate) {
+     stateRequestHandler = delegate;
+   }
+
+   /** Records {@code request}, then returns whatever the delegate returns for it. */
+   @Override
+   public CompletionStage<BeamFnApi.StateResponse.Builder> handle(BeamFnApi.StateRequest request)
+       throws Exception {
+     receivedRequests.add(request);
+     return stateRequestHandler.handle(request);
+   }
+
+   /** Returns the delegate's cache tokens followed by the tokens added to this handler. */
+   @Override
+   public Iterable<BeamFnApi.ProcessBundleRequest.CacheToken> getCacheTokens() {
+     return Iterables.concat(stateRequestHandler.getCacheTokens(), cacheTokens);
+   }
+
+   /** Returns how many requests have been passed to {@link #handle} so far. */
+   public int getRequestCount() {
+     return receivedRequests.size();
+   }
+
+   /** Adds {@code token} to the tokens reported by {@link #getCacheTokens}. */
+   public void addCacheToken(BeamFnApi.ProcessBundleRequest.CacheToken token) {
+     cacheTokens.add(token);
+   }
+ }
+
+ @Test
public void testExecutionWithTimer() throws Exception {
Pipeline p = Pipeline.create();
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index b40c9d5..e1d6ad4 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -24,6 +24,7 @@ import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -50,6 +51,7 @@ import
org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
import org.apache.beam.fn.harness.logging.BeamFnLoggingMDC;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
+import org.apache.beam.fn.harness.state.CachingBeamFnStateClient;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
@@ -74,6 +76,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
import org.apache.beam.sdk.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
@@ -140,10 +143,29 @@ public class ProcessBundleHandler {
REGISTERED_RUNNER_FACTORIES = builder.build();
}
+ // Creates a new map of state data for newly encountered state keys.
+ // The loader ignores the key's contents and always returns a fresh, empty HashMap; it is
+ // handed to CacheBuilder in the constructor to build the stateCache field.
+ // NOTE(review): HashMap is not thread-safe - confirm CachingBeamFnStateClient never mutates
+ // a cached map from several bundles at once.
+ // NOTE(review): the cache is built without a size or expiry bound - confirm that is intended.
+ private CacheLoader<
+ BeamFnApi.StateKey,
+ Map<CachingBeamFnStateClient.StateCacheKey,
BeamFnApi.StateGetResponse>>
+ stateKeyMapCacheLoader =
+ new CacheLoader<
+ BeamFnApi.StateKey,
+ Map<CachingBeamFnStateClient.StateCacheKey,
BeamFnApi.StateGetResponse>>() {
+ @Override
+ public Map<CachingBeamFnStateClient.StateCacheKey,
BeamFnApi.StateGetResponse> load(
+ BeamFnApi.StateKey key) {
+ return new HashMap<>();
+ }
+ };
+
private final PipelineOptions options;
private final Function<String, Message> fnApiRegistry;
private final BeamFnDataClient beamFnDataClient;
private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache;
+ private final LoadingCache<
+ BeamFnApi.StateKey,
+ Map<CachingBeamFnStateClient.StateCacheKey,
BeamFnApi.StateGetResponse>>
+ stateCache;
private final FinalizeBundleHandler finalizeBundleHandler;
private final ShortIdMap shortIds;
private final boolean runnerAcceptsShortIds;
@@ -186,6 +208,7 @@ public class ProcessBundleHandler {
this.fnApiRegistry = fnApiRegistry;
this.beamFnDataClient = beamFnDataClient;
this.beamFnStateGrpcClientCache = beamFnStateGrpcClientCache;
+ this.stateCache = CacheBuilder.newBuilder().build(stateKeyMapCacheLoader);
this.finalizeBundleHandler = finalizeBundleHandler;
this.shortIds = shortIds;
this.runnerAcceptsShortIds =
@@ -491,14 +514,27 @@ public class ProcessBundleHandler {
}
}
- // Instantiate a State API call handler depending on whether a State
ApiServiceDescriptor
- // was specified.
- HandleStateCallsForBundle beamFnStateClient =
- bundleDescriptor.hasStateApiServiceDescriptor()
- ? new BlockTillStateCallsFinish(
- beamFnStateGrpcClientCache.forApiServiceDescriptor(
- bundleDescriptor.getStateApiServiceDescriptor()))
- : new FailAllStateCallsForBundle(processBundleRequest);
+ // Instantiate a State API call handler depending on whether a State ApiServiceDescriptor
+ // was specified.
+ HandleStateCallsForBundle beamFnStateClient;
+ if (bundleDescriptor.hasStateApiServiceDescriptor()) {
+ BeamFnStateClient underlyingClient =
+ beamFnStateGrpcClientCache.forApiServiceDescriptor(
+ bundleDescriptor.getStateApiServiceDescriptor());
+
+ // If the pipeline is batch, use a CachingBeamFnStateClient to store state responses.
+ // Once streaming is supported, always use CachingBeamFnStateClient as the argument
+ // to BlockTillStateCallsFinish.
+ beamFnStateClient =
+ new BlockTillStateCallsFinish(
+ options.as(StreamingOptions.class).isStreaming()
+ ? underlyingClient
+ : new CachingBeamFnStateClient(
+ underlyingClient, stateCache,
processBundleRequest.getCacheTokensList()));
+
+ } else {
+ beamFnStateClient = new FailAllStateCallsForBundle(processBundleRequest);
+ }
// Instantiate a Timer client registration handler depending on whether a
Timer
// ApiServiceDescriptor was specified.