This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c8d56c  [BEAM-2926] Add support for side inputs to the runner harness.
6c8d56c is described below

commit 6c8d56cacfcf8a44f6c8c029706905eb2748b44a
Author: Luke Cwik <[email protected]>
AuthorDate: Wed Jan 31 09:58:57 2018 -0800

    [BEAM-2926] Add support for side inputs to the runner harness.
---
 .../construction/PCollectionViewTranslation.java   |  12 +-
 .../core/construction/ParDoTranslation.java        |   4 +-
 .../PCollectionViewTranslationTest.java            |  74 +++
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 527 +++++++++++++--------
 .../apache/beam/fn/harness/state/BagUserState.java |  66 +--
 .../state/LazyCachingIteratorToIterable.java       |  17 +
 .../beam/fn/harness/state/MultimapSideInput.java   |  85 ++++
 .../fn/harness/state/StateFetchingIterators.java   |  28 +-
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       | 273 ++++++++++-
 .../beam/fn/harness/state/BagUserStateTest.java    |  59 ++-
 .../state/LazyCachingIteratorToIterableTest.java   |  14 +
 .../fn/harness/state/MultimapSideInputTest.java    |  73 +++
 .../harness/state/StateFetchingIteratorsTest.java  |   2 +-
 13 files changed, 957 insertions(+), 277 deletions(-)

diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
index 25361ed..ade7229 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PCollectionViewTranslation.java
@@ -73,7 +73,11 @@ public class PCollectionViewTranslation {
     return view;
   }
 
