lukecwik commented on a change in pull request #11342:
URL: https://github.com/apache/beam/pull/11342#discussion_r413289389



##########
File path: 
runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java
##########
@@ -0,0 +1,606 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.artifact;
+
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.FutureTask;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
+import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.fnexecution.FnService;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.ResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ArtifactStagingService
+    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase 
implements FnService {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ArtifactStagingService.class);
+
+  private final ArtifactDestinationProvider destinationProvider;
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> toStage =
+      new ConcurrentHashMap<>();
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> staged =
+      new ConcurrentHashMap<>();
+
+  public ArtifactStagingService(ArtifactDestinationProvider 
destinationProvider) {
+    this.destinationProvider = destinationProvider;
+  }
+
+  /**
+   * Registers a set of artifacts to be staged with this service.
+   *
+   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with 
the given staging
+   * token and offer resolution and retrieval of this set of artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   * @param artifacts all artifacts to stage, keyed by environment
+   */
+  public void registerJob(
+      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> 
artifacts) {
+    assert !toStage.containsKey(stagingToken);
+    toStage.put(stagingToken, artifacts);
+  }
+
+  /**
+   * Returns the rewritten artifacts associated with this job, keyed by 
environment.
+   *
+   * <p>This should be called after the client has finished offering artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   */
+  public Map<String, List<RunnerApi.ArtifactInformation>> 
getStagedArtifacts(String stagingToken) {
+    toStage.remove(stagingToken);
+    return staged.remove(stagingToken);
+  }
+
+  /** Provides a concrete location to which artifacts can be staged on 
retrieval. */
+  public interface ArtifactDestinationProvider {
+    ArtifactDestination getDestination(String stagingToken, String name) 
throws IOException;
+  }
+
+  /**
+   * A pairing of a newly created artifact type and an output stream that will 
be readable at that
+   * type.
+   */
+  @AutoValue
+  public abstract static class ArtifactDestination {
+    public static ArtifactDestination create(
+        String typeUrn, ByteString typePayload, OutputStream out) {
+      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, 
typePayload, out);
+    }
+
+    public static ArtifactDestination fromFile(String path) throws IOException 
{
+      return fromFile(
+          path,
+          Channels.newOutputStream(
+              FileSystems.create(
+                  FileSystems.matchNewResource(path, false /* isDirectory */), 
MimeTypes.BINARY)));
+    }
+
+    public static ArtifactDestination fromFile(String path, OutputStream out) {
+      return create(
+          ArtifactRetrievalService.FILE_ARTIFACT_URN,
+          
RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
+          out);
+    }
+
+    public abstract String getTypeUrn();
+
+    public abstract ByteString getTypePayload();
+
+    public abstract OutputStream getOutputStream();
+  }
+
+  /**
+   * An ArtifactDestinationProvider that places new artifacts as files in a 
Beam filesystem.
+   *
+   * @param root the directory in which to place all artifacts
+   */
+  public static ArtifactDestinationProvider 
beamFilesystemArtifactDestinationProvider(String root) {
+    return (statingToken, name) -> {
+      ResourceId path =
+          FileSystems.matchNewResource(root, true)
+              .resolve(statingToken, 
ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
+              .resolve(name, 
ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
+      return ArtifactDestination.fromFile(path.toString());
+    };
+  }
+
+  private enum State {
+    START,
+    RESOLVE,
+    GET,
+    GETCHUNK,
+    DONE,
+    ERROR,
+  }
+
+  /**
+   * Like the standard Semaphore, but allows an aquire to go over the limit if 
there is any room.
+   *
+   * <p>Also allows setting an error, to avoid issues with un-released aquires 
after error.
+   */
+  private static class OverflowingSemaphore {
+    private int totalPermits;
+    private int usedPermits;
+    private Exception exception;
+
+    public OverflowingSemaphore(int totalPermits) {
+      this.totalPermits = totalPermits;
+      this.usedPermits = 0;
+    }
+
+    synchronized void aquire(int permits) throws Exception {
+      while (usedPermits >= totalPermits) {
+        if (exception != null) {
+          throw exception;
+        }
+        this.wait();
+      }
+      usedPermits += permits;
+    }
+
+    synchronized void release(int permits) {
+      usedPermits -= permits;
+      this.notifyAll();
+    }
+
+    synchronized void setException(Exception exception) {
+      this.exception = exception;
+      this.notifyAll();
+    }
+  }
+
+  /** A task that pulls bytes off a queue and actually writes them to a 
staging location. */
+  private class StoreArtifact implements 
Callable<RunnerApi.ArtifactInformation> {
+
+    private String stagingToken;
+    private String name;
+    private RunnerApi.ArtifactInformation originalArtifact;
+    private BlockingQueue<ByteString> bytesQueue;
+    private OverflowingSemaphore totalPendingBytes;
+
+    public StoreArtifact(
+        String stagingToken,
+        String name,
+        RunnerApi.ArtifactInformation originalArtifact,
+        BlockingQueue<ByteString> bytesQueue,
+        OverflowingSemaphore totalPendingBytes) {
+      this.stagingToken = stagingToken;
+      this.name = name;
+      this.originalArtifact = originalArtifact;
+      this.bytesQueue = bytesQueue;
+      this.totalPendingBytes = totalPendingBytes;
+    }
+
+    @Override
+    public RunnerApi.ArtifactInformation call() throws IOException {
+      try {
+        ArtifactDestination dest = 
destinationProvider.getDestination(stagingToken, name);
+        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, 
dest);
+        ByteString chunk = bytesQueue.take();
+        while (chunk.size() > 0) {
+          totalPendingBytes.release(chunk.size());
+          dest.getOutputStream().write(chunk.toByteArray());
+          chunk = bytesQueue.take();
+        }
+        dest.getOutputStream().close();
+        return originalArtifact
+            .toBuilder()
+            .setTypeUrn(dest.getTypeUrn())
+            .setTypePayload(dest.getTypePayload())
+            .build();
+      } catch (IOException | InterruptedException exn) {
+        // As this thread will no longer be draining the queue, we don't want 
to get stuck writing
+        // to it.
+        totalPendingBytes.setException(exn);
+        LOG.error("Exception staging artifacts", exn);
+        if (exn instanceof IOException) {
+          throw (IOException) exn;
+        } else {
+          throw new RuntimeException(exn);
+        }
+      }
+    }
+  }
+
+  @Override
+  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
reverseArtifactRetrievalService(
+      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+
+    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {
+
+      /** The maximum number of parallel threads to use to stage. */
+      public static final int THREAD_POOL_SIZE = 10;
+
+      /** The maximum number of bytes to buffer across all writes before 
throttling. */
+      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB
+
+      String stagingToken;
+      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
+      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
+      ExecutorService stagingExecutor;
+      OverflowingSemaphore totalPendingBytes;
+
+      State state = State.START;
+      Queue<String> pendingResolves;
+      String currentEnvironment;
+      int nameIndex;
+      Queue<RunnerApi.ArtifactInformation> pendingGets;
+      BlockingQueue<ByteString> currentOutput;
+
+      @Override
+      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = 
"fallthrough intended")
+      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper 
responseWrapper) {
+        switch (state) {
+          case START:
+            stagingToken = responseWrapper.getStagingToken();
+            LOG.info("Staging artifacts for {}.", stagingToken);
+            toResolve = toStage.get(stagingToken);
+            stagedFutures = new ConcurrentHashMap<>();
+            pendingResolves = new ArrayDeque<>();
+            pendingResolves.addAll(toResolve.keySet());
+            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
+            resolveNextEnvironment(responseObserver);
+            break;
+
+          case RESOLVE:
+            {
+              currentEnvironment = pendingResolves.remove();
+              stagedFutures.put(currentEnvironment, new ArrayList<>());
+              pendingGets = new ArrayDeque<>();
+              for (RunnerApi.ArtifactInformation artifact :
+                  
responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
+                Optional<RunnerApi.ArtifactInformation> fetched = 
getLocal(artifact);
+                if (fetched.isPresent()) {
+                  stagedFutures
+                      .get(currentEnvironment)
+                      .add(new FutureTask<RunnerApi.ArtifactInformation>(() -> 
fetched.get()));
+                } else {
+                  pendingGets.add(artifact);
+                  responseObserver.onNext(
+                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                          .setGetArtifact(
+                              
ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
+                          .build());
+                }
+              }
+              LOG.info(
+                  "Getting {} artifacts for {}.{}.",
+                  pendingGets.size(),
+                  stagingToken,
+                  pendingResolves.peek());
+              if (pendingGets.isEmpty()) {
+                resolveNextEnvironment(responseObserver);
+              } else {
+                state = State.GET;
+              }
+              break;
+            }
+
+          case GET:
+            RunnerApi.ArtifactInformation currentArtifact = 
pendingGets.remove();
+            String name = createFilename(nameIndex++, currentEnvironment, 
currentArtifact);
+            try {
+              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
+              currentOutput = new ArrayBlockingQueue<ByteString>(100);
+              stagedFutures
+                  .get(currentEnvironment)
+                  .add(
+                      stagingExecutor.submit(
+                          new StoreArtifact(
+                              stagingToken,
+                              name,
+                              currentArtifact,
+                              currentOutput,
+                              totalPendingBytes)));
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              responseObserver.onError(exn);
+            }
+            state = State.GETCHUNK;
+            // fall through
+
+          case GETCHUNK:
+            try {
+              ByteString chunk = 
responseWrapper.getGetArtifactResponse().getData();
+              if (chunk.size() > 0) {
+                totalPendingBytes.aquire(chunk.size());
+                currentOutput.put(chunk);
+              }
+              if (responseWrapper.getIsLast()) {
+                currentOutput.put(ByteString.EMPTY); // The EOF value.

Review comment:
       Makes sense, just have to be careful is all.

##########
File path: 
runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java
##########
@@ -0,0 +1,606 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.artifact;
+
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.FutureTask;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
+import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.fnexecution.FnService;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.ResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ArtifactStagingService
+    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase 
implements FnService {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ArtifactStagingService.class);
+
+  private final ArtifactDestinationProvider destinationProvider;
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> toStage =
+      new ConcurrentHashMap<>();
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> staged =
+      new ConcurrentHashMap<>();
+
+  public ArtifactStagingService(ArtifactDestinationProvider 
destinationProvider) {
+    this.destinationProvider = destinationProvider;
+  }
+
+  /**
+   * Registers a set of artifacts to be staged with this service.
+   *
+   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with 
the given staging
+   * token and offer resolution and retrieval of this set of artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   * @param artifacts all artifacts to stage, keyed by environment
+   */
+  public void registerJob(
+      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> 
artifacts) {
+    assert !toStage.containsKey(stagingToken);
+    toStage.put(stagingToken, artifacts);
+  }
+
+  /**
+   * Returns the rewritten artifacts associated with this job, keyed by 
environment.
+   *
+   * <p>This should be called after the client has finished offering artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   */
+  public Map<String, List<RunnerApi.ArtifactInformation>> 
getStagedArtifacts(String stagingToken) {
+    toStage.remove(stagingToken);
+    return staged.remove(stagingToken);
+  }
+
+  /** Provides a concrete location to which artifacts can be staged on 
retrieval. */
+  public interface ArtifactDestinationProvider {
+    ArtifactDestination getDestination(String stagingToken, String name) 
throws IOException;
+  }
+
+  /**
+   * A pairing of a newly created artifact type and an output stream that will 
be readable at that
+   * type.
+   */
+  @AutoValue
+  public abstract static class ArtifactDestination {
+    public static ArtifactDestination create(
+        String typeUrn, ByteString typePayload, OutputStream out) {
+      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, 
typePayload, out);
+    }
+
+    public static ArtifactDestination fromFile(String path) throws IOException 
{
+      return fromFile(
+          path,
+          Channels.newOutputStream(
+              FileSystems.create(
+                  FileSystems.matchNewResource(path, false /* isDirectory */), 
MimeTypes.BINARY)));
+    }
+
+    public static ArtifactDestination fromFile(String path, OutputStream out) {
+      return create(
+          ArtifactRetrievalService.FILE_ARTIFACT_URN,
+          
RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
+          out);
+    }
+
+    public abstract String getTypeUrn();
+
+    public abstract ByteString getTypePayload();
+
+    public abstract OutputStream getOutputStream();
+  }
+
+  /**
+   * An ArtifactDestinationProvider that places new artifacts as files in a 
Beam filesystem.
+   *
+   * @param root the directory in which to place all artifacts
+   */
+  public static ArtifactDestinationProvider 
beamFilesystemArtifactDestinationProvider(String root) {
+    return (statingToken, name) -> {
+      ResourceId path =
+          FileSystems.matchNewResource(root, true)
+              .resolve(statingToken, 
ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
+              .resolve(name, 
ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
+      return ArtifactDestination.fromFile(path.toString());
+    };
+  }
+
+  private enum State {
+    START,
+    RESOLVE,
+    GET,
+    GETCHUNK,
+    DONE,
+    ERROR,
+  }
+
+  /**
+   * Like the standard Semaphore, but allows an aquire to go over the limit if 
there is any room.
+   *
+   * <p>Also allows setting an error, to avoid issues with un-released aquires 
after error.
+   */
+  private static class OverflowingSemaphore {
+    private int totalPermits;
+    private int usedPermits;
+    private Exception exception;
+
+    public OverflowingSemaphore(int totalPermits) {
+      this.totalPermits = totalPermits;
+      this.usedPermits = 0;
+    }
+
+    synchronized void aquire(int permits) throws Exception {
+      while (usedPermits >= totalPermits) {
+        if (exception != null) {
+          throw exception;
+        }
+        this.wait();
+      }
+      usedPermits += permits;
+    }
+
+    synchronized void release(int permits) {
+      usedPermits -= permits;
+      this.notifyAll();
+    }
+
+    synchronized void setException(Exception exception) {
+      this.exception = exception;
+      this.notifyAll();
+    }
+  }
+
+  /** A task that pulls bytes off a queue and actually writes them to a 
staging location. */
+  private class StoreArtifact implements 
Callable<RunnerApi.ArtifactInformation> {
+
+    private String stagingToken;
+    private String name;
+    private RunnerApi.ArtifactInformation originalArtifact;
+    private BlockingQueue<ByteString> bytesQueue;
+    private OverflowingSemaphore totalPendingBytes;
+
+    public StoreArtifact(
+        String stagingToken,
+        String name,
+        RunnerApi.ArtifactInformation originalArtifact,
+        BlockingQueue<ByteString> bytesQueue,
+        OverflowingSemaphore totalPendingBytes) {
+      this.stagingToken = stagingToken;
+      this.name = name;
+      this.originalArtifact = originalArtifact;
+      this.bytesQueue = bytesQueue;
+      this.totalPendingBytes = totalPendingBytes;
+    }
+
+    @Override
+    public RunnerApi.ArtifactInformation call() throws IOException {
+      try {
+        ArtifactDestination dest = 
destinationProvider.getDestination(stagingToken, name);
+        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, 
dest);
+        ByteString chunk = bytesQueue.take();
+        while (chunk.size() > 0) {
+          totalPendingBytes.release(chunk.size());
+          dest.getOutputStream().write(chunk.toByteArray());
+          chunk = bytesQueue.take();
+        }
+        dest.getOutputStream().close();
+        return originalArtifact
+            .toBuilder()
+            .setTypeUrn(dest.getTypeUrn())
+            .setTypePayload(dest.getTypePayload())
+            .build();
+      } catch (IOException | InterruptedException exn) {
+        // As this thread will no longer be draining the queue, we don't want 
to get stuck writing
+        // to it.
+        totalPendingBytes.setException(exn);
+        LOG.error("Exception staging artifacts", exn);
+        if (exn instanceof IOException) {
+          throw (IOException) exn;
+        } else {
+          throw new RuntimeException(exn);
+        }
+      }
+    }
+  }
+
+  @Override
+  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
reverseArtifactRetrievalService(
+      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+
+    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {
+
+      /** The maximum number of parallel threads to use to stage. */
+      public static final int THREAD_POOL_SIZE = 10;
+
+      /** The maximum number of bytes to buffer across all writes before 
throttling. */
+      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB
+
+      String stagingToken;
+      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
+      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
+      ExecutorService stagingExecutor;
+      OverflowingSemaphore totalPendingBytes;
+
+      State state = State.START;
+      Queue<String> pendingResolves;
+      String currentEnvironment;
+      int nameIndex;
+      Queue<RunnerApi.ArtifactInformation> pendingGets;
+      BlockingQueue<ByteString> currentOutput;
+
+      @Override
+      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = 
"fallthrough intended")
+      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper 
responseWrapper) {
+        switch (state) {
+          case START:
+            stagingToken = responseWrapper.getStagingToken();
+            LOG.info("Staging artifacts for {}.", stagingToken);
+            toResolve = toStage.get(stagingToken);
+            stagedFutures = new ConcurrentHashMap<>();
+            pendingResolves = new ArrayDeque<>();
+            pendingResolves.addAll(toResolve.keySet());
+            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
+            resolveNextEnvironment(responseObserver);
+            break;
+
+          case RESOLVE:
+            {
+              currentEnvironment = pendingResolves.remove();
+              stagedFutures.put(currentEnvironment, new ArrayList<>());
+              pendingGets = new ArrayDeque<>();
+              for (RunnerApi.ArtifactInformation artifact :
+                  
responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
+                Optional<RunnerApi.ArtifactInformation> fetched = 
getLocal(artifact);
+                if (fetched.isPresent()) {
+                  stagedFutures
+                      .get(currentEnvironment)
+                      .add(new FutureTask<RunnerApi.ArtifactInformation>(() -> 
fetched.get()));
+                } else {
+                  pendingGets.add(artifact);
+                  responseObserver.onNext(
+                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                          .setGetArtifact(
+                              
ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
+                          .build());
+                }
+              }
+              LOG.info(
+                  "Getting {} artifacts for {}.{}.",
+                  pendingGets.size(),
+                  stagingToken,
+                  pendingResolves.peek());
+              if (pendingGets.isEmpty()) {
+                resolveNextEnvironment(responseObserver);
+              } else {
+                state = State.GET;
+              }
+              break;
+            }
+
+          case GET:
+            RunnerApi.ArtifactInformation currentArtifact = 
pendingGets.remove();
+            String name = createFilename(nameIndex++, currentEnvironment, 
currentArtifact);
+            try {
+              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
+              currentOutput = new ArrayBlockingQueue<ByteString>(100);
+              stagedFutures
+                  .get(currentEnvironment)
+                  .add(
+                      stagingExecutor.submit(
+                          new StoreArtifact(
+                              stagingToken,
+                              name,
+                              currentArtifact,
+                              currentOutput,
+                              totalPendingBytes)));
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              responseObserver.onError(exn);
+            }
+            state = State.GETCHUNK;
+            // fall through
+
+          case GETCHUNK:
+            try {
+              ByteString chunk = 
responseWrapper.getGetArtifactResponse().getData();
+              if (chunk.size() > 0) {
+                totalPendingBytes.aquire(chunk.size());
+                currentOutput.put(chunk);
+              }
+              if (responseWrapper.getIsLast()) {
+                currentOutput.put(ByteString.EMPTY); // The EOF value.
+                if (pendingGets.isEmpty()) {
+                  resolveNextEnvironment(responseObserver);
+                } else {
+                  state = State.GET;
+                  LOG.debug("Waiting for {}", pendingGets.peek());
+                }
+              }
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              onError(exn);
+            }
+            break;
+
+          default:
+            responseObserver.onError(
+                new StatusException(
+                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + 
state)));
+        }
+      }
+
+      private void resolveNextEnvironment(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        if (pendingResolves.isEmpty()) {
+          finishStaging(responseObserver);
+        } else {
+          state = State.RESOLVE;
+          LOG.info("Resolving artifacts for {}.{}.", stagingToken, 
pendingResolves.peek());
+          responseObserver.onNext(
+              ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                  .setResolveArtifact(
+                      ArtifactApi.ResolveArtifactsRequest.newBuilder()
+                          
.addAllArtifacts(toResolve.get(pendingResolves.peek())))
+                  .build());
+        }
+      }
+
+      private void finishStaging(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        LOG.debug("Finishing staging for {}.", stagingToken);
+        Map<String, List<RunnerApi.ArtifactInformation>> staged = new 
HashMap<>();
+        try {
+          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> 
entry :
+              stagedFutures.entrySet()) {
+            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
+            for (Future<RunnerApi.ArtifactInformation> future : 
entry.getValue()) {
+              envStaged.add(future.get());
+            }
+            staged.put(entry.getKey(), envStaged);
+          }
+          ArtifactStagingService.this.staged.put(stagingToken, staged);
+          stagingExecutor.shutdown();
+          state = State.DONE;
+          LOG.info("Artifacts fully staged for {}.", stagingToken);
+          responseObserver.onCompleted();
+        } catch (Exception exn) {
+          LOG.error("Error staging artifacts", exn);
+          responseObserver.onError(exn);
+          state = State.ERROR;
+          return;
+        }
+      }
+
+      /**
+       * Return an alternative artifact if we do not need to get this over the 
artifact API, or
+       * possibly at all.
+       */
+      private Optional<RunnerApi.ArtifactInformation> getLocal(
+          RunnerApi.ArtifactInformation artifact) {
+        return Optional.empty();
+      }
+
+      /**
+       * Attempts to provide a reasonable filename for the artifact.
+       *
+       * @param index a monotonically increasing index, which provides 
uniqueness
+       * @param environment the environment id
+       * @param artifact the artifact itself
+       */
+      private String createFilename(
+          int index, String environment, RunnerApi.ArtifactInformation 
artifact) {
+        String path;
+        try {
+          if 
(artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN))
 {
+            path =
+                
RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
+                    .getStagedName();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
+          } else {
+            path = "artifact";
+          }
+        } catch (InvalidProtocolBufferException exn) {
+          throw new RuntimeException(exn);
+        }
+        // Limit to the last contiguous alpha-numeric sequence. In particular, 
this will exclude
+        // all path separators.
+        List<String> components = 
Splitter.onPattern("[^A-Za-z-_.]]").splitToList(path);

Review comment:
       I was suggesting using the file separator as the escape character, so 
that you could create a file name that doesn't contain it, not as a way of 
figuring out the basename of the file.
   
   As you say, it's optional, and this can be improved later by choosing a 
better file name. Also, the stage_to payload will likely have the original file 
name + hash.

##########
File path: 
runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java
##########
@@ -0,0 +1,602 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.artifact;
+
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
+import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.fnexecution.FnService;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.ResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ArtifactStagingService
+    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase 
implements FnService {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ArtifactStagingService.class);
+
+  private final ArtifactDestinationProvider destinationProvider;
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> toStage =
+      new ConcurrentHashMap<>();
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> staged =
+      new ConcurrentHashMap<>();
+
+  public ArtifactStagingService(ArtifactDestinationProvider 
destinationProvider) {
+    this.destinationProvider = destinationProvider;
+  }
+
+  /**
+   * Registers a set of artifacts to be staged with this service.
+   *
+   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with 
the given staging
+   * token and offer resolution and retrieval of this set of artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   * @param artifacts all artifacts to stage, keyed by environment
+   */
+  public void registerJob(
+      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> 
artifacts) {
+    assert !toStage.containsKey(stagingToken);
+    toStage.put(stagingToken, artifacts);
+  }
+
+  /**
+   * Returns the rewritten artifacts associated with this job, keyed by 
environment.
+   *
+   * <p>This should be called after the client has finished offering artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   */
+  public Map<String, List<RunnerApi.ArtifactInformation>> 
getStagedArtifacts(String stagingToken) {
+    toStage.remove(stagingToken);
+    return staged.remove(stagingToken);
+  }
+
+  /** Provides a concrete location to which artifacts can be staged on 
retrieval. */
+  public interface ArtifactDestinationProvider {
+    ArtifactDestination getDestination(String stagingToken, String name) 
throws IOException;
+  }
+
+  /**
+   * A pairing of a newly created artifact type and an ouptut stream that will 
be readable at that
+   * type.
+   */
+  @AutoValue
+  public abstract static class ArtifactDestination {
+    public static ArtifactDestination create(
+        String typeUrn, ByteString typePayload, OutputStream out) {
+      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, 
typePayload, out);
+    }
+
+    public static ArtifactDestination fromFile(String path) throws IOException 
{
+      return fromFile(
+          path,
+          Channels.newOutputStream(
+              FileSystems.create(
+                  FileSystems.matchNewResource(path, false /* isDirectory */), 
MimeTypes.BINARY)));
+    }
+
+    public static ArtifactDestination fromFile(String path, OutputStream out) {
+      return create(
+          ArtifactRetrievalService.FILE_ARTIFACT_URN,
+          
RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
+          out);
+    }
+
+    public abstract String getTypeUrn();
+
+    public abstract ByteString getTypePayload();
+
+    public abstract OutputStream getOutputStream();
+  }
+
+  /**
+   * An ArtifactDestinationProvider that places new artifacts as files in a 
Beam filesystem.
+   *
+   * @param root the directory in which to place all artifacts
+   */
+  public static ArtifactDestinationProvider 
beamFilesystemArtifactDestinationProvider(String root) {
+    return (stagingToken, name) -> {
+      ResourceId path =
+          FileSystems.matchNewResource(root, true)
+              .resolve(stagingToken, 
ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
+              .resolve(name, 
ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
+      return ArtifactDestination.fromFile(path.toString());
+    };
+  }
+
+  private enum State {
+    START,
+    RESOLVE,
+    GET,
+    GETCHUNK,
+    DONE,
+    ERROR,
+  }
+
+  /**
+   * Like the standard Semaphore, but allows an aquire to go over the limit if 
there is any room.
+   *
+   * <p>Also allows setting an error, to avoid issues with un-released aquires 
after error.
+   */
+  private static class OverflowingSemaphore {
+    private int totalPermits;
+    private int usedPermits;
+    private Exception exception;
+
+    public OverflowingSemaphore(int totalPermits) {
+      this.totalPermits = totalPermits;
+      this.usedPermits = 0;
+    }
+
+    synchronized void aquire(int permits) throws Exception {
+      while (usedPermits >= totalPermits) {
+        if (exception != null) {
+          throw exception;
+        }
+        this.wait();
+      }
+      usedPermits += permits;
+    }
+
+    synchronized void release(int permits) {
+      usedPermits -= permits;
+      this.notifyAll();
+    }
+
+    synchronized void setException(Exception exception) {
+      this.exception = exception;
+      this.notifyAll();
+    }
+  }
+
+  /** A task that pulls bytes off a queue and actually writes them to a 
staging location. */
+  private class StoreArtifact implements 
Callable<RunnerApi.ArtifactInformation> {
+
+    private String stagingToken;
+    private String name;
+    private RunnerApi.ArtifactInformation originalArtifact;
+    private BlockingQueue<ByteString> bytesQueue;
+    private OverflowingSemaphore totalPendingBytes;
+
+    public StoreArtifact(
+        String stagingToken,
+        String name,
+        RunnerApi.ArtifactInformation originalArtifact,
+        BlockingQueue<ByteString> bytesQueue,
+        OverflowingSemaphore totalPendingBytes) {
+      this.stagingToken = stagingToken;
+      this.name = name;
+      this.originalArtifact = originalArtifact;
+      this.bytesQueue = bytesQueue;
+      this.totalPendingBytes = totalPendingBytes;
+    }
+
+    @Override
+    public RunnerApi.ArtifactInformation call() throws IOException {
+      try {
+        ArtifactDestination dest = 
destinationProvider.getDestination(stagingToken, name);
+        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, 
dest);
+        ByteString chunk = bytesQueue.take();
+        while (chunk.size() > 0) {
+          totalPendingBytes.release(chunk.size());
+          chunk.writeTo(dest.getOutputStream());
+          chunk = bytesQueue.take();
+        }
+        dest.getOutputStream().close();
+        return originalArtifact
+            .toBuilder()
+            .setTypeUrn(dest.getTypeUrn())
+            .setTypePayload(dest.getTypePayload())
+            .build();
+      } catch (IOException | InterruptedException exn) {
+        // As this thread will no longer be draining the queue, we don't want 
to get stuck writing
+        // to it.
+        totalPendingBytes.setException(exn);
+        LOG.error("Exception staging artifacts", exn);
+        if (exn instanceof IOException) {
+          throw (IOException) exn;
+        } else {
+          throw new RuntimeException(exn);
+        }
+      }
+    }
+  }
+
+  @Override
+  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
reverseArtifactRetrievalService(
+      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+
+    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {
+
+      /** The maximum number of parallel threads to use to stage. */
+      public static final int THREAD_POOL_SIZE = 10;
+
+      /** The maximum number of bytes to buffer across all writes before 
throttling. */
+      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB
+
+      IdGenerator idGenerator = IdGenerators.incrementingLongs();
+
+      String stagingToken;
+      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
+      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
+      ExecutorService stagingExecutor;
+      OverflowingSemaphore totalPendingBytes;
+
+      State state = State.START;
+      Queue<String> pendingResolves;
+      String currentEnvironment;
+      Queue<RunnerApi.ArtifactInformation> pendingGets;
+      BlockingQueue<ByteString> currentOutput;
+
+      @Override
+      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = 
"fallthrough intended")
+      // May be called by different threads for the same request; synchronized 
for memory
+      // synchronization.
+      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper 
responseWrapper) {
+        switch (state) {
+          case START:
+            stagingToken = responseWrapper.getStagingToken();
+            LOG.info("Staging artifacts for {}.", stagingToken);
+            toResolve = toStage.get(stagingToken);
+            stagedFutures = new ConcurrentHashMap<>();
+            pendingResolves = new ArrayDeque<>();
+            pendingResolves.addAll(toResolve.keySet());
+            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
+            resolveNextEnvironment(responseObserver);
+            break;
+
+          case RESOLVE:
+            {
+              currentEnvironment = pendingResolves.remove();
+              stagedFutures.put(currentEnvironment, new ArrayList<>());
+              pendingGets = new ArrayDeque<>();
+              for (RunnerApi.ArtifactInformation artifact :
+                  
responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
+                Optional<RunnerApi.ArtifactInformation> fetched = 
getLocal(artifact);
+                if (fetched.isPresent()) {
+                  stagedFutures
+                      .get(currentEnvironment)
+                      .add(CompletableFuture.completedFuture(fetched.get()));
+                } else {
+                  pendingGets.add(artifact);
+                  responseObserver.onNext(
+                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                          .setGetArtifact(
+                              
ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
+                          .build());
+                }
+              }
+              LOG.info(
+                  "Getting {} artifacts for {}.{}.",
+                  pendingGets.size(),
+                  stagingToken,
+                  pendingResolves.peek());
+              if (pendingGets.isEmpty()) {
+                resolveNextEnvironment(responseObserver);
+              } else {
+                state = State.GET;
+              }
+              break;
+            }
+
+          case GET:
+            RunnerApi.ArtifactInformation currentArtifact = 
pendingGets.remove();
+            String name = createFilename(currentEnvironment, currentArtifact);
+            try {
+              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
+              currentOutput = new ArrayBlockingQueue<ByteString>(100);
+              stagedFutures
+                  .get(currentEnvironment)
+                  .add(
+                      stagingExecutor.submit(
+                          new StoreArtifact(
+                              stagingToken,
+                              name,
+                              currentArtifact,
+                              currentOutput,
+                              totalPendingBytes)));
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              responseObserver.onError(exn);
+            }
+            state = State.GETCHUNK;
+            // fall through
+
+          case GETCHUNK:
+            try {
+              ByteString chunk = 
responseWrapper.getGetArtifactResponse().getData();
+              if (chunk.size() > 0) { // Make sure we don't accidentally send 
the EOF value.
+                totalPendingBytes.aquire(chunk.size());
+                currentOutput.put(chunk);
+              }
+              if (responseWrapper.getIsLast()) {
+                currentOutput.put(ByteString.EMPTY); // The EOF value.
+                if (pendingGets.isEmpty()) {
+                  resolveNextEnvironment(responseObserver);
+                } else {
+                  state = State.GET;
+                  LOG.debug("Waiting for {}", pendingGets.peek());
+                }
+              }
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              onError(exn);
+            }
+            break;
+
+          default:
+            responseObserver.onError(
+                new StatusException(
+                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + 
state)));
+        }
+      }
+
+      private void resolveNextEnvironment(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        if (pendingResolves.isEmpty()) {
+          finishStaging(responseObserver);
+        } else {
+          state = State.RESOLVE;
+          LOG.info("Resolving artifacts for {}.{}.", stagingToken, 
pendingResolves.peek());
+          responseObserver.onNext(
+              ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                  .setResolveArtifact(
+                      ArtifactApi.ResolveArtifactsRequest.newBuilder()
+                          
.addAllArtifacts(toResolve.get(pendingResolves.peek())))
+                  .build());
+        }
+      }
+
+      private void finishStaging(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        LOG.debug("Finishing staging for {}.", stagingToken);
+        Map<String, List<RunnerApi.ArtifactInformation>> staged = new 
HashMap<>();
+        try {
+          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> 
entry :
+              stagedFutures.entrySet()) {
+            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
+            for (Future<RunnerApi.ArtifactInformation> future : 
entry.getValue()) {
+              envStaged.add(future.get());
+            }
+            staged.put(entry.getKey(), envStaged);
+          }
+          ArtifactStagingService.this.staged.put(stagingToken, staged);
+          stagingExecutor.shutdown();
+          state = State.DONE;
+          LOG.info("Artifacts fully staged for {}.", stagingToken);
+          responseObserver.onCompleted();
+        } catch (Exception exn) {
+          LOG.error("Error staging artifacts", exn);
+          responseObserver.onError(exn);
+          state = State.ERROR;
+          return;
+        }
+      }
+
+      /**
+       * Return an alternative artifact if we do not need to get this over the 
artifact API, or
+       * possibly at all.
+       */
+      private Optional<RunnerApi.ArtifactInformation> getLocal(
+          RunnerApi.ArtifactInformation artifact) {
+        return Optional.empty();
+      }
+
+      /**
+       * Attempts to provide a reasonable filename for the artifact.
+       *
+       * @param index a monotonically increasing index, which provides 
uniqueness
+       * @param environment the environment id
+       * @param artifact the artifact itself
+       */
+      private String createFilename(String environment, 
RunnerApi.ArtifactInformation artifact) {
+        String path;
+        try {
+          if 
(artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN))
 {
+            path =
+                
RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
+                    .getStagedName();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
+          } else {
+            path = "artifact";
+          }
+        } catch (InvalidProtocolBufferException exn) {
+          throw new RuntimeException(exn);
+        }
+        // Limit to the last contiguous alpha-numeric sequence. In particular, 
this will exclude
+        // all path separators.
+        List<String> components = 
Splitter.onPattern("[^A-Za-z-_.]]").splitToList(path);
+        String base = components.get(components.size() - 1);
+        return clip(
+            String.format("%s-%s-%s", idGenerator.getId(), clip(environment, 
25), base), 100);
+      }
+
+      private String clip(String s, int maxLength) {
+        return s.length() < maxLength ? s : s.substring(0, maxLength);
+      }
+
+      @Override
+      public void onError(Throwable throwable) {
+        stagingExecutor.shutdownNow();
+        LOG.error("Error staging artifacts", throwable);
+        state = State.ERROR;
+      }
+
+      @Override
+      public void onCompleted() {
+        Preconditions.checkArgument(state == State.DONE);
+      }
+    };
+  }
+
+  @Override
+  public void close() throws Exception {
+    // Nothing to close.
+  }
+
+  /**
+   * Lazily stages artifacts by letting an ArtifactStagingService resolve and 
request artifacts.
+   *
+   * @param retrievalService an ArtifactRetrievalService used to resolve and 
retrieve artifacts
+   * @param stagingService an ArtifactStagingService stub which will request 
artifacts
+   * @param stagingToken the staging token of the job whose artifacts will be 
retrieved
+   * @throws InterruptedException
+   * @throws IOException
+   */
+  public static void offer(
+      ArtifactRetrievalService retrievalService,
+      ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+      String stagingToken)
+      throws ExecutionException, InterruptedException {
+    new StagingDriver(retrievalService, stagingService, 
stagingToken).getCompletionFuture().get();
+  }
+
+  /** Actually implements the reverse retrieval protocol. */
+  private static class StagingDriver implements 
StreamObserver<ArtifactApi.ArtifactRequestWrapper> {
+
+    private final ArtifactRetrievalService retrievalService;
+    private final StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
responseObserver;
+    private final CompletableFuture<Void> completionFuture;
+
+    public StagingDriver(
+        ArtifactRetrievalService retrievalService,
+        ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+        String stagingToken) {
+      this.retrievalService = retrievalService;
+      responseObserver = stagingService.reverseArtifactRetrievalService(this);
+      responseObserver.onNext(
+          
ArtifactApi.ArtifactResponseWrapper.newBuilder().setStagingToken(stagingToken).build());
+      completionFuture = new CompletableFuture<Void>();
+    }
+
+    public CompletableFuture<?> getCompletionFuture() {
+      return completionFuture;
+    }
+
+    @Override
+    public void onNext(ArtifactApi.ArtifactRequestWrapper requestWrapper) {
+      if (requestWrapper.hasResolveArtifact()) {
+        retrievalService.resolveArtifacts(
+            requestWrapper.getResolveArtifact(),
+            new StreamObserver<ArtifactApi.ResolveArtifactsResponse>() {
+
+              @Override
+              public void onNext(ArtifactApi.ResolveArtifactsResponse 
resolveArtifactsResponse) {
+                responseObserver.onNext(
+                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
+                        .setResolveArtifactResponse(resolveArtifactsResponse)
+                        .build());
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                responseObserver.onError(throwable);
+              }
+
+              @Override
+              public void onCompleted() {}
+            });
+      } else if (requestWrapper.hasGetArtifact()) {
+        retrievalService.getArtifact(
+            requestWrapper.getGetArtifact(),
+            new StreamObserver<ArtifactApi.GetArtifactResponse>() {
+
+              @Override
+              public void onNext(ArtifactApi.GetArtifactResponse 
getArtifactResponse) {
+                responseObserver.onNext(
+                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
+                        .setGetArtifactResponse(getArtifactResponse)
+                        .build());
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                responseObserver.onError(throwable);

Review comment:
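       Should this also complete the future with the error (e.g. 
`completionFuture.completeExceptionally(throwable)`)? As quoted here, this 
handler only forwards the error to `responseObserver`, so a caller blocked in 
`offer()` on `getCompletionFuture().get()` may not be notified of the failure.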

##########
File path: 
runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java
##########
@@ -0,0 +1,602 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.artifact;
+
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
+import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.fnexecution.FnService;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.ResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ArtifactStagingService
+    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase 
implements FnService {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ArtifactStagingService.class);
+
+  private final ArtifactDestinationProvider destinationProvider;
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> toStage =
+      new ConcurrentHashMap<>();
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> staged =
+      new ConcurrentHashMap<>();
+
+  public ArtifactStagingService(ArtifactDestinationProvider 
destinationProvider) {
+    this.destinationProvider = destinationProvider;
+  }
+
+  /**
+   * Registers a set of artifacts to be staged with this service.
+   *
+   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with 
the given staging
+   * token and offer resolution and retrieval of this set of artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   * @param artifacts all artifacts to stage, keyed by environment
+   */
+  public void registerJob(
+      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> 
artifacts) {
+    assert !toStage.containsKey(stagingToken);
+    toStage.put(stagingToken, artifacts);
+  }
+
+  /**
+   * Returns the rewritten artifacts associated with this job, keyed by 
environment.
+   *
+   * <p>This should be called after the client has finished offering artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   */
+  public Map<String, List<RunnerApi.ArtifactInformation>> 
getStagedArtifacts(String stagingToken) {
+    toStage.remove(stagingToken);
+    return staged.remove(stagingToken);
+  }
+
+  /** Provides a concrete location to which artifacts can be staged on 
retrieval. */
+  public interface ArtifactDestinationProvider {
+    ArtifactDestination getDestination(String stagingToken, String name) 
throws IOException;
+  }
+
+  /**
+   * A pairing of a newly created artifact type and an ouptut stream that will 
be readable at that
+   * type.
+   */
+  @AutoValue
+  public abstract static class ArtifactDestination {
+    public static ArtifactDestination create(
+        String typeUrn, ByteString typePayload, OutputStream out) {
+      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, 
typePayload, out);
+    }
+
+    public static ArtifactDestination fromFile(String path) throws IOException 
{
+      return fromFile(
+          path,
+          Channels.newOutputStream(
+              FileSystems.create(
+                  FileSystems.matchNewResource(path, false /* isDirectory */), 
MimeTypes.BINARY)));
+    }
+
+    public static ArtifactDestination fromFile(String path, OutputStream out) {
+      return create(
+          ArtifactRetrievalService.FILE_ARTIFACT_URN,
+          
RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
+          out);
+    }
+
+    public abstract String getTypeUrn();
+
+    public abstract ByteString getTypePayload();
+
+    public abstract OutputStream getOutputStream();
+  }
+
+  /**
+   * An ArtifactDestinationProvider that places new artifacts as files in a 
Beam filesystem.
+   *
+   * @param root the directory in which to place all artifacts
+   */
+  public static ArtifactDestinationProvider 
beamFilesystemArtifactDestinationProvider(String root) {
+    return (stagingToken, name) -> {
+      ResourceId path =
+          FileSystems.matchNewResource(root, true)
+              .resolve(stagingToken, 
ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
+              .resolve(name, 
ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
+      return ArtifactDestination.fromFile(path.toString());
+    };
+  }
+
+  private enum State {
+    START,
+    RESOLVE,
+    GET,
+    GETCHUNK,
+    DONE,
+    ERROR,
+  }
+
+  /**
+   * Like the standard Semaphore, but allows an aquire to go over the limit if 
there is any room.
+   *
+   * <p>Also allows setting an error, to avoid issues with un-released aquires 
after error.
+   */
+  private static class OverflowingSemaphore {
+    private int totalPermits;
+    private int usedPermits;
+    private Exception exception;
+
+    public OverflowingSemaphore(int totalPermits) {
+      this.totalPermits = totalPermits;
+      this.usedPermits = 0;
+    }
+
+    synchronized void aquire(int permits) throws Exception {
+      while (usedPermits >= totalPermits) {
+        if (exception != null) {
+          throw exception;
+        }
+        this.wait();
+      }
+      usedPermits += permits;
+    }
+
+    synchronized void release(int permits) {
+      usedPermits -= permits;
+      this.notifyAll();
+    }
+
+    synchronized void setException(Exception exception) {
+      this.exception = exception;
+      this.notifyAll();
+    }
+  }
+
+  /** A task that pulls bytes off a queue and actually writes them to a 
staging location. */
+  private class StoreArtifact implements 
Callable<RunnerApi.ArtifactInformation> {
+
+    private String stagingToken;
+    private String name;
+    private RunnerApi.ArtifactInformation originalArtifact;
+    private BlockingQueue<ByteString> bytesQueue;
+    private OverflowingSemaphore totalPendingBytes;
+
+    public StoreArtifact(
+        String stagingToken,
+        String name,
+        RunnerApi.ArtifactInformation originalArtifact,
+        BlockingQueue<ByteString> bytesQueue,
+        OverflowingSemaphore totalPendingBytes) {
+      this.stagingToken = stagingToken;
+      this.name = name;
+      this.originalArtifact = originalArtifact;
+      this.bytesQueue = bytesQueue;
+      this.totalPendingBytes = totalPendingBytes;
+    }
+
+    @Override
+    public RunnerApi.ArtifactInformation call() throws IOException {
+      try {
+        ArtifactDestination dest = 
destinationProvider.getDestination(stagingToken, name);
+        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, 
dest);
+        ByteString chunk = bytesQueue.take();
+        while (chunk.size() > 0) {
+          totalPendingBytes.release(chunk.size());
+          chunk.writeTo(dest.getOutputStream());
+          chunk = bytesQueue.take();
+        }
+        dest.getOutputStream().close();
+        return originalArtifact
+            .toBuilder()
+            .setTypeUrn(dest.getTypeUrn())
+            .setTypePayload(dest.getTypePayload())
+            .build();
+      } catch (IOException | InterruptedException exn) {
+        // As this thread will no longer be draining the queue, we don't want 
to get stuck writing
+        // to it.
+        totalPendingBytes.setException(exn);
+        LOG.error("Exception staging artifacts", exn);
+        if (exn instanceof IOException) {
+          throw (IOException) exn;
+        } else {
+          throw new RuntimeException(exn);
+        }
+      }
+    }
+  }
+
+  @Override
+  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
reverseArtifactRetrievalService(
+      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+
+    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {
+
+      /** The maximum number of parallel threads to use to stage. */
+      public static final int THREAD_POOL_SIZE = 10;
+
+      /** The maximum number of bytes to buffer across all writes before 
throttling. */
+      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB
+
+      IdGenerator idGenerator = IdGenerators.incrementingLongs();
+
+      String stagingToken;
+      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
+      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
+      ExecutorService stagingExecutor;
+      OverflowingSemaphore totalPendingBytes;
+
+      State state = State.START;
+      Queue<String> pendingResolves;
+      String currentEnvironment;
+      Queue<RunnerApi.ArtifactInformation> pendingGets;
+      BlockingQueue<ByteString> currentOutput;
+
+      @Override
+      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = 
"fallthrough intended")
+      // May be called by different threads for the same request; synchronized 
for memory
+      // synchronization.
+      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper 
responseWrapper) {
+        switch (state) {
+          case START:
+            stagingToken = responseWrapper.getStagingToken();
+            LOG.info("Staging artifacts for {}.", stagingToken);
+            toResolve = toStage.get(stagingToken);
+            stagedFutures = new ConcurrentHashMap<>();
+            pendingResolves = new ArrayDeque<>();
+            pendingResolves.addAll(toResolve.keySet());
+            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
+            resolveNextEnvironment(responseObserver);
+            break;
+
+          case RESOLVE:
+            {
+              currentEnvironment = pendingResolves.remove();
+              stagedFutures.put(currentEnvironment, new ArrayList<>());
+              pendingGets = new ArrayDeque<>();
+              for (RunnerApi.ArtifactInformation artifact :
+                  
responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
+                Optional<RunnerApi.ArtifactInformation> fetched = 
getLocal(artifact);
+                if (fetched.isPresent()) {
+                  stagedFutures
+                      .get(currentEnvironment)
+                      .add(CompletableFuture.completedFuture(fetched.get()));
+                } else {
+                  pendingGets.add(artifact);
+                  responseObserver.onNext(
+                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                          .setGetArtifact(
+                              
ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
+                          .build());
+                }
+              }
+              LOG.info(
+                  "Getting {} artifacts for {}.{}.",
+                  pendingGets.size(),
+                  stagingToken,
+                  pendingResolves.peek());
+              if (pendingGets.isEmpty()) {
+                resolveNextEnvironment(responseObserver);
+              } else {
+                state = State.GET;
+              }
+              break;
+            }
+
+          case GET:
+            RunnerApi.ArtifactInformation currentArtifact = 
pendingGets.remove();
+            String name = createFilename(currentEnvironment, currentArtifact);
+            try {
+              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
+              currentOutput = new ArrayBlockingQueue<ByteString>(100);
+              stagedFutures
+                  .get(currentEnvironment)
+                  .add(
+                      stagingExecutor.submit(
+                          new StoreArtifact(
+                              stagingToken,
+                              name,
+                              currentArtifact,
+                              currentOutput,
+                              totalPendingBytes)));
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              responseObserver.onError(exn);
+            }
+            state = State.GETCHUNK;
+            // fall through
+
+          case GETCHUNK:
+            try {
+              ByteString chunk = 
responseWrapper.getGetArtifactResponse().getData();
+              if (chunk.size() > 0) { // Make sure we don't accidentally send 
the EOF value.
+                totalPendingBytes.aquire(chunk.size());
+                currentOutput.put(chunk);
+              }
+              if (responseWrapper.getIsLast()) {
+                currentOutput.put(ByteString.EMPTY); // The EOF value.
+                if (pendingGets.isEmpty()) {
+                  resolveNextEnvironment(responseObserver);
+                } else {
+                  state = State.GET;
+                  LOG.debug("Waiting for {}", pendingGets.peek());
+                }
+              }
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              onError(exn);
+            }
+            break;
+
+          default:
+            responseObserver.onError(
+                new StatusException(
+                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + 
state)));
+        }
+      }
+
+      private void resolveNextEnvironment(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        if (pendingResolves.isEmpty()) {
+          finishStaging(responseObserver);
+        } else {
+          state = State.RESOLVE;
+          LOG.info("Resolving artifacts for {}.{}.", stagingToken, 
pendingResolves.peek());
+          responseObserver.onNext(
+              ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                  .setResolveArtifact(
+                      ArtifactApi.ResolveArtifactsRequest.newBuilder()
+                          
.addAllArtifacts(toResolve.get(pendingResolves.peek())))
+                  .build());
+        }
+      }
+
+      private void finishStaging(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        LOG.debug("Finishing staging for {}.", stagingToken);
+        Map<String, List<RunnerApi.ArtifactInformation>> staged = new 
HashMap<>();
+        try {
+          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> 
entry :
+              stagedFutures.entrySet()) {
+            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
+            for (Future<RunnerApi.ArtifactInformation> future : 
entry.getValue()) {
+              envStaged.add(future.get());
+            }
+            staged.put(entry.getKey(), envStaged);
+          }
+          ArtifactStagingService.this.staged.put(stagingToken, staged);
+          stagingExecutor.shutdown();
+          state = State.DONE;
+          LOG.info("Artifacts fully staged for {}.", stagingToken);
+          responseObserver.onCompleted();
+        } catch (Exception exn) {
+          LOG.error("Error staging artifacts", exn);
+          responseObserver.onError(exn);
+          state = State.ERROR;
+          return;
+        }
+      }
+
+      /**
+       * Return an alternative artifact if we do not need to get this over the 
artifact API, or
+       * possibly at all.
+       */
+      private Optional<RunnerApi.ArtifactInformation> getLocal(
+          RunnerApi.ArtifactInformation artifact) {
+        return Optional.empty();
+      }
+
+      /**
+       * Attempts to provide a reasonable filename for the artifact.
+       *
+       * <p>A monotonically increasing id from {@code idGenerator} is prepended, which provides
+       * uniqueness.
+       *
+       * @param environment the environment id
+       * @param artifact the artifact itself
+       */
+      private String createFilename(String environment, 
RunnerApi.ArtifactInformation artifact) {
+        String path;
+        try {
+          if 
(artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN))
 {
+            path =
+                
RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
+                    .getStagedName();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
+          } else {
+            path = "artifact";
+          }
+        } catch (InvalidProtocolBufferException exn) {
+          throw new RuntimeException(exn);
+        }
+        // Limit to the last contiguous alpha-numeric sequence. In particular, 
this will exclude
+        // all path separators.
+        List<String> components = 
Splitter.onPattern("[^A-Za-z-_.]]").splitToList(path);
+        String base = components.get(components.size() - 1);
+        return clip(
+            String.format("%s-%s-%s", idGenerator.getId(), clip(environment, 
25), base), 100);
+      }
+
+      private String clip(String s, int maxLength) {
+        return s.length() < maxLength ? s : s.substring(0, maxLength);
+      }
+
+      @Override
+      public void onError(Throwable throwable) {
+        stagingExecutor.shutdownNow();
+        LOG.error("Error staging artifacts", throwable);
+        state = State.ERROR;
+      }
+
+      @Override
+      public void onCompleted() {
+        Preconditions.checkArgument(state == State.DONE);
+      }
+    };
+  }
+
+  @Override
+  public void close() throws Exception {
+    // Nothing to close.
+  }
+
+  /**
+   * Lazily stages artifacts by letting an ArtifactStagingService resolve and request artifacts.
+   *
+   * @param retrievalService an ArtifactRetrievalService used to resolve and retrieve artifacts
+   * @param stagingService an ArtifactStagingService stub which will request artifacts
+   * @param stagingToken the staging token of the job whose artifacts will be retrieved
+   * @throws ExecutionException if the staging completion future completes exceptionally
+   * @throws InterruptedException if interrupted while waiting for staging to complete
+   */
+  public static void offer(
+      ArtifactRetrievalService retrievalService,
+      ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+      String stagingToken)
+      throws ExecutionException, InterruptedException {
+    new StagingDriver(retrievalService, stagingService, 
stagingToken).getCompletionFuture().get();
+  }
+
+  /** Actually implements the reverse retrieval protocol. */
+  private static class StagingDriver implements 
StreamObserver<ArtifactApi.ArtifactRequestWrapper> {
+
+    private final ArtifactRetrievalService retrievalService;
+    private final StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
responseObserver;
+    private final CompletableFuture<Void> completionFuture;
+
+    public StagingDriver(
+        ArtifactRetrievalService retrievalService,
+        ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+        String stagingToken) {
+      this.retrievalService = retrievalService;
+      responseObserver = stagingService.reverseArtifactRetrievalService(this);
+      responseObserver.onNext(
+          
ArtifactApi.ArtifactResponseWrapper.newBuilder().setStagingToken(stagingToken).build());
+      completionFuture = new CompletableFuture<Void>();
+    }
+
+    public CompletableFuture<?> getCompletionFuture() {
+      return completionFuture;
+    }
+
+    @Override
+    public void onNext(ArtifactApi.ArtifactRequestWrapper requestWrapper) {
+      if (requestWrapper.hasResolveArtifact()) {
+        retrievalService.resolveArtifacts(
+            requestWrapper.getResolveArtifact(),
+            new StreamObserver<ArtifactApi.ResolveArtifactsResponse>() {
+
+              @Override
+              public void onNext(ArtifactApi.ResolveArtifactsResponse 
resolveArtifactsResponse) {
+                responseObserver.onNext(
+                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
+                        .setResolveArtifactResponse(resolveArtifactsResponse)
+                        .build());
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                responseObserver.onError(throwable);

Review comment:
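       Should this also complete `completionFuture` with the error (e.g. `completionFuture.completeExceptionally(throwable)`), since `offer()` blocks on `getCompletionFuture().get()`?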

##########
File path: 
runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java
##########
@@ -0,0 +1,602 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.artifact;
+
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
+import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.fnexecution.FnService;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.ResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ArtifactStagingService
+    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase 
implements FnService {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ArtifactStagingService.class);
+
+  private final ArtifactDestinationProvider destinationProvider;
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> toStage =
+      new ConcurrentHashMap<>();
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> staged =
+      new ConcurrentHashMap<>();
+
+  public ArtifactStagingService(ArtifactDestinationProvider 
destinationProvider) {
+    this.destinationProvider = destinationProvider;
+  }
+
+  /**
+   * Registers a set of artifacts to be staged with this service.
+   *
+   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with 
the given staging
+   * token and offer resolution and retrieval of this set of artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   * @param artifacts all artifacts to stage, keyed by environment
+   */
+  public void registerJob(
+      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> 
artifacts) {
+    assert !toStage.containsKey(stagingToken);
+    toStage.put(stagingToken, artifacts);
+  }
+
+  /**
+   * Returns the rewritten artifacts associated with this job, keyed by 
environment.
+   *
+   * <p>This should be called after the client has finished offering artifacts.
+   *
+   * @param stagingToken a staging token for this job
+   */
+  public Map<String, List<RunnerApi.ArtifactInformation>> 
getStagedArtifacts(String stagingToken) {
+    toStage.remove(stagingToken);
+    return staged.remove(stagingToken);
+  }
+
+  /** Provides a concrete location to which artifacts can be staged on 
retrieval. */
+  public interface ArtifactDestinationProvider {
+    ArtifactDestination getDestination(String stagingToken, String name) 
throws IOException;
+  }
+
+  /**
+   * A pairing of a newly created artifact type and an output stream that will be readable at that
+   * type.
+   */
+  @AutoValue
+  public abstract static class ArtifactDestination {
+    public static ArtifactDestination create(
+        String typeUrn, ByteString typePayload, OutputStream out) {
+      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, 
typePayload, out);
+    }
+
+    public static ArtifactDestination fromFile(String path) throws IOException 
{
+      return fromFile(
+          path,
+          Channels.newOutputStream(
+              FileSystems.create(
+                  FileSystems.matchNewResource(path, false /* isDirectory */), 
MimeTypes.BINARY)));
+    }
+
+    public static ArtifactDestination fromFile(String path, OutputStream out) {
+      return create(
+          ArtifactRetrievalService.FILE_ARTIFACT_URN,
+          
RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
+          out);
+    }
+
+    public abstract String getTypeUrn();
+
+    public abstract ByteString getTypePayload();
+
+    public abstract OutputStream getOutputStream();
+  }
+
+  /**
+   * An ArtifactDestinationProvider that places new artifacts as files in a 
Beam filesystem.
+   *
+   * @param root the directory in which to place all artifacts
+   */
+  public static ArtifactDestinationProvider 
beamFilesystemArtifactDestinationProvider(String root) {
+    return (stagingToken, name) -> {
+      ResourceId path =
+          FileSystems.matchNewResource(root, true)
+              .resolve(stagingToken, 
ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
+              .resolve(name, 
ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
+      return ArtifactDestination.fromFile(path.toString());
+    };
+  }
+
+  private enum State {
+    START,
+    RESOLVE,
+    GET,
+    GETCHUNK,
+    DONE,
+    ERROR,
+  }
+
+  /**
+   * Like the standard Semaphore, but allows an acquire to go over the limit if there is any room.
+   *
+   * <p>Also allows setting an error, to avoid issues with un-released acquires after error.
+   */
+  private static class OverflowingSemaphore {
+    private int totalPermits;
+    private int usedPermits;
+    private Exception exception;
+
+    public OverflowingSemaphore(int totalPermits) {
+      this.totalPermits = totalPermits;
+      this.usedPermits = 0;
+    }
+
+    synchronized void aquire(int permits) throws Exception {
+      while (usedPermits >= totalPermits) {
+        if (exception != null) {
+          throw exception;
+        }
+        this.wait();
+      }
+      usedPermits += permits;
+    }
+
+    synchronized void release(int permits) {
+      usedPermits -= permits;
+      this.notifyAll();
+    }
+
+    synchronized void setException(Exception exception) {
+      this.exception = exception;
+      this.notifyAll();
+    }
+  }
+
+  /** A task that pulls bytes off a queue and actually writes them to a 
staging location. */
+  private class StoreArtifact implements 
Callable<RunnerApi.ArtifactInformation> {
+
+    private String stagingToken;
+    private String name;
+    private RunnerApi.ArtifactInformation originalArtifact;
+    private BlockingQueue<ByteString> bytesQueue;
+    private OverflowingSemaphore totalPendingBytes;
+
+    public StoreArtifact(
+        String stagingToken,
+        String name,
+        RunnerApi.ArtifactInformation originalArtifact,
+        BlockingQueue<ByteString> bytesQueue,
+        OverflowingSemaphore totalPendingBytes) {
+      this.stagingToken = stagingToken;
+      this.name = name;
+      this.originalArtifact = originalArtifact;
+      this.bytesQueue = bytesQueue;
+      this.totalPendingBytes = totalPendingBytes;
+    }
+
+    @Override
+    public RunnerApi.ArtifactInformation call() throws IOException {
+      try {
+        ArtifactDestination dest = 
destinationProvider.getDestination(stagingToken, name);
+        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, 
dest);
+        ByteString chunk = bytesQueue.take();
+        while (chunk.size() > 0) {
+          totalPendingBytes.release(chunk.size());
+          chunk.writeTo(dest.getOutputStream());
+          chunk = bytesQueue.take();
+        }
+        dest.getOutputStream().close();
+        return originalArtifact
+            .toBuilder()
+            .setTypeUrn(dest.getTypeUrn())
+            .setTypePayload(dest.getTypePayload())
+            .build();
+      } catch (IOException | InterruptedException exn) {
+        // As this thread will no longer be draining the queue, we don't want 
to get stuck writing
+        // to it.
+        totalPendingBytes.setException(exn);
+        LOG.error("Exception staging artifacts", exn);
+        if (exn instanceof IOException) {
+          throw (IOException) exn;
+        } else {
+          throw new RuntimeException(exn);
+        }
+      }
+    }
+  }
+
+  @Override
+  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
reverseArtifactRetrievalService(
+      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+
+    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {
+
+      /** The maximum number of parallel threads to use to stage. */
+      public static final int THREAD_POOL_SIZE = 10;
+
+      /** The maximum number of bytes to buffer across all writes before 
throttling. */
+      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB
+
+      IdGenerator idGenerator = IdGenerators.incrementingLongs();
+
+      String stagingToken;
+      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
+      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
+      ExecutorService stagingExecutor;
+      OverflowingSemaphore totalPendingBytes;
+
+      State state = State.START;
+      Queue<String> pendingResolves;
+      String currentEnvironment;
+      Queue<RunnerApi.ArtifactInformation> pendingGets;
+      BlockingQueue<ByteString> currentOutput;
+
+      @Override
+      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = 
"fallthrough intended")
+      // May be called by different threads for the same request; synchronized 
for memory
+      // synchronization.
+      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper 
responseWrapper) {
+        switch (state) {
+          case START:
+            stagingToken = responseWrapper.getStagingToken();
+            LOG.info("Staging artifacts for {}.", stagingToken);
+            toResolve = toStage.get(stagingToken);
+            stagedFutures = new ConcurrentHashMap<>();
+            pendingResolves = new ArrayDeque<>();
+            pendingResolves.addAll(toResolve.keySet());
+            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
+            resolveNextEnvironment(responseObserver);
+            break;
+
+          case RESOLVE:
+            {
+              currentEnvironment = pendingResolves.remove();
+              stagedFutures.put(currentEnvironment, new ArrayList<>());
+              pendingGets = new ArrayDeque<>();
+              for (RunnerApi.ArtifactInformation artifact :
+                  
responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
+                Optional<RunnerApi.ArtifactInformation> fetched = 
getLocal(artifact);
+                if (fetched.isPresent()) {
+                  stagedFutures
+                      .get(currentEnvironment)
+                      .add(CompletableFuture.completedFuture(fetched.get()));
+                } else {
+                  pendingGets.add(artifact);
+                  responseObserver.onNext(
+                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                          .setGetArtifact(
+                              
ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
+                          .build());
+                }
+              }
+              LOG.info(
+                  "Getting {} artifacts for {}.{}.",
+                  pendingGets.size(),
+                  stagingToken,
+                  pendingResolves.peek());
+              if (pendingGets.isEmpty()) {
+                resolveNextEnvironment(responseObserver);
+              } else {
+                state = State.GET;
+              }
+              break;
+            }
+
+          case GET:
+            RunnerApi.ArtifactInformation currentArtifact = 
pendingGets.remove();
+            String name = createFilename(currentEnvironment, currentArtifact);
+            try {
+              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
+              currentOutput = new ArrayBlockingQueue<ByteString>(100);
+              stagedFutures
+                  .get(currentEnvironment)
+                  .add(
+                      stagingExecutor.submit(
+                          new StoreArtifact(
+                              stagingToken,
+                              name,
+                              currentArtifact,
+                              currentOutput,
+                              totalPendingBytes)));
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              responseObserver.onError(exn);
+            }
+            state = State.GETCHUNK;
+            // fall through
+
+          case GETCHUNK:
+            try {
+              ByteString chunk = 
responseWrapper.getGetArtifactResponse().getData();
+              if (chunk.size() > 0) { // Make sure we don't accidentally send 
the EOF value.
+                totalPendingBytes.aquire(chunk.size());
+                currentOutput.put(chunk);
+              }
+              if (responseWrapper.getIsLast()) {
+                currentOutput.put(ByteString.EMPTY); // The EOF value.
+                if (pendingGets.isEmpty()) {
+                  resolveNextEnvironment(responseObserver);
+                } else {
+                  state = State.GET;
+                  LOG.debug("Waiting for {}", pendingGets.peek());
+                }
+              }
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              onError(exn);
+            }
+            break;
+
+          default:
+            responseObserver.onError(
+                new StatusException(
+                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + 
state)));
+        }
+      }
+
+      private void resolveNextEnvironment(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        if (pendingResolves.isEmpty()) {
+          finishStaging(responseObserver);
+        } else {
+          state = State.RESOLVE;
+          LOG.info("Resolving artifacts for {}.{}.", stagingToken, 
pendingResolves.peek());
+          responseObserver.onNext(
+              ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                  .setResolveArtifact(
+                      ArtifactApi.ResolveArtifactsRequest.newBuilder()
+                          
.addAllArtifacts(toResolve.get(pendingResolves.peek())))
+                  .build());
+        }
+      }
+
+      private void finishStaging(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) 
{
+        LOG.debug("Finishing staging for {}.", stagingToken);
+        Map<String, List<RunnerApi.ArtifactInformation>> staged = new 
HashMap<>();
+        try {
+          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> 
entry :
+              stagedFutures.entrySet()) {
+            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
+            for (Future<RunnerApi.ArtifactInformation> future : 
entry.getValue()) {
+              envStaged.add(future.get());
+            }
+            staged.put(entry.getKey(), envStaged);
+          }
+          ArtifactStagingService.this.staged.put(stagingToken, staged);
+          stagingExecutor.shutdown();
+          state = State.DONE;
+          LOG.info("Artifacts fully staged for {}.", stagingToken);
+          responseObserver.onCompleted();
+        } catch (Exception exn) {
+          LOG.error("Error staging artifacts", exn);
+          responseObserver.onError(exn);
+          state = State.ERROR;
+          return;
+        }
+      }
+
+      /**
+       * Return an alternative artifact if we do not need to get this over the 
artifact API, or
+       * possibly at all.
+       */
+      private Optional<RunnerApi.ArtifactInformation> getLocal(
+          RunnerApi.ArtifactInformation artifact) {
+        return Optional.empty();
+      }
+
+      /**
+       * Attempts to provide a reasonable filename for the artifact.
+       *
+       * <p>A monotonically increasing id from {@code idGenerator} is prepended, which provides
+       * uniqueness.
+       *
+       * @param environment the environment id
+       * @param artifact the artifact itself
+       */
+      private String createFilename(String environment, 
RunnerApi.ArtifactInformation artifact) {
+        String path;
+        try {
+          if 
(artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN))
 {
+            path =
+                
RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
+                    .getStagedName();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
+          } else if 
(artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
+            path = 
RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
+          } else {
+            path = "artifact";
+          }
+        } catch (InvalidProtocolBufferException exn) {
+          throw new RuntimeException(exn);
+        }
+        // Limit to the last contiguous alpha-numeric sequence. In particular, 
this will exclude
+        // all path separators.
+        List<String> components = 
Splitter.onPattern("[^A-Za-z-_.]]").splitToList(path);
+        String base = components.get(components.size() - 1);
+        return clip(
+            String.format("%s-%s-%s", idGenerator.getId(), clip(environment, 
25), base), 100);
+      }
+
+      /** Truncates {@code s} so that it is at most {@code maxLength} characters long. */
+      private String clip(String s, int maxLength) {
+        if (s.length() >= maxLength) {
+          return s.substring(0, maxLength);
+        }
+        return s;
+      }
+
+      /**
+       * Handles a failure of the inbound stream: stops any in-flight store tasks and marks this
+       * staging attempt as failed.
+       *
+       * @param throwable the error reported on the stream
+       */
+      @Override
+      public void onError(Throwable throwable) {
+        // The executor is only created once the first message (the staging token) has been
+        // processed, so the stream may fail before there is anything to shut down.
+        if (stagingExecutor != null) {
+          stagingExecutor.shutdownNow();
+        }
+        LOG.error("Error staging artifacts", throwable);
+        state = State.ERROR;
+      }
+
+      /**
+       * Invoked when the inbound stream completes; verifies that staging already reached DONE.
+       *
+       * <p>NOTE(review): this validates internal state rather than an argument, so a state check
+       * may be the better fit than checkArgument -- confirm before changing the exception type.
+       */
+      @Override
+      public void onCompleted() {
+        Preconditions.checkArgument(state == State.DONE);
+      }
+    };
+  }
+
+  /** {@inheritDoc} This implementation does nothing. */
+  @Override
+  public void close() throws Exception {
+    // Nothing to close.
+  }
+
+  /**
+   * Lazily stages artifacts by letting an ArtifactStagingService resolve and request artifacts.
+   *
+   * <p>Creates a StagingDriver for the given token and blocks until its completion future
+   * completes.
+   *
+   * @param retrievalService an ArtifactRetrievalService used to resolve and retrieve artifacts
+   * @param stagingService an ArtifactStagingService stub which will request artifacts
+   * @param stagingToken the staging token of the job whose artifacts will be retrieved
+   * @throws ExecutionException if the driver's completion future completed exceptionally
+   * @throws InterruptedException if the calling thread is interrupted while waiting
+   */
+  public static void offer(
+      ArtifactRetrievalService retrievalService,
+      ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+      String stagingToken)
+      throws ExecutionException, InterruptedException {
+    new StagingDriver(retrievalService, stagingService, 
stagingToken).getCompletionFuture().get();
+  }
+
+  /** Actually implements the reverse retrieval protocol. */
+  private static class StagingDriver implements 
StreamObserver<ArtifactApi.ArtifactRequestWrapper> {
+
+    private final ArtifactRetrievalService retrievalService;
+    private final StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
responseObserver;
+    private final CompletableFuture<Void> completionFuture;
+
+    public StagingDriver(
+        ArtifactRetrievalService retrievalService,
+        ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+        String stagingToken) {
+      this.retrievalService = retrievalService;
+      responseObserver = stagingService.reverseArtifactRetrievalService(this);
+      responseObserver.onNext(
+          
ArtifactApi.ArtifactResponseWrapper.newBuilder().setStagingToken(stagingToken).build());
+      completionFuture = new CompletableFuture<Void>();
+    }
+
+    public CompletableFuture<?> getCompletionFuture() {
+      return completionFuture;
+    }
+
+    @Override
+    public void onNext(ArtifactApi.ArtifactRequestWrapper requestWrapper) {
+      if (requestWrapper.hasResolveArtifact()) {
+        retrievalService.resolveArtifacts(
+            requestWrapper.getResolveArtifact(),
+            new StreamObserver<ArtifactApi.ResolveArtifactsResponse>() {
+
+              @Override
+              public void onNext(ArtifactApi.ResolveArtifactsResponse 
resolveArtifactsResponse) {
+                responseObserver.onNext(
+                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
+                        .setResolveArtifactResponse(resolveArtifactsResponse)
+                        .build());
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                responseObserver.onError(throwable);
+              }
+
+              @Override
+              public void onCompleted() {}
+            });
+      } else if (requestWrapper.hasGetArtifact()) {
+        retrievalService.getArtifact(
+            requestWrapper.getGetArtifact(),
+            new StreamObserver<ArtifactApi.GetArtifactResponse>() {
+
+              @Override
+              public void onNext(ArtifactApi.GetArtifactResponse 
getArtifactResponse) {
+                responseObserver.onNext(
+                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
+                        .setGetArtifactResponse(getArtifactResponse)
+                        .build());
+              }
+
+              @Override
+              public void onError(Throwable throwable) {
+                responseObserver.onError(throwable);
+              }
+
+              @Override
+              public void onCompleted() {
+                responseObserver.onNext(
+                    ArtifactApi.ArtifactResponseWrapper.newBuilder()
+                        .setGetArtifactResponse(
+                            
ArtifactApi.GetArtifactResponse.newBuilder().build())
+                        .setIsLast(true)
+                        .build());
+              }
+            });
+      } else {
+        responseObserver.onError(

Review comment:
       Should we also complete the future with the error here?

##########
File path: 
runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/ArtifactStagingService.java
##########
@@ -0,0 +1,602 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.artifact;
+
+import com.google.auto.value.AutoValue;
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.Channels;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
+import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.fnexecution.FnService;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.io.FileSystems;
+import org.apache.beam.sdk.io.fs.ResolveOptions;
+import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.Status;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
+import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ArtifactStagingService
+    extends ArtifactStagingServiceGrpc.ArtifactStagingServiceImplBase 
implements FnService {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ArtifactStagingService.class);
+
+  private final ArtifactDestinationProvider destinationProvider;
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> toStage =
+      new ConcurrentHashMap<>();
+
+  private final ConcurrentMap<String, Map<String, 
List<RunnerApi.ArtifactInformation>>> staged =
+      new ConcurrentHashMap<>();
+
+  /**
+   * Creates a staging service.
+   *
+   * @param destinationProvider supplies the destination for each artifact that is stored
+   */
+  public ArtifactStagingService(ArtifactDestinationProvider 
destinationProvider) {
+    this.destinationProvider = destinationProvider;
+  }
+
+  /**
+   * Registers a set of artifacts to be staged with this service.
+   *
+   * <p>A client (e.g. a Beam SDK) is expected to connect to this service with the given staging
+   * token and offer resolution and retrieval of this set of artifacts.
+   *
+   * <p>NOTE(review): the duplicate-token check is a plain {@code assert}, so it only fires when
+   * assertions are enabled; otherwise a second registration replaces the first.
+   *
+   * @param stagingToken a staging token for this job
+   * @param artifacts all artifacts to stage, keyed by environment
+   */
+  public void registerJob(
+      String stagingToken, Map<String, List<RunnerApi.ArtifactInformation>> 
artifacts) {
+    assert !toStage.containsKey(stagingToken);
+    toStage.put(stagingToken, artifacts);
+  }
+
+  /**
+   * Returns the rewritten artifacts associated with this job, keyed by environment.
+   *
+   * <p>This should be called after the client has finished offering artifacts. The call removes
+   * both the registered and the staged entries for the token, so a second call for the same
+   * token returns {@code null}.
+   *
+   * @param stagingToken a staging token for this job
+   * @return the staged artifacts, or {@code null} if nothing is recorded for the token
+   */
+  public Map<String, List<RunnerApi.ArtifactInformation>> 
getStagedArtifacts(String stagingToken) {
+    toStage.remove(stagingToken);
+    return staged.remove(stagingToken);
+  }
+
+  /** Provides a concrete location to which artifacts can be staged on retrieval. */
+  public interface ArtifactDestinationProvider {
+    /**
+     * Returns the destination for one artifact.
+     *
+     * @param stagingToken the staging token of the job the artifact belongs to
+     * @param name the filename chosen for the artifact
+     * @throws IOException if the destination cannot be created
+     */
+    ArtifactDestination getDestination(String stagingToken, String name) 
throws IOException;
+  }
+
+  /**
+   * A pairing of a newly created artifact type and an output stream that will be readable at that
+   * type.
+   */
+  @AutoValue
+  public abstract static class ArtifactDestination {
+    /** Creates a destination from an explicit type URN, type payload and output stream. */
+    public static ArtifactDestination create(
+        String typeUrn, ByteString typePayload, OutputStream out) {
+      return new AutoValue_ArtifactStagingService_ArtifactDestination(typeUrn, 
typePayload, out);
+    }
+
+    /**
+     * Creates a file destination at {@code path}, opening the output stream through Beam's
+     * FileSystems with a binary MIME type.
+     *
+     * @throws IOException if the file cannot be created
+     */
+    public static ArtifactDestination fromFile(String path) throws IOException 
{
+      return fromFile(
+          path,
+          Channels.newOutputStream(
+              FileSystems.create(
+                  FileSystems.matchNewResource(path, false /* isDirectory */), 
MimeTypes.BINARY)));
+    }
+
+    /** Creates a file-typed destination for {@code path} that writes to the given stream. */
+    public static ArtifactDestination fromFile(String path, OutputStream out) {
+      return create(
+          ArtifactRetrievalService.FILE_ARTIFACT_URN,
+          
RunnerApi.ArtifactFilePayload.newBuilder().setPath(path).build().toByteString(),
+          out);
+    }
+
+    /** The type URN under which the stored artifact is described. */
+    public abstract String getTypeUrn();
+
+    /** The serialized type payload that accompanies {@link #getTypeUrn()}. */
+    public abstract ByteString getTypePayload();
+
+    /** The stream to which the artifact's bytes are written. */
+    public abstract OutputStream getOutputStream();
+  }
+
+  /**
+   * An ArtifactDestinationProvider that places new artifacts as files in a Beam filesystem.
+   *
+   * <p>Each artifact is written to {@code root/<stagingToken>/<name>}.
+   *
+   * @param root the directory in which to place all artifacts
+   * @return a provider that creates file destinations beneath {@code root}
+   */
+  public static ArtifactDestinationProvider 
beamFilesystemArtifactDestinationProvider(String root) {
+    return (stagingToken, name) -> {
+      ResourceId path =
+          FileSystems.matchNewResource(root, true)
+              .resolve(stagingToken, 
ResolveOptions.StandardResolveOptions.RESOLVE_DIRECTORY)
+              .resolve(name, 
ResolveOptions.StandardResolveOptions.RESOLVE_FILE);
+      return ArtifactDestination.fromFile(path.toString());
+    };
+  }
+
+  /** Phases of the per-stream staging state machine that is advanced by each incoming message. */
+  private enum State {
+    // Waiting for the first message, which carries the staging token.
+    START,
+    // Waiting for the response to a resolve request for one environment.
+    RESOLVE,
+    // Waiting for the first chunk of the next requested artifact.
+    GET,
+    // Receiving further chunks of the artifact currently being stored.
+    GETCHUNK,
+    // All artifacts were staged and the stream was completed.
+    DONE,
+    // Staging failed.
+    ERROR,
+  }
+
+  /**
+   * Like the standard Semaphore, but allows an acquire to go over the limit if there is any room.
+   *
+   * <p>Also allows setting an error, to avoid issues with un-released acquires after error: once
+   * an exception has been set, every blocked and every later acquire fails with that exception.
+   */
+  private static class OverflowingSemaphore {
+    private final int totalPermits;
+    private int usedPermits;
+    private Exception exception;
+
+    public OverflowingSemaphore(int totalPermits) {
+      this.totalPermits = totalPermits;
+      this.usedPermits = 0;
+    }
+
+    /**
+     * Takes {@code permits} permits, blocking while the limit has already been reached.
+     *
+     * @param permits the number of permits to take; may push usage above the limit
+     * @throws Exception the exception passed to {@link #setException}, if one has been set
+     */
+    synchronized void aquire(int permits) throws Exception {
+      while (exception == null && usedPermits >= totalPermits) {
+        this.wait();
+      }
+      // Checked even when there is room: after a failure nobody will release these permits, so
+      // the caller must not keep buffering.
+      if (exception != null) {
+        throw exception;
+      }
+      usedPermits += permits;
+    }
+
+    /** Returns {@code permits} permits and wakes any blocked acquirers. */
+    synchronized void release(int permits) {
+      usedPermits -= permits;
+      this.notifyAll();
+    }
+
+    /** Records a failure and wakes any blocked acquirers so they can observe it. */
+    synchronized void setException(Exception exception) {
+      this.exception = exception;
+      this.notifyAll();
+    }
+  }
+
+  /**
+   * A task that pulls bytes off a queue and actually writes them to a staging location.
+   *
+   * <p>An empty chunk on the queue marks the end of the artifact. The task returns the original
+   * artifact description with its type URN and payload replaced by those of the destination.
+   */
+  private class StoreArtifact implements Callable<RunnerApi.ArtifactInformation> {
+
+    private final String stagingToken;
+    private final String name;
+    private final RunnerApi.ArtifactInformation originalArtifact;
+    private final BlockingQueue<ByteString> bytesQueue;
+    private final OverflowingSemaphore totalPendingBytes;
+
+    public StoreArtifact(
+        String stagingToken,
+        String name,
+        RunnerApi.ArtifactInformation originalArtifact,
+        BlockingQueue<ByteString> bytesQueue,
+        OverflowingSemaphore totalPendingBytes) {
+      this.stagingToken = stagingToken;
+      this.name = name;
+      this.originalArtifact = originalArtifact;
+      this.bytesQueue = bytesQueue;
+      this.totalPendingBytes = totalPendingBytes;
+    }
+
+    /**
+     * Drains the queue into the destination until the empty end-of-artifact chunk arrives.
+     *
+     * @return the rewritten artifact description pointing at the stored copy
+     * @throws IOException if the destination cannot be created, written or closed
+     */
+    @Override
+    public RunnerApi.ArtifactInformation call() throws IOException {
+      try {
+        ArtifactDestination dest = destinationProvider.getDestination(stagingToken, name);
+        LOG.debug("Storing artifact for {}.{} at {}", stagingToken, name, dest);
+        // try-with-resources so the destination stream is closed even when a write fails.
+        try (OutputStream out = dest.getOutputStream()) {
+          ByteString chunk = bytesQueue.take();
+          while (chunk.size() > 0) {
+            totalPendingBytes.release(chunk.size());
+            chunk.writeTo(out);
+            chunk = bytesQueue.take();
+          }
+        }
+        return originalArtifact
+            .toBuilder()
+            .setTypeUrn(dest.getTypeUrn())
+            .setTypePayload(dest.getTypePayload())
+            .build();
+      } catch (IOException | InterruptedException exn) {
+        // As this thread will no longer be draining the queue, we don't want to get stuck writing
+        // to it.
+        totalPendingBytes.setException(exn);
+        LOG.error("Exception staging artifacts", exn);
+        if (exn instanceof IOException) {
+          throw (IOException) exn;
+        } else {
+          // Restore the interrupt flag that was cleared when InterruptedException was thrown.
+          Thread.currentThread().interrupt();
+          throw new RuntimeException(exn);
+        }
+      }
+    }
+  }
+
+  @Override
+  public StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
reverseArtifactRetrievalService(
+      StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+
+    return new StreamObserver<ArtifactApi.ArtifactResponseWrapper>() {
+
+      /** The maximum number of parallel threads to use to stage. */
+      public static final int THREAD_POOL_SIZE = 10;
+
+      /** The maximum number of bytes to buffer across all writes before 
throttling. */
+      public static final int MAX_PENDING_BYTES = 100 << 20; // 100 MB
+
+      IdGenerator idGenerator = IdGenerators.incrementingLongs();
+
+      String stagingToken;
+      Map<String, List<RunnerApi.ArtifactInformation>> toResolve;
+      Map<String, List<Future<RunnerApi.ArtifactInformation>>> stagedFutures;
+      ExecutorService stagingExecutor;
+      OverflowingSemaphore totalPendingBytes;
+
+      State state = State.START;
+      Queue<String> pendingResolves;
+      String currentEnvironment;
+      Queue<RunnerApi.ArtifactInformation> pendingGets;
+      BlockingQueue<ByteString> currentOutput;
+
+      /**
+       * Advances the staging state machine with the next message from the client.
+       *
+       * <ul>
+       *   <li>START: the message carries the staging token; resolution of the first environment
+       *       is requested.
+       *   <li>RESOLVE: the message carries the resolved artifacts of one environment; each one
+       *       that has no local replacement is requested.
+       *   <li>GET: the message is the first chunk of the next artifact; a store task is submitted
+       *       and the chunk is then handled as in GETCHUNK.
+       *   <li>GETCHUNK: the chunk is queued for the store task; after the last chunk the next
+       *       artifact or environment is processed.
+       * </ul>
+       *
+       * <p>Any other state is reported to the client as INVALID_ARGUMENT.
+       */
+      @Override
+      @SuppressFBWarnings(value = "SF_SWITCH_FALLTHROUGH", justification = "fallthrough intended")
+      // May be called by different threads for the same request; synchronized for memory
+      // synchronization.
+      public synchronized void onNext(ArtifactApi.ArtifactResponseWrapper responseWrapper) {
+        switch (state) {
+          case START:
+            stagingToken = responseWrapper.getStagingToken();
+            LOG.info("Staging artifacts for {}.", stagingToken);
+            toResolve = toStage.get(stagingToken);
+            if (toResolve == null) {
+              // Nothing was registered under this token; fail the stream instead of hitting a
+              // NullPointerException below.
+              state = State.ERROR;
+              responseObserver.onError(
+                  new StatusException(
+                      Status.INVALID_ARGUMENT.withDescription(
+                          "Unknown staging token " + stagingToken)));
+              return;
+            }
+            stagedFutures = new ConcurrentHashMap<>();
+            pendingResolves = new ArrayDeque<>();
+            pendingResolves.addAll(toResolve.keySet());
+            stagingExecutor = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
+            totalPendingBytes = new OverflowingSemaphore(MAX_PENDING_BYTES);
+            resolveNextEnvironment(responseObserver);
+            break;
+
+          case RESOLVE:
+            {
+              currentEnvironment = pendingResolves.remove();
+              stagedFutures.put(currentEnvironment, new ArrayList<>());
+              pendingGets = new ArrayDeque<>();
+              for (RunnerApi.ArtifactInformation artifact :
+                  responseWrapper.getResolveArtifactResponse().getReplacementsList()) {
+                Optional<RunnerApi.ArtifactInformation> fetched = getLocal(artifact);
+                if (fetched.isPresent()) {
+                  stagedFutures
+                      .get(currentEnvironment)
+                      .add(CompletableFuture.completedFuture(fetched.get()));
+                } else {
+                  pendingGets.add(artifact);
+                  responseObserver.onNext(
+                      ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                          .setGetArtifact(
+                              ArtifactApi.GetArtifactRequest.newBuilder().setArtifact(artifact))
+                          .build());
+                }
+              }
+              // Log the environment just taken off the queue; peeking at the queue here would
+              // name the next environment (or null) instead.
+              LOG.info(
+                  "Getting {} artifacts for {}.{}.",
+                  pendingGets.size(),
+                  stagingToken,
+                  currentEnvironment);
+              if (pendingGets.isEmpty()) {
+                resolveNextEnvironment(responseObserver);
+              } else {
+                state = State.GET;
+              }
+              break;
+            }
+
+          case GET:
+            RunnerApi.ArtifactInformation currentArtifact = pendingGets.remove();
+            String name = createFilename(currentEnvironment, currentArtifact);
+            try {
+              LOG.debug("Storing artifacts for {} as {}", stagingToken, name);
+              currentOutput = new ArrayBlockingQueue<ByteString>(100);
+              stagedFutures
+                  .get(currentEnvironment)
+                  .add(
+                      stagingExecutor.submit(
+                          new StoreArtifact(
+                              stagingToken,
+                              name,
+                              currentArtifact,
+                              currentOutput,
+                              totalPendingBytes)));
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              state = State.ERROR;
+              responseObserver.onError(exn);
+              // No store task will drain currentOutput, so the chunk must not be queued.
+              return;
+            }
+            state = State.GETCHUNK;
+            // fall through
+
+          case GETCHUNK:
+            try {
+              ByteString chunk = responseWrapper.getGetArtifactResponse().getData();
+              if (chunk.size() > 0) { // Make sure we don't accidentally send the EOF value.
+                totalPendingBytes.aquire(chunk.size());
+                currentOutput.put(chunk);
+              }
+              if (responseWrapper.getIsLast()) {
+                currentOutput.put(ByteString.EMPTY); // The EOF value.
+                if (pendingGets.isEmpty()) {
+                  resolveNextEnvironment(responseObserver);
+                } else {
+                  state = State.GET;
+                  LOG.debug("Waiting for {}", pendingGets.peek());
+                }
+              }
+            } catch (Exception exn) {
+              LOG.error("Error submitting.", exn);
+              // Clean up locally, then tell the client, which would otherwise keep waiting.
+              onError(exn);
+              responseObserver.onError(exn);
+            }
+            break;
+
+          default:
+            responseObserver.onError(
+                new StatusException(
+                    Status.INVALID_ARGUMENT.withDescription("Illegal state " + state)));
+        }
+      }
+
+      /**
+       * Asks the client to resolve the artifacts of the next pending environment, or finishes
+       * staging when no environment is left.
+       */
+      private void resolveNextEnvironment(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+        if (pendingResolves.isEmpty()) {
+          finishStaging(responseObserver);
+          return;
+        }
+        state = State.RESOLVE;
+        // The environment stays at the head of the queue until its resolve response arrives.
+        String nextEnvironment = pendingResolves.peek();
+        LOG.info("Resolving artifacts for {}.{}.", stagingToken, nextEnvironment);
+        ArtifactApi.ResolveArtifactsRequest.Builder resolveRequest =
+            ArtifactApi.ResolveArtifactsRequest.newBuilder()
+                .addAllArtifacts(toResolve.get(nextEnvironment));
+        responseObserver.onNext(
+            ArtifactApi.ArtifactRequestWrapper.newBuilder()
+                .setResolveArtifact(resolveRequest)
+                .build());
+      }
+
+      /**
+       * Waits for every pending store task, publishes the rewritten artifacts under the staging
+       * token, and completes the outbound stream.
+       *
+       * <p>On failure the error is forwarded to {@code responseObserver}, the staging executor is
+       * shut down so its worker threads are not leaked, and the state becomes ERROR.
+       *
+       * @param responseObserver the outbound stream to complete or fail
+       */
+      private void finishStaging(
+          StreamObserver<ArtifactApi.ArtifactRequestWrapper> responseObserver) {
+        LOG.debug("Finishing staging for {}.", stagingToken);
+        Map<String, List<RunnerApi.ArtifactInformation>> staged = new HashMap<>();
+        try {
+          for (Map.Entry<String, List<Future<RunnerApi.ArtifactInformation>>> entry :
+              stagedFutures.entrySet()) {
+            List<RunnerApi.ArtifactInformation> envStaged = new ArrayList<>();
+            for (Future<RunnerApi.ArtifactInformation> future : entry.getValue()) {
+              // Blocks until the corresponding store task has finished writing.
+              envStaged.add(future.get());
+            }
+            staged.put(entry.getKey(), envStaged);
+          }
+          ArtifactStagingService.this.staged.put(stagingToken, staged);
+          stagingExecutor.shutdown();
+          state = State.DONE;
+          LOG.info("Artifacts fully staged for {}.", stagingToken);
+          responseObserver.onCompleted();
+        } catch (Exception exn) {
+          if (exn instanceof InterruptedException) {
+            // Keep the interrupt visible to whoever owns this thread.
+            Thread.currentThread().interrupt();
+          }
+          LOG.error("Error staging artifacts", exn);
+          // The success path is the only other place the pool is shut down; without this the
+          // worker threads would outlive a failed staging attempt.
+          stagingExecutor.shutdownNow();
+          state = State.ERROR;
+          responseObserver.onError(exn);
+        }
+      }
+
+      /**
+       * Returns an alternative artifact if we do not need to get this one over the artifact API,
+       * or possibly at all.
+       *
+       * <p>This implementation never substitutes anything: it always returns an empty Optional.
+       *
+       * @param artifact the resolved artifact being considered
+       * @return the replacement artifact, or empty if the artifact must be fetched
+       */
+      private Optional<RunnerApi.ArtifactInformation> getLocal(
+          RunnerApi.ArtifactInformation artifact) {
+        return Optional.empty();
+      }
+
+      /**
+       * Attempts to provide a reasonable filename for the artifact.
+       *
+       * <p>The result has the form {@code <unique id>-<environment>-<base>}, with the environment
+       * clipped to 25 characters and the whole name clipped to 100. The base is taken from the
+       * artifact's staged name, file path, or URL (falling back to "artifact"); uniqueness comes
+       * from the id generator prefix.
+       *
+       * @param environment the environment id
+       * @param artifact the artifact itself
+       */
+      private String createFilename(String environment, RunnerApi.ArtifactInformation artifact) {
+        String path;
+        try {
+          if (artifact.getRoleUrn().equals(ArtifactRetrievalService.STAGING_TO_ARTIFACT_URN)) {
+            path =
+                RunnerApi.ArtifactStagingToRolePayload.parseFrom(artifact.getRolePayload())
+                    .getStagedName();
+          } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.FILE_ARTIFACT_URN)) {
+            path = RunnerApi.ArtifactFilePayload.parseFrom(artifact.getTypePayload()).getPath();
+          } else if (artifact.getTypeUrn().equals(ArtifactRetrievalService.URL_ARTIFACT_URN)) {
+            path = RunnerApi.ArtifactUrlPayload.parseFrom(artifact.getTypePayload()).getUrl();
+          } else {
+            path = "artifact";
+          }
+        } catch (InvalidProtocolBufferException exn) {
+          throw new RuntimeException(exn);
+        }
+        // Limit to the last contiguous run of filename-safe characters (letters, digits, '_', '.'
+        // and '-'). In particular, this will exclude all path separators. The '-' is placed last
+        // in the character class so it is always read as a literal.
+        List<String> components = Splitter.onPattern("[^A-Za-z0-9_.-]").splitToList(path);
+        String base = components.get(components.size() - 1);
+        return clip(
+            String.format("%s-%s-%s", idGenerator.getId(), clip(environment, 25), base), 100);
+      }
+
+      /** Truncates {@code s} so that it is at most {@code maxLength} characters long. */
+      private String clip(String s, int maxLength) {
+        if (s.length() >= maxLength) {
+          return s.substring(0, maxLength);
+        }
+        return s;
+      }
+
+      /**
+       * Handles a failure of the inbound stream: stops any in-flight store tasks and marks this
+       * staging attempt as failed.
+       *
+       * @param throwable the error reported on the stream
+       */
+      @Override
+      public void onError(Throwable throwable) {
+        // The executor is only created once the first message (the staging token) has been
+        // processed, so the stream may fail before there is anything to shut down.
+        if (stagingExecutor != null) {
+          stagingExecutor.shutdownNow();
+        }
+        LOG.error("Error staging artifacts", throwable);
+        state = State.ERROR;
+      }
+
+      /**
+       * Invoked when the inbound stream completes; verifies that staging already reached DONE.
+       *
+       * <p>NOTE(review): this validates internal state rather than an argument, so a state check
+       * may be the better fit than checkArgument -- confirm before changing the exception type.
+       */
+      @Override
+      public void onCompleted() {
+        Preconditions.checkArgument(state == State.DONE);
+      }
+    };
+  }
+
+  /** {@inheritDoc} This implementation does nothing. */
+  @Override
+  public void close() throws Exception {
+    // Nothing to close.
+  }
+
+  /**
+   * Lazily stages artifacts by letting an ArtifactStagingService resolve and request artifacts.
+   *
+   * <p>Creates a StagingDriver for the given token and blocks until its completion future
+   * completes.
+   *
+   * @param retrievalService an ArtifactRetrievalService used to resolve and retrieve artifacts
+   * @param stagingService an ArtifactStagingService stub which will request artifacts
+   * @param stagingToken the staging token of the job whose artifacts will be retrieved
+   * @throws ExecutionException if the driver's completion future completed exceptionally
+   * @throws InterruptedException if the calling thread is interrupted while waiting
+   */
+  public static void offer(
+      ArtifactRetrievalService retrievalService,
+      ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+      String stagingToken)
+      throws ExecutionException, InterruptedException {
+    new StagingDriver(retrievalService, stagingService, 
stagingToken).getCompletionFuture().get();
+  }
+
+  /** Actually implements the reverse retrieval protocol. */
+  private static class StagingDriver implements 
StreamObserver<ArtifactApi.ArtifactRequestWrapper> {
+
+    private final ArtifactRetrievalService retrievalService;
+    private final StreamObserver<ArtifactApi.ArtifactResponseWrapper> 
responseObserver;
+    private final CompletableFuture<Void> completionFuture;
+
+    public StagingDriver(
+        ArtifactRetrievalService retrievalService,
+        ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stagingService,
+        String stagingToken) {
+      this.retrievalService = retrievalService;
+      responseObserver = stagingService.reverseArtifactRetrievalService(this);
+      responseObserver.onNext(
+          
ArtifactApi.ArtifactResponseWrapper.newBuilder().setStagingToken(stagingToken).build());
+      completionFuture = new CompletableFuture<Void>();
+    }
+
+    public CompletableFuture<?> getCompletionFuture() {
+      return completionFuture;
+    }
+
+    @Override
+    public void onNext(ArtifactApi.ArtifactRequestWrapper requestWrapper) {
+      if (requestWrapper.hasResolveArtifact()) {

Review comment:
       If the future is completed, we should call responseObserver.onError with 
the exception.
   
   This allows us to propagate the root cause further along.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to