This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bcca67d286b140610215a3222baa23aa8237a035
Author: Dian Fu <[email protected]>
AuthorDate: Fri Nov 24 14:46:35 2023 +0800

    [FLINK-33613][python] Port Beam DefaultJobBundleFactory class to 
flink-python module
---
 .../control/DefaultJobBundleFactory.java           | 777 +++++++++++++++++++++
 1 file changed, 777 insertions(+)

diff --git 
a/flink-python/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
 
b/flink-python/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
new file mode 100644
index 00000000000..c22175b1742
--- /dev/null
+++ 
b/flink-python/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
@@ -0,0 +1,777 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.ProvisionApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
+import org.apache.beam.model.pipeline.v1.RunnerApi.StandardEnvironments;
+import org.apache.beam.runners.core.construction.BeamUrns;
+import org.apache.beam.runners.core.construction.Environments;
+import org.apache.beam.runners.core.construction.PipelineOptionsTranslation;
+import org.apache.beam.runners.core.construction.Timer;
+import org.apache.beam.runners.core.construction.graph.ExecutableStage;
+import org.apache.beam.runners.fnexecution.artifact.ArtifactRetrievalService;
+import 
org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor;
+import 
org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.TimerSpec;
+import 
org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor;
+import org.apache.beam.runners.fnexecution.data.GrpcDataService;
+import 
org.apache.beam.runners.fnexecution.environment.DockerEnvironmentFactory;
+import 
org.apache.beam.runners.fnexecution.environment.EmbeddedEnvironmentFactory;
+import org.apache.beam.runners.fnexecution.environment.EnvironmentFactory;
+import 
org.apache.beam.runners.fnexecution.environment.ExternalEnvironmentFactory;
+import 
org.apache.beam.runners.fnexecution.environment.ProcessEnvironmentFactory;
+import org.apache.beam.runners.fnexecution.environment.RemoteEnvironment;
+import org.apache.beam.runners.fnexecution.logging.GrpcLoggingService;
+import org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter;
+import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
+import 
org.apache.beam.runners.fnexecution.provisioning.StaticGrpcProvisionService;
+import org.apache.beam.runners.fnexecution.state.GrpcStateService;
+import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.server.GrpcContextHeaderAccessorProvider;
+import org.apache.beam.sdk.fn.server.GrpcFnServer;
+import org.apache.beam.sdk.fn.server.ServerFactory;
+import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
+import org.apache.beam.sdk.function.ThrowingFunction;
+import org.apache.beam.sdk.options.ExperimentalOptions;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PortablePipelineOptions;
+import org.apache.beam.sdk.util.NoopLock;
+import org.apache.beam.sdk.values.KV;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheBuilder;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.CacheLoader;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.io.IOException;
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * A {@link JobBundleFactory} for which the implementation can specify a custom {@link
+ * EnvironmentFactory} for environment management. Note that returned {@link StageBundleFactory
+ * stage bundle factories} are not thread-safe. Instead, a new stage factory should be created for
+ * each client. {@link DefaultJobBundleFactory} initializes the environment lazily when {@link
+ * #forStage(ExecutableStage)} is called for a stage.
+ */
+@ThreadSafe
+@SuppressWarnings({
+    "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
+    "nullness" // TODO(https://github.com/apache/beam/issues/20497)
+})
+public class DefaultJobBundleFactory implements JobBundleFactory {
+    private static final Logger LOG = 
LoggerFactory.getLogger(DefaultJobBundleFactory.class);
+    private static final IdGenerator factoryIdGenerator = 
IdGenerators.incrementingLongs();
+
+    private final String factoryId = factoryIdGenerator.getId();
+    private final ImmutableList<EnvironmentCacheAndLock> environmentCaches;
+    private final AtomicInteger stageBundleFactoryCount = new AtomicInteger();
+    private final Map<String, EnvironmentFactory.Provider> 
environmentFactoryProviderMap;
+    private final ExecutorService executor;
+    private final MapControlClientPool clientPool;
+    private final IdGenerator stageIdGenerator;
+    private final int environmentExpirationMillis;
+    private final Semaphore availableCachesSemaphore;
+    private final LinkedBlockingDeque<EnvironmentCacheAndLock> availableCaches;
+    private final boolean loadBalanceBundles;
+    /**
+     * Clients which were evicted due to environment expiration but still had 
pending references.
+     */
+    private final Set<WrappedSdkHarnessClient> evictedActiveClients;
+
+    private boolean closed;
+
+    public static DefaultJobBundleFactory create(JobInfo jobInfo) {
+        PipelineOptions pipelineOptions =
+                
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions());
+        Map<String, EnvironmentFactory.Provider> environmentFactoryProviderMap 
=
+                ImmutableMap.of(
+                        
BeamUrns.getUrn(StandardEnvironments.Environments.DOCKER),
+                        new DockerEnvironmentFactory.Provider(pipelineOptions),
+                        
BeamUrns.getUrn(StandardEnvironments.Environments.PROCESS),
+                        new 
ProcessEnvironmentFactory.Provider(pipelineOptions),
+                        
BeamUrns.getUrn(StandardEnvironments.Environments.EXTERNAL),
+                        new ExternalEnvironmentFactory.Provider(),
+                        Environments.ENVIRONMENT_EMBEDDED, // Non Public urn 
for testing.
+                        new 
EmbeddedEnvironmentFactory.Provider(pipelineOptions));
+        return new DefaultJobBundleFactory(jobInfo, 
environmentFactoryProviderMap);
+    }
+
+    public static DefaultJobBundleFactory create(
+            JobInfo jobInfo,
+            Map<String, EnvironmentFactory.Provider> 
environmentFactoryProviderMap) {
+        return new DefaultJobBundleFactory(jobInfo, 
environmentFactoryProviderMap);
+    }
+
    /**
     * Creates the factory for the given job. Environment expiration, SDK worker parallelism and
     * bundle load balancing are all derived from the job's pipeline options.
     */
    DefaultJobBundleFactory(
            JobInfo jobInfo, Map<String, EnvironmentFactory.Provider> environmentFactoryMap) {
        IdGenerator stageIdSuffixGenerator = IdGenerators.incrementingLongs();
        this.environmentFactoryProviderMap = environmentFactoryMap;
        this.executor = Executors.newCachedThreadPool();
        this.clientPool = MapControlClientPool.create();
        // Stage ids combine the factory id with a per-factory incrementing suffix.
        this.stageIdGenerator = () -> factoryId + "-" + stageIdSuffixGenerator.getId();
        this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
        this.loadBalanceBundles = shouldLoadBalanceBundles(jobInfo);
        // One cache per SDK worker; each cache obtains its ServerInfo lazily inside its loader.
        this.environmentCaches =
                createEnvironmentCaches(
                        serverFactory -> createServerInfo(jobInfo, serverFactory),
                        getMaxEnvironmentClients(jobInfo));
        // Fair semaphore plus deque hand out caches first-come-first-served when load balancing.
        this.availableCachesSemaphore = new Semaphore(environmentCaches.size(), true);
        this.availableCaches = new LinkedBlockingDeque<>(environmentCaches);
        this.evictedActiveClients = Sets.newConcurrentHashSet();
    }
+
    /**
     * Test-only constructor: uses the supplied stage id generator and a pre-created {@link
     * ServerInfo} shared by all environment caches, instead of creating server info per cache.
     */
    @VisibleForTesting
    DefaultJobBundleFactory(
            JobInfo jobInfo,
            Map<String, EnvironmentFactory.Provider> environmentFactoryMap,
            IdGenerator stageIdGenerator,
            ServerInfo serverInfo) {
        this.environmentFactoryProviderMap = environmentFactoryMap;
        this.executor = Executors.newCachedThreadPool();
        this.clientPool = MapControlClientPool.create();
        this.stageIdGenerator = stageIdGenerator;
        this.environmentExpirationMillis = getEnvironmentExpirationMillis(jobInfo);
        this.loadBalanceBundles = shouldLoadBalanceBundles(jobInfo);
        // Every cache returns the single injected server info, ignoring the server factory.
        this.environmentCaches =
                createEnvironmentCaches(
                        serverFactory -> serverInfo, getMaxEnvironmentClients(jobInfo));
        this.availableCachesSemaphore = new Semaphore(environmentCaches.size(), true);
        this.availableCaches = new LinkedBlockingDeque<>(environmentCaches);
        this.evictedActiveClients = Sets.newConcurrentHashSet();
    }
+
    /**
     * Pairs an environment cache with the lock that guards against the race between cache
     * eviction (which unrefs a client) and a caller that has fetched a client from the cache but
     * has not yet issued ref() on it.
     */
    private static class EnvironmentCacheAndLock {
        // A fair ReentrantLock when environment expiration is enabled, otherwise a no-op lock.
        final Lock lock;
        final LoadingCache<Environment, WrappedSdkHarnessClient> cache;

        EnvironmentCacheAndLock(
                LoadingCache<Environment, WrappedSdkHarnessClient> cache, Lock lock) {
            this.lock = lock;
            this.cache = cache;
        }
    }
+
    /**
     * Creates {@code count} environment caches, each guarded by its own lock when environment
     * expiration is enabled. Each cache lazily starts an SDK harness environment (with its
     * servers) on first load and unrefs the client when the entry is evicted.
     *
     * @param serverInfoCreator supplies the {@link ServerInfo} used by environments loaded into a
     *     cache
     * @param count number of caches to create, i.e. the configured SDK worker parallelism
     */
    private ImmutableList<EnvironmentCacheAndLock> createEnvironmentCaches(
            ThrowingFunction<ServerFactory, ServerInfo> serverInfoCreator, int count) {

        ImmutableList.Builder<EnvironmentCacheAndLock> caches = ImmutableList.builder();
        for (int i = 0; i < count; i++) {

            final Lock refLock;
            if (environmentExpirationMillis > 0) {
                // The lock ensures there is no race condition between expiring an environment and
                // a client still attempting to use it, hence referencing it.
                refLock = new ReentrantLock(true);
            } else {
                // Without expiration there is no eviction race to guard, so skip real locking.
                refLock = NoopLock.get();
            }

            CacheBuilder<Environment, WrappedSdkHarnessClient> cacheBuilder =
                    CacheBuilder.newBuilder()
                            .removalListener(
                                    notification -> {
                                        WrappedSdkHarnessClient client = notification.getValue();
                                        final int refCount;
                                        // We need to use a lock here to ensure we are not causing
                                        // the environment to be removed if beforehand a
                                        // StageBundleFactory has retrieved it but not yet issued
                                        // ref() on it.
                                        refLock.lock();
                                        try {
                                            refCount = client.unref();
                                        } finally {
                                            refLock.unlock();
                                        }
                                        if (refCount > 0) {
                                            // Bundles still reference the evicted client; track
                                            // it so close() can force-release it at shutdown.
                                            LOG.warn(
                                                    "Expiring environment {} with {} remaining bundle references. Taking note to clean it up during shutdown if the references are not removed by then.",
                                                    notification.getKey(),
                                                    refCount);
                                            evictedActiveClients.add(client);
                                        }
                                    });

            if (environmentExpirationMillis > 0) {
                cacheBuilder.expireAfterWrite(environmentExpirationMillis, TimeUnit.MILLISECONDS);
            }

            LoadingCache<Environment, WrappedSdkHarnessClient> cache =
                    cacheBuilder.build(
                            new CacheLoader<Environment, WrappedSdkHarnessClient>() {
                                @Override
                                public WrappedSdkHarnessClient load(Environment environment)
                                        throws Exception {
                                    // Resolve the provider registered for this environment's urn
                                    // and obtain the server endpoints for the new environment.
                                    EnvironmentFactory.Provider environmentFactoryProvider =
                                            environmentFactoryProviderMap.get(environment.getUrn());
                                    ServerFactory serverFactory =
                                            environmentFactoryProvider.getServerFactory();
                                    ServerInfo serverInfo = serverInfoCreator.apply(serverFactory);
                                    String workerId = stageIdGenerator.getId();
                                    // Register the environment with the provisioning service
                                    // under the new worker id before creating the environment.
                                    serverInfo
                                            .getProvisioningServer()
                                            .getService()
                                            .registerEnvironment(workerId, environment);
                                    EnvironmentFactory environmentFactory =
                                            environmentFactoryProvider.createEnvironmentFactory(
                                                    serverInfo.getControlServer(),
                                                    serverInfo.getLoggingServer(),
                                                    serverInfo.getRetrievalServer(),
                                                    serverInfo.getProvisioningServer(),
                                                    clientPool,
                                                    stageIdGenerator);
                                    return WrappedSdkHarnessClient.wrapping(
                                            environmentFactory.createEnvironment(
                                                    environment, workerId),
                                            serverInfo);
                                }
                            });

            caches.add(new EnvironmentCacheAndLock(cache, refLock));
        }
        return caches.build();
    }
+
+    private static int getEnvironmentExpirationMillis(JobInfo jobInfo) {
+        PipelineOptions pipelineOptions =
+                
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions());
+        return 
pipelineOptions.as(PortablePipelineOptions.class).getEnvironmentExpirationMillis();
+    }
+
+    private static int getMaxEnvironmentClients(JobInfo jobInfo) {
+        PortablePipelineOptions pipelineOptions =
+                PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
+                        .as(PortablePipelineOptions.class);
+        int maxEnvironments =
+                
MoreObjects.firstNonNull(pipelineOptions.getSdkWorkerParallelism(), 1);
+        Preconditions.checkArgument(maxEnvironments >= 0, 
"sdk_worker_parallelism must be >= 0");
+        if (maxEnvironments == 0) {
+            // if this is 0, use the auto behavior of num_cores - 1 so that we 
leave some resources
+            // available for the java process
+            maxEnvironments = 
Math.max(Runtime.getRuntime().availableProcessors() - 1, 1);
+        }
+        return maxEnvironments;
+    }
+
+    private static boolean shouldLoadBalanceBundles(JobInfo jobInfo) {
+        PipelineOptions pipelineOptions =
+                
PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions());
+        boolean loadBalanceBundles =
+                
pipelineOptions.as(PortablePipelineOptions.class).getLoadBalanceBundles();
+        if (loadBalanceBundles) {
+            int stateCacheSize =
+                    Integer.parseInt(
+                            MoreObjects.firstNonNull(
+                                    ExperimentalOptions.getExperimentValue(
+                                            pipelineOptions, 
ExperimentalOptions.STATE_CACHE_SIZE),
+                                    "0"));
+            Preconditions.checkArgument(
+                    stateCacheSize == 0,
+                    "%s must be 0 when using bundle load balancing",
+                    ExperimentalOptions.STATE_CACHE_SIZE);
+        }
+        return loadBalanceBundles;
+    }
+
+    @Override
+    public StageBundleFactory forStage(ExecutableStage executableStage) {
+        return new SimpleStageBundleFactory(executableStage);
+    }
+
    /**
     * Closes all environment caches (and thereby active environments), force-releases clients
     * that were evicted while still referenced, and shuts down the executor. Idempotent.
     * Exceptions raised by individual shutdown steps are collected (suppressed onto the first)
     * and rethrown at the end so that cleanup always runs to completion.
     */
    @Override
    public synchronized void close() throws Exception {
        if (closed) {
            return;
        }
        // The following code is written defensively to guard against any exceptions occurring
        // during shutdown. It is not visually appealing but unless it can be written more
        // defensively, there is no reason to change it.
        Exception exception = null;
        for (EnvironmentCacheAndLock environmentCache : environmentCaches) {
            try {
                // Clear the cache. This closes all active environments.
                // note this may cause open calls to be cancelled by the peer
                environmentCache.cache.invalidateAll();
                environmentCache.cache.cleanUp();
            } catch (Exception e) {
                if (exception != null) {
                    exception.addSuppressed(e);
                } else {
                    exception = e;
                }
            }
        }
        // Cleanup any left-over environments which were not properly dereferenced by the Runner,
        // e.g. when the bundle was not closed properly. This ensures we do not leak resources.
        for (WrappedSdkHarnessClient client : evictedActiveClients) {
            try {
                //noinspection StatementWithEmptyBody
                while (client.unref() > 0) {
                    // Remove any pending references from the client to force closing the
                    // environment
                }
            } catch (Exception e) {
                if (exception != null) {
                    exception.addSuppressed(e);
                } else {
                    exception = e;
                }
            }
        }
        try {
            executor.shutdown();
        } catch (Exception e) {
            if (exception != null) {
                exception.addSuppressed(e);
            } else {
                exception = e;
            }
        }
        closed = true;
        if (exception != null) {
            throw exception;
        }
    }