-  private static ViewFn<?, ?> viewFnFromProto(RunnerApi.SdkFunctionSpec viewFn)
+  /**
+   * Converts a {@link org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec} into
+   * a {@link ViewFn} using the URN.
+   */
+  public static ViewFn<?, ?> viewFnFromProto(RunnerApi.SdkFunctionSpec viewFn)
       throws InvalidProtocolBufferException {
     RunnerApi.FunctionSpec spec = viewFn.getSpec();
     checkArgument(
@@ -86,7 +90,11 @@ public class PCollectionViewTranslation {
             spec.getPayload().toByteArray(), "Custom ViewFn");
   }
 
-  private static WindowMappingFn<?> windowMappingFnFromProto(
+  /**
+   * Converts a {@link org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec} into
+   * a {@link WindowMappingFn} using the URN.
+   */
+  public static WindowMappingFn<?> windowMappingFnFromProto(
       RunnerApi.SdkFunctionSpec windowMappingFn)
       throws InvalidProtocolBufferException {
     RunnerApi.FunctionSpec spec = windowMappingFn.getSpec();
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index a9b3c56..6365d77 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -498,7 +498,7 @@ public class ParDoTranslation {
     return builder.build();
   }
 
-  private static SdkFunctionSpec translateViewFn(ViewFn<?, ?> viewFn, SdkComponents components) {
+  public static SdkFunctionSpec translateViewFn(ViewFn<?, ?> viewFn, SdkComponents components) {
     return SdkFunctionSpec.newBuilder()
         .setEnvironmentId(components.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT))
         .setSpec(
@@ -526,7 +526,7 @@ public class ParDoTranslation {
     return payload.getSplittable();
   }
 
-  private static SdkFunctionSpec translateWindowMappingFn(
+  public static SdkFunctionSpec translateWindowMappingFn(
       WindowMappingFn<?> windowMappingFn, SdkComponents components) {
     return SdkFunctionSpec.newBuilder()
         .setEnvironmentId(components.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT))
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java
new file mode 100644
index 0000000..85156a9
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.beam.sdk.transforms.Materialization;
+import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link PCollectionViewTranslation}.
+ */
+@RunWith(JUnit4.class)
+public class PCollectionViewTranslationTest {
+  @Test
+  public void testViewFnTranslation() throws Exception {
+    assertEquals(new TestViewFn(),
+        PCollectionViewTranslation.viewFnFromProto(
+            ParDoTranslation.translateViewFn(new TestViewFn(),
+                SdkComponents.create())));
+  }
+
+  @Test
+  public void testWindowMappingFnTranslation() throws Exception {
+    assertEquals(new GlobalWindows().getDefaultWindowMappingFn(),
+        PCollectionViewTranslation.windowMappingFnFromProto(
+            ParDoTranslation.translateWindowMappingFn(
+                new GlobalWindows().getDefaultWindowMappingFn(),
+                SdkComponents.create())));
+  }
+
+  /** Test implementation to check for equality. */
+  private static class TestViewFn extends ViewFn<Object, Object> {
+    @Override
+    public Materialization<Object> getMaterialization() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public Object apply(Object o) {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      return obj instanceof TestViewFn;
+    }
+
+    @Override
+    public int hashCode() {
+      return 0;
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 721207a..cf3a227 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -17,17 +17,20 @@
  */
 package org.apache.beam.fn.harness;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.service.AutoService;
-import com.google.common.base.Suppliers;
+import com.google.auto.value.AutoValue;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableListMultimap;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.ListMultimap;
 import com.google.common.collect.Multimap;
+import com.google.common.collect.Sets;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.IOException;
@@ -36,6 +39,7 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -43,14 +47,14 @@ import org.apache.beam.fn.harness.data.BeamFnDataClient;
 import org.apache.beam.fn.harness.fn.ThrowingRunnable;
 import org.apache.beam.fn.harness.state.BagUserState;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
+import org.apache.beam.fn.harness.state.MultimapSideInput;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
 import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.construction.PCollectionViewTranslation;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.RehydratedComponents;
@@ -77,6 +81,8 @@ import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFn.OnTimerContext;
 import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
+import org.apache.beam.sdk.transforms.Materializations;
+import org.apache.beam.sdk.transforms.ViewFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
@@ -86,6 +92,7 @@ import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.DoFnInfo;
 import org.apache.beam.sdk.util.SerializableUtils;
@@ -167,10 +174,16 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
               (Collection<FnDataReceiver<WindowedValue<OutputT>>>)
                   (Collection) tagToOutputMap.get(doFnInfo.getMainOutput()),
               tagToOutputMap,
+              ImmutableMap.of(),
               doFnInfo.getWindowingStrategy());
 
       registerHandlers(
-          runner, pTransform, addStartFunction, addFinishFunction, pCollectionIdsToConsumers);
+          runner,
+          pTransform,
+          ImmutableSet.of(),
+          addStartFunction,
+          addFinishFunction,
+          pCollectionIdsToConsumers);
       return runner;
     }
   }
@@ -198,23 +211,53 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
       Coder<InputT> inputCoder;
       WindowingStrategy<InputT, ?> windowingStrategy;
 
+      ImmutableMap.Builder<TupleTag<?>, SideInputSpec> tagToSideInputSpecMap =
+          ImmutableMap.builder();
+      ParDoPayload parDoPayload;
       try {
         RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(
             RunnerApi.Components.newBuilder()
                 .putAllCoders(coders).putAllWindowingStrategies(windowingStrategies).build());
-        ParDoPayload parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
-        if (parDoPayload.getSideInputsCount() != 0) {
-          throw new UnsupportedOperationException("Side inputs not yet supported.");
-        }
+        parDoPayload = ParDoPayload.parseFrom(pTransform.getSpec().getPayload());
         doFn = (DoFn) ParDoTranslation.getDoFn(parDoPayload);
         mainOutputTag = (TupleTag) ParDoTranslation.getMainOutputTag(parDoPayload);
-        // There will only be one due to the check above.
-        RunnerApi.PCollection mainInput = pCollections.get(
-            Iterables.getOnlyElement(pTransform.getInputsMap().values()));
-        inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(mainInput.getCoderId());
-        windowingStrategy =
-            (WindowingStrategy)
-                rehydratedComponents.getWindowingStrategy(mainInput.getWindowingStrategyId());
+        String mainInputTag = Iterables.getOnlyElement(Sets.difference(
+            pTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
+        RunnerApi.PCollection mainInput =
+            pCollections.get(pTransform.getInputsOrThrow(mainInputTag));
+        inputCoder = (Coder<InputT>) rehydratedComponents.getCoder(
+            mainInput.getCoderId());
+        windowingStrategy = (WindowingStrategy) rehydratedComponents.getWindowingStrategy(
+            mainInput.getWindowingStrategyId());
+
+        // Build the map from tag id to side input specification
+        for (Map.Entry<String, RunnerApi.SideInput> entry
+            : parDoPayload.getSideInputsMap().entrySet()) {
+          String sideInputTag = entry.getKey();
+          RunnerApi.SideInput sideInput = entry.getValue();
+          checkArgument(
+              Materializations.MULTIMAP_MATERIALIZATION_URN.equals(
+                  sideInput.getAccessPattern().getUrn()),
+              "This SDK is only capable of dealing with %s materializations "
+                  + "but was asked to handle %s for PCollectionView with tag %s.",
+              Materializations.MULTIMAP_MATERIALIZATION_URN,
+              sideInput.getAccessPattern().getUrn(),
+              sideInputTag);
+
+          RunnerApi.PCollection sideInputPCollection =
+              pCollections.get(pTransform.getInputsOrThrow(sideInputTag));
+          WindowingStrategy sideInputWindowingStrategy =
+              rehydratedComponents.getWindowingStrategy(
+                  sideInputPCollection.getWindowingStrategyId());
+          tagToSideInputSpecMap.put(
+              new TupleTag<>(entry.getKey()),
+              SideInputSpec.create(
+                  rehydratedComponents.getCoder(sideInputPCollection.getCoderId()),
+                  sideInputWindowingStrategy.getWindowFn().windowCoder(),
+                  PCollectionViewTranslation.viewFnFromProto(entry.getValue().getViewFn()),
+                  PCollectionViewTranslation.windowMappingFnFromProto(
+                      entry.getValue().getWindowMappingFn())));
+        }
       } catch (InvalidProtocolBufferException exn) {
         throw new IllegalArgumentException("Malformed ParDoPayload", exn);
       } catch (IOException exn) {
@@ -241,9 +284,15 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
           (Collection<FnDataReceiver<WindowedValue<OutputT>>>) (Collection)
               tagToConsumer.get(mainOutputTag),
           tagToConsumer,
+          tagToSideInputSpecMap.build(),
           windowingStrategy);
       registerHandlers(
-          runner, pTransform, addStartFunction, addFinishFunction, pCollectionIdsToConsumers);
+          runner,
+          pTransform,
+          parDoPayload.getSideInputsMap().keySet(),
+          addStartFunction,
+          addFinishFunction,
+          pCollectionIdsToConsumers);
       return runner;
     }
   }
@@ -251,14 +300,16 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
   private static <InputT, OutputT> void registerHandlers(
       DoFnRunner<InputT, OutputT> runner,
       RunnerApi.PTransform pTransform,
+      Set<String> sideInputLocalNames,
       Consumer<ThrowingRunnable> addStartFunction,
       Consumer<ThrowingRunnable> addFinishFunction,
       Multimap<String, FnDataReceiver<WindowedValue<?>>> pCollectionIdsToConsumers) {
     // Register the appropriate handlers.
     addStartFunction.accept(runner::startBundle);
-    for (String pcollectionId : pTransform.getInputsMap().values()) {
+    for (String localInputName
+        : Sets.difference(pTransform.getInputsMap().keySet(), sideInputLocalNames)) {
       pCollectionIdsToConsumers.put(
-          pcollectionId,
+          pTransform.getInputsOrThrow(localInputName),
           (FnDataReceiver) (FnDataReceiver<WindowedValue<InputT>>) runner::processElement);
     }
     addFinishFunction.accept(runner::finishBundle);
@@ -274,6 +325,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
   private final Coder<InputT> inputCoder;
   private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
   private final Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap;
+  private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
+  private final Map<StateKey, Object> stateKeyObjectCache;
   private final WindowingStrategy windowingStrategy;
   private final DoFnSignature doFnSignature;
   private final DoFnInvoker<InputT, OutputT> doFnInvoker;
@@ -296,12 +349,16 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
   private BoundedWindow currentWindow;
 
   /**
-   * This member should only be accessed indirectly by calling
-   * {@link #createOrUseCachedBagUserStateKey} and is only valid during {@link #processElement}
-   * and is null otherwise.
+   * The lifetime of this member is only valid during {@link #processElement}
+   * and only when processing a {@link KV} and is null otherwise.
    */
-  private StateKey.BagUserState cachedPartialBagUserStateKey;
+  private ByteString encodedCurrentKey;
 
+  /**
+   * The lifetime of this member is only valid during {@link #processElement}
+   * and is null otherwise.
+   */
+  private ByteString encodedCurrentWindow;
 
   FnApiDoFnRunner(
       PipelineOptions pipelineOptions,
@@ -312,6 +369,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
       Coder<InputT> inputCoder,
       Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers,
       Multimap<TupleTag<?>, FnDataReceiver<WindowedValue<?>>> outputMap,
+      Map<TupleTag<?>, SideInputSpec> sideInputSpecMap,
       WindowingStrategy windowingStrategy) {
     this.pipelineOptions = pipelineOptions;
     this.beamFnStateClient = beamFnStateClient;
@@ -321,6 +379,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
     this.inputCoder = inputCoder;
     this.mainOutputConsumers = mainOutputConsumers;
     this.outputMap = outputMap;
+    this.sideInputSpecMap = sideInputSpecMap;
+    this.stateKeyObjectCache = new HashMap<>();
     this.windowingStrategy = windowingStrategy;
     this.doFnSignature = DoFnSignatures.signatureForDoFn(doFn);
     this.doFnInvoker = DoFnInvokers.invokerFor(doFn);
@@ -349,7 +409,8 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
     } finally {
       currentElement = null;
       currentWindow = null;
-      cachedPartialBagUserStateKey = null;
+      encodedCurrentKey = null;
+      encodedCurrentWindow = null;
     }
   }
 
@@ -377,6 +438,9 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
     } catch (Exception e) {
       throw new IllegalStateException(e);
     }
+
+    // TODO: Support caching state data across bundle boundaries.
+    stateKeyObjectCache.clear();
   }
 
   /**
@@ -592,7 +656,7 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
 
     @Override
     public <T> T sideInput(PCollectionView<T> view) {
-      throw new UnsupportedOperationException("TODO: Support side inputs");
+      return bindSideInputView(view.getTagInternal());
     }
 
     @Override
@@ -705,87 +769,85 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
    * {@link #bindWatermark} should never be implemented.
    */
   private class BeamFnStateBinder implements StateBinder {
-    private final Map<StateKey.BagUserState, Object> stateObjectCache = new HashMap<>();
-
     @Override
     public <T> ValueState<T> bindValue(String id, StateSpec<ValueState<T>> spec, Coder<T> coder) {
-      return (ValueState<T>) stateObjectCache.computeIfAbsent(
-          createOrUseCachedBagUserStateKey(id),
-          new Function<StateKey.BagUserState, Object>() {
-        @Override
-        public Object apply(StateKey.BagUserState s) {
-          return new ValueState<T>() {
-            private final BagUserState<T> impl = createBagUserState(id, coder);
-
+      return (ValueState<T>) stateKeyObjectCache.computeIfAbsent(
+          createBagUserStateKey(id),
+          new Function<StateKey, Object>() {
             @Override
-            public void clear() {
-              impl.clear();
-            }
+            public Object apply(StateKey key) {
+              return new ValueState<T>() {
+                private final BagUserState<T> impl = createBagUserState(id, coder);
 
-            @Override
-            public void write(T input) {
-              impl.clear();
-              impl.append(input);
-            }
+                @Override
+                public void clear() {
+                  impl.clear();
+                }
 
-            @Override
-            public T read() {
-              Iterator<T> value = impl.get().iterator();
-              if (value.hasNext()) {
-                return value.next();
-              } else {
-                return null;
-              }
-            }
+                @Override
+                public void write(T input) {
+                  impl.clear();
+                  impl.append(input);
+                }
 
-            @Override
-            public ValueState<T> readLater() {
-              // TODO: Support prefetching.
-              return this;
+                @Override
+                public T read() {
+                  Iterator<T> value = impl.get().iterator();
+                  if (value.hasNext()) {
+                    return value.next();
+                  } else {
+                    return null;
+                  }
+                }
+
+                @Override
+                public ValueState<T> readLater() {
+                  // TODO: Support prefetching.
+                  return this;
+                }
+              };
             }
-          };
-        }
-      });
+          });
     }
 
     @Override
     public <T> BagState<T> bindBag(String id, StateSpec<BagState<T>> spec, Coder<T> elemCoder) {
-      return (BagState<T>) stateObjectCache.computeIfAbsent(
-          createOrUseCachedBagUserStateKey(id),
-          new Function<StateKey.BagUserState, Object>() {
-        @Override
-        public Object apply(StateKey.BagUserState s) {
-          return new BagState<T>() {
-            private final BagUserState<T> impl = createBagUserState(id, elemCoder);
-
+      return (BagState<T>) stateKeyObjectCache.computeIfAbsent(
+          createBagUserStateKey(id),
+          new Function<StateKey, Object>() {
             @Override
-            public void add(T value) {
-              impl.append(value);
-            }
+            public Object apply(StateKey key) {
+              return new BagState<T>() {
+                private final BagUserState<T> impl = createBagUserState(id, elemCoder);
 
-            @Override
-            public ReadableState<Boolean> isEmpty() {
-              return ReadableStates.immediate(!impl.get().iterator().hasNext());
-            }
+                @Override
+                public void add(T value) {
+                  impl.append(value);
+                }
 
-            @Override
-            public Iterable<T> read() {
-              return impl.get();
-            }
+                @Override
+                public ReadableState<Boolean> isEmpty() {
+                  return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                }
 
-            @Override
-            public BagState<T> readLater() {
-              // TODO: Support prefetching.
-              return this;
-            }
+                @Override
+                public Iterable<T> read() {
+                  return impl.get();
+                }
 
-            @Override
-            public void clear() {
-              impl.clear();
+                @Override
+                public BagState<T> readLater() {
+                  // TODO: Support prefetching.
+                  return this;
+                }
+
+                @Override
+                public void clear() {
+                  impl.clear();
+                }
+              };
             }
-          };
-        }
-      });
+          });
     }
 
     @Override
@@ -805,77 +867,77 @@ public class FnApiDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Outp
         String id,
         StateSpec<CombiningState<InputT, AccumT, OutputT>> spec, Coder<AccumT> accumCoder,
         CombineFn<InputT, AccumT, OutputT> combineFn) {
-      return (CombiningState<InputT, AccumT, OutputT>) stateObjectCache.computeIfAbsent(
-          createOrUseCachedBagUserStateKey(id),
-          new Function<StateKey.BagUserState, Object>() {
-        @Override
-        public Object apply(StateKey.BagUserState s) {
-          // TODO: Support squashing accumulators depending on whether we know of all
-          // remote accumulators and local accumulators or just local accumulators.
-          return new CombiningState<InputT, AccumT, OutputT>() {
-            private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
-
+      return (CombiningState<InputT, AccumT, OutputT>) stateKeyObjectCache.computeIfAbsent(
+          createBagUserStateKey(id),
+          new Function<StateKey, Object>() {
             @Override
-            public AccumT getAccum() {
-              Iterator<AccumT> iterator = impl.get().iterator();
-              if (iterator.hasNext()) {
-                return iterator.next();
-              }
-              return combineFn.createAccumulator();
-            }
+            public Object apply(StateKey key) {
+              // TODO: Support squashing accumulators depending on whether we know of all
+              // remote accumulators and local accumulators or just local accumulators.
+              return new CombiningState<InputT, AccumT, OutputT>() {
+                private final BagUserState<AccumT> impl = createBagUserState(id, accumCoder);
 
-            @Override
-            public void addAccum(AccumT accum) {
-              Iterator<AccumT> iterator = impl.get().iterator();
+                @Override
+                public AccumT getAccum() {
+                  Iterator<AccumT> iterator = impl.get().iterator();
+                  if (iterator.hasNext()) {
+                    return iterator.next();
+                  }
+                  return combineFn.createAccumulator();
+                }
 
-              // Only merge if there was a prior value
-              if (iterator.hasNext()) {
-                accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
-                // Since there was a prior value, we need to clear.
-                impl.clear();
-              }
+                @Override
+                public void addAccum(AccumT accum) {
+                  Iterator<AccumT> iterator = impl.get().iterator();
 
-              impl.append(accum);
-            }
+                  // Only merge if there was a prior value
+                  if (iterator.hasNext()) {
+                    accum = combineFn.mergeAccumulators(ImmutableList.of(iterator.next(), accum));
+                    // Since there was a prior value, we need to clear.
+                    impl.clear();
+                  }
 
-            @Override
-            public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
-              return combineFn.mergeAccumulators(accumulators);
-            }
+                  impl.append(accum);
+                }
 
-            @Override
-            public CombiningState<InputT, AccumT, OutputT> readLater() {
-              return this;
-            }
+                @Override
+                public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
+                  return combineFn.mergeAccumulators(accumulators);
+                }
 
-            @Override
-            public OutputT read() {
-              Iterator<AccumT> iterator = impl.get().iterator();
-              if (iterator.hasNext()) {
-                return combineFn.extractOutput(iterator.next());
-              }
-              return combineFn.defaultValue();
-            }
+                @Override
+                public CombiningState<InputT, AccumT, OutputT> readLater() {
+                  return this;
+                }
 
-            @Override
-            public void add(InputT value) {
-              AccumT newAccumulator = combineFn.addInput(getAccum(), value);
-              impl.clear();
-              impl.append(newAccumulator);
-            }
+                @Override
+                public OutputT read() {
+                  Iterator<AccumT> iterator = impl.get().iterator();
+                  if (iterator.hasNext()) {
+                    return combineFn.extractOutput(iterator.next());
+                  }
+                  return combineFn.defaultValue();
+                }
 
-            @Override
-            public ReadableState<Boolean> isEmpty() {
-              return ReadableStates.immediate(!impl.get().iterator().hasNext());
-            }
+                @Override
+                public void add(InputT value) {
+                  AccumT newAccumulator = combineFn.addInput(getAccum(), value);
+                  impl.clear();
+                  impl.append(newAccumulator);
+                }
 
-            @Override
-            public void clear() {
-              impl.clear();
+                @Override
+                public ReadableState<Boolean> isEmpty() {
+                  return ReadableStates.immediate(!impl.get().iterator().hasNext());
+                }
+
+                @Override
+                public void clear() {
+                  impl.clear();
+                }
+              };
             }
-          };
-        }
-      });
+          });
     }
 
     @Override
@@ -885,32 +947,25 @@ public class FnApiDoFnRunner<InputT, OutputT> implements 
DoFnRunner<InputT, Outp
         StateSpec<CombiningState<InputT, AccumT, OutputT>> spec,
         Coder<AccumT> accumCoder,
         CombineFnWithContext<InputT, AccumT, OutputT> combineFn) {
-      return (CombiningState<InputT, AccumT, OutputT>)
-          stateObjectCache.computeIfAbsent(
-              createOrUseCachedBagUserStateKey(id),
-              s ->
-                  bindCombining(
-                      id,
-                      spec,
-                      accumCoder,
-                      CombineFnUtil.bindContext(
-                          combineFn,
-                          new StateContext<BoundedWindow>() {
-                            @Override
-                            public PipelineOptions getPipelineOptions() {
-                              return pipelineOptions;
-                            }
-
-                            @Override
-                            public <T> T sideInput(PCollectionView<T> view) {
-                              return processBundleContext.sideInput(view);
-                            }
-
-                            @Override
-                            public BoundedWindow window() {
-                              return currentWindow;
-                            }
-                          })));
+      return (CombiningState<InputT, AccumT, OutputT>) 
stateKeyObjectCache.computeIfAbsent(
+          createBagUserStateKey(id),
+          key -> bindCombining(id, spec, accumCoder, 
CombineFnUtil.bindContext(combineFn,
+              new StateContext<BoundedWindow>() {
+                @Override
+                public PipelineOptions getPipelineOptions() {
+                  return pipelineOptions;
+                }
+
+                @Override
+                public <T> T sideInput(PCollectionView<T> view) {
+                  return processBundleContext.sideInput(view);
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return currentWindow;
+                }
+              })));
     }
 
     /**
@@ -924,37 +979,41 @@ public class FnApiDoFnRunner<InputT, OutputT> implements 
DoFnRunner<InputT, Outp
       throw new UnsupportedOperationException("WatermarkHoldState is 
unsupported by the Fn API.");
     }
 
-    private <T> BagUserState<T> createBagUserState(String id, Coder<T> coder) {
-      BagUserState rval =
-          new BagUserState<>(
-              beamFnStateClient,
-              id,
-              coder,
-              new Supplier<StateRequest.Builder>() {
-                /** Memoizes the partial state key for the lifetime of the 
{@link BagUserState}. */
-                private final Supplier<StateKey.BagUserState> 
memoizingSupplier =
-                    Suppliers.memoize(() -> 
createOrUseCachedBagUserStateKey(id))::get;
-
-                @Override
-                public Builder get() {
-                  return StateRequest.newBuilder()
-                      
.setInstructionReference(processBundleInstructionId.get())
-                      
.setStateKey(StateKey.newBuilder().setBagUserState(memoizingSupplier.get()));
-                }
-              });
+    private <T> BagUserState<T> createBagUserState(
+        String stateId, Coder<T> valueCoder) {
+      BagUserState rval = new BagUserState<T>(
+          beamFnStateClient,
+          processBundleInstructionId.get(),
+          ptransformId,
+          stateId,
+          encodedCurrentWindow,
+          encodedCurrentKey,
+          valueCoder);
       stateFinalizers.add(rval::asyncClose);
       return rval;
     }
   }
 
