Added: hama/trunk/core/src/main/java/org/apache/hama/ipc/RPC.java URL: http://svn.apache.org/viewvc/hama/trunk/core/src/main/java/org/apache/hama/ipc/RPC.java?rev=1514580&view=auto ============================================================================== --- hama/trunk/core/src/main/java/org/apache/hama/ipc/RPC.java (added) +++ hama/trunk/core/src/main/java/org/apache/hama/ipc/RPC.java Fri Aug 16 05:20:15 2013 @@ -0,0 +1,650 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hama.ipc; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.lang.reflect.Array; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; +import java.util.HashMap; +import java.util.Map; + +import javax.net.SocketFactory; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.ObjectWritable; +import org.apache.hadoop.io.UTF8; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.security.SaslRpcServer; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.SecretManager; +import org.apache.hadoop.security.token.TokenIdentifier; +import org.apache.hama.util.BSPNetUtils; + +/** + * A simple RPC mechanism. + * + * A <i>protocol</i> is a Java interface. All parameters and return types must + * be one of: + * + * <ul> + * <li>a primitive type, <code>boolean</code>, <code>byte</code>, + * <code>char</code>, <code>short</code>, <code>int</code>, <code>long</code>, + * <code>float</code>, <code>double</code>, or <code>void</code>; or</li> + * + * <li>a {@link String}; or</li> + * + * <li>a {@link Writable}; or</li> + * + * <li>an array of the above types</li> + * </ul> + * + * All methods in the protocol should throw only IOException. No field data of + * the protocol instance is transmitted. + */ +public class RPC { + private static final Log LOG = LogFactory.getLog(RPC.class); + + private RPC() { + } // no public ctor + + /** A method invocation, including the method name and its parameters. 
*/ + private static class Invocation implements Writable, Configurable { + private String methodName; + private Class[] parameterClasses; + private Object[] parameters; + private Configuration conf; + + public Invocation() { + } + + public Invocation(Method method, Object[] parameters) { + this.methodName = method.getName(); + this.parameterClasses = method.getParameterTypes(); + this.parameters = parameters; + } + + /** The name of the method invoked. */ + public String getMethodName() { + return methodName; + } + + /** The parameter classes. */ + public Class[] getParameterClasses() { + return parameterClasses; + } + + /** The parameter instances. */ + public Object[] getParameters() { + return parameters; + } + + public void readFields(DataInput in) throws IOException { + methodName = UTF8.readString(in); + parameters = new Object[in.readInt()]; + parameterClasses = new Class[parameters.length]; + ObjectWritable objectWritable = new ObjectWritable(); + for (int i = 0; i < parameters.length; i++) { + parameters[i] = ObjectWritable + .readObject(in, objectWritable, this.conf); + parameterClasses[i] = objectWritable.getDeclaredClass(); + } + } + + public void write(DataOutput out) throws IOException { + UTF8.writeString(out, methodName); + out.writeInt(parameterClasses.length); + for (int i = 0; i < parameterClasses.length; i++) { + ObjectWritable.writeObject(out, parameters[i], parameterClasses[i], + conf); + } + } + + public String toString() { + StringBuffer buffer = new StringBuffer(); + buffer.append(methodName); + buffer.append("("); + for (int i = 0; i < parameters.length; i++) { + if (i != 0) + buffer.append(", "); + buffer.append(parameters[i]); + } + buffer.append(")"); + return buffer.toString(); + } + + public void setConf(Configuration conf) { + this.conf = conf; + } + + public Configuration getConf() { + return this.conf; + } + + } + + /* Cache a client using its socket factory as the hash key */ + static private class ClientCache { + private 
Map<SocketFactory, Client> clients = new HashMap<SocketFactory, Client>(); + + /** + * Construct & cache an IPC client with the user-provided SocketFactory if + * no cached client exists. + * + * @param conf Configuration + * @return an IPC client + */ + private synchronized Client getClient(Configuration conf, + SocketFactory factory) { + // Construct & cache client. The configuration is only used for timeout, + // and Clients have connection pools. So we can either (a) lose some + // connection pooling and leak sockets, or (b) use the same timeout for + // all + // configurations. Since the IPC is usually intended globally, not + // per-job, we choose (a). + Client client = clients.get(factory); + if (client == null) { + client = new Client(ObjectWritable.class, conf, factory); + clients.put(factory, client); + } else { + client.incCount(); + } + return client; + } + + /** + * Construct & cache an IPC client with the default SocketFactory if no + * cached client exists. + * + * @param conf Configuration + * @return an IPC client + */ + private synchronized Client getClient(Configuration conf) { + return getClient(conf, SocketFactory.getDefault()); + } + + /** + * Stop a RPC client connection A RPC client is closed only when its + * reference count becomes zero. + */ + private void stopClient(Client client) { + synchronized (this) { + client.decCount(); + if (client.isZeroReference()) { + clients.remove(client.getSocketFactory()); + } + } + if (client.isZeroReference()) { + client.stop(); + } + } + } + + private static ClientCache CLIENTS = new ClientCache(); + + // for unit testing only + static Client getClient(Configuration conf) { + return CLIENTS.getClient(conf); + } + + private static class Invoker implements InvocationHandler { + private Client.ConnectionId remoteId; + private Client client; + private boolean isClosed = false; + + private Invoker(Class<? 
extends VersionedProtocol> protocol, + InetSocketAddress address, UserGroupInformation ticket, + Configuration conf, SocketFactory factory, int rpcTimeout, + RetryPolicy connectionRetryPolicy) throws IOException { + this.remoteId = Client.ConnectionId.getConnectionId(address, protocol, + ticket, rpcTimeout, connectionRetryPolicy, conf); + this.client = CLIENTS.getClient(conf, factory); + } + + public Object invoke(Object proxy, Method method, Object[] args) + throws Throwable { + final boolean logDebug = LOG.isDebugEnabled(); + long startTime = 0; + if (logDebug) { + startTime = System.currentTimeMillis(); + } + + ObjectWritable value = (ObjectWritable) client.call(new Invocation( + method, args), remoteId); + if (logDebug) { + long callTime = System.currentTimeMillis() - startTime; + LOG.debug("Call: " + method.getName() + " " + callTime); + } + return value.get(); + } + + /* close the IPC client that's responsible for this invoker's RPCs */ + synchronized private void close() { + if (!isClosed) { + isClosed = true; + CLIENTS.stopClient(client); + } + } + } + + /** + * A version mismatch for the RPC protocol. + */ + public static class VersionMismatch extends IOException { + private String interfaceName; + private long clientVersion; + private long serverVersion; + + /** + * Create a version mismatch exception + * + * @param interfaceName the name of the protocol mismatch + * @param clientVersion the client's version of the protocol + * @param serverVersion the server's version of the protocol + */ + public VersionMismatch(String interfaceName, long clientVersion, + long serverVersion) { + super("Protocol " + interfaceName + " version mismatch. (client = " + + clientVersion + ", server = " + serverVersion + ")"); + this.interfaceName = interfaceName; + this.clientVersion = clientVersion; + this.serverVersion = serverVersion; + } + + /** + * Get the interface name + * + * @return the java class name (eg. 
+ * org.apache.hadoop.mapred.InterTrackerProtocol) + */ + public String getInterfaceName() { + return interfaceName; + } + + /** + * Get the client's preferred version + */ + public long getClientVersion() { + return clientVersion; + } + + /** + * Get the server's agreed to version. + */ + public long getServerVersion() { + return serverVersion; + } + } + + public static VersionedProtocol waitForProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf) throws IOException { + return waitForProxy(protocol, clientVersion, addr, conf, 0, Long.MAX_VALUE); + } + + /** + * Get a proxy connection to a remote server + * + * @param protocol protocol class + * @param clientVersion client version + * @param addr remote address + * @param conf configuration to use + * @param connTimeout time in milliseconds before giving up + * @return the proxy + * @throws IOException if the far end through a RemoteException + */ + static VersionedProtocol waitForProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, long connTimeout) + throws IOException { + return waitForProxy(protocol, clientVersion, addr, conf, 0, connTimeout); + } + + static VersionedProtocol waitForProxy( + Class<? 
extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, int rpcTimeout, + long connTimeout) throws IOException { + long startTime = System.currentTimeMillis(); + IOException ioe; + while (true) { + try { + return getProxy(protocol, clientVersion, addr, conf, rpcTimeout); + } catch (ConnectException se) { // namenode has not been started + LOG.info("Server at " + addr + " not available yet, Zzzzz..."); + ioe = se; + } catch (SocketTimeoutException te) { // namenode is busy + LOG.info("Problem connecting to server: " + addr); + ioe = te; + } + // check if timed out + if (System.currentTimeMillis() - connTimeout >= startTime) { + throw ioe; + } + + // wait for retry + try { + Thread.sleep(1000); + } catch (InterruptedException ie) { + // IGNORE + } + } + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + */ + public static VersionedProtocol getProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, SocketFactory factory) + throws IOException { + UserGroupInformation ugi = UserGroupInformation.getCurrentUser(); + return getProxy(protocol, clientVersion, addr, ugi, conf, factory, 0); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + */ + public static VersionedProtocol getProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, SocketFactory factory, + int rpcTimeout) throws IOException { + UserGroupInformation ugi = UserGroupInformation.getCurrentUser(); + return getProxy(protocol, clientVersion, addr, ugi, conf, factory, + rpcTimeout); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + */ + public static VersionedProtocol getProxy( + Class<? 
extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, + SocketFactory factory) throws IOException { + return getProxy(protocol, clientVersion, addr, ticket, conf, factory, 0); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + */ + public static VersionedProtocol getProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, + SocketFactory factory, int rpcTimeout) throws IOException { + return getProxy(protocol, clientVersion, addr, ticket, conf, factory, + rpcTimeout, null, true); + } + + /** + * Construct a client-side proxy object that implements the named protocol, + * talking to a server at the named address. + */ + public static VersionedProtocol getProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, UserGroupInformation ticket, Configuration conf, + SocketFactory factory, int rpcTimeout, RetryPolicy connectionRetryPolicy, + boolean checkVersion) throws IOException { + + if (UserGroupInformation.isSecurityEnabled()) { + SaslRpcServer.init(conf); + } + final Invoker invoker = new Invoker(protocol, addr, ticket, conf, factory, + rpcTimeout, connectionRetryPolicy); + VersionedProtocol proxy = (VersionedProtocol) Proxy.newProxyInstance( + protocol.getClassLoader(), new Class[] { protocol }, invoker); + + if (checkVersion) { + checkVersion(protocol, clientVersion, proxy); + } + return proxy; + } + + /** Get server version and then compare it with client version. */ + public static void checkVersion(Class<? 
extends VersionedProtocol> protocol, + long clientVersion, VersionedProtocol proxy) throws IOException { + long serverVersion = proxy.getProtocolVersion(protocol.getName(), + clientVersion); + if (serverVersion != clientVersion) { + throw new VersionMismatch(protocol.getName(), clientVersion, + serverVersion); + } + } + + /** + * Construct a client-side proxy object with the default SocketFactory + * + * @param protocol + * @param clientVersion + * @param addr + * @param conf + * @return a proxy instance + * @throws IOException + */ + public static VersionedProtocol getProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf) throws IOException { + return getProxy(protocol, clientVersion, addr, conf, + BSPNetUtils.getDefaultSocketFactory(conf), 0); + } + + public static VersionedProtocol getProxy( + Class<? extends VersionedProtocol> protocol, long clientVersion, + InetSocketAddress addr, Configuration conf, int rpcTimeout) + throws IOException { + + return getProxy(protocol, clientVersion, addr, conf, + BSPNetUtils.getDefaultSocketFactory(conf), rpcTimeout); + } + + /** + * Stop this proxy and release its invoker's resource + * + * @param proxy the proxy to be stopped + */ + public static void stopProxy(VersionedProtocol proxy) { + if (proxy != null) { + ((Invoker) Proxy.getInvocationHandler(proxy)).close(); + } + } + + /** + * Expert: Make multiple, parallel calls to a set of servers. + * + * @deprecated Use + * {@link #call(Method, Object[][], InetSocketAddress[], UserGroupInformation, Configuration)} + * instead + */ + public static Object[] call(Method method, Object[][] params, + InetSocketAddress[] addrs, Configuration conf) throws IOException, + InterruptedException { + return call(method, params, addrs, null, conf); + } + + /** Expert: Make multiple, parallel calls to a set of servers. 
*/ + public static Object[] call(Method method, Object[][] params, + InetSocketAddress[] addrs, UserGroupInformation ticket, Configuration conf) + throws IOException, InterruptedException { + + Invocation[] invocations = new Invocation[params.length]; + for (int i = 0; i < params.length; i++) + invocations[i] = new Invocation(method, params[i]); + Client client = CLIENTS.getClient(conf); + try { + Writable[] wrappedValues = client.call(invocations, addrs, + method.getDeclaringClass(), ticket, conf); + + if (method.getReturnType() == Void.TYPE) { + return null; + } + + Object[] values = (Object[]) Array.newInstance(method.getReturnType(), + wrappedValues.length); + for (int i = 0; i < values.length; i++) + if (wrappedValues[i] != null) + values[i] = ((ObjectWritable) wrappedValues[i]).get(); + + return values; + } finally { + CLIENTS.stopClient(client); + } + } + + /** + * Construct a server for a protocol implementation instance listening on a + * port and address. + */ + public static Server getServer(final Object instance, + final String bindAddress, final int port, Configuration conf) + throws IOException { + return getServer(instance, bindAddress, port, 1, false, conf); + } + + /** + * Construct a server for a protocol implementation instance listening on a + * port and address. + */ + public static Server getServer(final Object instance, + final String bindAddress, final int port, final int numHandlers, + final boolean verbose, Configuration conf) throws IOException { + return getServer(instance, bindAddress, port, numHandlers, verbose, conf, + null); + } + + /** + * Construct a server for a protocol implementation instance listening on a + * port and address, with a secret manager. + */ + public static Server getServer(final Object instance, + final String bindAddress, final int port, final int numHandlers, + final boolean verbose, Configuration conf, + SecretManager<? 
extends TokenIdentifier> secretManager) + throws IOException { + return new Server(instance, conf, bindAddress, port, numHandlers, verbose, + secretManager); + } + + /** An RPC Server. */ + public static class Server extends org.apache.hama.ipc.Server { + private Object instance; + private boolean verbose; + + /** + * Construct an RPC server. + * + * @param instance the instance whose methods will be called + * @param conf the configuration to use + * @param bindAddress the address to bind on to listen for connection + * @param port the port to listen for connections on + */ + public Server(Object instance, Configuration conf, String bindAddress, + int port) throws IOException { + this(instance, conf, bindAddress, port, 1, false, null); + } + + private static String classNameBase(String className) { + String[] names = className.split("\\.", -1); + if (names == null || names.length == 0) { + return className; + } + return names[names.length - 1]; + } + + /** + * Construct an RPC server. + * + * @param instance the instance whose methods will be called + * @param conf the configuration to use + * @param bindAddress the address to bind on to listen for connection + * @param port the port to listen for connections on + * @param numHandlers the number of method handler threads to run + * @param verbose whether each call should be logged + */ + public Server(Object instance, Configuration conf, String bindAddress, + int port, int numHandlers, boolean verbose, + SecretManager<? 
extends TokenIdentifier> secretManager) + throws IOException { + super(bindAddress, port, Invocation.class, numHandlers, conf, + classNameBase(instance.getClass().getName()), secretManager); + this.instance = instance; + this.verbose = verbose; + } + + public Writable call(Class<?> protocol, Writable param, long receivedTime) + throws IOException { + try { + Invocation call = (Invocation) param; + if (verbose) + log("Call: " + call); + + Method method = protocol.getMethod(call.getMethodName(), + call.getParameterClasses()); + method.setAccessible(true); + + long startTime = System.currentTimeMillis(); + Object value = method.invoke(instance, call.getParameters()); + int processingTime = (int) (System.currentTimeMillis() - startTime); + int qTime = (int) (startTime - receivedTime); + if (LOG.isDebugEnabled()) { + LOG.debug("Served: " + call.getMethodName() + " queueTime= " + qTime + + " procesingTime= " + processingTime); + } + if (verbose) + log("Return: " + value); + + return new ObjectWritable(method.getReturnType(), value); + + } catch (InvocationTargetException e) { + Throwable target = e.getTargetException(); + if (target instanceof IOException) { + throw (IOException) target; + } else { + IOException ioe = new IOException(target.toString()); + ioe.setStackTrace(target.getStackTrace()); + throw ioe; + } + } catch (Throwable e) { + if (!(e instanceof IOException)) { + LOG.error("Unexpected throwable object ", e); + } + IOException ioe = new IOException(e.toString()); + ioe.setStackTrace(e.getStackTrace()); + throw ioe; + } + } + } + + private static void log(String value) { + if (value != null && value.length() > 55) + value = value.substring(0, 55) + "..."; + LOG.info(value); + } +}
Added: hama/trunk/core/src/main/java/org/apache/hama/ipc/RemoteException.java URL: http://svn.apache.org/viewvc/hama/trunk/core/src/main/java/org/apache/hama/ipc/RemoteException.java?rev=1514580&view=auto ============================================================================== --- hama/trunk/core/src/main/java/org/apache/hama/ipc/RemoteException.java (added) +++ hama/trunk/core/src/main/java/org/apache/hama/ipc/RemoteException.java Fri Aug 16 05:20:15 2013 @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hama.ipc; + +import java.io.IOException; +import java.lang.reflect.Constructor; + +import org.xml.sax.Attributes; + +public class RemoteException extends IOException { + /** For java.io.Serializable */ + private static final long serialVersionUID = 1L; + + private String className; + + public RemoteException(String className, String msg) { + super(msg); + this.className = className; + } + + public String getClassName() { + return className; + } + + /** + * If this remote exception wraps up one of the lookupTypes + * then return this exception. + * <p> + * Unwraps any IOException. + * + * @param lookupTypes the desired exception class. 
+ * @return IOException, which is either the lookupClass exception or this. + */ + public IOException unwrapRemoteException(Class<?>... lookupTypes) { + if(lookupTypes == null) + return this; + for(Class<?> lookupClass : lookupTypes) { + if(!lookupClass.getName().equals(getClassName())) + continue; + try { + return instantiateException(lookupClass.asSubclass(IOException.class)); + } catch(Exception e) { + // cannot instantiate lookupClass, just return this + return this; + } + } + // wrapped up exception is not in lookupTypes, just return this + return this; + } + + /** + * Instantiate and return the exception wrapped up by this remote exception. + * + * <p> This unwraps any <code>Throwable</code> that has a constructor taking + * a <code>String</code> as a parameter. + * Otherwise it returns this. + * + * @return <code>Throwable + */ + public IOException unwrapRemoteException() { + try { + Class<?> realClass = Class.forName(getClassName()); + return instantiateException(realClass.asSubclass(IOException.class)); + } catch(Exception e) { + // cannot instantiate the original exception, just return this + } + return this; + } + + private IOException instantiateException(Class<? extends IOException> cls) + throws Exception { + Constructor<? 
extends IOException> cn = cls.getConstructor(String.class); + cn.setAccessible(true); + String firstLine = this.getMessage(); + int eol = firstLine.indexOf('\n'); + if (eol>=0) { + firstLine = firstLine.substring(0, eol); + } + IOException ex = cn.newInstance(firstLine); + ex.initCause(this); + return ex; + } + + /** Create RemoteException from attributes */ + public static RemoteException valueOf(Attributes attrs) { + return new RemoteException(attrs.getValue("class"), + attrs.getValue("message")); + } +} Added: hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicies.java URL: http://svn.apache.org/viewvc/hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicies.java?rev=1514580&view=auto ============================================================================== --- hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicies.java (added) +++ hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicies.java Fri Aug 16 05:20:15 2013 @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hama.ipc; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.ipc.RemoteException; + +/** + * <p> + * A collection of useful implementations of {@link RetryPolicy}. + * </p> + */ +public class RetryPolicies { + private static final Log LOG = LogFactory.getLog(RetryPolicies.class); + + private static ThreadLocal<Random> RANDOM = new ThreadLocal<Random>() { + @Override + protected Random initialValue() { + return new Random(); + } + }; + + /** + * <p> + * Try once, and fail by re-throwing the exception. + * This corresponds to having no retry mechanism in place. + * </p> + */ + public static final RetryPolicy TRY_ONCE_THEN_FAIL = new TryOnceThenFail(); + + /** + * <p> + * Try once, and fail silently for <code>void</code> methods, or by + * re-throwing the exception for non-<code>void</code> methods. + * </p> + */ + public static final RetryPolicy TRY_ONCE_DONT_FAIL = new TryOnceDontFail(); + + /** + * <p> + * Keep trying forever. + * </p> + */ + public static final RetryPolicy RETRY_FOREVER = new RetryForever(); + + /** + * <p> + * Keep trying a limited number of times, waiting a fixed time between attempts, + * and then fail by re-throwing the exception. + * </p> + */ + public static final RetryPolicy retryUpToMaximumCountWithFixedSleep(int maxRetries, long sleepTime, TimeUnit timeUnit) { + return new RetryUpToMaximumCountWithFixedSleep(maxRetries, sleepTime, timeUnit); + } + + /** + * <p> + * Keep trying for a maximum time, waiting a fixed time between attempts, + * and then fail by re-throwing the exception. 
+ * </p> + */ + public static final RetryPolicy retryUpToMaximumTimeWithFixedSleep(long maxTime, long sleepTime, TimeUnit timeUnit) { + return new RetryUpToMaximumTimeWithFixedSleep(maxTime, sleepTime, timeUnit); + } + + /** + * <p> + * Keep trying a limited number of times, waiting a growing amount of time between attempts, + * and then fail by re-throwing the exception. + * The time between attempts is <code>sleepTime</code> mutliplied by the number of tries so far. + * </p> + */ + public static final RetryPolicy retryUpToMaximumCountWithProportionalSleep(int maxRetries, long sleepTime, TimeUnit timeUnit) { + return new RetryUpToMaximumCountWithProportionalSleep(maxRetries, sleepTime, timeUnit); + } + + /** + * <p> + * Keep trying a limited number of times, waiting a growing amount of time between attempts, + * and then fail by re-throwing the exception. + * The time between attempts is <code>sleepTime</code> mutliplied by a random + * number in the range of [0, 2 to the number of retries) + * </p> + */ + public static final RetryPolicy exponentialBackoffRetry( + int maxRetries, long sleepTime, TimeUnit timeUnit) { + return new ExponentialBackoffRetry(maxRetries, sleepTime, timeUnit); + } + + /** + * <p> + * Set a default policy with some explicit handlers for specific exceptions. + * </p> + */ + public static final RetryPolicy retryByException(RetryPolicy defaultPolicy, + Map<Class<? extends Exception>, RetryPolicy> exceptionToPolicyMap) { + return new ExceptionDependentRetry(defaultPolicy, exceptionToPolicyMap); + } + + /** + * <p> + * A retry policy for RemoteException + * Set a default policy with some explicit handlers for specific exceptions. + * </p> + */ + public static final RetryPolicy retryByRemoteException( + RetryPolicy defaultPolicy, + Map<Class<? 
extends Exception>, RetryPolicy> exceptionToPolicyMap) {
+    return new RemoteExceptionDependentRetry(defaultPolicy, exceptionToPolicyMap);
+  }
+
+  /** Never retries: the first failure is re-thrown to the caller. */
+  static class TryOnceThenFail implements RetryPolicy {
+    @Override
+    public boolean shouldRetry(Exception e, int retries) throws Exception {
+      throw e;
+    }
+  }
+
+  /** Never retries, but reports "do not retry" instead of throwing. */
+  static class TryOnceDontFail implements RetryPolicy {
+    @Override
+    public boolean shouldRetry(Exception e, int retries) throws Exception {
+      return false;
+    }
+  }
+
+  /** Always asks for another retry, without sleeping in between. */
+  static class RetryForever implements RetryPolicy {
+    @Override
+    public boolean shouldRetry(Exception e, int retries) throws Exception {
+      return true;
+    }
+  }
+
+  /**
+   * Retry up to maxRetries.
+   * The actual sleep time of the n-th retry is f(n, sleepTime),
+   * where f is a function provided by the subclass implementation.
+   *
+   * The object of the subclasses should be immutable;
+   * otherwise, the subclass must override hashCode(), equals(..) and toString().
+   */
+  static abstract class RetryLimited implements RetryPolicy {
+    final int maxRetries;
+    final long sleepTime;
+    final TimeUnit timeUnit;
+
+    /** Lazily built toString(); equals() and hashCode() are derived from it. */
+    private String myString;
+
+    RetryLimited(int maxRetries, long sleepTime, TimeUnit timeUnit) {
+      if (maxRetries < 0) {
+        throw new IllegalArgumentException("maxRetries = " + maxRetries + " < 0");
+      }
+      if (sleepTime < 0) {
+        throw new IllegalArgumentException("sleepTime = " + sleepTime + " < 0");
+      }
+
+      this.maxRetries = maxRetries;
+      this.sleepTime = sleepTime;
+      this.timeUnit = timeUnit;
+    }
+
+    /**
+     * Re-throws {@code e} once {@code retries} reaches maxRetries; otherwise
+     * sleeps for calculateSleepTime(retries) (in timeUnit) and returns true.
+     * An interrupted sleep still returns true, but the thread's interrupt
+     * status is restored so that callers can observe the interruption.
+     */
+    @Override
+    public boolean shouldRetry(Exception e, int retries) throws Exception {
+      if (retries >= maxRetries) {
+        throw e;
+      }
+      try {
+        timeUnit.sleep(calculateSleepTime(retries));
+      } catch (InterruptedException ie) {
+        // Still retry, but do not swallow the interruption.
+        Thread.currentThread().interrupt();
+      }
+      return true;
+    }
+
+    /** @return the sleep before the retry numbered {@code retries}, in timeUnit. */
+    protected abstract long calculateSleepTime(int retries);
+
+    @Override
+    public int hashCode() {
+      return toString().hashCode();
+    }
+
+    @Override
+    public boolean equals(final Object that) {
+      if (this == that) {
+        return true;
+      } else if (that == null || this.getClass() != that.getClass()) {
+        return false;
+      }
+      return this.toString().equals(that.toString());
+    }
+
+    @Override
+    public String toString() {
+      if (myString == null) {
+        myString = getClass().getSimpleName() + "(maxRetries=" + maxRetries
+            + ", sleepTime=" + sleepTime + " " + timeUnit + ")";
+      }
+      return myString;
+    }
+  }
+
+  /** Retries up to maxRetries times, sleeping sleepTime before each retry. */
+  static class RetryUpToMaximumCountWithFixedSleep extends RetryLimited {
+    public RetryUpToMaximumCountWithFixedSleep(int maxRetries, long sleepTime, TimeUnit timeUnit) {
+      super(maxRetries, sleepTime, timeUnit);
+    }
+
+    @Override
+    protected long calculateSleepTime(int retries) {
+      return sleepTime;
+    }
+  }
+
+  /**
+   * Retries with a fixed sleep, up to maxTime / sleepTime times
+   * (maxTime and sleepTime are both expressed in timeUnit).
+   */
+  static class RetryUpToMaximumTimeWithFixedSleep extends RetryUpToMaximumCountWithFixedSleep {
+    public RetryUpToMaximumTimeWithFixedSleep(long maxTime, long sleepTime, TimeUnit timeUnit) {
+      super(toMaxRetries(maxTime, sleepTime), sleepTime, timeUnit);
+    }
+
+    /**
+     * Converts a time budget into a retry count. A non-positive sleepTime is
+     * rejected up front (the division would otherwise throw
+     * ArithmeticException), and the count is capped at Integer.MAX_VALUE
+     * instead of letting the int cast wrap around.
+     *
+     * @throws IllegalArgumentException if maxTime < 0 or sleepTime <= 0.
+     */
+    private static int toMaxRetries(long maxTime, long sleepTime) {
+      if (maxTime < 0) {
+        throw new IllegalArgumentException("maxTime = " + maxTime + " < 0");
+      }
+      if (sleepTime <= 0) {
+        throw new IllegalArgumentException("sleepTime = " + sleepTime + " <= 0");
+      }
+      return (int) Math.min(maxTime / sleepTime, Integer.MAX_VALUE);
+    }
+  }
+
+  /** Retries up to maxRetries times; calculateSleepTime(n) is sleepTime * (n + 1). */
+  static class RetryUpToMaximumCountWithProportionalSleep extends RetryLimited {
+    public RetryUpToMaximumCountWithProportionalSleep(int maxRetries, long sleepTime, TimeUnit timeUnit) {
+      super(maxRetries, sleepTime, timeUnit);
+    }
+
+    @Override
+    protected long calculateSleepTime(int retries) {
+      return sleepTime * (retries + 1);
+    }
+  }
+
+  /**
+   * Given pairs of number of retries and sleep time (n0, t0), (n1, t1), ...,
+   * the first n0 retries sleep t0 milliseconds on average,
+   * the following n1 retries sleep t1 milliseconds on average, and so on.
+   *
+   * For every sleep, the actual sleep time is drawn (roughly uniformly) from
+   * the range [0.5t, 1.5t], where t is the specified sleep time.
+   *
+   * The objects of this class are immutable.
+   */
+  public static class MultipleLinearRandomRetry implements RetryPolicy {
+    /** Pairs of numRetries and sleepMillis. */
+    public static class Pair {
+      final int numRetries;
+      final int sleepMillis;
+
+      public Pair(final int numRetries, final int sleepMillis) {
+        if (numRetries < 0) {
+          throw new IllegalArgumentException("numRetries = " + numRetries + " < 0");
+        }
+        if (sleepMillis < 0) {
+          throw new IllegalArgumentException("sleepMillis = " + sleepMillis + " < 0");
+        }
+
+        this.numRetries = numRetries;
+        this.sleepMillis = sleepMillis;
+      }
+
+      @Override
+      public String toString() {
+        return numRetries + "x" + sleepMillis + "ms";
+      }
+    }
+
+    private final List<Pair> pairs;
+    private String myString;
+
+    /**
+     * @param pairs the (numRetries, sleepMillis) schedule, in order. The list
+     *        is copied, so later changes to the caller's list have no effect.
+     * @throws IllegalArgumentException if pairs is null or empty.
+     */
+    public MultipleLinearRandomRetry(List<Pair> pairs) {
+      if (pairs == null || pairs.isEmpty()) {
+        throw new IllegalArgumentException("pairs must be neither null nor empty.");
+      }
+      // Defensive copy: the class is documented as immutable, and
+      // equals()/hashCode() depend on a cached toString() of this list.
+      this.pairs = Collections.unmodifiableList(new ArrayList<Pair>(pairs));
+    }
+
+    @Override
+    public boolean shouldRetry(Exception e, int curRetry) throws Exception {
+      final Pair p = searchPair(curRetry);
+      if (p == null) {
+        //no more retries, re-throw the original exception.
+        throw e;
+      }
+
+      //sleep and return true.
+      //If the sleep is interrupted, throw the InterruptedException out.
+      final double ratio = RANDOM.get().nextDouble() + 0.5; //0.5 <= ratio < 1.5
+      Thread.sleep(Math.round(p.sleepMillis * ratio));
+      return true;
+    }
+
+    /**
+     * Given the current number of retry, search the corresponding pair.
+     * NOTE(review): the loop compares with {@code >}. If curRetry counts from
+     * 0, the first pair covers n0 + 1 values (0..n0) and later pairs cover
+     * exactly n_i values. If curRetry counts from 1, every pair covers exactly
+     * n_i values. Confirm which convention the callers use.
+     * @return the corresponding pair,
+     *         or null if the current number of retry > maximum number of retry.
+     */
+    private Pair searchPair(int curRetry) {
+      int i = 0;
+      for (; i < pairs.size() && curRetry > pairs.get(i).numRetries; i++) {
+        curRetry -= pairs.get(i).numRetries;
+      }
+      return i == pairs.size() ? null : pairs.get(i);
+    }
+
+    @Override
+    public int hashCode() {
+      return toString().hashCode();
+    }
+
+    @Override
+    public boolean equals(final Object that) {
+      if (this == that) {
+        return true;
+      } else if (that == null || this.getClass() != that.getClass()) {
+        return false;
+      }
+      return this.toString().equals(that.toString());
+    }
+
+    @Override
+    public String toString() {
+      if (myString == null) {
+        myString = getClass().getSimpleName() + pairs;
+      }
+      return myString;
+    }
+
+    /**
+     * Parse the given string as a MultipleLinearRandomRetry object.
+     * The format of the string is "t_1, n_1, t_2, n_2, ...",
+     * where t_i and n_i are the i-th pair of sleep time and number of retries.
+     * Note that white space around each element is ignored.
+     *
+     * @return the parsed object, or null if the parsing fails.
+     */
+    public static MultipleLinearRandomRetry parseCommaSeparatedString(String s) {
+      final String[] elements = s.split(",");
+      if (elements.length == 0) {
+        LOG.warn("Illegal value: there is no element in \"" + s + "\".");
+        return null;
+      }
+      if (elements.length % 2 != 0) {
+        LOG.warn("Illegal value: the number of elements in \"" + s + "\" is "
+            + elements.length + " but an even number of elements is expected.");
+        return null;
+      }
+
+      final List<RetryPolicies.MultipleLinearRandomRetry.Pair> pairs
+          = new ArrayList<RetryPolicies.MultipleLinearRandomRetry.Pair>();
+
+      for (int i = 0; i < elements.length; ) {
+        //parse the i-th sleep-time
+        final int sleep = parsePositiveInt(elements, i++, s);
+        if (sleep == -1) {
+          return null; //parse fails
+        }
+
+        //parse the i-th number-of-retries
+        final int retries = parsePositiveInt(elements, i++, s);
+        if (retries == -1) {
+          return null; //parse fails
+        }
+
+        pairs.add(new RetryPolicies.MultipleLinearRandomRetry.Pair(retries, sleep));
+      }
+      return new RetryPolicies.MultipleLinearRandomRetry(pairs);
+    }
+
+    /**
+     * Parse the i-th element as an integer.
+     * @return -1 if the parsing fails or the parsed value <= 0;
+     *         otherwise, return the parsed value.
+     */
+    private static int parsePositiveInt(final String[] elements,
+        final int i, final String originalString) {
+      final String s = elements[i].trim();
+      final int n;
+      try {
+        n = Integer.parseInt(s);
+      } catch (NumberFormatException nfe) {
+        LOG.warn("Failed to parse \"" + s + "\", which is the index " + i
+            + " element in \"" + originalString + "\"", nfe);
+        return -1;
+      }
+
+      if (n <= 0) {
+        LOG.warn("The value " + n + " <= 0: it is parsed from the string \""
+            + s + "\" which is the index " + i + " element in \""
+            + originalString + "\"");
+        return -1;
+      }
+      return n;
+    }
+  }
+
+  /**
+   * Picks a policy by the exact class of the exception (subclasses are not
+   * matched), falling back to defaultPolicy when no entry exists.
+   */
+  static class ExceptionDependentRetry implements RetryPolicy {
+
+    RetryPolicy defaultPolicy;
+    Map<Class<? extends Exception>, RetryPolicy> exceptionToPolicyMap;
+
+    public ExceptionDependentRetry(RetryPolicy defaultPolicy,
+                                   Map<Class<? extends Exception>, RetryPolicy> exceptionToPolicyMap) {
+      this.defaultPolicy = defaultPolicy;
+      this.exceptionToPolicyMap = exceptionToPolicyMap;
+    }
+
+    @Override
+    public boolean shouldRetry(Exception e, int retries) throws Exception {
+      RetryPolicy policy = exceptionToPolicyMap.get(e.getClass());
+      if (policy == null) {
+        policy = defaultPolicy;
+      }
+      return policy.shouldRetry(e, retries);
+    }
+
+  }
+
+  /**
+   * For a RemoteException, picks a policy by the class name that the
+   * exception reports through getClassName(). Any other exception, or an
+   * unregistered name, uses defaultPolicy.
+   */
+  static class RemoteExceptionDependentRetry implements RetryPolicy {
+
+    RetryPolicy defaultPolicy;
+    Map<String, RetryPolicy> exceptionNameToPolicyMap;
+
+    public RemoteExceptionDependentRetry(RetryPolicy defaultPolicy,
+                                   Map<Class<? extends Exception>,
+                                   RetryPolicy> exceptionToPolicyMap) {
+      this.defaultPolicy = defaultPolicy;
+      this.exceptionNameToPolicyMap = new HashMap<String, RetryPolicy>();
+      for (Entry<Class<? extends Exception>, RetryPolicy> e :
+          exceptionToPolicyMap.entrySet()) {
+        exceptionNameToPolicyMap.put(e.getKey().getName(), e.getValue());
+      }
+    }
+
+    @Override
+    public boolean shouldRetry(Exception e, int retries) throws Exception {
+      RetryPolicy policy = null;
+      if (e instanceof RemoteException) {
+        policy = exceptionNameToPolicyMap.get(
+            ((RemoteException) e).getClassName());
+      }
+      if (policy == null) {
+        policy = defaultPolicy;
+      }
+      return policy.shouldRetry(e, retries);
+    }
+  }
+
+  /**
+   * Retries up to maxRetries (at most 30) times; calculateSleepTime(n) is
+   * sleepTime multiplied by a random integer in [0, 2^(n + 1)).
+   */
+  static class ExponentialBackoffRetry extends RetryLimited {
+    private final Random r = new Random();
+
+    public ExponentialBackoffRetry(
+        int maxRetries, long sleepTime, TimeUnit timeUnit) {
+      super(maxRetries, sleepTime, timeUnit);
+
+      // maxRetries < 0 has already been rejected by RetryLimited.
+      if (maxRetries > 30) {
+        //if maxRetries > 30, calculateSleepTime will overflow.
+        throw new IllegalArgumentException("maxRetries = " + maxRetries + " > 30");
+      }
+    }
+
+    @Override
+    protected long calculateSleepTime(int retries) {
+      return sleepTime * r.nextInt(1 << (retries + 1));
+    }
+  }
+}
Added: hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicy.java
URL: http://svn.apache.org/viewvc/hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicy.java?rev=1514580&view=auto
==============================================================================
--- hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicy.java (added)
+++ hama/trunk/core/src/main/java/org/apache/hama/ipc/RetryPolicy.java Fri Aug 16 05:20:15 2013
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ipc;
+
+/**
+ * <p>
+ * Specifies a policy for retrying failed method invocations.
+ * Implementations of this interface are expected to be immutable.
+ * </p>
+ */
+public interface RetryPolicy {
+  /**
+   * <p>
+   * Decides whether the framework should make another attempt at a method
+   * that has just failed, based on the exception it threw and on how many
+   * retries have already been made for that operation.
+   * </p>
+   * @param e the exception that caused the method to fail.
+   * @param retries how many times the method has been retried so far.
+   * @return <code>true</code> to retry the method;
+   *   <code>false</code> to give up without failing with an
+   *   exception (only meaningful for void methods).
+   * @throws Exception the re-thrown exception <code>e</code>,
+   *   signalling that the method failed and should not be
+   *   retried further.
+   */
+  boolean shouldRetry(Exception e, int retries) throws Exception;
+}