+
+    private static Map<String, RemoteOutputReceiver<?>> getOutputReceivers(
+            ExecutableProcessBundleDescriptor processBundleDescriptor,
+            OutputReceiverFactory outputReceiverFactory) {
+        ImmutableMap.Builder<String, RemoteOutputReceiver<?>> outputReceivers =
+                ImmutableMap.builder();
+        for (Map.Entry<String, Coder> remoteOutputCoder :
+                processBundleDescriptor.getRemoteOutputCoders().entrySet()) {
+            String outputTransform = remoteOutputCoder.getKey();
+            Coder coder = remoteOutputCoder.getValue();
+            String bundleOutputPCollection =
+                    Iterables.getOnlyElement(
+                            processBundleDescriptor
+                                    .getProcessBundleDescriptor()
+                                    .getTransformsOrThrow(outputTransform)
+                                    .getInputsMap()
+                                    .values());
+            FnDataReceiver outputReceiver = 
outputReceiverFactory.create(bundleOutputPCollection);
+            outputReceivers.put(outputTransform, 
RemoteOutputReceiver.of(coder, outputReceiver));
+        }
+        return outputReceivers.build();
+    }
+
+    private static Map<KV<String, String>, RemoteOutputReceiver<Timer<?>>> 
getTimerReceivers(
+            ExecutableProcessBundleDescriptor processBundleDescriptor,
+            TimerReceiverFactory timerReceiverFactory) {
+        ImmutableMap.Builder<KV<String, String>, 
RemoteOutputReceiver<Timer<?>>> timerReceivers =
+                ImmutableMap.builder();
+        for (Map.Entry<String, Map<String, TimerSpec>> transformTimerSpecs :
+                processBundleDescriptor.getTimerSpecs().entrySet()) {
+            for (TimerSpec timerSpec : 
transformTimerSpecs.getValue().values()) {
+                FnDataReceiver<Timer<?>> receiver =
+                        (FnDataReceiver)
+                                timerReceiverFactory.create(
+                                        timerSpec.transformId(), 
timerSpec.timerId());
+                timerReceivers.put(
+                        KV.of(timerSpec.transformId(), timerSpec.timerId()),
+                        RemoteOutputReceiver.of(timerSpec.coder(), receiver));
+            }
+        }
+        return timerReceivers.build();
+    }
+
    /**
     * A {@link BundleProcessor} together with the descriptor it was built from and the harness
     * client it is bound to; created once per (stage, client) pair by {@code prepare}.
     */
    private static class PreparedClient {
        private BundleProcessor processor;
        private ExecutableProcessBundleDescriptor processBundleDescriptor;
        private WrappedSdkHarnessClient wrappedClient;
    }