+  private StateKey createBagUserStateKey(String stateId) {
+    cacheEncodedKeyAndWindowForKeyedContext();
+    StateKey.Builder builder = StateKey.newBuilder();
+    builder.getBagUserStateBuilder()
+        .setWindow(encodedCurrentWindow)
+        .setKey(encodedCurrentKey)
+        .setPtransformId(ptransformId)
+        .setUserStateId(stateId);
+    return builder.build();
+  }
+
   /**
-   * Memoizes a partially built {@link StateKey} saving on the encoding cost 
of the key and
-   * window across multiple state cells for the lifetime of {@link 
#processElement}.
+   * Memoizes an encoded key and window for the current element being 
processed saving on the
+   * encoding cost of the key and window across multiple state cells for the 
lifetime of
+   * {@link #processElement}.
    *
    * <p>This should only be called during {@link #processElement}.
    */
-  private <K> StateKey.BagUserState createOrUseCachedBagUserStateKey(String 
id) {
-    if (cachedPartialBagUserStateKey == null) {
+  private <K> void cacheEncodedKeyAndWindowForKeyedContext() {
+    if (encodedCurrentKey == null) {
       checkState(currentElement.getValue() instanceof KV,
           "Accessing state in unkeyed context. Current element is not a KV: 
%s.",
           currentElement);
@@ -976,19 +1035,85 @@ public class FnApiDoFnRunner<InputT, OutputT> implements 
DoFnRunner<InputT, Outp
       } catch (IOException e) {
         throw new IllegalStateException(e);
       }
+      encodedCurrentKey = encodedKeyOut.toByteString();
+    }
 
+    if (encodedCurrentWindow == null) {
       ByteString.Output encodedWindowOut = ByteString.newOutput();
       try {
         windowingStrategy.getWindowFn().windowCoder().encode(currentWindow, 
encodedWindowOut);
       } catch (IOException e) {
         throw new IllegalStateException(e);
       }
+      encodedCurrentWindow = encodedWindowOut.toByteString();
+    }
+  }
 
-      cachedPartialBagUserStateKey = StateKey.BagUserState.newBuilder()
-          .setPtransformId(ptransformId)
-          .setKey(encodedKeyOut.toByteString())
-          .setWindow(encodedWindowOut.toByteString()).buildPartial();
+  /**
+   * A specification for side inputs containing a value {@link Coder},
+   * the window {@link Coder}, {@link ViewFn}, and the {@link WindowMappingFn}.
+   * @param <W>
+   */
+  @AutoValue
+  abstract static class SideInputSpec<W extends BoundedWindow> {
+    static <W extends BoundedWindow> SideInputSpec create(
+        Coder<?> coder,
+        Coder<W> windowCoder,
+        ViewFn<?, ?> viewFn,
+        WindowMappingFn<W> windowMappingFn) {
+      return new AutoValue_FnApiDoFnRunner_SideInputSpec<>(
+          coder, windowCoder, viewFn, windowMappingFn);
+    }
+
+    abstract Coder<?> getCoder();
+
+    abstract Coder<W> getWindowCoder();
+
+    abstract ViewFn<?, ?> getViewFn();
+
+    abstract WindowMappingFn<W> getWindowMappingFn();
+  }
+
+  private <T, K, V> T bindSideInputView(TupleTag<?> view) {
+    SideInputSpec sideInputSpec = sideInputSpecMap.get(view);
+    checkArgument(sideInputSpec != null,
+        "Attempting to access unknown side input %s.",
+        view);
+    KvCoder<K, V> kvCoder = (KvCoder) sideInputSpec.getCoder();
+
+    ByteString.Output encodedWindowOut = ByteString.newOutput();
+    try {
+      sideInputSpec.getWindowCoder().encode(
+          
sideInputSpec.getWindowMappingFn().getSideInputWindow(currentWindow), 
encodedWindowOut);
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
     }
-    return cachedPartialBagUserStateKey.toBuilder().setUserStateId(id).build();
+    ByteString encodedWindow = encodedWindowOut.toByteString();
+
+    StateKey.Builder cacheKeyBuilder = StateKey.newBuilder();
+    cacheKeyBuilder.getMultimapSideInputBuilder()
+        .setPtransformId(ptransformId)
+        .setSideInputId(view.getId())
+        .setWindow(encodedWindow);
+    return (T) stateKeyObjectCache.computeIfAbsent(
+        cacheKeyBuilder.build(),
+        key -> sideInputSpec.getViewFn().apply(createMultimapSideInput(
+            view.getId(), encodedWindow, kvCoder.getKeyCoder(), 
kvCoder.getValueCoder())));
+  }
+
+  private <K, V> MultimapSideInput<K, V> createMultimapSideInput(
+      String sideInputId,
+      ByteString encodedWindow,
+      Coder<K> keyCoder,
+      Coder<V> valueCoder) {
+
+    return new MultimapSideInput<>(
+        beamFnStateClient,
+        processBundleInstructionId.get(),
+        ptransformId,
+        sideInputId,
+        encodedWindow,
+        keyCoder,
+        valueCoder);
   }
 }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
