http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/ThriftFactory.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/com/twitter/common/thrift/ThriftFactory.java b/commons/src/main/java/com/twitter/common/thrift/ThriftFactory.java new file mode 100644 index 0000000..be6a1c4 --- /dev/null +++ b/commons/src/main/java/com/twitter/common/thrift/ThriftFactory.java @@ -0,0 +1,657 @@ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ================================================================================================= + +package com.twitter.common.thrift; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Function; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import org.apache.thrift.async.TAsyncClient; +import org.apache.thrift.async.TAsyncClientManager; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransport; + +import com.twitter.common.base.Closure; +import com.twitter.common.base.Closures; +import com.twitter.common.base.MorePreconditions; +import com.twitter.common.net.loadbalancing.LeastConnectedStrategy; +import com.twitter.common.net.loadbalancing.LoadBalancer; +import com.twitter.common.net.loadbalancing.LoadBalancerImpl; +import com.twitter.common.net.loadbalancing.LoadBalancingStrategy; +import com.twitter.common.net.loadbalancing.MarkDeadStrategyWithHostCheck; +import 
com.twitter.common.net.loadbalancing.TrafficMonitorAdapter; +import com.twitter.common.net.monitoring.TrafficMonitor; +import com.twitter.common.net.pool.Connection; +import com.twitter.common.net.pool.ConnectionPool; +import com.twitter.common.net.pool.DynamicHostSet; +import com.twitter.common.net.pool.DynamicPool; +import com.twitter.common.net.pool.MetaPool; +import com.twitter.common.net.pool.ObjectPool; +import com.twitter.common.quantity.Amount; +import com.twitter.common.quantity.Time; +import com.twitter.common.stats.Stats; +import com.twitter.common.stats.StatsProvider; +import com.twitter.common.thrift.ThriftConnectionFactory.TransportType; +import com.twitter.common.util.BackoffDecider; +import com.twitter.common.util.BackoffStrategy; +import com.twitter.common.util.TruncatedBinaryBackoff; +import com.twitter.common.util.concurrent.ForwardingExecutorService; +import com.twitter.thrift.ServiceInstance; + +/** + * A utility that provides convenience methods to build common {@link Thrift}s. + * + * The thrift factory allows you to specify parameters that define how the client connects to + * and communicates with servers, such as the transport type, connection settings, and load + * balancing. Request-level settings like sync/async and retries should be set on the + * {@link Thrift} instance that this factory will create. + * + * The factory will attempt to provide reasonable defaults to allow the caller to minimize the + * amount of necessary configuration. Currently, the default behavior includes: + * + * <ul> + * <li> A test lease/release for each host will be performed every second + * {@link #withDeadConnectionRestoreInterval(Amount)} + * <li> At most 50 connections will be established to each host + * {@link #withMaxConnectionsPerEndpoint(int)} + * <li> Unframed transport {@link #useFramedTransport(boolean)} + * <li> A load balancing strategy that will mark hosts dead and prefer least-connected hosts. 
+ * Hosts are marked dead if the most recent connection attempt was a failure or else based on + * the windowed error rate of attempted RPCs. If the error rate for a connected host exceeds + * 20% over the last second, the host will be disabled for 2 seconds ascending up to 10 seconds + * if the elevated error rate persists. + * {@link #withLoadBalancingStrategy(LoadBalancingStrategy)} + * <li> Statistics are reported through {@link Stats} + * {@link #withStatsProvider(StatsProvider)} + * <li> A service name matching the thrift interface name {@link #withServiceName(String)} + * </ul> + * + * @author John Sirois + */ +public class ThriftFactory<T> { + private static final Amount<Long,Time> DEFAULT_DEAD_TARGET_RESTORE_INTERVAL = + Amount.of(1L, Time.SECONDS); + + private static final int DEFAULT_MAX_CONNECTIONS_PER_ENDPOINT = 50; + + private Class<T> serviceInterface; + private Function<TTransport, T> clientFactory; + private int maxConnectionsPerEndpoint; + private Amount<Long,Time> connectionRestoreInterval; + private boolean framedTransport; + private LoadBalancingStrategy<InetSocketAddress> loadBalancingStrategy = null; + private final TrafficMonitor<InetSocketAddress> monitor; + private Amount<Long,Time> socketTimeout = null; + private Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback = Closures.noop(); + private StatsProvider statsProvider = Stats.STATS_PROVIDER; + private Optional<String> endpointName = Optional.absent(); + private String serviceName; + private boolean sslTransport; + + public static <T> ThriftFactory<T> create(Class<T> serviceInterface) { + return new ThriftFactory<T>(serviceInterface); + } + + /** + * Creates a default factory that will use unframed blocking transport. + * + * @param serviceInterface The interface of the thrift service to make a client for. 
+ */ + private ThriftFactory(Class<T> serviceInterface) { + this.serviceInterface = Thrift.checkServiceInterface(serviceInterface); + this.maxConnectionsPerEndpoint = DEFAULT_MAX_CONNECTIONS_PER_ENDPOINT; + this.connectionRestoreInterval = DEFAULT_DEAD_TARGET_RESTORE_INTERVAL; + this.framedTransport = false; + this.monitor = new TrafficMonitor<InetSocketAddress>(serviceInterface.getName()); + this.serviceName = serviceInterface.getEnclosingClass().getSimpleName(); + this.sslTransport = false; + } + + private void checkBaseState() { + Preconditions.checkArgument(maxConnectionsPerEndpoint > 0, + "Must allow at least 1 connection per endpoint; %s specified", maxConnectionsPerEndpoint); + } + + public TrafficMonitor<InetSocketAddress> getMonitor() { + return monitor; + } + + /** + * Creates the thrift client, and initializes connection pools. + * + * @param backends Backends to connect to. + * @return A new thrift client. + */ + public Thrift<T> build(Set<InetSocketAddress> backends) { + checkBaseState(); + MorePreconditions.checkNotBlank(backends); + + ManagedThreadPool managedThreadPool = createManagedThreadpool(backends.size()); + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Function<TTransport, T> clientFactory = getClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(backends, loadBalancer, managedThreadPool, false); + + return new Thrift<T>(managedThreadPool, connectionPool, loadBalancer, serviceName, + serviceInterface, clientFactory, false, sslTransport); + } + + /** + * Creates a synchronous thrift client that will communicate with a dynamic host set. + * + * @param hostSet The host set to use as a backend. + * @return A thrift client. + * @throws ThriftFactoryException If an error occurred while creating the client. 
+ */ + public Thrift<T> build(DynamicHostSet<ServiceInstance> hostSet) throws ThriftFactoryException { + checkBaseState(); + Preconditions.checkNotNull(hostSet); + + ManagedThreadPool managedThreadPool = createManagedThreadpool(1); + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Function<TTransport, T> clientFactory = getClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(hostSet, loadBalancer, managedThreadPool, false, endpointName); + + return new Thrift<T>(managedThreadPool, connectionPool, loadBalancer, serviceName, + serviceInterface, clientFactory, false, sslTransport); + } + + private ManagedThreadPool createManagedThreadpool(int initialEndpointCount) { + return new ManagedThreadPool(serviceName, initialEndpointCount, maxConnectionsPerEndpoint); + } + + /** + * A finite thread pool that monitors backend choice events to dynamically resize. This + * {@link java.util.concurrent.ExecutorService} implementation immediately rejects requests when + * there are no more available worked threads (requests are not queued). 
+ */ + private static class ManagedThreadPool extends ForwardingExecutorService<ThreadPoolExecutor> + implements Closure<Collection<InetSocketAddress>> { + + private static final Logger LOG = Logger.getLogger(ManagedThreadPool.class.getName()); + + private static ThreadPoolExecutor createThreadPool(String serviceName, int initialSize) { + ThreadFactory threadFactory = + new ThreadFactoryBuilder() + .setNameFormat("Thrift[" +serviceName + "][%d]") + .setDaemon(true) + .build(); + return new ThreadPoolExecutor(initialSize, initialSize, 0, TimeUnit.MILLISECONDS, + new SynchronousQueue<Runnable>(), threadFactory); + } + + private final int maxConnectionsPerEndpoint; + + public ManagedThreadPool(String serviceName, int initialEndpointCount, + int maxConnectionsPerEndpoint) { + + super(createThreadPool(serviceName, initialEndpointCount * maxConnectionsPerEndpoint)); + this.maxConnectionsPerEndpoint = maxConnectionsPerEndpoint; + setRejectedExecutionHandler(initialEndpointCount); + } + + private void setRejectedExecutionHandler(int endpointCount) { + final String message = + String.format("All %d x %d connections in use", endpointCount, maxConnectionsPerEndpoint); + delegate.setRejectedExecutionHandler(new RejectedExecutionHandler() { + @Override public void rejectedExecution(Runnable runnable, ThreadPoolExecutor executor) { + throw new RejectedExecutionException(message); + } + }); + } + + @Override + public void execute(Collection<InetSocketAddress> chosenBackends) { + int previousPoolSize = delegate.getMaximumPoolSize(); + /* + * In the case of no available backends, we need to make sure we pass in a positive pool + * size to our delegate. In particular, java.util.concurrent.ThreadPoolExecutor does not + * accept zero as a valid core or max pool size. 
+ */ + int backendCount = Math.max(chosenBackends.size(), 1); + int newPoolSize = backendCount * maxConnectionsPerEndpoint; + + if (previousPoolSize != newPoolSize) { + LOG.info(String.format("Re-sizing deadline thread pool from: %d to: %d", + previousPoolSize, newPoolSize)); + if (previousPoolSize < newPoolSize) { // Don't cross the beams! + delegate.setMaximumPoolSize(newPoolSize); + delegate.setCorePoolSize(newPoolSize); + } else { + delegate.setCorePoolSize(newPoolSize); + delegate.setMaximumPoolSize(newPoolSize); + } + setRejectedExecutionHandler(backendCount); + } + } + } + + /** + * Creates an asynchronous thrift client that will communicate with a fixed set of backends. + * + * @param backends Backends to connect to. + * @return A thrift client. + * @throws ThriftFactoryException If an error occurred while creating the client. + */ + public Thrift<T> buildAsync(Set<InetSocketAddress> backends) throws ThriftFactoryException { + checkBaseState(); + MorePreconditions.checkNotBlank(backends); + + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Closure<Collection<InetSocketAddress>> noop = Closures.noop(); + Function<TTransport, T> asyncClientFactory = getAsyncClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(backends, loadBalancer, noop, true); + + return new Thrift<T>(connectionPool, loadBalancer, + serviceName, serviceInterface, asyncClientFactory, true); + } + + /** + * Creates an asynchronous thrift client that will communicate with a dynamic host set. + * + * @param hostSet The host set to use as a backend. + * @return A thrift client. + * @throws ThriftFactoryException If an error occurred while creating the client. 
+ */ + public Thrift<T> buildAsync(DynamicHostSet<ServiceInstance> hostSet) + throws ThriftFactoryException { + checkBaseState(); + Preconditions.checkNotNull(hostSet); + + LoadBalancer<InetSocketAddress> loadBalancer = createLoadBalancer(); + Closure<Collection<InetSocketAddress>> noop = Closures.noop(); + Function<TTransport, T> asyncClientFactory = getAsyncClientFactory(); + + ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool = + createConnectionPool(hostSet, loadBalancer, noop, true, endpointName); + + return new Thrift<T>(connectionPool, loadBalancer, + serviceName, serviceInterface, asyncClientFactory, true); + } + + /** + * Prepare the client factory, which will create client class instances from transports. + * + * @return The client factory to use. + */ + private Function<TTransport, T> getClientFactory() { + return clientFactory == null ? createClientFactory(serviceInterface) : clientFactory; + } + + /** + * Prepare the async client factory, which will create client class instances from transports. + * + * @return The client factory to use. + * @throws ThriftFactoryException If there was a problem creating the factory. + */ + private Function<TTransport, T> getAsyncClientFactory() throws ThriftFactoryException { + try { + return clientFactory == null ? 
createAsyncClientFactory(serviceInterface) : clientFactory; + } catch (IOException e) { + throw new ThriftFactoryException("Failed to create async client factory.", e); + } + } + + private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool( + Set<InetSocketAddress> backends, LoadBalancer<InetSocketAddress> loadBalancer, + Closure<Collection<InetSocketAddress>> onBackendsChosen, boolean nonblocking) { + + ImmutableMap.Builder<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>> + backendBuilder = ImmutableMap.builder(); + for (InetSocketAddress backend : backends) { + backendBuilder.put(backend, createConnectionPool(backend, nonblocking)); + } + + return new MetaPool<TTransport, InetSocketAddress>(backendBuilder.build(), + loadBalancer, onBackendsChosen, connectionRestoreInterval); + } + + private ObjectPool<Connection<TTransport, InetSocketAddress>> createConnectionPool( + DynamicHostSet<ServiceInstance> hostSet, LoadBalancer<InetSocketAddress> loadBalancer, + Closure<Collection<InetSocketAddress>> onBackendsChosen, + final boolean nonblocking, Optional<String> serviceEndpointName) + throws ThriftFactoryException { + + Function<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>> + endpointPoolFactory = + new Function<InetSocketAddress, ObjectPool<Connection<TTransport, InetSocketAddress>>>() { + @Override public ObjectPool<Connection<TTransport, InetSocketAddress>> apply( + InetSocketAddress endpoint) { + return createConnectionPool(endpoint, nonblocking); + } + }; + + try { + return new DynamicPool<ServiceInstance, TTransport, InetSocketAddress>(hostSet, + endpointPoolFactory, loadBalancer, onBackendsChosen, connectionRestoreInterval, + Util.getAddress(serviceEndpointName), Util.IS_ALIVE); + } catch (DynamicHostSet.MonitorException e) { + throw new ThriftFactoryException("Failed to monitor host set.", e); + } + } + + private ObjectPool<Connection<TTransport, InetSocketAddress>> 
createConnectionPool( + InetSocketAddress backend, boolean nonblocking) { + + ThriftConnectionFactory connectionFactory = new ThriftConnectionFactory( + backend, maxConnectionsPerEndpoint, TransportType.get(framedTransport, nonblocking), + socketTimeout, postCreateCallback, sslTransport); + + return new ConnectionPool<Connection<TTransport, InetSocketAddress>>(connectionFactory, + statsProvider); + } + + @VisibleForTesting + public ThriftFactory<T> withClientFactory(Function<TTransport, T> clientFactory) { + this.clientFactory = Preconditions.checkNotNull(clientFactory); + + return this; + } + + public ThriftFactory<T> withSslEnabled() { + this.sslTransport = true; + return this; + } + + /** + * Specifies the maximum number of connections that should be made to any single endpoint. + * + * @param maxConnectionsPerEndpoint Maximum number of connections per endpoint. + * @return A reference to the factory. + */ + public ThriftFactory<T> withMaxConnectionsPerEndpoint(int maxConnectionsPerEndpoint) { + Preconditions.checkArgument(maxConnectionsPerEndpoint > 0); + this.maxConnectionsPerEndpoint = maxConnectionsPerEndpoint; + + return this; + } + + /** + * Specifies the interval at which dead endpoint connections should be checked and revived. + * + * @param connectionRestoreInterval the time interval to check. + * @return A reference to the factory. + */ + public ThriftFactory<T> withDeadConnectionRestoreInterval( + Amount<Long, Time> connectionRestoreInterval) { + Preconditions.checkNotNull(connectionRestoreInterval); + Preconditions.checkArgument(connectionRestoreInterval.getValue() >= 0, + "A negative interval is invalid: %s", connectionRestoreInterval); + this.connectionRestoreInterval = connectionRestoreInterval; + + return this; + } + + /** + * Instructs the factory whether framed transport should be used. + * + * @param framedTransport Whether to use framed transport. + * @return A reference to the factory. 
+ */ + public ThriftFactory<T> useFramedTransport(boolean framedTransport) { + this.framedTransport = framedTransport; + + return this; + } + + /** + * Specifies the load balancer to use when interacting with multiple backends. + * + * @param strategy Load balancing strategy. + * @return A reference to the factory. + */ + public ThriftFactory<T> withLoadBalancingStrategy( + LoadBalancingStrategy<InetSocketAddress> strategy) { + this.loadBalancingStrategy = Preconditions.checkNotNull(strategy); + + return this; + } + + private LoadBalancer<InetSocketAddress> createLoadBalancer() { + if (loadBalancingStrategy == null) { + loadBalancingStrategy = createDefaultLoadBalancingStrategy(); + } + + return LoadBalancerImpl.create(TrafficMonitorAdapter.create(loadBalancingStrategy, monitor)); + } + + private LoadBalancingStrategy<InetSocketAddress> createDefaultLoadBalancingStrategy() { + Function<InetSocketAddress, BackoffDecider> backoffFactory = + new Function<InetSocketAddress, BackoffDecider>() { + @Override public BackoffDecider apply(InetSocketAddress socket) { + BackoffStrategy backoffStrategy = new TruncatedBinaryBackoff( + Amount.of(2L, Time.SECONDS), Amount.of(10L, Time.SECONDS)); + + return BackoffDecider.builder(socket.toString()) + .withTolerateFailureRate(0.2) + .withRequestWindow(Amount.of(1L, Time.SECONDS)) + .withSeedSize(5) + .withStrategy(backoffStrategy) + .withRecoveryType(BackoffDecider.RecoveryType.FULL_CAPACITY) + .withStatsProvider(statsProvider) + .build(); + } + }; + + return new MarkDeadStrategyWithHostCheck<InetSocketAddress>( + new LeastConnectedStrategy<InetSocketAddress>(), backoffFactory); + } + + /** + * Specifies the net read/write timeout to set via SO_TIMEOUT on the thrift blocking client + * or AsyncClient.setTimeout on the thrift async client. Defaults to the connectTimeout on + * the blocking client if not set. + * + * @param socketTimeout timeout on thrift i/o operations + * @return A reference to the factory. 
+ */ + public ThriftFactory<T> withSocketTimeout(Amount<Long, Time> socketTimeout) { + this.socketTimeout = Preconditions.checkNotNull(socketTimeout); + Preconditions.checkArgument(socketTimeout.as(Time.MILLISECONDS) >= 0); + + return this; + } + + /** + * Specifies the callback to notify when a connection has been created. The callback may + * be used to make thrift calls to the connection, but must not invalidate it. + * Defaults to a no-op closure. + * + * @param postCreateCallback function to setup new connections + * @return A reference to the factory. + */ + public ThriftFactory<T> withPostCreateCallback( + Closure<Connection<TTransport, InetSocketAddress>> postCreateCallback) { + this.postCreateCallback = Preconditions.checkNotNull(postCreateCallback); + + return this; + } + + /** + * Registers a custom stats provider to use to track various client stats. + * + * @param statsProvider the {@code StatsProvider} to use + * @return A reference to the factory. + */ + public ThriftFactory<T> withStatsProvider(StatsProvider statsProvider) { + this.statsProvider = Preconditions.checkNotNull(statsProvider); + + return this; + } + + /** + * Name to be passed to Thrift constructor, used in stats. + * + * @param serviceName string to use + * @return A reference to the factory. + */ + public ThriftFactory<T> withServiceName(String serviceName) { + this.serviceName = MorePreconditions.checkNotBlank(serviceName); + + return this; + } + + /** + * Set the end-point to use from {@link ServiceInstance#getAdditionalEndpoints()}. + * If not set, the default behavior is to use {@link ServiceInstance#getServiceEndpoint()}. 
+ * + * @param endpointName the (optional) name of the end-point, if unset - the + * default/primary end-point is selected + * @return a reference to the factory for chaining + */ + public ThriftFactory<T> withEndpointName(String endpointName) { + this.endpointName = Optional.of(endpointName); + return this; + } + + private static <T> Function<TTransport, T> createClientFactory(Class<T> serviceInterface) { + final Constructor<? extends T> implementationConstructor = + findImplementationConstructor(serviceInterface); + + return new Function<TTransport, T>() { + @Override public T apply(TTransport transport) { + try { + return implementationConstructor.newInstance(new TBinaryProtocol(transport)); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (InvocationTargetException e) { + throw new RuntimeException(e); + } + } + }; + } + + private <T> Function<TTransport, T> createAsyncClientFactory( + final Class<T> serviceInterface) throws IOException { + + final TAsyncClientManager clientManager = new TAsyncClientManager(); + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override public void run() { + clientManager.stop(); + } + }); + + final Constructor<? 
extends T> implementationConstructor = + findAsyncImplementationConstructor(serviceInterface); + + return new Function<TTransport, T>() { + @Override public T apply(TTransport transport) { + Preconditions.checkNotNull(transport); + Preconditions.checkArgument(transport instanceof TNonblockingTransport, + "Invalid transport provided to client factory: " + transport.getClass()); + + try { + T client = implementationConstructor.newInstance(new TBinaryProtocol.Factory(), + clientManager, transport); + + if (socketTimeout != null) { + ((TAsyncClient) client).setTimeout(socketTimeout.as(Time.MILLISECONDS)); + } + + return client; + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } catch (InvocationTargetException e) { + throw new RuntimeException(e); + } + } + }; + } + + private static <T> Constructor<? extends T> findImplementationConstructor( + final Class<T> serviceInterface) { + Class<? extends T> implementationClass = findImplementationClass(serviceInterface); + try { + return implementationClass.getConstructor(TProtocol.class); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Failed to find a single argument TProtocol constructor " + + "in service client class: " + implementationClass); + } + } + + private static <T> Constructor<? extends T> findAsyncImplementationConstructor( + final Class<T> serviceInterface) { + Class<? extends T> implementationClass = findImplementationClass(serviceInterface); + try { + return implementationClass.getConstructor(TProtocolFactory.class, TAsyncClientManager.class, + TNonblockingTransport.class); + } catch (NoSuchMethodException e) { + throw new IllegalArgumentException("Failed to find expected constructor " + + "in service client class: " + implementationClass); + } + } + + @SuppressWarnings("unchecked") + private static <T> Class<? 
extends T> findImplementationClass(final Class<T> serviceInterface) { + try { + return (Class<? extends T>) + Iterables.find(ImmutableList.copyOf(serviceInterface.getEnclosingClass().getClasses()), + new Predicate<Class<?>>() { + @Override public boolean apply(Class<?> inner) { + return !serviceInterface.equals(inner) + && serviceInterface.isAssignableFrom(inner); + } + }); + } catch (NoSuchElementException e) { + throw new IllegalArgumentException("Could not find a sibling enclosed implementation of " + + "service interface: " + serviceInterface); + } + } + + public static class ThriftFactoryException extends Exception { + public ThriftFactoryException(String msg) { + super(msg); + } + + public ThriftFactoryException(String msg, Throwable t) { + super(msg, t); + } + } +}
http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/Util.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/com/twitter/common/thrift/Util.java b/commons/src/main/java/com/twitter/common/thrift/Util.java new file mode 100644 index 0000000..92eba98 --- /dev/null +++ b/commons/src/main/java/com/twitter/common/thrift/Util.java @@ -0,0 +1,237 @@ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ================================================================================================= + +package com.twitter.common.thrift; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.annotation.Nullable; + +import com.google.common.base.Function; +import com.google.common.base.Joiner; +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.base.Strings; +import com.google.common.collect.Lists; + +import org.apache.thrift.TBase; +import org.apache.thrift.TFieldIdEnum; +import org.apache.thrift.meta_data.FieldMetaData; + +import com.twitter.thrift.Endpoint; +import com.twitter.thrift.ServiceInstance; + +/** + * Utility functions for thrift. + * + * @author William Farner + */ +public class Util { + + /** + * Maps a {@link ServiceInstance} to an {@link InetSocketAddress} given the {@code endpointName}. + * + * @param optionalEndpointName the name of the end-point on the service's additional end-points, + * if not set, maps to the primary service end-point + */ + public static Function<ServiceInstance, InetSocketAddress> getAddress( + final Optional<String> optionalEndpointName) { + if (!optionalEndpointName.isPresent()) { + return GET_ADDRESS; + } + + final String endpointName = optionalEndpointName.get(); + return getAddress( + new Function<ServiceInstance, Endpoint>() { + @Override public Endpoint apply(@Nullable ServiceInstance serviceInstance) { + Map<String, Endpoint> endpoints = serviceInstance.getAdditionalEndpoints(); + Preconditions.checkArgument(endpoints.containsKey(endpointName), + "Did not find end-point %s on %s", endpointName, serviceInstance); + return endpoints.get(endpointName); + } + }); + } + + private static Function<ServiceInstance, InetSocketAddress> getAddress( + final Function<ServiceInstance, Endpoint> serviceToEndpoint) { + return new Function<ServiceInstance, 
InetSocketAddress>() { + @Override public InetSocketAddress apply(ServiceInstance serviceInstance) { + Endpoint endpoint = serviceToEndpoint.apply(serviceInstance); + return InetSocketAddress.createUnresolved(endpoint.getHost(), endpoint.getPort()); + } + }; + } + + private static Function<ServiceInstance, Endpoint> GET_PRIMARY_ENDPOINT = + new Function<ServiceInstance, Endpoint>() { + @Override public Endpoint apply(ServiceInstance input) { + return input.getServiceEndpoint(); + } + }; + + public static Function<ServiceInstance, InetSocketAddress> GET_ADDRESS = + getAddress(GET_PRIMARY_ENDPOINT); + + public static final Predicate<ServiceInstance> IS_ALIVE = new Predicate<ServiceInstance>() { + @Override public boolean apply(ServiceInstance serviceInstance) { + switch (serviceInstance.getStatus()) { + case ALIVE: + return true; + + // We'll be optimistic here and let MTCP's ranking deal with + // unhealthy services in a WARNING state. + case WARNING: + return true; + + // Services which are just starting up, on the other hand... are much easier to just not + // send requests to. The STARTING state is useful to distinguish from WARNING or ALIVE: + // you exist in ZooKeeper, but don't yet serve traffic. + case STARTING: + default: + return false; + } + } + }; + + /** + * Pretty-prints a thrift object contents. + * + * @param t The thrift object to print. + * @return The pretty-printed version of the thrift object. + */ + public static String prettyPrint(TBase t) { + return t == null ? "null" : printTbase(t, 0); + } + + /** + * Prints an object contained in a thrift message. + * + * @param o The object to print. + * @param depth The print nesting level. + * @return The pretty-printed version of the thrift field. 
+ */ + private static String printValue(Object o, int depth) { + if (o == null) { + return "null"; + } else if (TBase.class.isAssignableFrom(o.getClass())) { + return "\n" + printTbase((TBase) o, depth + 1); + } else if (Map.class.isAssignableFrom(o.getClass())) { + return printMap((Map) o, depth + 1); + } else if (List.class.isAssignableFrom(o.getClass())) { + return printList((List) o, depth + 1); + } else if (Set.class.isAssignableFrom(o.getClass())) { + return printSet((Set) o, depth + 1); + } else if (String.class == o.getClass()) { + return '"' + o.toString() + '"'; + } else { + return o.toString(); + } + } + + private static final String METADATA_MAP_FIELD_NAME = "metaDataMap"; + + /** + * Prints a TBase. + * + * @param t The object to print. + * @param depth The print nesting level. + * @return The pretty-printed version of the TBase. + */ + private static String printTbase(TBase t, int depth) { + List<String> fields = Lists.newArrayList(); + for (Map.Entry<? extends TFieldIdEnum, FieldMetaData> entry : + FieldMetaData.getStructMetaDataMap(t.getClass()).entrySet()) { + @SuppressWarnings("unchecked") + boolean fieldSet = t.isSet(entry.getKey()); + String strValue; + if (fieldSet) { + @SuppressWarnings("unchecked") + Object value = t.getFieldValue(entry.getKey()); + strValue = printValue(value, depth); + } else { + strValue = "not set"; + } + fields.add(tabs(depth) + entry.getValue().fieldName + ": " + strValue); + } + + return Joiner.on("\n").join(fields); + } + + /** + * Prints a map in a style that is consistent with TBase pretty printing. + * + * @param map The map to print + * @param depth The print nesting level. + * @return The pretty-printed version of the map. 
+ */ + private static String printMap(Map<?, ?> map, int depth) { + List<String> entries = Lists.newArrayList(); + for (Map.Entry entry : map.entrySet()) { + entries.add(tabs(depth) + printValue(entry.getKey(), depth) + + " = " + printValue(entry.getValue(), depth)); + } + + return entries.isEmpty() ? "{}" + : String.format("{\n%s\n%s}", Joiner.on(",\n").join(entries), tabs(depth - 1)); + } + + /** + * Prints a list in a style that is consistent with TBase pretty printing. + * + * @param list The list to print + * @param depth The print nesting level. + * @return The pretty-printed version of the list + */ + private static String printList(List<?> list, int depth) { + List<String> entries = Lists.newArrayList(); + for (int i = 0; i < list.size(); i++) { + entries.add( + String.format("%sItem[%d] = %s", tabs(depth), i, printValue(list.get(i), depth))); + } + + return entries.isEmpty() ? "[]" + : String.format("[\n%s\n%s]", Joiner.on(",\n").join(entries), tabs(depth - 1)); + } + /** + * Prints a set in a style that is consistent with TBase pretty printing. + * + * @param set The set to print + * @param depth The print nesting level. + * @return The pretty-printed version of the set + */ + private static String printSet(Set<?> set, int depth) { + List<String> entries = Lists.newArrayList(); + for (Object item : set) { + entries.add( + String.format("%sItem = %s", tabs(depth), printValue(item, depth))); + } + + return entries.isEmpty() ? "{}" + : String.format("{\n%s\n%s}", Joiner.on(",\n").join(entries), tabs(depth - 1)); + } + + private static String tabs(int n) { + return Strings.repeat(" ", n); + } + + private Util() { + // Utility class. 
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/Caller.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/Caller.java b/commons/src/main/java/com/twitter/common/thrift/callers/Caller.java
new file mode 100644
index 0000000..80c9e67
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/Caller.java
@@ -0,0 +1,102 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================
+
+package com.twitter.common.thrift.callers;
+
+import com.google.common.base.Preconditions;
+import com.twitter.common.quantity.Amount;
+import com.twitter.common.quantity.Time;
+import org.apache.thrift.async.AsyncMethodCallback;
+
+import javax.annotation.Nullable;
+import java.lang.reflect.Method;
+
+/**
+* A caller that invokes a method on an object.
+*
+* @author William Farner
+*/
+public interface Caller {
+
+  /**
+   * Invokes a method on an object, using the given arguments. The method call may be
+   * asynchronous, in which case {@code callback} will be non-null.
+   *
+   * @param method The method being invoked.
+   * @param args The arguments to call {@code method} with.
+   * @param callback The callback to use if the method is asynchronous.
+   * @param connectTimeoutOverride Optional override for the default connection timeout.
+   * @return The return value from invoking the method.
+   * @throws Throwable Exception, as prescribed by the method's contract.
+   */
+  public Object call(Method method, Object[] args, @Nullable AsyncMethodCallback callback,
+      @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable;
+
+  /**
+   * Captures the result of a request, whether synchronous or asynchronous. It should be expected
+   * that for every request made, exactly one of these methods will be called.
+   */
+  static interface ResultCapture {
+    /**
+     * Called when the request completed successfully.
+     */
+    void success();
+
+    /**
+     * Called when the request failed.
+     *
+     * @param t Throwable that was caught. Must never be null.
+     * @return {@code true} if a wrapped callback should be notified of the failure,
+     *     {@code false} otherwise.
+     */
+    boolean fail(Throwable t);
+  }
+
+  /**
+   * A callback that adapts a {@link ResultCapture} with an {@link AsyncMethodCallback} while
+   * maintaining the AsyncMethodCallback interface. The wrapped callback will handle invocation
+   * of the underlying callback based on the return values from the ResultCapture.
+   *
+   * <p>The underlying callback is notified at most once. A notification arriving after that is
+   * rejected with {@link IllegalStateException} before the capture or the underlying callback
+   * sees it.
+   */
+  static class WrappedMethodCallback implements AsyncMethodCallback {
+    private final AsyncMethodCallback wrapped;
+    private final ResultCapture capture;
+
+    private boolean callbackTriggered = false;
+
+    public WrappedMethodCallback(AsyncMethodCallback wrapped, ResultCapture capture) {
+      this.wrapped = wrapped;
+      this.capture = capture;
+    }
+
+    /** Fails if the underlying callback has already been notified. */
+    private void checkNotTriggered() {
+      Preconditions.checkState(!callbackTriggered, "Each callback may only be triggered once.");
+    }
+
+    /** Records that the underlying callback is about to be notified. */
+    private void callbackTriggered() {
+      checkNotTriggered();
+      callbackTriggered = true;
+    }
+
+    @Override @SuppressWarnings("unchecked") public void onComplete(Object o) {
+      // Enforce the once-only contract before any side effects, so a duplicate completion cannot
+      // run the capture (e.g. return a pooled connection) or notify the wrapped callback twice.
+      callbackTriggered();
+      capture.success();
+      wrapped.onComplete(o);
+    }
+
+    @Override public void onError(Exception t) {
+      // Reject a late failure up front rather than letting it reach the capture.
+      checkNotTriggered();
+      if (capture.fail(t)) {
+        callbackTriggered();
+        wrapped.onError(t);
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/CallerDecorator.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/CallerDecorator.java b/commons/src/main/java/com/twitter/common/thrift/callers/CallerDecorator.java
new file mode 100644
index 0000000..fc85b8a
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/CallerDecorator.java
@@ -0,0 +1,81 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================
+
+package com.twitter.common.thrift.callers;
+
+import com.twitter.common.quantity.Amount;
+import com.twitter.common.quantity.Time;
+import org.apache.thrift.async.AsyncMethodCallback;
+
+import javax.annotation.Nullable;
+import java.lang.reflect.Method;
+
+/**
+* A caller that decorates another caller.
+*
+* @author William Farner
+*/
+abstract class CallerDecorator implements Caller {
+  private final Caller decoratedCaller;
+  private final boolean async;
+
+  CallerDecorator(Caller decoratedCaller, boolean async) {
+    this.decoratedCaller = decoratedCaller;
+    this.async = async;
+  }
+
+  /**
+   * Convenience method for invoking the method and shunting the capture into the callback if
+   * the call is asynchronous.
+   *
+   * <p>A non-null capture is notified once per invocation. For synchronous calls this method
+   * notifies it directly; for asynchronous calls it is notified through a
+   * {@link WrappedMethodCallback}.
+   *
+   * @param method The method being invoked.
+   * @param args The arguments to call {@code method} with.
+   * @param callback The callback to use if the method is asynchronous.
+   * @param capture The result capture to notify of the call result, or {@code null} for none.
+   * @param connectTimeoutOverride Optional override for the default connection timeout.
+   * @return The return value from invoking the method, or {@code null} if an asynchronous call
+   *     failed synchronously (the failure is then delivered to {@code callback}).
+   * @throws Throwable Exception, as prescribed by the method's contract.
+   */
+  protected final Object invoke(Method method, Object[] args,
+      @Nullable AsyncMethodCallback callback, @Nullable final ResultCapture capture,
+      @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable {
+
+    // Swap the wrapped callback out for ours, but only when there is a capture to notify:
+    // WrappedMethodCallback dereferences its capture, so wrapping with a null capture (as
+    // DeadlineCaller passes) would throw NullPointerException on every async completion.
+    if (callback != null && capture != null) {
+      callback = new WrappedMethodCallback(callback, capture);
+    }
+
+    Object result;
+    try {
+      result = decoratedCaller.call(method, args, callback, connectTimeoutOverride);
+    } catch (Exception t) {
+      // We allow this one to go to both sync and async captures.
+      if (callback != null) {
+        callback.onError(t);
+        return null;
+      }
+      if (capture != null) capture.fail(t);
+      throw t;
+    }
+
+    // Report success outside the try block: if the capture itself throws here, it must not then
+    // also be told that the call failed.
+    if (callback == null && capture != null) capture.success();
+    return result;
+  }
+
+  /**
+   * @return {@code true} if this decorator was created for asynchronous calls.
+   */
+  boolean isAsync() {
+    return async;
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/DeadlineCaller.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/DeadlineCaller.java b/commons/src/main/java/com/twitter/common/thrift/callers/DeadlineCaller.java
new file mode 100644
index 0000000..63f8f4d
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/DeadlineCaller.java
@@ -0,0 +1,96 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================
+
+package com.twitter.common.thrift.callers;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeoutException;
+
+import javax.annotation.Nullable;
+
+import com.google.common.base.Throwables;
+
+import org.apache.thrift.async.AsyncMethodCallback;
+
+import com.twitter.common.quantity.Amount;
+import com.twitter.common.quantity.Time;
+import com.twitter.common.thrift.TResourceExhaustedException;
+import com.twitter.common.thrift.TTimeoutException;
+
+/**
+ * A caller that imposes a time deadline on the underlying caller. If the underlying calls fail
+ * to meet the deadline {@link TTimeoutException} is thrown. If the executor service rejects
+ * execution of a task, {@link TResourceExhaustedException} is thrown.
+ *
+ * @author William Farner
+ */
+public class DeadlineCaller extends CallerDecorator {
+  private final ExecutorService executorService;
+  private final Amount<Long, Time> timeout;
+
+  /**
+   * Creates a new deadline caller.
+   *
+   * @param decoratedCaller The caller to decorate with a deadline.
+   * @param async Whether the caller is asynchronous.
+   * @param executorService The executor service to use for performing calls.
+   * @param timeout The timeout by which the underlying call should complete in.
+   */
+  public DeadlineCaller(Caller decoratedCaller, boolean async, ExecutorService executorService,
+      Amount<Long, Time> timeout) {
+    super(decoratedCaller, async);
+
+    this.executorService = executorService;
+    this.timeout = timeout;
+  }
+
+  /**
+   * Runs the decorated call as a task on {@code executorService} and waits at most
+   * {@code timeout} for it to finish.
+   *
+   * @throws TTimeoutException If the task does not finish within the deadline; the task is then
+   *     cancelled with interruption.
+   * @throws TResourceExhaustedException If the executor service rejects the task.
+   * @throws InterruptedException If the calling thread is interrupted while waiting; the task is
+   *     then cancelled with interruption.
+   */
+  @Override
+  public Object call(final Method method, final Object[] args,
+      @Nullable final AsyncMethodCallback callback,
+      @Nullable final Amount<Long, Time> connectTimeoutOverride) throws Throwable {
+    try {
+      Future<Object> result = executorService.submit(new Callable<Object>() {
+        @Override public Object call() throws Exception {
+          try {
+            return invoke(method, args, callback, null, connectTimeoutOverride);
+          } catch (Throwable t) {
+            Throwables.propagateIfInstanceOf(t, Exception.class);
+            throw new RuntimeException(t);
+          }
+        }
+      });
+
+      try {
+        return result.get(timeout.getValue(), timeout.getUnit().getTimeUnit());
+      } catch (TimeoutException e) {
+        result.cancel(true);
+        throw new TTimeoutException(e);
+      } catch (InterruptedException e) {
+        // Nobody is waiting for the result any more, so don't leave the call running unobserved.
+        result.cancel(true);
+        throw e;
+      } catch (ExecutionException e) {
+        throw e.getCause();
+      }
+    } catch (RejectedExecutionException e) {
+      throw new TResourceExhaustedException(e);
+    } catch (InvocationTargetException e) {
+      throw e.getCause();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/DebugCaller.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/DebugCaller.java b/commons/src/main/java/com/twitter/common/thrift/callers/DebugCaller.java
new file mode 100644
index 0000000..17929cd
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/DebugCaller.java
@@ -0,0 +1,76 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================
+
+package com.twitter.common.thrift.callers;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Throwables;
+import com.twitter.common.quantity.Amount;
+import com.twitter.common.quantity.Time;
+import org.apache.thrift.async.AsyncMethodCallback;
+
+import javax.annotation.Nullable;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * A caller that reports debugging information about calls.
+ *
+ * @author William Farner
+ */
+public class DebugCaller extends CallerDecorator {
+  private static final Logger LOG = Logger.getLogger(DebugCaller.class.getName());
+  // Render null arguments as "null" instead of letting Joiner throw NullPointerException while a
+  // failure is being reported.
+  private static final Joiner ARG_JOINER = Joiner.on(", ").useForNull("null");
+
+  /**
+   * Creates a new debug caller.
+   *
+   * @param decoratedCaller The caller to decorate with debug information.
+   * @param async Whether the caller is asynchronous.
+   */
+  public DebugCaller(Caller decoratedCaller, boolean async) {
+    super(decoratedCaller, async);
+  }
+
+  /**
+   * Invokes the decorated caller. Each failed call is logged as a warning that names the
+   * method, its arguments, and the failure.
+   */
+  @Override
+  public Object call(final Method method, final Object[] args,
+      @Nullable AsyncMethodCallback callback, @Nullable Amount<Long, Time> connectTimeoutOverride)
+      throws Throwable {
+    ResultCapture capture = new ResultCapture() {
+      @Override public void success() {
+        // No-op.
+      }
+
+      @Override public boolean fail(Throwable t) {
+        StringBuilder message = new StringBuilder("Thrift call failed: ");
+        message.append(method.getName()).append("(");
+        // Reflection may represent an empty argument list as a null array.
+        if (args != null) {
+          ARG_JOINER.appendTo(message, args);
+        }
+        message.append(")");
+        LOG.log(Level.WARNING, message.toString(), t);
+
+        return true;
+      }
+    };
+
+    try {
+      return invoke(method, args, callback, capture, connectTimeoutOverride);
+    } catch (Throwable t) {
+      // invoke() already reports every Exception it catches to the capture. Only report what it
+      // lets through (e.g. Errors), so that each failure is logged once.
+      if (!(t instanceof Exception)) {
+        capture.fail(t);
+      }
+      throw t;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/RetryingCaller.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/RetryingCaller.java b/commons/src/main/java/com/twitter/common/thrift/callers/RetryingCaller.java
new file mode 100644
index 0000000..a04dffc
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/RetryingCaller.java
@@ -0,0 +1,227 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================
+
+package com.twitter.common.thrift.callers;
+
+import java.lang.reflect.Method;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.logging.Logger;
+
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
+import com.google.common.base.Joiner;
+import com.google.common.base.Predicate;
+import com.google.common.base.Throwables;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+
+import org.apache.thrift.async.AsyncMethodCallback;
+
+import com.twitter.common.quantity.Amount;
+import com.twitter.common.quantity.Time;
+import com.twitter.common.stats.StatsProvider;
+import com.twitter.common.thrift.TResourceExhaustedException;
+
+/**
+* A caller that will retry calls to the wrapped caller.
+*
+* @author William Farner
+*/
+public class RetryingCaller extends CallerDecorator {
+  private static final Logger LOG = Logger.getLogger(RetryingCaller.class.getName());
+
+  @VisibleForTesting
+  public static final Amount<Long, Time> NONBLOCKING_TIMEOUT = Amount.of(-1L, Time.MILLISECONDS);
+
+  private final StatsProvider statsProvider;
+  private final String serviceName;
+  private final int retries;
+  private final ImmutableSet<Class<? extends Exception>> retryableExceptions;
+  private final boolean debug;
+
+  /**
+   * Creates a new retrying caller. The retrying caller will retry invoked methods on the
+   * underlying caller at most {@code retries} times, i.e. make at most {@code retries + 1}
+   * attempts in total. A retry will be performed only when one of the
+   * {@code retryableExceptions} (or a subclass of one) is caught.
+   *
+   * @param decoratedCall The caller to decorate with retries.
+   * @param async Whether the caller is asynchronous.
+   * @param statsProvider The stat provider to export retry statistics through.
+   * @param serviceName The service name that calls are being invoked on.
+   * @param retries The maximum number of retries to perform.
+   * @param retryableExceptions The exceptions that can be retried.
+   * @param debug Whether to include debugging information when retries are being performed.
+   */
+  public RetryingCaller(Caller decoratedCall, boolean async, StatsProvider statsProvider,
+      String serviceName, int retries, ImmutableSet<Class<? extends Exception>> retryableExceptions,
+      boolean debug) {
+    super(decoratedCall, async);
+    this.statsProvider = statsProvider;
+    this.serviceName = serviceName;
+    this.retries = retries;
+    this.retryableExceptions = retryableExceptions;
+    this.debug = debug;
+  }
+
+  private final LoadingCache<Method, AtomicLong> stats =
+      CacheBuilder.newBuilder().build(new CacheLoader<Method, AtomicLong>() {
+        @Override public AtomicLong load(Method method) {
+          // Thrift does not support overloads - so just the name disambiguates all calls.
+          return statsProvider.makeCounter(serviceName + "_" + method.getName() + "_retries");
+        }
+      });
+
+  @Override public Object call(final Method method, final Object[] args,
+      @Nullable final AsyncMethodCallback callback,
+      @Nullable final Amount<Long, Time> connectTimeoutOverride) throws Throwable {
+    final AtomicLong retryCounter = stats.get(method);
+    final AtomicInteger attempts = new AtomicInteger();
+    final List<Throwable> exceptions = Lists.newArrayList();
+
+    final ResultCapture capture = new ResultCapture() {
+      @Override public void success() {
+        // No-op.
+      }
+
+      @Override public boolean fail(Throwable t) {
+        if (!isRetryable(t)) {
+          if (debug) {
+            LOG.warning(String.format(
+                "Call failed with un-retryable exception of [%s]: %s, previous exceptions: %s",
+                t.getClass().getName(), t.getMessage(), combineStackTraces(exceptions)));
+          }
+
+          return true;
+        } else if (attempts.get() >= retries) {
+          exceptions.add(t);
+
+          if (debug) {
+            LOG.warning(String.format("Retried %d times, last error: %s, exceptions: %s",
+                attempts.get(), t, combineStackTraces(exceptions)));
+          }
+
+          return true;
+        } else {
+          exceptions.add(t);
+
+          if (isAsync() && attempts.incrementAndGet() <= retries) {
+            try {
+              retryCounter.incrementAndGet();
+              // override connect timeout in ThriftCaller to prevent blocking for a connection
+              // for async retries (since this is within the callback in the selector thread)
+              invoke(method, args, callback, this, NONBLOCKING_TIMEOUT);
+            } catch (Throwable throwable) {
+              return fail(throwable);
+            }
+          }
+
+          return false;
+        }
+      }
+    };
+
+    // The most recent retryable throwable caught below; only consulted if the capture recorded
+    // nothing.
+    Throwable lastCaught = null;
+    boolean continueLoop;
+    do {
+      try {
+        // If this is an async call, the looping will be handled within the capture.
+        return invoke(method, args, callback, capture, connectTimeoutOverride);
+      } catch (Throwable t) {
+        lastCaught = t;
+        if (!isRetryable(t)) {
+          Throwable propagated = t;
+
+          if (!exceptions.isEmpty() && (t instanceof TResourceExhaustedException)) {
+            // If we've been trucking along through retries that have had remote call failures
+            // and we suddenly can't immediately get a connection on the next retry, throw the
+            // previous remote call failure - the idea here is that the remote call failure is
+            // more interesting than a transient inability to get an immediate connection.
+            propagated = exceptions.remove(exceptions.size() - 1);
+          }
+
+          // An async failure is delivered to the callback, which ends the call. Throwables that
+          // are not Exceptions cannot be passed to onError, so they are rethrown instead of
+          // failing with a ClassCastException.
+          if (isAsync() && propagated instanceof Exception) {
+            callback.onError((Exception) propagated);
+            return null;
+          }
+          throw propagated;
+        }
+      }
+
+      continueLoop = !isAsync() && attempts.incrementAndGet() <= retries;
+      if (continueLoop) retryCounter.incrementAndGet();
+    } while (continueLoop);
+
+    // Retryable synchronous failures are recorded by the capture. Fall back to the last caught
+    // throwable so that an empty list cannot surface as a NoSuchElementException.
+    Throwable lastRetriedException =
+        exceptions.isEmpty() ? lastCaught : Iterables.getLast(exceptions);
+    if (debug) {
+      if (!exceptions.isEmpty()) {
+        LOG.warning(
+            String.format("Retried %d times, last error: %s, previous exceptions: %s",
+                attempts.get(), lastRetriedException, combineStackTraces(exceptions)));
+      } else {
+        LOG.warning(
+            String.format("Retried 1 time, last error: %s", lastRetriedException));
+      }
+    }
+
+    if (!isAsync()) throw lastRetriedException;
+    return null;
+  }
+
+  private boolean isRetryable(Throwable throwable) {
+    return isRetryable.getUnchecked(throwable.getClass());
+  }
+
+  private final LoadingCache<Class<? extends Throwable>, Boolean> isRetryable =
+      CacheBuilder.newBuilder().build(new CacheLoader<Class<? extends Throwable>, Boolean>() {
+        @Override public Boolean load(Class<? extends Throwable> exceptionClass) {
+          return isRetryable(exceptionClass);
+        }
+      });
+
+  private boolean isRetryable(final Class<? extends Throwable> exceptionClass) {
+    if (retryableExceptions.contains(exceptionClass)) {
+      return true;
+    }
+    return Iterables.any(retryableExceptions, new Predicate<Class<? extends Exception>>() {
+      @Override public boolean apply(Class<? extends Exception> retryableExceptionClass) {
+        return retryableExceptionClass.isAssignableFrom(exceptionClass);
+      }
+    });
+  }
+
+  private static final Joiner STACK_TRACE_JOINER = Joiner.on('\n');
+
+  private static String combineStackTraces(List<Throwable> exceptions) {
+    if (exceptions.isEmpty()) {
+      return "none";
+    } else {
+      return STACK_TRACE_JOINER.join(Iterables.transform(exceptions,
+          new Function<Throwable, String>() {
+            private int index = 1;
+            @Override public String apply(Throwable exception) {
+              return String.format("[%d] %s",
+                  index++, Throwables.getStackTraceAsString(exception));
+            }
+          }));
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/StatTrackingCaller.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/StatTrackingCaller.java b/commons/src/main/java/com/twitter/common/thrift/callers/StatTrackingCaller.java
new file mode 100644
index 0000000..60bf709
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/StatTrackingCaller.java
@@ -0,0 +1,106 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================
+
+package com.twitter.common.thrift.callers;
+
+import java.lang.reflect.Method;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import javax.annotation.Nullable;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+
+import org.apache.thrift.async.AsyncMethodCallback;
+
+import com.twitter.common.quantity.Amount;
+import com.twitter.common.quantity.Time;
+import com.twitter.common.stats.StatsProvider;
+import com.twitter.common.stats.StatsProvider.RequestTimer;
+import com.twitter.common.thrift.TResourceExhaustedException;
+import com.twitter.common.thrift.TTimeoutException;
+
+/**
+ * A caller that exports statistics about calls made to the wrapped caller.
+ *
+ * @author William Farner
+ */
+public class StatTrackingCaller extends CallerDecorator {
+
+  private final StatsProvider statsProvider;
+  private final String serviceName;
+
+  // One request timer per method, created on first use and named <serviceName>_<methodName>.
+  private final LoadingCache<Method, RequestTimer> stats =
+      CacheBuilder.newBuilder().build(new CacheLoader<Method, RequestTimer>() {
+        @Override public RequestTimer load(Method method) {
+          // Thrift does not support overloads - so just the name disambiguates all calls.
+          return statsProvider.makeRequestTimer(serviceName + "_" + method.getName());
+        }
+      });
+
+  /**
+   * Creates a new stat tracking caller, which will export stats to the given {@link StatsProvider}.
+   *
+   * @param decoratedCaller The caller whose calls should be measured.
+   * @param async Whether the caller is asynchronous.
+   * @param statsProvider The stat provider to export statistics to.
+   * @param serviceName The name of the service that methods are being called on; used as the
+   *     prefix of each per-method request timer name.
+   */
+  public StatTrackingCaller(Caller decoratedCaller, boolean async, StatsProvider statsProvider,
+      String serviceName) {
+    super(decoratedCaller, async);
+
+    this.statsProvider = statsProvider;
+    this.serviceName = serviceName;
+  }
+
+  /**
+   * Invokes the decorated caller and records the outcome in the method's request timer. A
+   * success records its latency in microseconds. A timeout increments the timeout count. Any
+   * other failure increments the error count, and also the reconnect count unless it was a
+   * {@link TResourceExhaustedException}.
+   */
+  @Override
+  public Object call(Method method, Object[] args, @Nullable AsyncMethodCallback callback,
+      @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable {
+    final RequestTimer requestStats = stats.get(method);
+    final long startTime = System.nanoTime();
+
+    ResultCapture capture = new ResultCapture() {
+      @Override public void success() {
+        requestStats.requestComplete(TimeUnit.NANOSECONDS.toMicros(
+            System.nanoTime() - startTime));
+      }
+
+      @Override public boolean fail(Throwable t) {
+        // TODO(John Sirois): the ruby client reconnects for timeouts too - this provides a natural
+        // backoff mechanism - consider how to plumb something similar.
+        if (t instanceof TTimeoutException || t instanceof TimeoutException) {
+          requestStats.incTimeouts();
+          return true;
+        }
+
+        // TODO(John Sirois): consider ditching reconnects since its nearly redundant with errors as
+        // it stands.
+        if (!(t instanceof TResourceExhaustedException)) {
+          requestStats.incReconnects();
+        }
+        // TODO(John Sirois): provide more detailed stats: track counts for distinct exceptions types,
+        // track retries-per-method, etc...
+        requestStats.incErrors();
+        return true;
+      }
+    };
+
+    return invoke(method, args, callback, capture, connectTimeoutOverride);
+  }
+}

http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/callers/ThriftCaller.java
----------------------------------------------------------------------
diff --git a/commons/src/main/java/com/twitter/common/thrift/callers/ThriftCaller.java b/commons/src/main/java/com/twitter/common/thrift/callers/ThriftCaller.java
new file mode 100644
index 0000000..9e112f5
--- /dev/null
+++ b/commons/src/main/java/com/twitter/common/thrift/callers/ThriftCaller.java
@@ -0,0 +1,160 @@
+// =================================================================================================
+// Copyright 2011 Twitter, Inc.
+// -------------------------------------------------------------------------------------------------
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at:
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =================================================================================================

package com.twitter.common.thrift.callers;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.twitter.common.net.pool.Connection;
import com.twitter.common.net.pool.ObjectPool;
import com.twitter.common.quantity.Amount;
import com.twitter.common.quantity.Time;
import com.twitter.common.net.pool.ResourceExhaustedException;
import com.twitter.common.thrift.TResourceExhaustedException;
import com.twitter.common.thrift.TTimeoutException;
import com.twitter.common.net.loadbalancing.RequestTracker;
import org.apache.thrift.async.AsyncMethodCallback;
import org.apache.thrift.transport.TTransport;

import javax.annotation.Nullable;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.TimeoutException;
import java.util.logging.Logger;

/**
 * A caller that issues calls to a target that is assumed to be a client to a thrift service.
 *
 * <p>Each call borrows a connection from the connection pool, builds a client around the
 * connection's transport with the client factory, and invokes the requested method on that client
 * reflectively. A successful call releases the connection back to the pool; a failed call removes
 * it from the pool. In both cases the outcome and elapsed time are reported to the request
 * tracker. For async calls the caller's callback is wrapped in a {@code WrappedMethodCallback}
 * together with this call's result capture.
 *
 * @author William Farner
 */
public class ThriftCaller<T> implements Caller {
  private static final Logger LOG = Logger.getLogger(ThriftCaller.class.getName());

  private final ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool;
  private final RequestTracker<InetSocketAddress> requestTracker;
  private final Function<TTransport, T> clientFactory;
  private final Amount<Long, Time> timeout;
  private final boolean debug;

  /**
   * Creates a new thrift caller.
   *
   * @param connectionPool The connection pool to use.
   * @param requestTracker The request tracker to notify of request results.
   * @param clientFactory Factory to use for building client object instances.
   * @param timeout The timeout to use when requesting objects from the connection pool. When no
   *     per-call override is given and this value is not positive, the pool's untimed
   *     {@code get()} is used instead.
   * @param debug Whether to use the caller in debug mode (failed calls are logged).
   */
  public ThriftCaller(ObjectPool<Connection<TTransport, InetSocketAddress>> connectionPool,
      RequestTracker<InetSocketAddress> requestTracker, Function<TTransport, T> clientFactory,
      Amount<Long, Time> timeout, boolean debug) {

    this.connectionPool = connectionPool;
    this.requestTracker = requestTracker;
    this.clientFactory = clientFactory;
    this.timeout = timeout;
    this.debug = debug;
  }

  /**
   * Invokes {@code method} on a client built over a pooled connection.
   *
   * @param method The client method to invoke.
   * @param args Arguments for {@code method}; for async calls the callback is appended to these.
   *     May be null.
   * @param callback If non-null, the call is treated as async and this callback (wrapped together
   *     with the result capture) is passed as the final argument.
   * @param connectTimeoutOverride If non-null, used instead of the configured pool timeout.
   * @return The method's return value, or {@code null} when an async call failed synchronously and
   *     the failure was delivered to the callback.
   * @throws Throwable The target's exception for failed sync calls, or any error raised while
   *     obtaining a connection, building the client, or invoking reflectively.
   */
  @Override
  public Object call(Method method, Object[] args, @Nullable AsyncMethodCallback callback,
      @Nullable Amount<Long, Time> connectTimeoutOverride) throws Throwable {

    final Connection<TTransport, InetSocketAddress> connection =
        getConnection(connectTimeoutOverride);
    final long startNanos = System.nanoTime();

    ResultCapture capture = new ResultCapture() {
      @Override public void success() {
        try {
          requestTracker.requestResult(connection.getEndpoint(),
              RequestTracker.RequestResult.SUCCESS, System.nanoTime() - startNanos);
        } finally {
          connectionPool.release(connection);
        }
      }

      @Override public boolean fail(Throwable t) {
        if (debug) {
          LOG.warning(String.format("Call to endpoint: %s failed: %s", connection, t));
        }

        try {
          requestTracker.requestResult(connection.getEndpoint(),
              RequestTracker.RequestResult.FAILED, System.nanoTime() - startNanos);
        } finally {
          connectionPool.remove(connection);
        }
        return true;
      }
    };

    T client;
    try {
      client = clientFactory.apply(connection.get());
    } catch (RuntimeException e) {
      // Without this the borrowed connection would never be returned to or removed from the pool.
      capture.fail(e);
      throw e;
    }

    return invokeMethod(client, method, args, callback, capture);
  }

  /**
   * Reflectively invokes {@code method} on {@code target} and settles {@code capture}.
   *
   * <p>Sync calls settle the capture directly: success after a normal return, failure on any
   * exception. For async calls ({@code callback != null}), success is left to the wrapped callback.
   * A target exception is routed to the wrapped callback when it is an {@link Exception}. Anything
   * else, including reflective failures, fails the capture directly and is rethrown.
   */
  private static Object invokeMethod(Object target, Method method, Object[] args,
      AsyncMethodCallback callback, final ResultCapture capture) throws Throwable {

    // Swap the wrapped callback out for ours.
    if (callback != null) {
      callback = new WrappedMethodCallback(callback, capture);

      // Reflective invocation may supply null instead of an empty array for no-arg methods.
      List<Object> argsList =
          (args == null) ? Lists.<Object>newArrayList() : Lists.<Object>newArrayList(args);
      argsList.add(callback);
      args = argsList.toArray();
    }

    Object result;
    try {
      result = method.invoke(target, args);
    } catch (InvocationTargetException e) {
      Throwable cause = e.getCause();
      // We allow this one to go to both sync and async captures.
      if (callback != null && cause instanceof Exception) {
        callback.onError((Exception) cause);
        return null;
      }
      // Sync call, or an Error that the async callback's onError(Exception) cannot accept.
      capture.fail(cause);
      throw cause;
    } catch (IllegalAccessException e) {
      // The target never ran; still settle the borrowed connection before propagating.
      capture.fail(e);
      throw e;
    } catch (IllegalArgumentException e) {
      capture.fail(e);
      throw e;
    }

    // Outside the try so that a failure while recording success is not reported as a failure too.
    if (callback == null) {
      capture.success();
    }
    return result;
  }

  /**
   * Borrows a connection from the pool.
   *
   * @param connectTimeoutOverride If non-null, the timeout to use instead of the configured one.
   * @return A non-null connection.
   * @throws TResourceExhaustedException If the pool is exhausted or hands back no connection.
   * @throws TTimeoutException If the pool times out.
   */
  private Connection<TTransport, InetSocketAddress> getConnection(
      Amount<Long, Time> connectTimeoutOverride)
      throws TResourceExhaustedException, TTimeoutException {
    try {
      Connection<TTransport, InetSocketAddress> connection;
      if (connectTimeoutOverride != null) {
        connection = connectionPool.get(connectTimeoutOverride);
      } else {
        connection = (timeout.getValue() > 0)
            ? connectionPool.get(timeout) : connectionPool.get();
      }

      if (connection == null) {
        throw new TResourceExhaustedException("no connection was available");
      }
      return connection;
    } catch (ResourceExhaustedException e) {
      throw new TResourceExhaustedException(e);
    } catch (TimeoutException e) {
      throw new TTimeoutException(e);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredNonblockingServerSocket.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredNonblockingServerSocket.java b/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredNonblockingServerSocket.java new file mode 100644 index 0000000..4e47d99 --- /dev/null +++ b/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredNonblockingServerSocket.java @@ -0,0 +1,83 @@ +//
================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// =================================================================================================

package com.twitter.common.thrift.monitoring;

import com.google.common.base.Preconditions;
import com.twitter.common.net.monitoring.ConnectionMonitor;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingSocket;
import org.apache.thrift.transport.TTransportException;

import java.net.InetSocketAddress;

/**
 * Extension of TNonblockingServerSocket intended to allow tracking of connected clients.
 *
 * <p>NOTE: tracking is not implemented yet. The supplied {@link ConnectionMonitor} is checked for
 * null and retained, but it is never notified; accepted sockets are returned exactly as the
 * superclass produces them.
 *
 * @author William Farner
 */
public class TMonitoredNonblockingServerSocket extends TNonblockingServerSocket {
  // Held for the pending connection-tracking implementation; not read anywhere yet.
  private final ConnectionMonitor<?> monitor;

  /**
   * Creates a server socket for {@code port} (see the matching superclass constructor).
   *
   * @param port Port passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredNonblockingServerSocket(int port, ConnectionMonitor<?> monitor)
      throws TTransportException {
    super(port);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code port} with a client timeout.
   *
   * @param port Port passed to the superclass.
   * @param clientTimeout Client timeout passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredNonblockingServerSocket(int port, int clientTimeout,
      ConnectionMonitor<?> monitor) throws TTransportException {
    super(port, clientTimeout);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code bindAddr}.
   *
   * @param bindAddr Bind address passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredNonblockingServerSocket(InetSocketAddress bindAddr,
      ConnectionMonitor<?> monitor) throws TTransportException {
    super(bindAddr);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code bindAddr} with a client timeout.
   *
   * @param bindAddr Bind address passed to the superclass.
   * @param clientTimeout Client timeout passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredNonblockingServerSocket(InetSocketAddress bindAddr, int clientTimeout,
      ConnectionMonitor<?> monitor) throws TTransportException {
    super(bindAddr, clientTimeout);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  @Override
  protected TNonblockingSocket acceptImpl() throws TTransportException {
    // TODO(William Farner): Notify the monitor on connect and disconnect. Intercepting close() on
    // the accepted socket may require wrapping it in an object proxy.
    return super.acceptImpl();
  }
}
http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredProcessor.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredProcessor.java b/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredProcessor.java new file mode 100644 index 0000000..b89e689 --- /dev/null +++ b/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredProcessor.java @@ -0,0 +1,65 @@ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+// =================================================================================================

package com.twitter.common.thrift.monitoring;

import com.google.common.base.Preconditions;
import com.twitter.common.net.loadbalancing.RequestTracker;
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TSocket;

import java.net.InetSocketAddress;

import static com.twitter.common.net.loadbalancing.RequestTracker.RequestResult.FAILED;
import static com.twitter.common.net.loadbalancing.RequestTracker.RequestResult.SUCCESS;

/**
 * A TProcessor that joins a wrapped TProcessor with a monitor.
 *
 * <p>Every {@link #process} call is timed and reported to the request tracker. The report is keyed
 * by the remote address that {@link TMonitoredServerSocket} recorded for the input transport. A
 * call counts as failed if the wrapped processor throws anything, checked or unchecked.
 *
 * @author William Farner
 */
public class TMonitoredProcessor implements TProcessor {
  private final TProcessor wrapped;
  private final TMonitoredServerSocket monitoredServerSocket;
  private final RequestTracker<InetSocketAddress> monitor;

  /**
   * Creates a monitored processor.
   *
   * @param wrapped Processor that handles the requests; must not be null.
   * @param monitoredServerSocket Server socket used to map transports to remote addresses; must
   *     not be null.
   * @param monitor Tracker notified of each request's result and latency; must not be null.
   */
  public TMonitoredProcessor(TProcessor wrapped, TMonitoredServerSocket monitoredServerSocket,
      RequestTracker<InetSocketAddress> monitor) {
    this.wrapped = Preconditions.checkNotNull(wrapped);
    this.monitoredServerSocket = Preconditions.checkNotNull(monitoredServerSocket);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  @Override
  public boolean process(TProtocol in, TProtocol out) throws TException {
    long startNanos = System.nanoTime();
    // Set only after a normal return, so RuntimeExceptions and Errors are also counted as failures.
    boolean succeeded = false;
    try {
      boolean result = wrapped.process(in, out);
      succeeded = true;
      return result;
    } finally {
      long elapsedNanos = System.nanoTime() - startNanos;
      // NOTE(review): assumes the input transport is the TSocket created by
      // TMonitoredServerSocket.acceptImpl().
      InetSocketAddress address = monitoredServerSocket.getAddress((TSocket) in.getTransport());
      if (succeeded) {
        Preconditions.checkState(address != null,
            "Address unknown for transport " + in.getTransport());
      }
      // If the wrapped processor already threw, skip reporting rather than masking its exception.
      if (address != null) {
        monitor.requestResult(address, succeeded ? SUCCESS : FAILED, elapsedNanos);
      }
    }
  }
}
http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredServerSocket.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredServerSocket.java b/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredServerSocket.java new file mode 100644 index 0000000..ebc37b9 --- /dev/null +++ b/commons/src/main/java/com/twitter/common/thrift/monitoring/TMonitoredServerSocket.java @@ -0,0 +1,114 @@ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+// =================================================================================================

package com.twitter.common.thrift.monitoring;

import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.twitter.common.net.monitoring.ConnectionMonitor;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransportException;

import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.util.Collections;
import java.util.Map;

/**
 * Extension of TServerSocket that allows for tracking of connected clients.
 *
 * <p>Each accepted socket is reported to the {@link ConnectionMonitor} via
 * {@code connected(remoteAddress)}. The first time that socket is closed, it is reported again via
 * {@code released(remoteAddress)}. While the socket is open, its remote address can be looked up
 * with {@link #getAddress(TSocket)}.
 *
 * @author William Farner
 */
public class TMonitoredServerSocket extends TServerSocket {
  private final ConnectionMonitor<InetSocketAddress> monitor;

  // Remote address of every socket returned by acceptImpl() that has not been closed yet.
  private final Map<TSocket, InetSocketAddress> addressMap =
      Collections.synchronizedMap(Maps.<TSocket, InetSocketAddress>newHashMap());

  /**
   * Wraps an existing server socket.
   *
   * @param serverSocket Server socket passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   */
  public TMonitoredServerSocket(ServerSocket serverSocket,
      ConnectionMonitor<InetSocketAddress> monitor) {
    super(serverSocket);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Wraps an existing server socket with a client timeout.
   *
   * @param serverSocket Server socket passed to the superclass.
   * @param clientTimeout Client timeout passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   */
  public TMonitoredServerSocket(ServerSocket serverSocket, int clientTimeout,
      ConnectionMonitor<InetSocketAddress> monitor) {
    super(serverSocket, clientTimeout);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code port}.
   *
   * @param port Port passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredServerSocket(int port, ConnectionMonitor<InetSocketAddress> monitor)
      throws TTransportException {
    super(port);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code port} with a client timeout.
   *
   * @param port Port passed to the superclass.
   * @param clientTimeout Client timeout passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredServerSocket(int port, int clientTimeout,
      ConnectionMonitor<InetSocketAddress> monitor) throws TTransportException {
    super(port, clientTimeout);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code bindAddr}.
   *
   * @param bindAddr Bind address passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredServerSocket(InetSocketAddress bindAddr,
      ConnectionMonitor<InetSocketAddress> monitor) throws TTransportException {
    super(bindAddr);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Creates a server socket for {@code bindAddr} with a client timeout.
   *
   * @param bindAddr Bind address passed to the superclass.
   * @param clientTimeout Client timeout passed to the superclass.
   * @param monitor Connection monitor; must not be null.
   * @throws TTransportException If the superclass constructor fails.
   */
  public TMonitoredServerSocket(InetSocketAddress bindAddr, int clientTimeout,
      ConnectionMonitor<InetSocketAddress> monitor)
      throws TTransportException {
    super(bindAddr, clientTimeout);
    this.monitor = Preconditions.checkNotNull(monitor);
  }

  /**
   * Looks up the remote address of a socket handed out by this server socket.
   *
   * @param socket A socket previously returned by {@link #acceptImpl()}.
   * @return The socket's remote address, or {@code null} if the socket is unknown or already
   *     closed.
   */
  public InetSocketAddress getAddress(TSocket socket) {
    return addressMap.get(socket);
  }

  @Override
  protected TSocket acceptImpl() throws TTransportException {
    final TSocket socket = super.acceptImpl();
    final InetSocketAddress remoteAddress =
        (InetSocketAddress) socket.getSocket().getRemoteSocketAddress();

    // Wrap the accepted socket so that its first close() notifies the monitor and drops the
    // address mapping; later close() calls only close the underlying socket again.
    TSocket monitoredSocket = new TSocket(socket.getSocket()) {
      boolean closed = false;

      @Override public void close() {
        try {
          super.close();
        } finally {
          if (!closed) {
            monitor.released(remoteAddress);
            addressMap.remove(this);
          }
          closed = true;
        }
      }
    };

    addressMap.put(monitoredSocket, remoteAddress);

    monitor.connected(remoteAddress);
    return monitoredSocket;
  }
}
http://git-wip-us.apache.org/repos/asf/aurora/blob/86a547b9/commons/src/main/java/com/twitter/common/thrift/testing/MockTSocket.java ---------------------------------------------------------------------- diff --git a/commons/src/main/java/com/twitter/common/thrift/testing/MockTSocket.java b/commons/src/main/java/com/twitter/common/thrift/testing/MockTSocket.java new file mode 100644 index 0000000..5dcc4a1 --- /dev/null +++ b/commons/src/main/java/com/twitter/common/thrift/testing/MockTSocket.java @@ -0,0 +1,48 @@ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License.
+// You may obtain a copy of the License in the LICENSE file, or at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// =================================================================================================

package com.twitter.common.thrift.testing;

import org.apache.thrift.transport.TSocket;

/**
 * A TSocket test double whose open state is just an in-memory flag.
 *
 * <p>{@link #open()} and {@link #close()} only toggle the flag (neither calls the superclass), and
 * {@link #isOpen()} reports it.
 *
 * @author William Farner
 */
public class MockTSocket extends TSocket {
  /** Host name passed to the TSocket superclass constructor. */
  public static final String HOST = "dummyHost";
  /** Port passed to the TSocket superclass constructor. */
  public static final int PORT = 1000;

  private boolean connected = false;

  /** Creates a mock socket for {@link #HOST}:{@link #PORT}; it starts out not open. */
  public MockTSocket() {
    super(HOST, PORT);
  }

  /** Marks the socket as open without connecting to anything. */
  @Override
  public void open() {
    connected = true;
    // TODO(William Farner): Allow for failure injection here by throwing TTransportException.
  }

  @Override
  public boolean isOpen() {
    return connected;
  }

  /** Marks the socket as closed; the superclass close() is not invoked. */
  @Override
  public void close() {
    connected = false;
  }
}