+
    /**
     * Builds the process bundle descriptor and {@link BundleProcessor} binding the given stage to
     * the given harness client's data and state endpoints.
     *
     * @throws RuntimeException if creating the process bundle descriptor fails
     */
    private PreparedClient prepare(
            WrappedSdkHarnessClient wrappedClient, ExecutableStage executableStage) {
        PreparedClient preparedClient = new PreparedClient();
        try {
            preparedClient.wrappedClient = wrappedClient;
            // The descriptor points the stage at this client's data and state servers.
            preparedClient.processBundleDescriptor =
                    ProcessBundleDescriptors.fromExecutableStage(
                            stageIdGenerator.getId(),
                            executableStage,
                            wrappedClient.getServerInfo().getDataServer().getApiServiceDescriptor(),
                            wrappedClient
                                    .getServerInfo()
                                    .getStateServer()
                                    .getApiServiceDescriptor());
        } catch (IOException e) {
            throw new RuntimeException("Failed to create ProcessBundleDescriptor.", e);
        }

        preparedClient.processor =
                wrappedClient
                        .getClient()
                        .getProcessor(
                                preparedClient.processBundleDescriptor.getProcessBundleDescriptor(),
                                preparedClient.processBundleDescriptor.getRemoteInputDestinations(),
                                wrappedClient.getServerInfo().getStateServer().getService(),
                                preparedClient.processBundleDescriptor.getTimerSpecs());
        return preparedClient;
    }