index f2e852c..1b08e58 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/BagUserState.java
@@ -23,12 +23,10 @@ import com.google.common.collect.Iterables;
 import com.google.protobuf.ByteString;
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.List;
 import java.util.concurrent.CompletableFuture;
-import java.util.function.Supplier;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.stream.DataStreams;
 
@@ -46,62 +44,76 @@ import org.apache.beam.sdk.fn.stream.DataStreams;
  */
 public class BagUserState<T> {
   private final BeamFnStateClient beamFnStateClient;
-  private final String stateId;
-  private final Coder<T> coder;
-  private final Supplier<Builder> partialRequestSupplier;
+  private final StateRequest request;
+  private final Coder<T> valueCoder;
   private Iterable<T> oldValues;
   private ArrayList<T> newValues;
-  private List<T> unmodifiableNewValues;
   private boolean isClosed;
 
   public BagUserState(
       BeamFnStateClient beamFnStateClient,
+      String instructionId,
+      String ptransformId,
       String stateId,
-      Coder<T> coder,
-      Supplier<Builder> partialRequestSupplier) {
+      ByteString encodedWindow,
+      ByteString encodedKey,
+      Coder<T> valueCoder) {
     this.beamFnStateClient = beamFnStateClient;
-    this.stateId = stateId;
-    this.coder = coder;
-    this.partialRequestSupplier = partialRequestSupplier;
+    this.valueCoder = valueCoder;
+
+    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
+    requestBuilder
+        .setInstructionReference(instructionId)
+        .getStateKeyBuilder()
+        .getBagUserStateBuilder()
+        .setPtransformId(ptransformId)
+        .setUserStateId(stateId)
+        .setWindow(encodedWindow)
+        .setKey(encodedKey);
+    request = requestBuilder.build();
+
     this.oldValues = new LazyCachingIteratorToIterable<>(
-        new DataStreams.DataStreamDecoder(coder,
+        new DataStreams.DataStreamDecoder(valueCoder,
             DataStreams.inbound(
-                StateFetchingIterators.usingPartialRequestWithStateKey(
+                StateFetchingIterators.forFirstChunk(
                     beamFnStateClient,
-                    partialRequestSupplier))));
+                    request))));
     this.newValues = new ArrayList<>();