+
    /**
     * A {@link StageBundleFactory} for remotely processing bundles that supports environment
     * expiration. Not thread-safe: a new instance must be created per client.
     */
    private class SimpleStageBundleFactory implements StageBundleFactory {

        private final ExecutableStage executableStage;
        // Cache this factory is pinned to (round-robin assignment); used only when bundle load
        // balancing is disabled.
        private final int environmentIndex;
        // Identity-keyed: a prepared processor belongs to one specific client instance.
        private final IdentityHashMap<WrappedSdkHarnessClient, PreparedClient> preparedClients =
                new IdentityHashMap<>();
        private volatile PreparedClient currentClient;

        private SimpleStageBundleFactory(ExecutableStage executableStage) {
            this.executableStage = executableStage;
            // Distribute stage bundle factories across the environment caches round-robin.
            this.environmentIndex =
                    stageBundleFactoryCount.getAndIncrement() % environmentCaches.size();
            // Eagerly load (or reuse) the environment for this stage and prepare a processor.
            WrappedSdkHarnessClient client =
                    environmentCaches
                            .get(environmentIndex)
                            .cache
                            .getUnchecked(executableStage.getEnvironment());
            this.currentClient = prepare(client, executableStage);
            this.preparedClients.put(client, currentClient);
        }

        /**
         * Acquires a harness client (the pinned one, or with load balancing the next available
         * cache), refs it for the duration of the bundle, and returns a {@link RemoteBundle}
         * whose {@code close()} releases the reference again.
         */
        @Override
        public RemoteBundle getBundle(
                OutputReceiverFactory outputReceiverFactory,
                TimerReceiverFactory timerReceiverFactory,
                StateRequestHandler stateRequestHandler,
                BundleProgressHandler progressHandler,
                BundleFinalizationHandler finalizationHandler,
                BundleCheckpointHandler checkpointHandler)
                throws Exception {
            // TODO: Consider having BundleProcessor#newBundle take in an OutputReceiverFactory
            // rather than constructing the receiver map here. Every bundle factory will need
            // this.

            final EnvironmentCacheAndLock currentCache;
            final WrappedSdkHarnessClient client;
            if (loadBalanceBundles) {
                // The semaphore is used to ensure fairness, i.e. first come, first served.
                availableCachesSemaphore.acquire();
                // The blocking queue of caches for serving multiple bundles concurrently.
                currentCache = availableCaches.take();
                // Lock because the environment expiration can remove the ref for the client
                // which would close the underlying environment before we can ref it.
                currentCache.lock.lock();
                try {
                    client = currentCache.cache.getUnchecked(executableStage.getEnvironment());
                    client.ref();
                } finally {
                    currentCache.lock.unlock();
                }

                currentClient = preparedClients.get(client);
                if (currentClient == null) {
                    // we are using this client for the first time
                    preparedClients.put(client, currentClient = prepare(client, executableStage));
                    // cleanup any expired clients
                    preparedClients.keySet().removeIf(c -> c.bundleRefCount.get() == 0);
                }

            } else {
                currentCache = environmentCaches.get(environmentIndex);
                // Lock because the environment expiration can remove the ref for the client
                // which would close the underlying environment before we can ref it.
                currentCache.lock.lock();
                try {
                    client = currentCache.cache.getUnchecked(executableStage.getEnvironment());
                    client.ref();
                } finally {
                    currentCache.lock.unlock();
                }

                if (currentClient.wrappedClient != client) {
                    // reset after environment expired
                    preparedClients.clear();
                    currentClient = prepare(client, executableStage);
                    preparedClients.put(client, currentClient);
                }
            }

            if (environmentExpirationMillis > 0) {
                // Cleanup list of clients which were active during eviction but now do not hold
                // references
                evictedActiveClients.removeIf(c -> c.bundleRefCount.get() == 0);
            }

            final RemoteBundle bundle =
                    currentClient.processor.newBundle(
                            getOutputReceivers(
                                    currentClient.processBundleDescriptor, outputReceiverFactory),
                            getTimerReceivers(
                                    currentClient.processBundleDescriptor, timerReceiverFactory),
                            stateRequestHandler,
                            progressHandler,
                            finalizationHandler,
                            checkpointHandler);
            // Wrap the bundle so closing it also unrefs the client (and, with load balancing,
            // returns the cache to the pool).
            return new RemoteBundle() {
                @Override
                public String getId() {
                    return bundle.getId();
                }

                @Override
                public Map<String, FnDataReceiver> getInputReceivers() {
                    return bundle.getInputReceivers();
                }

                @Override
                public Map<KV<String, String>, FnDataReceiver<Timer>> getTimerReceivers() {
                    return bundle.getTimerReceivers();
                }

                @Override
                public void requestProgress() {
                    throw new UnsupportedOperationException();
                }

                @Override
                public void split(double fractionOfRemainder) {
                    bundle.split(fractionOfRemainder);
                }

                @Override
                public void close() throws Exception {
                    try {
                        bundle.close();
                    } finally {
                        // Always release the reference, even if closing the bundle failed.
                        client.unref();
                        if (loadBalanceBundles) {
                            availableCaches.offer(currentCache);
                            availableCachesSemaphore.release();
                        }
                    }
                }
            };
        }

        @Override
        public ExecutableProcessBundleDescriptor getProcessBundleDescriptor() {
            return currentClient.processBundleDescriptor;
        }

        @Override
        public InstructionRequestHandler getInstructionRequestHandler() {
            return currentClient.wrappedClient.getClient().getInstructionRequestHandler();
        }

        @Override
        public void close() throws Exception {
            // Clear reference to encourage cache eviction. Values are weakly referenced.
            preparedClients.clear();
        }
    }