-    this.unmodifiableNewValues = Collections.unmodifiableList(newValues);
   }
 
   public Iterable<T> get() {
     checkState(!isClosed,
-        "Bag user state is no longer usable because it is closed for %s", 
stateId);
-    // If we were cleared we should disregard old values.
+        "Bag user state is no longer usable because it is closed for %s", 
request.getStateKey());
     if (oldValues == null) {
-      return unmodifiableNewValues;
+      // If we were cleared we should disregard old values.
+      return Iterables.limit(Collections.unmodifiableList(newValues), 
newValues.size());
+    } else if (newValues.isEmpty()) {
+      // If we have no new values then just return the old values.
+      return oldValues;
     }
-    return Iterables.concat(oldValues, unmodifiableNewValues);
+    return Iterables.concat(oldValues,
+        Iterables.limit(Collections.unmodifiableList(newValues), 
newValues.size()));
   }
 
   public void append(T t) {
     checkState(!isClosed,
-        "Bag user state is no longer usable because it is closed for %s", 
stateId);
+        "Bag user state is no longer usable because it is closed for %s", 
request.getStateKey());
     newValues.add(t);
   }
 
   public void clear() {
     checkState(!isClosed,
-        "Bag user state is no longer usable because it is closed for %s", 
stateId);
+        "Bag user state is no longer usable because it is closed for %s", 
request.getStateKey());
     oldValues = null;
-    newValues.clear();
+    newValues = new ArrayList<>();
   }
 
   public void asyncClose() throws Exception {
     checkState(!isClosed,
-        "Bag user state is no longer usable because it is closed for %s", 
stateId);
+        "Bag user state is no longer usable because it is closed for %s", 
request.getStateKey());
     if (oldValues == null) {
       beamFnStateClient.handle(
-          partialRequestSupplier.get()
+          request.toBuilder()
               .setClear(StateClearRequest.getDefaultInstance()),
           new CompletableFuture<>());
     }
@@ -109,10 +121,10 @@ public class BagUserState<T> {
       ByteString.Output out = ByteString.newOutput();
       for (T newValue : newValues) {
         // TODO: Replace with chunking output stream
-        coder.encode(newValue, out);
+        valueCoder.encode(newValue, out);
       }
       beamFnStateClient.handle(
-          partialRequestSupplier.get()
+          request.toBuilder()
               
.setAppend(StateAppendRequest.newBuilder().setData(out.toByteString())),
           new CompletableFuture<>());
     }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
index 0a43317..0a6232c 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterable.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.fn.harness.state;
 
+import com.google.common.collect.Iterables;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
@@ -69,4 +70,20 @@ class LazyCachingIteratorToIterable<T> implements 
Iterable<T> {
       return rval;
     }
   }
+
+  @Override
+  public int hashCode() {
+    return iterator.hasNext() ? iterator.next().hashCode() : -1789023489;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    return obj instanceof Iterable
+        && Iterables.elementsEqual(this, (Iterable) obj);
+  }
+
+  @Override
+  public String toString() {
+    return Iterables.toString(this);
+  }
 }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
new file mode 100644
index 0000000..874d0fc
--- /dev/null
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapSideInput.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.stream.DataStreams;
+import org.apache.beam.sdk.transforms.Materializations.MultimapView;
+
+/**
+ * An implementation of a multimap side input that utilizes the Beam Fn State 
API to fetch values.
+ *
+ * <p>TODO: Support block level caching and prefetch.
+ */
+public class MultimapSideInput<K, V> implements MultimapView<K, V> {
+
+  private final BeamFnStateClient beamFnStateClient;
+  private final String instructionId;
+  private final String ptransformId;
+  private final String sideInputId;
+  private final ByteString encodedWindow;
+  private final Coder<K> keyCoder;
+  private final Coder<V> valueCoder;
+
+  public MultimapSideInput(
+      BeamFnStateClient beamFnStateClient,
+      String instructionId,
+      String ptransformId,
+      String sideInputId,
+      ByteString encodedWindow,
+      Coder<K> keyCoder,
+      Coder<V> valueCoder) {
+    this.beamFnStateClient = beamFnStateClient;
+    this.instructionId = instructionId;
+    this.ptransformId = ptransformId;
+    this.sideInputId = sideInputId;
+    this.encodedWindow = encodedWindow;
+    this.keyCoder = keyCoder;
+    this.valueCoder = valueCoder;
+  }
+
+  public Iterable<V> get(K k) {
+    ByteString.Output output = ByteString.newOutput();
+    try {
+      keyCoder.encode(k, output);
+    } catch (IOException e) {
+      throw new IllegalStateException(
+          String.format("Failed to encode key %s for side input id %s.", k, 
sideInputId),
+          e);
+    }
+    StateRequest.Builder requestBuilder = StateRequest.newBuilder();
+    requestBuilder
+        .setInstructionReference(instructionId)
+        .getStateKeyBuilder()
+        .getMultimapSideInputBuilder()
+        .setPtransformId(ptransformId)
+        .setSideInputId(sideInputId)
+        .setWindow(encodedWindow)
+        .setKey(output.toByteString());
+
+    return new LazyCachingIteratorToIterable<>(
+        new DataStreams.DataStreamDecoder(valueCoder,
+            DataStreams.inbound(
+                StateFetchingIterators.forFirstChunk(
+                    beamFnStateClient,
+                    requestBuilder.build()))));
+  }
+}
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
index b64c946..683314a 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java
@@ -23,10 +23,8 @@ import java.util.Iterator;
 import java.util.NoSuchElementException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.function.Supplier;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
 
 /**
@@ -40,18 +38,18 @@ public class StateFetchingIterators {
 
   /**
    * This adapter handles using the continuation token to provide iteration 
over all the chunks
-   * returned by the Beam Fn State API using the supplied state client and 
partially filled
-   * out state request containing a state key.
+   * returned by the Beam Fn State API using the supplied state client and 
state request for
+   * the first chunk of the state stream.
    *
    * @param beamFnStateClient A client for handling state requests.
-   * @param partialStateRequestBuilder A {@link StateRequest} with the
-   * {@link StateRequest#getStateKey()} already set.
-   * @return An {@code Iterator<ByteString>} representing all the requested 
data.
+   * @param stateRequestForFirstChunk A fully populated state request for the 
first (and possibly
+   * only) chunk of a state stream. This state request will be populated with 
a continuation token
+   * to request further chunks of the stream if required.
    */