+
    /**
     * Holder for an {@link SdkHarnessClient} along with its associated state and data servers. As
     * of now, there is a 1:1 relationship between data services and harness clients. The servers
     * are packaged here to tie server lifetimes to harness client lifetimes.
     */
    protected static class WrappedSdkHarnessClient {

        // Remote SDK harness environment backing the wrapped client.
        private final RemoteEnvironment environment;
        private final SdkHarnessClient client;
        // Grpc servers (control/logging/retrieval/provisioning/data/state) tied to this client.
        private final ServerInfo serverInfo;
        // Reference count: the wrapper takes the initial reference in its constructor (ref());
        // bundle lifecycles call unref() when done. Resources are closed when it reaches zero.
        private final AtomicInteger bundleRefCount = new AtomicInteger();

        // Written and read only inside the synchronized close(); makes close() idempotent.
        private boolean closed;

        /**
         * Creates a wrapper whose {@link SdkHarnessClient} talks to the given environment's
         * instruction request handler via the data server in {@code serverInfo}.
         */
        static WrappedSdkHarnessClient wrapping(
                RemoteEnvironment environment, ServerInfo serverInfo) {
            SdkHarnessClient client =
                    SdkHarnessClient.usingFnApiClient(
                            environment.getInstructionRequestHandler(),
                            serverInfo.getDataServer().getService());
            return new WrappedSdkHarnessClient(environment, client, serverInfo);
        }

        private WrappedSdkHarnessClient(
                RemoteEnvironment environment, SdkHarnessClient client, ServerInfo serverInfo) {
            this.environment = environment;
            this.client = client;
            this.serverInfo = serverInfo;
            // Take the wrapper's own reference; released via unref() by the owning factory.
            ref();
        }

        SdkHarnessClient getClient() {
            return client;
        }

        ServerInfo getServerInfo() {
            return serverInfo;
        }

        /**
         * Closes the environment and every Grpc server it depends on. Idempotent: subsequent
         * calls after a successful close are no-ops. Failures are logged, not rethrown.
         */
        public synchronized void close() {
            if (closed) {
                return;
            }
            // DO NOT ADD ANYTHING HERE WHICH MIGHT CAUSE THE BLOCK BELOW TO NOT BE EXECUTED.
            // If we exit prematurely (e.g. due to an exception), resources won't be cleaned up
            // properly.
            // Please make an AutoCloseable and add it to the try statement below.
            // These will be closed in the reverse creation order:
            try (AutoCloseable envCloser = environment;
                    AutoCloseable provisioningServer = serverInfo.getProvisioningServer();
                    AutoCloseable retrievalServer = serverInfo.getRetrievalServer();
                    AutoCloseable stateServer = serverInfo.getStateServer();
                    AutoCloseable dataServer = serverInfo.getDataServer();
                    AutoCloseable controlServer = serverInfo.getControlServer();
                    // Close the logging server first to prevent spamming the logs with error
                    // messages
                    AutoCloseable loggingServer = serverInfo.getLoggingServer()) {
                // Wrap resources in try-with-resources to ensure all are cleaned up.
                // This will close _all_ of these even in the presence of exceptions.
                // The first exception encountered will be the base exception,
                // the next one will be added via Throwable#addSuppressed.
                closed = true;
            } catch (Exception e) {
                LOG.warn("Error cleaning up servers {}", environment.getEnvironment(), e);
            }
            // TODO: Wait for executor shutdown?
        }

        /** Increments the reference count and returns the new count. */
        private int ref() {
            return bundleRefCount.incrementAndGet();
        }

        /**
         * Decrements the reference count and returns the new count; closes the environment and
         * servers once the count reaches zero.
         */
        private int unref() {
            int refCount = bundleRefCount.decrementAndGet();
            Preconditions.checkState(refCount >= 0, "Reference count must not be negative.");
            if (refCount == 0) {
                // Close environment after it was removed from cache and all bundles finished.
                LOG.info("Closing environment {}", environment.getEnvironment());
                close();
            }
            return refCount;
        }
    }
+
+    private ServerInfo createServerInfo(JobInfo jobInfo, ServerFactory 
serverFactory)
+            throws IOException {
+        Preconditions.checkNotNull(serverFactory, "serverFactory can not be 
null");
+
+        PortablePipelineOptions portableOptions =
+                PipelineOptionsTranslation.fromProto(jobInfo.pipelineOptions())
+                        .as(PortablePipelineOptions.class);
+
+        GrpcFnServer<FnApiControlClientPoolService> controlServer =
+                GrpcFnServer.allocatePortAndCreateFor(
+                        FnApiControlClientPoolService.offeringClientsToPool(
+                                clientPool.getSink(),
+                                
GrpcContextHeaderAccessorProvider.getHeaderAccessor()),
+                        serverFactory);
+        GrpcFnServer<GrpcLoggingService> loggingServer =
+                GrpcFnServer.allocatePortAndCreateFor(
+                        
GrpcLoggingService.forWriter(Slf4jLogWriter.getDefault()), serverFactory);
+        GrpcFnServer<ArtifactRetrievalService> retrievalServer =
+                GrpcFnServer.allocatePortAndCreateFor(
+                        new ArtifactRetrievalService(), serverFactory);
+        ProvisionApi.ProvisionInfo.Builder provisionInfo = 
jobInfo.toProvisionInfo().toBuilder();
+        
provisionInfo.setLoggingEndpoint(loggingServer.getApiServiceDescriptor());
+        
provisionInfo.setArtifactEndpoint(retrievalServer.getApiServiceDescriptor());
+        
provisionInfo.setControlEndpoint(controlServer.getApiServiceDescriptor());
+        GrpcFnServer<StaticGrpcProvisionService> provisioningServer =
+                GrpcFnServer.allocatePortAndCreateFor(
+                        StaticGrpcProvisionService.create(
+                                provisionInfo.build(),
+                                
GrpcContextHeaderAccessorProvider.getHeaderAccessor()),
+                        serverFactory);
+        GrpcFnServer<GrpcDataService> dataServer =
+                GrpcFnServer.allocatePortAndCreateFor(
+                        GrpcDataService.create(
+                                portableOptions, executor, 
OutboundObserverFactory.serverDirect()),
+                        serverFactory);
+        GrpcFnServer<GrpcStateService> stateServer =
+                
GrpcFnServer.allocatePortAndCreateFor(GrpcStateService.create(), serverFactory);
+
+        ServerInfo serverInfo =
+                new AutoValue_DefaultJobBundleFactory_ServerInfo.Builder()
+                        .setControlServer(controlServer)
+                        .setLoggingServer(loggingServer)
+                        .setRetrievalServer(retrievalServer)
+                        .setProvisioningServer(provisioningServer)
+                        .setDataServer(dataServer)
+                        .setStateServer(stateServer)
+                        .build();
+        return serverInfo;
+    }
+
    /**
     * A container for EnvironmentFactory and its corresponding Grpc servers.
     *
     * <p>Implemented by the AutoValue-generated
     * {@code AutoValue_DefaultJobBundleFactory_ServerInfo} class and constructed via its
     * {@link Builder} (see {@code createServerInfo}).
     */
    public abstract static class ServerInfo {
        /** Server offering harness control clients to the client pool. */
        abstract GrpcFnServer<FnApiControlClientPoolService> getControlServer();

        /** Server receiving SDK harness log output. */
        abstract GrpcFnServer<GrpcLoggingService> getLoggingServer();

        /** Server exposing the artifact retrieval service to the harness. */
        abstract GrpcFnServer<ArtifactRetrievalService> getRetrievalServer();

        /** Server exposing static provision info (endpoints, pipeline options) to the harness. */
        abstract GrpcFnServer<StaticGrpcProvisionService> getProvisioningServer();

        /** Server transporting bundle data elements to and from the harness. */
        abstract GrpcFnServer<GrpcDataService> getDataServer();

        /** Server backing the harness's state requests. */
        abstract GrpcFnServer<GrpcStateService> getStateServer();

        /** Returns a builder pre-populated with this instance's values. */
        abstract Builder toBuilder();

        /** AutoValue builder for {@link ServerInfo}. */
        abstract static class Builder {
            abstract Builder setControlServer(GrpcFnServer<FnApiControlClientPoolService> server);

            abstract Builder setLoggingServer(GrpcFnServer<GrpcLoggingService> server);

            abstract Builder setRetrievalServer(GrpcFnServer<ArtifactRetrievalService> server);

            abstract Builder setProvisioningServer(GrpcFnServer<StaticGrpcProvisionService> server);

            abstract Builder setDataServer(GrpcFnServer<GrpcDataService> server);

            abstract Builder setStateServer(GrpcFnServer<GrpcStateService> server);

            abstract ServerInfo build();
        }
    }
+}


Reply via email to