-  public static Iterator<ByteString> usingPartialRequestWithStateKey(
+  public static Iterator<ByteString> forFirstChunk(
       BeamFnStateClient beamFnStateClient,
-      Supplier<StateRequest.Builder> partialStateRequestBuilder) {
-    return new LazyBlockingStateFetchingIterator(beamFnStateClient, 
partialStateRequestBuilder);
+      StateRequest stateRequestForFirstChunk) {
+    return new LazyBlockingStateFetchingIterator(beamFnStateClient, 
stateRequestForFirstChunk);
   }
 
   /**
@@ -63,18 +61,17 @@ public class StateFetchingIterators {
   static class LazyBlockingStateFetchingIterator implements 
Iterator<ByteString> {
     private enum State { READ_REQUIRED, HAS_NEXT, EOF };
     private final BeamFnStateClient beamFnStateClient;
-    /** Allows for the partially built state request to be memoized across 
many requests. */
-    private final Supplier<Builder> stateRequestSupplier;
+    private final StateRequest stateRequestForFirstChunk;
     private State currentState;
     private ByteString continuationToken;
     private ByteString next;
 
     LazyBlockingStateFetchingIterator(
         BeamFnStateClient beamFnStateClient,
-        Supplier<StateRequest.Builder> stateRequestSupplier) {
+        StateRequest stateRequestForFirstChunk) {
       this.currentState = State.READ_REQUIRED;
       this.beamFnStateClient = beamFnStateClient;
-      this.stateRequestSupplier = stateRequestSupplier;
+      this.stateRequestForFirstChunk = stateRequestForFirstChunk;
       this.continuationToken = ByteString.EMPTY;
     }
 
@@ -86,7 +83,7 @@ public class StateFetchingIterators {
         case READ_REQUIRED:
           CompletableFuture<StateResponse> stateResponseFuture = new 
CompletableFuture<>();
           beamFnStateClient.handle(
-              stateRequestSupplier.get().setGet(
+              stateRequestForFirstChunk.toBuilder().setGet(
                   
StateGetRequest.newBuilder().setContinuationToken(continuationToken)),
               stateResponseFuture);
           StateResponse stateResponse;
@@ -122,5 +119,4 @@ public class StateFetchingIterators {
       return next;
     }
   }
-
 }
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 70aca2e..22bcebd 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -23,6 +23,7 @@ import static 
org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.hasSize;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
@@ -45,6 +46,9 @@ import org.apache.beam.fn.harness.state.FakeBeamFnStateClient;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.SdkComponents;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
@@ -57,17 +61,28 @@ import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
 import org.apache.beam.sdk.transforms.CombineWithContext.Context;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.DoFnInfo;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.hamcrest.collection.IsMapContaining;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -291,10 +306,10 @@ public class FnApiDoFnRunnerTest {
         .build();
 
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(ImmutableMap.of(
-        key("value", "X"), encode("X0"),
-        key("bag", "X"), encode("X0"),
-        key("combine", "X"), encode("X0"),
-        key("combineWithContext", "X"), encode("X0")
+        bagUserStateKey("value", "X"), encode("X0"),
+        bagUserStateKey("bag", "X"), encode("X0"),
+        bagUserStateKey("combine", "X"), encode("X0"),
+        bagUserStateKey("combineWithContext", "X"), encode("X0")
     ));
 
     List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
@@ -355,21 +370,21 @@ public class FnApiDoFnRunnerTest {
 
     assertEquals(
         ImmutableMap.<StateKey, ByteString>builder()
-            .put(key("value", "X"), encode("X2"))
-            .put(key("bag", "X"), encode("X0", "X1", "X2"))
-            .put(key("combine", "X"), encode("X0X1X2"))
-            .put(key("combineWithContext", "X"), encode("X0X1X2"))
-            .put(key("value", "Y"), encode("Y2"))
-            .put(key("bag", "Y"), encode("Y1", "Y2"))
-            .put(key("combine", "Y"), encode("Y1Y2"))
-            .put(key("combineWithContext", "Y"), encode("Y1Y2"))
+            .put(bagUserStateKey("value", "X"), encode("X2"))
+            .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
+            .put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
+            .put(bagUserStateKey("combineWithContext", "X"), encode("X0X1X2"))
+            .put(bagUserStateKey("value", "Y"), encode("Y2"))
+            .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
+            .put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
+            .put(bagUserStateKey("combineWithContext", "Y"), encode("Y1Y2"))
             .build(),
         fakeClient.getData());
     mainOutputValues.clear();
   }
 
-  /** Produces a {@link StateKey} for the test PTransform id in the Global 
Window. */
-  private StateKey key(String userStateId, String key) throws IOException {
+  /** Produces a bag user {@link StateKey} for the test PTransform id in the 
global window. */
+  private StateKey bagUserStateKey(String userStateId, String key) throws 
IOException {
     return StateKey.newBuilder().setBagUserState(
         StateKey.BagUserState.newBuilder()
             .setPtransformId(TEST_PTRANSFORM_ID)
@@ -380,6 +395,236 @@ public class FnApiDoFnRunnerTest {
         .build();
   }
 
+  private static class TestSideInputDoFn extends DoFn<String, String> {
+    private final PCollectionView<String> defaultSingletonSideInput;
+    private final PCollectionView<String> singletonSideInput;
+    private final PCollectionView<Iterable<String>> iterableSideInput;
+    private TestSideInputDoFn(
+        PCollectionView<String> defaultSingletonSideInput,
+        PCollectionView<String> singletonSideInput,
+        PCollectionView<Iterable<String>> iterableSideInput) {
+      this.defaultSingletonSideInput = defaultSingletonSideInput;
+      this.singletonSideInput = singletonSideInput;
+      this.iterableSideInput = iterableSideInput;
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext context) {
+      context.output(context.element() + ":" + 
context.sideInput(defaultSingletonSideInput));
+      context.output(context.element() + ":" + 
context.sideInput(singletonSideInput));
+      for (String sideInputValue : context.sideInput(iterableSideInput)) {
+        context.output(context.element() + ":" + sideInputValue);
+      }
+    }
+  }
+
+  @Test
+  public void testUsingSideInput() throws Exception {
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> defaultSingletonSideInputView = 
valuePCollection.apply(
+        View.<String>asSingleton().withDefaultValue("defaultSingletonValue"));
+    PCollectionView<String> singletonSideInputView = 
valuePCollection.apply(View.asSingleton());
+    PCollectionView<Iterable<String>> iterableSideInputView =
+        valuePCollection.apply(View.asIterable());
+    PCollection<String> outputPCollection = 
valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
+        new TestSideInputDoFn(
+            defaultSingletonSideInputView,
+            singletonSideInputView,
+            iterableSideInputView))
+        .withSideInputs(
+            defaultSingletonSideInputView, singletonSideInputView, 
iterableSideInputView));
+
+    SdkComponents sdkComponents = SdkComponents.create();
+    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+    String inputPCollectionId = 
sdkComponents.registerPCollection(valuePCollection);
+    String outputPCollectionId = 
sdkComponents.registerPCollection(outputPCollection);
+
+    RunnerApi.PTransform pTransform = 
pProto.getComponents().getTransformsOrThrow(
+        
pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+
+    ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
+        multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), 
ByteString.EMPTY),
+        encode("singletonValue"),
+        multimapSideInputKey(iterableSideInputView.getTagInternal().getId(), 
ByteString.EMPTY),
+        encode("iterableValue1", "iterableValue2", "iterableValue3"));
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = 
HashMultimap.create();
+    
consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) 
mainOutputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        fakeClient,
+        TEST_PTRANSFORM_ID,
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        pProto.getComponents().getPcollectionsMap(),
+        pProto.getComponents().getCodersMap(),
+        pProto.getComponents().getWindowingStrategiesMap(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    Iterables.getOnlyElement(startFunctions).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, 
outputPCollectionId));
+
+    // Ensure that each side input (a singleton with no state that falls back to its default
+    // value, a populated singleton, and an iterable) is visible while processing every element.
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        Iterables.getOnlyElement(consumers.get(inputPCollectionId));
+    mainInput.accept(valueInGlobalWindow("X"));
+    mainInput.accept(valueInGlobalWindow("Y"));
+    assertThat(mainOutputValues, contains(
+        valueInGlobalWindow("X:defaultSingletonValue"),
+        valueInGlobalWindow("X:singletonValue"),
+        valueInGlobalWindow("X:iterableValue1"),
+        valueInGlobalWindow("X:iterableValue2"),
+        valueInGlobalWindow("X:iterableValue3"),
+        valueInGlobalWindow("Y:defaultSingletonValue"),
+        valueInGlobalWindow("Y:singletonValue"),
+        valueInGlobalWindow("Y:iterableValue1"),
+        valueInGlobalWindow("Y:iterableValue2"),
+        valueInGlobalWindow("Y:iterableValue3")));
+    mainOutputValues.clear();
+
+    Iterables.getOnlyElement(finishFunctions).run();
+    assertThat(mainOutputValues, empty());
+
+    // Assert that state data did not change
+    assertEquals(stateData, fakeClient.getData());
+    mainOutputValues.clear();
+  }
+
+  private static class TestSideInputIsAccessibleForDownstreamCallersDoFn
+      extends DoFn<String, Iterable<String>> {
+    private final PCollectionView<Iterable<String>> iterableSideInput;
+    private TestSideInputIsAccessibleForDownstreamCallersDoFn(
+        PCollectionView<Iterable<String>> iterableSideInput) {
+      this.iterableSideInput = iterableSideInput;
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext context) {
+      context.output(context.sideInput(iterableSideInput));
+    }
+  }
+
+  @Test
+  public void testSideInputIsAccessibleForDownstreamCallers() throws Exception 
{
+    FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
+    IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
+    IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
+    ByteString encodedWindowA =
+        
ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), 
windowA));
+    ByteString encodedWindowB =
+        
ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), 
windowB));
+
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"))
+        .apply(Window.into(windowFn));
+    PCollectionView<Iterable<String>> iterableSideInputView =
+        valuePCollection.apply(View.asIterable());
+    PCollection<Iterable<String>> outputPCollection =
+        valuePCollection.apply(TEST_PTRANSFORM_ID, ParDo.of(
+            new 
TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
+            .withSideInputs(iterableSideInputView));
+
+    SdkComponents sdkComponents = SdkComponents.create();
+    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+    String inputPCollectionId = 
sdkComponents.registerPCollection(valuePCollection);
+    String outputPCollectionId = 
sdkComponents.registerPCollection(outputPCollection);
+
+    RunnerApi.PTransform pTransform = 
pProto.getComponents().getTransformsOrThrow(
+        
pProto.getComponents().getTransformsOrThrow(TEST_PTRANSFORM_ID).getSubtransforms(0));
+
+    ImmutableMap<StateKey, ByteString> stateData = ImmutableMap.of(
+        multimapSideInputKey(
+            iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY, 
encodedWindowA),
+        encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
+        multimapSideInputKey(
+            iterableSideInputView.getTagInternal().getId(), ByteString.EMPTY, 
encodedWindowB),
+        encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+    List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
+    Multimap<String, FnDataReceiver<WindowedValue<?>>> consumers = 
HashMultimap.create();
+    
consumers.put(Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) 
mainOutputValues::add);
+    List<ThrowingRunnable> startFunctions = new ArrayList<>();
+    List<ThrowingRunnable> finishFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.NewFactory<>().createRunnerForPTransform(
+        PipelineOptionsFactory.create(),
+        null /* beamFnDataClient */,
+        fakeClient,
+        TEST_PTRANSFORM_ID,
+        pTransform,
+        Suppliers.ofInstance("57L")::get,
+        pProto.getComponents().getPcollectionsMap(),
+        pProto.getComponents().getCodersMap(),
+        pProto.getComponents().getWindowingStrategiesMap(),
+        consumers,
+        startFunctions::add,
+        finishFunctions::add);
+
+    Iterables.getOnlyElement(startFunctions).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, 
outputPCollectionId));
+
+    // Ensure that each element sees the side input contents for its own window and that the
+    // emitted side input iterables stay readable after the elements have been processed.
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        Iterables.getOnlyElement(consumers.get(inputPCollectionId));
+    mainInput.accept(valueInWindow("X", windowA));
+    mainInput.accept(valueInWindow("Y", windowB));
+    assertThat(mainOutputValues, hasSize(2));
+    assertThat(mainOutputValues.get(0).getValue(), contains(
+        "iterableValue1A", "iterableValue2A", "iterableValue3A"));
+    assertThat(mainOutputValues.get(1).getValue(), contains(
+        "iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+    // Assert that state data did not change
+    assertEquals(stateData, fakeClient.getData());
+  }
+
+  private <T> WindowedValue<T> valueInWindow(T value, BoundedWindow window) {
+    return WindowedValue.of(value, window.maxTimestamp(), window, 
PaneInfo.ON_TIME_AND_ONLY_FIRING);
+  }
+
+  /**
+   * Produces a multimap side input {@link StateKey} for the test PTransform 
id in the global
+   * window.
+   */
+  private StateKey multimapSideInputKey(String sideInputId, ByteString key) 
throws IOException {
+    return multimapSideInputKey(sideInputId, key, ByteString.copyFrom(
+        CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, 
GlobalWindow.INSTANCE)));
+  }
+
+  /**
+   * Produces a multimap side input {@link StateKey} for the test PTransform 
id in the supplied
+   * window.
+   */
+  private StateKey multimapSideInputKey(String sideInputId, ByteString key, 
ByteString windowKey) {
+    return StateKey.newBuilder().setMultimapSideInput(
+        StateKey.MultimapSideInput.newBuilder()
+            .setPtransformId(TEST_PTRANSFORM_ID)
+            .setSideInputId(sideInputId)
+            .setKey(key)
+            .setWindow(windowKey))
+        .build();
+  }
+
   private ByteString encode(String ... values) throws IOException {
     ByteString.Output out = ByteString.newOutput();
     for (String value : values) {
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
index 6d3e078..29c4a8a 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/BagUserStateTest.java
@@ -19,6 +19,7 @@ package org.apache.beam.fn.harness.state;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
 
 import com.google.common.collect.ImmutableMap;
@@ -26,7 +27,6 @@ import com.google.common.collect.Iterables;
 import com.google.protobuf.ByteString;
 import java.io.IOException;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.junit.Rule;
 import org.junit.Test;
@@ -44,7 +44,14 @@ public class BagUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(ImmutableMap.of(
         key("A"), encode("A1", "A2", "A3")));
     BagUserState<String> userState =
-        new BagUserState<>(fakeClient, "A", StringUtf8Coder.of(), () -> 
requestForId("A"));
+        new BagUserState<>(
+            fakeClient,
+            "instructionId",
+            "ptransformId",
+            "stateId",
+            ByteString.copyFromUtf8("encodedWindow"),
+            encode("A"),
+            StringUtf8Coder.of());
     assertArrayEquals(new String[]{ "A1", "A2", "A3" },
         Iterables.toArray(userState.get(), String.class));
 
@@ -58,9 +65,23 @@ public class BagUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(ImmutableMap.of(
         key("A"), encode("A1")));
     BagUserState<String> userState =
-        new BagUserState<>(fakeClient, "A", StringUtf8Coder.of(), () -> 
requestForId("A"));
+        new BagUserState<>(
+            fakeClient,
+            "instructionId",
+            "ptransformId",
+            "stateId",
+            ByteString.copyFromUtf8("encodedWindow"),
+            encode("A"),
+            StringUtf8Coder.of());
     userState.append("A2");
+    Iterable<String> stateBeforeA3 = userState.get();
+    assertArrayEquals(new String[]{ "A1", "A2" },
+        Iterables.toArray(stateBeforeA3, String.class));
     userState.append("A3");
+    assertArrayEquals(new String[]{ "A1", "A2" },
+        Iterables.toArray(stateBeforeA3, String.class));
+    assertArrayEquals(new String[]{ "A1", "A2", "A3" },
+        Iterables.toArray(userState.get(), String.class));
     userState.asyncClose();
 
     assertEquals(encode("A1", "A2", "A3"), fakeClient.getData().get(key("A")));
@@ -73,11 +94,23 @@ public class BagUserStateTest {
     FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(ImmutableMap.of(
         key("A"), encode("A1", "A2", "A3")));
     BagUserState<String> userState =
-        new BagUserState<>(fakeClient, "A", StringUtf8Coder.of(), () -> 
requestForId("A"));
-
+        new BagUserState<>(
+            fakeClient,
+            "instructionId",
+            "ptransformId",
+            "stateId",
+            ByteString.copyFromUtf8("encodedWindow"),
+            encode("A"),
+            StringUtf8Coder.of());
+    assertArrayEquals(new String[]{ "A1", "A2", "A3" },
+        Iterables.toArray(userState.get(), String.class));
     userState.clear();
-    userState.append("A1");
+    assertFalse(userState.get().iterator().hasNext());
+    userState.append("A4");
+    assertArrayEquals(new String[]{ "A4" },
+        Iterables.toArray(userState.get(), String.class));
     userState.clear();
+    assertFalse(userState.get().iterator().hasNext());
     userState.asyncClose();
 
     assertNull(fakeClient.getData().get(key("A")));
@@ -85,15 +118,13 @@ public class BagUserStateTest {
     userState.clear();
   }
 
-  private StateRequest.Builder requestForId(String id) {
-    return StateRequest.newBuilder().setStateKey(
-        StateKey.newBuilder().setBagUserState(
-            
StateKey.BagUserState.newBuilder().setKey(ByteString.copyFromUtf8(id))));
-  }
-
-  private StateKey key(String id) {
+  private StateKey key(String id) throws IOException {
     return StateKey.newBuilder().setBagUserState(
-        
StateKey.BagUserState.newBuilder().setKey(ByteString.copyFromUtf8(id))).build();
+        StateKey.BagUserState.newBuilder()
+            .setPtransformId("ptransformId")
+            .setUserStateId("stateId")
+            .setWindow(ByteString.copyFromUtf8("encodedWindow"))
+            .setKey(encode(id))).build();
   }
 
   private ByteString encode(String ... values) throws IOException {
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
index 53eefb4..1e44452 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/LazyCachingIteratorToIterableTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.fn.harness.state;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
 
 import com.google.common.collect.Iterables;
@@ -73,4 +74,17 @@ public class LazyCachingIteratorToIterableTest {
     thrown.expect(NoSuchElementException.class);
     iterator1.next();
   }
+
+  @Test
+  public void testEqualsAndHashCode() {
+    Iterable<String> iterA = new 
LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+    Iterable<String> iterB = new 
LazyCachingIteratorToIterable<>(Iterators.forArray("A", "B", "C"));
+    Iterable<String> iterC = new 
LazyCachingIteratorToIterable<>(Iterators.forArray());
+    Iterable<String> iterD = new 
LazyCachingIteratorToIterable<>(Iterators.forArray());
+    assertEquals(iterA, iterB);
+    assertEquals(iterC, iterD);
+    assertNotEquals(iterA, iterC);
+    assertEquals(iterA.hashCode(), iterB.hashCode());
+    assertEquals(iterC.hashCode(), iterD.hashCode());
+  }
 }
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
new file mode 100644
index 0000000..39c0cbd
--- /dev/null
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapSideInputTest.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Iterables;
+import com.google.protobuf.ByteString;
+import java.io.IOException;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link MultimapSideInput}. */
+@RunWith(JUnit4.class)
+public class MultimapSideInputTest {
+  @Test
+  public void testGet() throws Exception {
+    FakeBeamFnStateClient fakeBeamFnStateClient = new 
FakeBeamFnStateClient(ImmutableMap.of(
+        key("A"), encode("A1", "A2", "A3"),
+        key("B"), encode("B1", "B2")));
+
+    MultimapSideInput<String, String> multimapSideInput = new 
MultimapSideInput<>(
+        fakeBeamFnStateClient,
+        "instructionId",
+        "ptransformId",
+        "sideInputId",
+        ByteString.copyFromUtf8("encodedWindow"),
+        StringUtf8Coder.of(),
+        StringUtf8Coder.of());
+    assertArrayEquals(new String[]{ "A1", "A2", "A3" },
+        Iterables.toArray(multimapSideInput.get("A"), String.class));
+    assertArrayEquals(new String[]{ "B1", "B2" },
+        Iterables.toArray(multimapSideInput.get("B"), String.class));
+    assertArrayEquals(new String[]{ },
+        Iterables.toArray(multimapSideInput.get("unknown"), String.class));
+  }
+
+  private StateKey key(String id) throws IOException {
+    return StateKey.newBuilder().setMultimapSideInput(
+        StateKey.MultimapSideInput.newBuilder()
+            .setPtransformId("ptransformId")
+            .setSideInputId("sideInputId")
+            .setWindow(ByteString.copyFromUtf8("encodedWindow"))
+            .setKey(encode(id))).build();
+  }
+
+  private ByteString encode(String ... values) throws IOException {
+    ByteString.Output out = ByteString.newOutput();
+    for (String value : values) {
+      StringUtf8Coder.of().encode(value, out);
+    }
+    return out.toByteString();
+  }
+}
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
index 6ddec56..b4f37ab 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/StateFetchingIteratorsTest.java
@@ -91,7 +91,7 @@ public class StateFetchingIteratorsTest {
                     .build());
           };
       Iterator<ByteString> byteStrings =
-          new LazyBlockingStateFetchingIterator(fakeStateClient, 
StateRequest::newBuilder);
+          new LazyBlockingStateFetchingIterator(fakeStateClient, 
StateRequest.getDefaultInstance());
       assertArrayEquals(expected, Iterators.toArray(byteStrings, 
Object.class));
     }
   }

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to