http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ClientInvokerImpl.java ---------------------------------------------------------------------- diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ClientInvokerImpl.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ClientInvokerImpl.java new file mode 100644 index 0000000..562fea7 --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ClientInvokerImpl.java @@ -0,0 +1,327 @@ +/** + * Copyright 2005-2015 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. 
*/
package org.apache.aries.rsa.provider.fastbin.tcp;

import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.aries.rsa.provider.fastbin.api.Dispatched;
import org.apache.aries.rsa.provider.fastbin.api.ObjectSerializationStrategy;
import org.apache.aries.rsa.provider.fastbin.api.Serialization;
import org.apache.aries.rsa.provider.fastbin.api.SerializationStrategy;
import org.apache.aries.rsa.provider.fastbin.io.ClientInvoker;
import org.apache.aries.rsa.provider.fastbin.io.ProtocolCodec;
import org.apache.aries.rsa.provider.fastbin.io.Transport;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.BufferEditor;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.hawtdispatch.DispatchQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Client side of the fastbin TCP remoting protocol.
 * <p>
 * A handler obtained from {@link #getProxy(String, String, ClassLoader)} turns every method
 * call into a request frame, sends it through a per-address {@link TransportPool} and waits
 * for the response that carries the same correlation id.
 * <p>
 * Request frame layout, as written by {@link #request}:
 * <ol>
 *   <li>4-byte big-endian frame size (patched in after encoding; it counts itself)</li>
 *   <li>var-long correlation id</li>
 *   <li>var-int length + bytes of the service id</li>
 *   <li>var-int length + bytes of the method signature (see {@link #encodeClassName})</li>
 *   <li>arguments, as written by the method's {@link InvocationStrategy}</li>
 * </ol>
 * {@link #transports} and {@link #requests} are plain {@link HashMap}s. They are written from
 * tasks submitted to {@link #queue}. The pool callbacks that read them ({@link #onCommand},
 * {@link #onFailure}) presumably arrive on that same queue, because each pool is created with
 * it. TODO confirm in {@code TransportPool}.
 */
public class ClientInvokerImpl implements ClientInvoker, Dispatched {

    /** Default time {@link #request} waits for a response: five minutes, in milliseconds. */
    public static final long DEFAULT_TIMEOUT = TimeUnit.MINUTES.toMillis(5);

    protected static final Logger LOGGER = LoggerFactory.getLogger(ClientInvokerImpl.class);

    /** Single-letter JVM descriptors used to encode primitive parameter types. */
    private static final Map<Class<?>, String> CLASS_TO_PRIMITIVE = new HashMap<Class<?>, String>(8, 1.0F);

    static {
        CLASS_TO_PRIMITIVE.put(boolean.class, "Z");
        CLASS_TO_PRIMITIVE.put(byte.class, "B");
        CLASS_TO_PRIMITIVE.put(char.class, "C");
        CLASS_TO_PRIMITIVE.put(short.class, "S");
        CLASS_TO_PRIMITIVE.put(int.class, "I");
        CLASS_TO_PRIMITIVE.put(long.class, "J");
        CLASS_TO_PRIMITIVE.put(float.class, "F");
        CLASS_TO_PRIMITIVE.put(double.class, "D");
    }

    /** Source of per-request correlation ids; responses are matched back by this id. */
    protected final AtomicLong correlationGenerator = new AtomicLong();
    protected final DispatchQueue queue;
    /** One connection pool per remote address; only touched from tasks run on {@link #queue}. */
    protected final Map<String, TransportPool> transports = new HashMap<String, TransportPool>();
    protected final AtomicBoolean running = new AtomicBoolean(false);
    /** Requests still waiting for a response, keyed by correlation id. */
    protected final Map<Long, ResponseFuture> requests = new HashMap<Long, ResponseFuture>();
    /** How long {@link #request} waits for a response, in milliseconds. */
    protected final long timeout;
    /** Named strategies that a method can select with the {@link Serialization} annotation. */
    protected final Map<String, SerializationStrategy> serializationStrategies;

    /**
     * Creates an invoker that waits at most {@link #DEFAULT_TIMEOUT} for each response.
     *
     * @param queue                   dispatch queue that owns the transport pools
     * @param serializationStrategies strategies selectable through {@link Serialization}
     */
    public ClientInvokerImpl(DispatchQueue queue, Map<String, SerializationStrategy> serializationStrategies) {
        this(queue, DEFAULT_TIMEOUT, serializationStrategies);
    }

    /**
     * Creates an invoker.
     *
     * @param queue                   dispatch queue that owns the transport pools
     * @param timeout                 response timeout in milliseconds
     * @param serializationStrategies strategies selectable through {@link Serialization}
     */
    public ClientInvokerImpl(DispatchQueue queue, long timeout, Map<String, SerializationStrategy> serializationStrategies) {
        this.queue = queue;
        this.timeout = timeout;
        this.serializationStrategies = serializationStrategies;
    }

    public DispatchQueue queue() {
        return queue;
    }

    public void start() throws Exception {
        start(null);
    }

    /**
     * Marks the invoker as running and runs {@code onComplete} (if any) synchronously.
     * Transport pools are created lazily by {@link #request}.
     */
    public void start(Runnable onComplete) throws Exception {
        running.set(true);
        if (onComplete != null) {
            onComplete.run();
        }
    }

    public void stop() {
        stop(null);
    }

    /**
     * Stops every transport pool and then runs {@code onComplete}.
     * <p>
     * If the invoker is not running, {@code onComplete} runs immediately on the calling
     * thread. Otherwise it runs once all pools have reported that they stopped, or straight
     * away (on {@link #queue}) when no pool was ever created.
     *
     * @param onComplete callback to run once stopped; may be {@code null}
     */
    public void stop(final Runnable onComplete) {
        if (running.compareAndSet(true, false)) {
            queue().execute(new Runnable() {
                public void run() {
                    if (transports.isEmpty()) {
                        // No pool would ever count a zero-valued latch down, so complete now;
                        // otherwise onComplete would never be invoked.
                        if (onComplete != null) {
                            onComplete.run();
                        }
                        return;
                    }
                    final AtomicInteger latch = new AtomicInteger(transports.size());
                    final Runnable countDown = new Runnable() {
                        public void run() {
                            if (latch.decrementAndGet() == 0) {
                                if (onComplete != null) {
                                    onComplete.run();
                                }
                            }
                        }
                    };
                    for (TransportPool pool : transports.values()) {
                        pool.stop(countDown);
                    }
                }
            });
        } else {
            if (onComplete != null) {
                onComplete.run();
            }
        }
    }

    /**
     * Returns an invocation handler that forwards calls to {@code service} at {@code address}.
     *
     * @param address     transport URI of the remote server
     * @param service     id under which the service is registered on the server
     * @param classLoader class loader handed to the invocation strategy
     */
    public InvocationHandler getProxy(String address, String service, ClassLoader classLoader) {
        return new ProxyInvocationHandler(address, service, classLoader);
    }

    /**
     * Handles one response frame received by {@code pool}: reads its correlation id, tells the
     * pool the request is done and completes the matching pending future, if there is one.
     * Decoding errors are logged and otherwise ignored.
     */
    protected void onCommand(TransportPool pool, Object data) {
        try {
            DataByteArrayInputStream bais = new DataByteArrayInputStream((Buffer) data);
            bais.readInt(); // frame size prefix; framing itself is done by the codec
            long correlation = bais.readVarLong();
            pool.onDone(correlation);
            ResponseFuture response = requests.remove(correlation);
            if (response != null) {
                response.set(bais);
            }
        } catch (Exception e) {
            LOGGER.info("Error while reading response", e);
        }
    }

    /** Fails the pending request with correlation id {@code id}, if it is still pending. */
    protected void onFailure(Object id, Throwable throwable) {
        ResponseFuture response = requests.remove(id);
        if (response != null) {
            response.fail(throwable);
        }
    }

    /**
     * Encoded signature and strategies per method. Shared by all instances, guarded by
     * synchronizing on the map itself; keys are held weakly.
     * NOTE(review): the cache is static, but {@link MethodData} captures a strategy looked up in
     * the {@code serializationStrategies} of whichever instance encoded the method first.
     */
    static final WeakHashMap<Method, MethodData> method_cache = new WeakHashMap<Method, MethodData>();

    /** Per-method data computed once and reused for every call. */
    static class MethodData {
        private final SerializationStrategy serializationStrategy;
        final Buffer signature;
        final InvocationStrategy invocationStrategy;

        MethodData(InvocationStrategy invocationStrategy, SerializationStrategy serializationStrategy, Buffer signature) {
            this.invocationStrategy = invocationStrategy;
            this.serializationStrategy = serializationStrategy;
            this.signature = signature;
        }
    }

    /**
     * Returns (computing and caching on first use) the wire signature and strategies for
     * {@code method}.
     * <p>
     * The signature is {@code name,} followed by the comma-separated encoded parameter types.
     * The serialization strategy comes from the {@link Serialization} annotation, falling back
     * to {@link ObjectSerializationStrategy#INSTANCE}. The invocation strategy is async when
     * {@code AsyncInvocationStrategy.isAsyncMethod} says so, and blocking otherwise.
     *
     * @throws RuntimeException if the annotation names a strategy that is not registered
     */
    private MethodData getMethodData(Method method) throws IOException {
        MethodData rc;
        synchronized (method_cache) {
            rc = method_cache.get(method);
        }
        if (rc == null) {
            StringBuilder sb = new StringBuilder();
            sb.append(method.getName());
            sb.append(",");
            Class<?>[] types = method.getParameterTypes();
            for (int i = 0; i < types.length; i++) {
                if (i != 0) {
                    sb.append(",");
                }
                sb.append(encodeClassName(types[i]));
            }
            Buffer signature = new UTF8Buffer(sb.toString()).buffer();

            Serialization annotation = method.getAnnotation(Serialization.class);
            SerializationStrategy serializationStrategy;
            if (annotation != null) {
                serializationStrategy = serializationStrategies.get(annotation.value());
                if (serializationStrategy == null) {
                    throw new RuntimeException("Could not find the serialization strategy named: " + annotation.value());
                }
            } else {
                serializationStrategy = ObjectSerializationStrategy.INSTANCE;
            }

            final InvocationStrategy strategy;
            if (AsyncInvocationStrategy.isAsyncMethod(method)) {
                strategy = AsyncInvocationStrategy.INSTANCE;
            } else {
                strategy = BlockingInvocationStrategy.INSTANCE;
            }

            rc = new MethodData(strategy, serializationStrategy, signature);
            synchronized (method_cache) {
                method_cache.put(method, rc);
            }
        }
        return rc;
    }

    /**
     * Encodes a parameter type for the method signature. Arrays become {@code [} + component,
     * primitives their one-letter descriptor, and other classes {@code L} + binary name (with
     * no trailing {@code ;}).
     */
    String encodeClassName(Class<?> type) {
        if (type.getComponentType() != null) {
            return "[" + encodeClassName(type.getComponentType());
        }
        if (type.isPrimitive()) {
            return CLASS_TO_PRIMITIVE.get(type);
        } else {
            return "L" + type.getName();
        }
    }

    /**
     * Encodes and sends one remote call, then waits for its result.
     * <p>
     * Encoding happens on the caller's thread, so encoding errors reach the caller directly
     * and the serial {@link #queue} only does the I/O hand-off. The pool for {@code address}
     * is created and started on first use.
     *
     * @return whatever {@link ResponseFuture#get} yields within {@link #timeout} milliseconds
     * @throws IllegalStateException if the invoker has been stopped
     * @throws Exception             from encoding, from sending, or from the future
     */
    protected Object request(ProxyInvocationHandler handler, final String address, final UTF8Buffer service, final ClassLoader classLoader, final Method method, final Object[] args) throws Exception {
        if (!running.get()) {
            throw new IllegalStateException("DOSGi Client stopped");
        }

        final long correlation = correlationGenerator.incrementAndGet();

        // Size the buffer from the previous request through this handler, with ~10% headroom.
        DataByteArrayOutputStream baos = new DataByteArrayOutputStream((int) (handler.lastRequestSize * 1.10));
        baos.writeInt(0); // frame size placeholder, patched once the frame is complete
        baos.writeVarLong(correlation);
        writeBuffer(baos, service);

        MethodData methodData = getMethodData(method);
        writeBuffer(baos, methodData.signature);

        final ResponseFuture future = methodData.invocationStrategy.request(methodData.serializationStrategy, classLoader, method, args, baos);

        // toBuffer() avoids the array copy that toByteArray() would make.
        final Buffer command = baos.toBuffer();

        // Patch the real frame size into the placeholder.
        BufferEditor editor = command.buffer().bigEndianEditor();
        editor.writeInt(command.length);
        handler.lastRequestSize = command.length;

        queue().execute(new Runnable() {
            public void run() {
                try {
                    TransportPool pool = transports.get(address);
                    if (pool == null) {
                        pool = new InvokerTransportPool(address, queue());
                        transports.put(address, pool);
                        pool.start();
                    }
                    requests.put(correlation, future);
                    pool.offer(command, correlation);
                } catch (Exception e) {
                    LOGGER.info("Error while sending request", e);
                    future.fail(e);
                }
            }
        });

        // NOTE(review): if this times out, the entry stays in `requests` until a response or a
        // transport failure for this correlation id arrives; nothing else removes it.
        return future.get(timeout, TimeUnit.MILLISECONDS);
    }

    /** Writes {@code value} as a var-int length followed by its bytes. */
    private void writeBuffer(DataByteArrayOutputStream baos, Buffer value) throws IOException {
        baos.writeVarInt(value.length);
        baos.write(value);
    }

    /** Proxy handler bound to one remote service; forwards every call to {@link #request}. */
    protected class ProxyInvocationHandler implements InvocationHandler {

        final String address;
        final UTF8Buffer service;
        final ClassLoader classLoader;
        /** Size of the last request frame, used to pre-size the next one. */
        int lastRequestSize = 250;

        public ProxyInvocationHandler(String address, String service, ClassLoader classLoader) {
            this.address = address;
            this.service = new UTF8Buffer(service);
            this.classLoader = classLoader;
        }

        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            return request(this, address, service, classLoader, method, args);
        }

    }

    /**
     * Connection pool for one address, using TCP transports and length-prefixed framing, and
     * routing received frames and failures back to this invoker.
     */
    protected class InvokerTransportPool extends TransportPool {

        public InvokerTransportPool(String uri, DispatchQueue queue) {
            // Passes twice the request timeout to the pool; its meaning there is not visible
            // here (presumably an idle/eviction timeout). TODO confirm in TransportPool.
            super(uri, queue, TransportPool.DEFAULT_POOL_SIZE, timeout << 1);
        }

        @Override
        protected Transport createTransport(String uri) throws Exception {
            return new TcpTransportFactory().connect(uri);
        }

        @Override
        protected ProtocolCodec createCodec() {
            return new LengthPrefixedCodec();
        }

        @Override
        protected void onCommand(Object command) {
            ClientInvokerImpl.this.onCommand(this, command);
        }

        @Override
        protected void onFailure(Object id, Throwable throwable) {
            ClientInvokerImpl.this.onFailure(id, throwable);
        }
    }

}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/InvocationStrategy.java ---------------------------------------------------------------------- diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/InvocationStrategy.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/InvocationStrategy.java new file mode 100644 index 0000000..8b3901b --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/InvocationStrategy.java @@ -0,0 +1,34 @@ +/** + * Copyright 2005-2015 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. 
*/
package org.apache.aries.rsa.provider.fastbin.tcp;

import java.lang.reflect.Method;

import org.apache.aries.rsa.provider.fastbin.api.SerializationStrategy;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;

/**
 * Encodes a remote call on the client side and carries it out on the server side.
 * <p>
 * Client and server in this package pick the implementation per method with the same rule:
 * {@code AsyncInvocationStrategy.INSTANCE} when {@code AsyncInvocationStrategy.isAsyncMethod}
 * returns true for the method, {@code BlockingInvocationStrategy.INSTANCE} otherwise. The
 * exact argument/result encoding is defined by those implementations (not shown here).
 */
public interface InvocationStrategy {

    /**
     * Client side: appends the call's arguments to {@code requestStream} and returns the future
     * that will receive the response.
     *
     * @param serializationStrategy strategy used to encode the call
     * @param loader                class loader passed through from the proxy handler
     * @param method                the invoked interface method
     * @param args                  the call arguments as handed to the proxy; {@code null} when
     *                              the method has no parameters (per {@code InvocationHandler})
     * @param requestStream         stream already holding the frame header (size placeholder,
     *                              correlation id, service id, method signature)
     * @return future that the invoker completes when the matching response or a failure arrives
     * @throws Exception presumably when the arguments cannot be encoded
     */
    ResponseFuture request(SerializationStrategy serializationStrategy,
                           ClassLoader loader,
                           Method method,
                           Object[] args,
                           DataByteArrayOutputStream requestStream) throws Exception;

    /**
     * Server side: decodes the arguments from {@code requestStream}, invokes {@code method} on
     * {@code target} and writes the outcome to {@code responseStream}.
     * <p>
     * The caller has already written the size placeholder and correlation id to
     * {@code responseStream}. Inside {@code onComplete} it sends whatever the stream then holds,
     * so the response must be fully written before {@code onComplete} runs.
     *
     * @param serializationStrategy strategy used to decode arguments and encode the result
     * @param loader                class loader registered with the service
     * @param method                the resolved service method
     * @param target                the service instance to invoke
     * @param requestStream         request frame positioned just after the method signature
     * @param responseStream        response frame being built
     * @param onComplete            callback that sends the response; run once per call
     */
    void service(SerializationStrategy serializationStrategy,
                 ClassLoader loader,
                 Method method,
                 Object target,
                 DataByteArrayInputStream requestStream,
                 DataByteArrayOutputStream responseStream,
                 Runnable onComplete);
}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/LengthPrefixedCodec.java ---------------------------------------------------------------------- diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/LengthPrefixedCodec.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/LengthPrefixedCodec.java new file mode 100644 index 0000000..ca8c588 --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/LengthPrefixedCodec.java @@ -0,0 +1,175 @@ +/** + * Copyright 2005-2015 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied.
See the License for the specific language governing
 * permissions and limitations under the License.
 */
package org.apache.aries.rsa.provider.fastbin.tcp;

import java.io.EOFException;
import java.io.IOException;
import java.net.ProtocolException;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import java.util.LinkedList;
import java.util.Queue;

import org.apache.aries.rsa.provider.fastbin.io.ProtocolCodec;
import org.fusesource.hawtbuf.Buffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Codec for frames that start with a 4-byte big-endian size which counts itself, so the
 * smallest legal frame is 4 bytes with an empty payload.
 * <p>
 * Writing: {@link #write} only queues the given {@link Buffer}; the caller must already have
 * put the size prefix into it. {@link #flush} then writes as much as the channel accepts.
 * <p>
 * Reading: {@link #read} returns one complete frame, size prefix included, as a {@link Buffer},
 * or {@code null} when the channel has no more bytes for now.
 */
public class LengthPrefixedCodec implements ProtocolCodec {

    private static final Logger LOGGER = LoggerFactory.getLogger(LengthPrefixedCodec.class);

    final int write_buffer_size = 1024 * 64;
    long write_counter = 0L;
    WritableByteChannel write_channel;

    /** Frames waiting to be written, oldest first. */
    final Queue<ByteBuffer> next_write_buffers = new LinkedList<ByteBuffer>();
    /** Bytes queued but not yet written. */
    int next_write_size = 0;

    /** This codec never applies back-pressure on writes. */
    public boolean full() {
        return false;
    }

    /** Returns true when no queued bytes remain to be written. */
    protected boolean empty() {
        if (next_write_size > 0) {
            return false;
        }
        if (!next_write_buffers.isEmpty()) {
            for (ByteBuffer b : next_write_buffers) {
                if (b.remaining() > 0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Sets the channel to write to; for socket channels also tries to set the socket send
     * buffer to {@link #write_buffer_size}. A failure there is logged and otherwise ignored.
     */
    public void setWritableByteChannel(WritableByteChannel channel) {
        this.write_channel = channel;
        if (channel instanceof SocketChannel) {
            try {
                ((SocketChannel) channel).socket().setSendBufferSize(write_buffer_size);
            } catch (SocketException e) {
                // Best effort: the OS default send buffer size is still usable.
                LOGGER.warn("Could not set the socket send buffer size to " + write_buffer_size, e);
            }
        }
    }

    /**
     * Queues {@code value} (a {@link Buffer} holding a complete, size-prefixed frame).
     *
     * @return {@code WAS_EMPTY} if nothing was pending before this call, else {@code NOT_EMPTY}
     */
    public BufferState write(Object value) throws IOException {
        if (full()) {
            return BufferState.FULL;
        } else {
            boolean wasEmpty = empty();
            Buffer buffer = (Buffer) value;
            next_write_size += buffer.length;
            next_write_buffers.add(buffer.toByteBuffer());
            return wasEmpty ? BufferState.WAS_EMPTY : BufferState.NOT_EMPTY;
        }
    }

    /**
     * Writes queued frames in order until the queue is drained or the channel stops accepting
     * bytes.
     *
     * @return {@code WAS_EMPTY} if nothing was pending and nothing was written, {@code EMPTY}
     *         if this call drained the queue, {@code NOT_EMPTY} if bytes remain
     */
    public BufferState flush() throws IOException {
        final long writeCounterBeforeFlush = write_counter;
        while (!next_write_buffers.isEmpty()) {
            final ByteBuffer nextBuffer = next_write_buffers.peek();
            if (nextBuffer.remaining() < 1) {
                next_write_buffers.remove();
                continue;
            }
            int bytesWritten = write_channel.write(nextBuffer);
            write_counter += bytesWritten;
            next_write_size -= bytesWritten;
            if (nextBuffer.remaining() > 0) {
                // Channel accepted only part of the frame; retry on the next flush.
                break;
            }
        }
        if (empty()) {
            if (writeCounterBeforeFlush == write_counter) {
                return BufferState.WAS_EMPTY;
            } else {
                return BufferState.EMPTY;
            }
        }
        return BufferState.NOT_EMPTY;
    }

    /** Total number of bytes written to the channel so far. */
    public long getWriteCounter() {
        return write_counter;
    }

    long read_counter = 0L;
    int read_buffer_size = 1024 * 64;
    ReadableByteChannel read_channel = null;
    /**
     * Buffer being filled: a 4-byte buffer while reading the size prefix, then a buffer of
     * exactly the frame size. Its capacity tells {@link #read} which state it is in.
     */
    ByteBuffer read_buffer = ByteBuffer.allocate(4);

    /**
     * Sets the channel to read from; for socket channels also tries to set the socket receive
     * buffer to {@link #read_buffer_size}. A failure there is logged and otherwise ignored.
     */
    public void setReadableByteChannel(ReadableByteChannel channel) {
        read_channel = channel;
        if (channel instanceof SocketChannel) {
            try {
                ((SocketChannel) channel).socket().setReceiveBufferSize(read_buffer_size);
            } catch (SocketException e) {
                // Best effort: the OS default receive buffer size is still usable.
                LOGGER.warn("Could not set the socket receive buffer size to " + read_buffer_size, e);
            }
        }
    }

    /**
     * Reads until one complete frame is available.
     *
     * @return the frame (size prefix included), or {@code null} if the channel currently has no
     *         more data
     * @throws EOFException      if the peer closed the connection
     * @throws ProtocolException if the size prefix is smaller than 4
     */
    public Object read() throws IOException {
        while (true) {
            if (read_buffer.remaining() != 0) {
                // Keep reading from the channel until the current buffer is full.
                int count = read_channel.read(read_buffer);
                if (count == -1) {
                    throw new EOFException("Peer disconnected");
                } else if (count == 0) {
                    return null;
                }
                read_counter += count;
            } else {
                // Current buffer is full: interpret it.
                read_buffer.flip();

                if (read_buffer.capacity() == 4) {
                    // Only the size prefix has been read so far.
                    int size = read_buffer.getInt(0);
                    if (size < 4) {
                        throw new ProtocolException("Expecting a size greater than 3");
                    }
                    if (size == 4) {
                        // Empty payload: the prefix alone is the whole frame.
                        Buffer rc = new Buffer(read_buffer);
                        read_buffer = ByteBuffer.allocate(4);
                        return rc;
                    } else {
                        // Switch to a buffer of the full frame size (prefix re-written at its
                        // start) and keep reading the payload into it.
                        // NOTE(review): `size` comes from the peer and has no upper bound.
                        ByteBuffer next = ByteBuffer.allocate(size);
                        next.putInt(size);
                        read_buffer = next;
                    }
                } else {
                    // Whole frame read: hand it out and go back to reading a size prefix.
                    Buffer rc = new Buffer(read_buffer);
                    read_buffer = ByteBuffer.allocate(4);
                    return rc;
                }
            }
        }
    }

    /** Total number of bytes read from the channel so far. */
    public long getReadCounter() {
        return read_counter;
    }

}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ResponseFuture.java ---------------------------------------------------------------------- diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ResponseFuture.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ResponseFuture.java new file mode 100644 index 0000000..2744f9a --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ResponseFuture.java @@ -0,0 +1,31 @@ +/** + * Copyright 2005-2015 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License.
*/
package org.apache.aries.rsa.provider.fastbin.tcp;

import java.util.concurrent.TimeUnit;

import org.fusesource.hawtbuf.DataByteArrayInputStream;

/**
 * Pending outcome of one remote call, created by {@link InvocationStrategy#request} on the
 * client and registered under the call's correlation id until a response or failure arrives.
 */
public interface ResponseFuture {

    /**
     * Completes the call with a response frame.
     *
     * @param responseStream response frame positioned just after the size prefix and
     *                       correlation id
     * @throws Exception presumably if the response cannot be decoded
     */
    void set(DataByteArrayInputStream responseStream) throws Exception;

    /**
     * Waits up to {@code timeout} for the outcome of the call.
     *
     * @param timeout maximum time to wait
     * @param unit    unit of {@code timeout}
     * @return the call's result; the exact semantics are defined by the implementation
     * @throws Exception on failure or timeout (exact exception types depend on the implementation)
     */
    Object get(long timeout, TimeUnit unit) throws Exception;

    /**
     * Completes the call with an error, e.g. when sending failed or the transport failed.
     *
     * @param throwable the cause
     */
    void fail(Throwable throwable);
}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ServerInvokerImpl.java ---------------------------------------------------------------------- diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ServerInvokerImpl.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ServerInvokerImpl.java new file mode 100644 index 0000000..9a12dbd --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/ServerInvokerImpl.java @@ -0,0 +1,314 @@ +/** + * Copyright 2005-2015 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License.
*/
package org.apache.aries.rsa.provider.fastbin.tcp;

import java.io.EOFException;
import java.io.IOException;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.aries.rsa.provider.fastbin.api.Dispatched;
import org.apache.aries.rsa.provider.fastbin.api.ObjectSerializationStrategy;
import org.apache.aries.rsa.provider.fastbin.api.Serialization;
import org.apache.aries.rsa.provider.fastbin.api.SerializationStrategy;
import org.apache.aries.rsa.provider.fastbin.io.ServerInvoker;
import org.apache.aries.rsa.provider.fastbin.io.Transport;
import org.apache.aries.rsa.provider.fastbin.io.TransportAcceptListener;
import org.apache.aries.rsa.provider.fastbin.io.TransportListener;
import org.apache.aries.rsa.provider.fastbin.io.TransportServer;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.hawtbuf.BufferEditor;
import org.fusesource.hawtbuf.DataByteArrayInputStream;
import org.fusesource.hawtbuf.DataByteArrayOutputStream;
import org.fusesource.hawtbuf.UTF8Buffer;
import org.fusesource.hawtdispatch.DispatchQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server side of the fastbin TCP remoting protocol.
 * <p>
 * Accepts TCP connections, decodes each request frame (size, correlation id, service id,
 * method signature, arguments), invokes the registered service and sends back a response frame
 * that starts with the same correlation id.
 * <p>
 * Services that implement {@link Dispatched} are invoked on their own queue. All others run on
 * {@link #blockingExecutor}, a fixed pool of 8 threads. {@link #holders} is only modified from
 * tasks run on {@link #queue}, and accepted transports are given that queue as their dispatch
 * queue. Presumably request callbacks therefore arrive on it as well; TODO confirm in
 * {@code TcpTransport}.
 */
public class ServerInvokerImpl implements ServerInvoker, Dispatched {

    protected static final Logger LOGGER = LoggerFactory.getLogger(ServerInvokerImpl.class);

    /** Maps the one-letter JVM descriptors used in method signatures back to primitive types. */
    private static final Map<String, Class<?>> PRIMITIVE_TO_CLASS = new HashMap<String, Class<?>>(8, 1.0F);

    static {
        PRIMITIVE_TO_CLASS.put("Z", boolean.class);
        PRIMITIVE_TO_CLASS.put("B", byte.class);
        PRIMITIVE_TO_CLASS.put("C", char.class);
        PRIMITIVE_TO_CLASS.put("S", short.class);
        PRIMITIVE_TO_CLASS.put("I", int.class);
        PRIMITIVE_TO_CLASS.put("J", long.class);
        PRIMITIVE_TO_CLASS.put("F", float.class);
        PRIMITIVE_TO_CLASS.put("D", double.class);
    }

    /** Runs invocations of services that do not provide their own dispatch queue. */
    protected final ExecutorService blockingExecutor = Executors.newFixedThreadPool(8);
    protected final DispatchQueue queue;
    private final Map<String, SerializationStrategy> serializationStrategies;
    protected final TransportServer server;
    /** Registered services keyed by their id; modified only from tasks on {@link #queue}. */
    protected final Map<UTF8Buffer, ServiceFactoryHolder> holders = new HashMap<UTF8Buffer, ServiceFactoryHolder>();

    /** Resolved method plus the strategies used to decode/encode its calls. */
    static class MethodData {

        private final SerializationStrategy serializationStrategy;
        final InvocationStrategy invocationStrategy;
        final Method method;

        MethodData(InvocationStrategy invocationStrategy, SerializationStrategy serializationStrategy, Method method) {
            this.invocationStrategy = invocationStrategy;
            this.serializationStrategy = serializationStrategy;
            this.method = method;
        }
    }

    /**
     * A registered service: its factory, class loader and a cache from encoded method
     * signatures to resolved {@link MethodData}.
     */
    class ServiceFactoryHolder {

        private final ServiceFactory factory;
        private final ClassLoader loader;
        private final Class<?> clazz;
        private HashMap<Buffer, MethodData> method_cache = new HashMap<Buffer, MethodData>();

        public ServiceFactoryHolder(ServiceFactory factory, ClassLoader loader) {
            this.factory = factory;
            this.loader = loader;
            // Obtain one instance just to learn the implementation class, then release it.
            Object o = factory.get();
            clazz = o.getClass();
            factory.unget();
        }

        /**
         * Resolves an encoded signature ({@code name,} followed by comma-separated encoded
         * parameter types) to a public method of the service class, caching the result.
         *
         * @throws NoSuchMethodException  if the service class has no such public method
         * @throws ClassNotFoundException if a parameter type cannot be loaded by {@link #loader}
         * @throws RuntimeException       if a {@link Serialization} annotation names an unknown
         *                                strategy
         */
        private MethodData getMethodData(Buffer data) throws IOException, NoSuchMethodException, ClassNotFoundException {
            MethodData rc = method_cache.get(data);
            if (rc == null) {
                String[] parts = data.utf8().toString().split(",");
                String name = parts[0];
                Class<?>[] params = new Class<?>[parts.length - 1];
                for (int i = 0; i < params.length; i++) {
                    params[i] = decodeClass(parts[i + 1]);
                }
                Method method = clazz.getMethod(name, params);

                Serialization annotation = method.getAnnotation(Serialization.class);
                SerializationStrategy serializationStrategy;
                if (annotation != null) {
                    serializationStrategy = serializationStrategies.get(annotation.value());
                    if (serializationStrategy == null) {
                        throw new RuntimeException("Could not find the serialization strategy named: " + annotation.value());
                    }
                } else {
                    serializationStrategy = ObjectSerializationStrategy.INSTANCE;
                }

                final InvocationStrategy invocationStrategy;
                if (AsyncInvocationStrategy.isAsyncMethod(method)) {
                    invocationStrategy = AsyncInvocationStrategy.INSTANCE;
                } else {
                    invocationStrategy = BlockingInvocationStrategy.INSTANCE;
                }

                rc = new MethodData(invocationStrategy, serializationStrategy, method);
                method_cache.put(data, rc);
            }
            return rc;
        }

        /**
         * Decodes one parameter type: {@code [} prefixes an array, {@code L} prefixes a class
         * name loaded through {@link #loader}, otherwise the first letter selects a primitive
         * (an unknown letter yields {@code null}).
         */
        private Class<?> decodeClass(String s) throws ClassNotFoundException {
            if (s.startsWith("[")) {
                Class<?> nested = decodeClass(s.substring(1));
                return Array.newInstance(nested, 0).getClass();
            }
            String c = s.substring(0, 1);
            if (c.equals("L")) {
                return loader.loadClass(s.substring(1));
            } else {
                return PRIMITIVE_TO_CLASS.get(c);
            }
        }

    }

    /**
     * Binds a TCP server to {@code address}; it starts accepting connections after
     * {@link #start}.
     */
    public ServerInvokerImpl(String address, DispatchQueue queue, Map<String, SerializationStrategy> serializationStrategies) throws Exception {
        this.queue = queue;
        this.serializationStrategies = serializationStrategies;
        this.server = new TcpTransportFactory().bind(address);
        this.server.setDispatchQueue(queue);
        this.server.setAcceptListener(new InvokerAcceptListener());
    }

    public InetSocketAddress getSocketAddress() {
        return this.server.getSocketAddress();
    }

    public DispatchQueue queue() {
        return queue;
    }

    public String getConnectAddress() {
        return this.server.getConnectAddress();
    }

    /** Asynchronously (on {@link #queue}) registers {@code service} under {@code id}. */
    public void registerService(final String id, final ServiceFactory service, final ClassLoader classLoader) {
        queue().execute(new Runnable() {
            public void run() {
                holders.put(new UTF8Buffer(id), new ServiceFactoryHolder(service, classLoader));
            }
        });
    }

    /** Asynchronously (on {@link #queue}) removes the service registered under {@code id}. */
    public void unregisterService(final String id) {
        queue().execute(new Runnable() {
            public void run() {
                holders.remove(new UTF8Buffer(id));
            }
        });
    }

    public void start() throws Exception {
        start(null);
    }

    public void start(Runnable onComplete) throws Exception {
        this.server.start(onComplete);
    }

    public void stop() {
        stop(null);
    }

    /**
     * Stops the TCP server, then shuts down {@link #blockingExecutor} and runs
     * {@code onComplete} (if any).
     */
    public void stop(final Runnable onComplete) {
        this.server.stop(new Runnable() {
            public void run() {
                blockingExecutor.shutdown();
                if (onComplete != null) {
                    onComplete.run();
                }
            }
        });
    }

    /**
     * Handles one request frame received on {@code transport}.
     * <p>
     * Decodes the header, resolves the target service and method, and runs the invocation on
     * the service's own queue (for {@link Dispatched} services) or on {@link #blockingExecutor}.
     * The response frame (size, correlation id, strategy-written body) is offered back on
     * {@code transport} from {@link #queue}. Requests that cannot be dispatched are logged and
     * get no response.
     */
    protected void onCommand(final Transport transport, Object data) {
        try {
            final DataByteArrayInputStream bais = new DataByteArrayInputStream((Buffer) data);
            bais.readInt(); // frame size prefix; framing itself is done by the codec
            final long correlation = bais.readVarLong();

            // Keep the service id as a UTF8Buffer so it can be looked up without decoding it
            // into a String for every request.
            final UTF8Buffer service = readBuffer(bais).utf8();
            final Buffer encoded_method = readBuffer(bais);

            final ServiceFactoryHolder holder = holders.get(service);
            if (holder == null) {
                // Without this check an unknown id surfaced as a NullPointerException logged
                // as a read error.
                LOGGER.warn("Ignoring request {} for unknown service id {}", correlation, service);
                return;
            }
            final MethodData methodData = holder.getMethodData(encoded_method);

            final Object svc = holder.factory.get();

            Runnable task = new Runnable() {
                public void run() {

                    final DataByteArrayOutputStream baos = new DataByteArrayOutputStream();
                    try {
                        baos.writeInt(0); // frame size placeholder, patched before sending
                        baos.writeVarLong(correlation);
                    } catch (IOException e) { // should not happen
                        throw new RuntimeException(e);
                    }

                    // Decode the remaining arguments on the target's executor, off the
                    // transport queue.
                    methodData.invocationStrategy.service(methodData.serializationStrategy, holder.loader, methodData.method, svc, bais, baos, new Runnable() {
                        public void run() {
                            holder.factory.unget();
                            final Buffer command = baos.toBuffer();

                            // Patch the real frame size into the placeholder.
                            BufferEditor editor = command.buffer().bigEndianEditor();
                            editor.writeInt(command.length);

                            queue().execute(new Runnable() {
                                public void run() {
                                    transport.offer(command);
                                }
                            });
                        }
                    });
                }
            };

            Executor executor;
            if (svc instanceof Dispatched) {
                executor = ((Dispatched) svc).queue();
            } else {
                executor = blockingExecutor;
            }
            try {
                executor.execute(task);
            } catch (RuntimeException e) {
                // e.g. blockingExecutor already shut down by stop(): the task (and the unget in
                // its completion callback) will never run, so release the instance here.
                holder.factory.unget();
                throw e;
            }

        } catch (Exception e) {
            LOGGER.info("Error while reading request", e);
        }
    }

    /** Reads a var-int length followed by that many bytes. */
    private Buffer readBuffer(DataByteArrayInputStream bais) throws IOException {
        byte[] b = new byte[bais.readVarInt()];
        bais.readFully(b);
        return new Buffer(b);
    }

    /** Configures each accepted connection with length-prefixed framing and starts it. */
    class InvokerAcceptListener implements TransportAcceptListener {

        public void onAccept(TransportServer transportServer, TcpTransport transport) {
            transport.setProtocolCodec(new LengthPrefixedCodec());
            transport.setDispatchQueue(queue());
            transport.setTransportListener(new InvokerTransportListener());
            transport.start();
        }

        public void onAcceptError(TransportServer transportServer, Exception error) {
            LOGGER.info("Error accepting incoming connection", error);
        }
    }

    /**
     * Routes received frames to {@link ServerInvokerImpl#onCommand} and logs unexpected
     * transport failures (not EOF, and not on already-disposed transports).
     */
    class InvokerTransportListener implements TransportListener {

        public void onTransportCommand(Transport transport, Object command) {
            ServerInvokerImpl.this.onCommand(transport, command);
        }

        public void onRefill(Transport transport) {
        }

        public void onTransportFailure(Transport transport, IOException error) {
            if (!transport.isDisposed() && !(error instanceof EOFException)) {
                LOGGER.info("Transport failure", error);
            }
        }

        public void onTransportConnected(Transport transport) {
            transport.resumeRead();
        }

        public void onTransportDisconnected(Transport transport) {
        }
    }

}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransport.java ---------------------------------------------------------------------- diff --git
a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransport.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransport.java
new file mode 100644
index 0000000..ebedf0e
--- /dev/null
+++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransport.java
@@ -0,0 +1,828 @@
+/**
+ * Copyright 2005-2015 Red Hat, Inc.
+ *
+ * Red Hat licenses this file to you under the Apache License, version
+ * 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package org.apache.aries.rsa.provider.fastbin.tcp;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.URI;
+import java.net.UnknownHostException;
+import java.nio.ByteBuffer;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.SocketChannel;
+import java.nio.channels.WritableByteChannel;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.aries.rsa.provider.fastbin.io.ProtocolCodec;
+import org.apache.aries.rsa.provider.fastbin.io.Transport;
+import org.apache.aries.rsa.provider.fastbin.io.TransportListener;
+import org.fusesource.hawtdispatch.Dispatch;
+import org.fusesource.hawtdispatch.DispatchQueue;
+import org.fusesource.hawtdispatch.DispatchSource;
+import org.fusesource.hawtdispatch.Retained;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Non-blocking TCP {@link Transport} driven by a HawtDispatch queue.
+ * <p>
+ * Two state machines are tracked:
+ * <ul>
+ * <li>the service state ({@link #CREATED}, STARTING, {@link #STARTED}, STOPPING, {@link #STOPPED}),
+ * changed by {@link #start(Runnable)} / {@link #stop(Runnable)} on the dispatch queue;</li>
+ * <li>the socket state (DISCONNECTED, CONNECTING, CONNECTED, CANCELING, CANCELED),
+ * changed by connect/cancel events.</li>
+ * </ul>
+ * Reads and writes are driven by dispatch sources registered for OP_READ / OP_WRITE and
+ * are encoded/decoded by the configured {@link ProtocolCodec}.
+ */
+public class TcpTransport implements Transport {
+
+    private static final Logger LOG = LoggerFactory.getLogger(TcpTransport.class);
+
+    protected State _serviceState = CREATED;
+
+    protected Map<String, Object> socketOptions;
+
+    /** Base type of the service lifecycle states. */
+    public static class State {
+        public String toString() {
+            return getClass().getSimpleName();
+        }
+        public boolean isStarted() {
+            return false;
+        }
+    }
+
+    /** A transitional state that queues completion callbacks and runs them once the transition ends. */
+    static class CallbackSupport extends State {
+        LinkedList<Runnable> callbacks = new LinkedList<Runnable>();
+
+        void add(Runnable r) {
+            if (r != null) {
+                callbacks.add(r);
+            }
+        }
+
+        void done() {
+            for (Runnable callback : callbacks) {
+                callback.run();
+            }
+        }
+    }
+
+    /** Base type of the socket states; each subclass overrides the events it reacts to. */
+    abstract static class SocketState {
+        void onStop(Runnable onCompleted) {
+        }
+        void onCanceled() {
+        }
+        boolean is(Class<? extends SocketState> clazz) {
+            return getClass() == clazz;
+        }
+    }
+
+    public static final State CREATED = new State();
+
+    public static class STARTING extends CallbackSupport {
+    }
+
+    public static final State STARTED = new State() {
+        public boolean isStarted() {
+            return true;
+        }
+    };
+
+    public static class STOPPING extends CallbackSupport {
+    }
+
+    public static final State STOPPED = new State();
+
+
+    final public void start() {
+        start(null);
+    }
+
+    final public void stop() {
+        stop(null);
+    }
+
+    /**
+     * Asynchronously starts the transport on its dispatch queue.
+     *
+     * @param onCompleted optional callback run once the transport is started (may be null)
+     */
+    final public void start(final Runnable onCompleted) {
+        queue().execute(new Runnable() {
+            public void run() {
+                if (_serviceState == CREATED || _serviceState == STOPPED) {
+                    final STARTING state = new STARTING();
+                    state.add(onCompleted);
+                    _serviceState = state;
+                    _start(new Runnable() {
+                        public void run() {
+                            _serviceState = STARTED;
+                            state.done();
+                        }
+                    });
+                } else if (_serviceState instanceof STARTING) {
+                    // A start is already in progress: just wait for it to finish.
+                    ((STARTING) _serviceState).add(onCompleted);
+                } else if (_serviceState == STARTED) {
+                    if (onCompleted != null) {
+                        onCompleted.run();
+                    }
+                } else {
+                    if (onCompleted != null) {
+                        onCompleted.run();
+                    }
+                    LOG.error("start should not be called from state: " + _serviceState);
+                }
+            }
+        });
+    }
+
+    /**
+     * Asynchronously stops the transport on its dispatch queue.
+     *
+     * @param onCompleted optional callback run once the transport is stopped (may be null)
+     */
+    final public void stop(final Runnable onCompleted) {
+        queue().execute(new Runnable() {
+            public void run() {
+                if (_serviceState == STARTED) {
+                    final STOPPING state = new STOPPING();
+                    state.add(onCompleted);
+                    _serviceState = state;
+                    _stop(new Runnable() {
+                        public void run() {
+                            _serviceState = STOPPED;
+                            state.done();
+                        }
+                    });
+                } else if (_serviceState instanceof STOPPING) {
+                    // A stop is already in progress: just wait for it to finish.
+                    ((STOPPING) _serviceState).add(onCompleted);
+                } else if (_serviceState == STOPPED) {
+                    if (onCompleted != null) {
+                        onCompleted.run();
+                    }
+                } else {
+                    if (onCompleted != null) {
+                        onCompleted.run();
+                    }
+                    LOG.error("stop should not be called from state: " + _serviceState);
+                }
+            }
+        });
+    }
+
+    protected State getServiceState() {
+        return _serviceState;
+    }
+
+    static class DISCONNECTED extends SocketState {
+        // Nothing to tear down, so a stop completes immediately; without this the
+        // STOPPING service state would never reach STOPPED.
+        void onStop(Runnable onCompleted) {
+            if (onCompleted != null) {
+                onCompleted.run();
+            }
+        }
+    }
+
+    class CONNECTING extends SocketState {
+        void onStop(Runnable onCompleted) {
+            trace("CONNECTING.onStop");
+            CANCELING state = new CANCELING();
+            socketState = state;
+            state.onStop(onCompleted);
+        }
+        void onCanceled() {
+            trace("CONNECTING.onCanceled");
+            CANCELING state = new CANCELING();
+            socketState = state;
+            state.onCanceled();
+        }
+    }
+
+    class CONNECTED extends SocketState {
+        void onStop(Runnable onCompleted) {
+            trace("CONNECTED.onStop");
+            CANCELING state = new CANCELING();
+            socketState = state;
+            state.add(createDisconnectTask());
+            state.onStop(onCompleted);
+        }
+        void onCanceled() {
+            trace("CONNECTED.onCanceled");
+            CANCELING state = new CANCELING();
+            socketState = state;
+            state.add(createDisconnectTask());
+            state.onCanceled();
+        }
+        Runnable createDisconnectTask() {
+            return new Runnable() {
+                public void run() {
+                    listener.onTransportDisconnected(TcpTransport.this);
+                }
+            };
+        }
+    }
+
+    /**
+     * Cancels the read/write dispatch sources and waits for their cancel handlers;
+     * when the last one reports in, closes the channel and runs the queued callbacks.
+     */
+    class CANCELING extends SocketState {
+        private LinkedList<Runnable> runnables = new LinkedList<Runnable>();
+        private int remaining;
+        private boolean dispose;
+
+        public CANCELING() {
+            if (readSource != null) {
+                remaining++;
+                readSource.cancel();
+            }
+            if (writeSource != null) {
+                remaining++;
+                writeSource.cancel();
+            }
+        }
+        void onStop(Runnable onCompleted) {
+            trace("CANCELING.onStop");
+            add(onCompleted);
+            dispose = true;
+        }
+        void add(Runnable onCompleted) {
+            if (onCompleted != null) {
+                runnables.add(onCompleted);
+            }
+        }
+        void onCanceled() {
+            trace("CANCELING.onCanceled");
+            remaining--;
+            if (remaining != 0) {
+                return;
+            }
+            try {
+                channel.close();
+            } catch (IOException ignore) {
+                // best effort: the socket is being torn down anyway
+            }
+            socketState = new CANCELED(dispose);
+            for (Runnable runnable : runnables) {
+                runnable.run();
+            }
+            if (dispose) {
+                dispose();
+            }
+        }
+    }
+
+    class CANCELED extends SocketState {
+        private boolean disposed;
+
+        public CANCELED(boolean disposed) {
+            this.disposed = disposed;
+        }
+
+        void onStop(Runnable onCompleted) {
+            trace("CANCELED.onStop");
+            if (!disposed) {
+                disposed = true;
+                dispose();
+            }
+            // stop() passes a null callback; reaching this state after a transport
+            // failure must not turn a plain stop() into a NullPointerException.
+            if (onCompleted != null) {
+                onCompleted.run();
+            }
+        }
+    }
+
+    protected URI remoteLocation;
+    protected URI localLocation;
+    protected TransportListener listener;
+    protected String remoteAddress;
+    protected ProtocolCodec codec;
+
+    protected SocketChannel channel;
+
+    protected SocketState socketState = new DISCONNECTED();
+
+    protected DispatchQueue dispatchQueue;
+    private DispatchSource readSource;
+    private DispatchSource writeSource;
+
+    protected boolean useLocalHost = true;
+
+    int max_read_rate;
+    int max_write_rate;
+    protected RateLimitingChannel rateLimitingChannel;
+
+    /**
+     * Channel wrapper that caps the bytes read/written between allowance resets
+     * (reset once a second by {@link #scheduleRateAllowanceReset()}). A rate of 0 means unlimited.
+     */
+    class RateLimitingChannel implements ReadableByteChannel, WritableByteChannel {
+
+        int read_allowance = max_read_rate;
+        boolean read_suspended = false;
+        int read_resume_counter = 0;
+        int write_allowance = max_write_rate;
+        boolean write_suspended = false;
+
+        public void resetAllowance() {
+            if (read_allowance != max_read_rate || write_allowance != max_write_rate) {
+                read_allowance = max_read_rate;
+                write_allowance = max_write_rate;
+                if (write_suspended) {
+                    write_suspended = false;
+                    resumeWrite();
+                }
+                if (read_suspended) {
+                    read_suspended = false;
+                    resumeRead();
+                    // Replay the resumeRead() calls deferred while throttled, then forget
+                    // them so they are not replayed again on the next reset.
+                    int deferred = read_resume_counter;
+                    read_resume_counter = 0;
+                    for (int i = 0; i < deferred; i++) {
+                        resumeRead();
+                    }
+                }
+            }
+        }
+
+        public int read(ByteBuffer dst) throws IOException {
+            if (max_read_rate == 0) {
+                return channel.read(dst);
+            } else {
+                int remaining = dst.remaining();
+                if (read_allowance == 0 || remaining == 0) {
+                    return 0;
+                }
+
+                int reduction = 0;
+                if (remaining > read_allowance) {
+                    reduction = remaining - read_allowance;
+                    dst.limit(dst.limit() - reduction);
+                }
+                int rc = 0;
+                try {
+                    rc = channel.read(dst);
+                    read_allowance -= rc;
+                } finally {
+                    if (reduction != 0) {
+                        if (dst.remaining() == 0) {
+                            // we need to suspend the read now until we get
+                            // a new allowance..
+                            readSource.suspend();
+                            read_suspended = true;
+                        }
+                        dst.limit(dst.limit() + reduction);
+                    }
+                }
+                return rc;
+            }
+        }
+
+        public int write(ByteBuffer src) throws IOException {
+            if (max_write_rate == 0) {
+                return channel.write(src);
+            } else {
+                int remaining = src.remaining();
+                if (write_allowance == 0 || remaining == 0) {
+                    return 0;
+                }
+
+                int reduction = 0;
+                if (remaining > write_allowance) {
+                    reduction = remaining - write_allowance;
+                    src.limit(src.limit() - reduction);
+                }
+                int rc = 0;
+                try {
+                    rc = channel.write(src);
+                    write_allowance -= rc;
+                } finally {
+                    if (reduction != 0) {
+                        if (src.remaining() == 0) {
+                            // we need to suspend the write now until we get
+                            // a new allowance..
+                            write_suspended = true;
+                            suspendWrite();
+                        }
+                        src.limit(src.limit() + reduction);
+                    }
+                }
+                return rc;
+            }
+        }
+
+        public boolean isOpen() {
+            return channel.isOpen();
+        }
+
+        public void close() throws IOException {
+            channel.close();
+        }
+
+        /** Resumes reading, or defers the resume until the next allowance reset while throttled. */
+        public void resumeRead() {
+            if (read_suspended) {
+                read_resume_counter += 1;
+            } else {
+                _resumeRead();
+            }
+        }
+
+    }
+
+    private final Runnable CANCEL_HANDLER = new Runnable() {
+        public void run() {
+            socketState.onCanceled();
+        }
+    };
+
+    static final class OneWay {
+        final Object command;
+        final Retained retained;
+
+        public OneWay(Object command, Retained retained) {
+            this.command = command;
+            this.retained = retained;
+        }
+    }
+
+    /**
+     * Adopts an already-connected (e.g. accepted) socket channel and moves to the CONNECTED state.
+     */
+    public void connected(SocketChannel channel) throws IOException, Exception {
+        this.channel = channel;
+
+        if (codec != null) {
+            initializeCodec();
+        }
+
+        this.channel.configureBlocking(false);
+        this.remoteAddress = channel.socket().getRemoteSocketAddress().toString();
+        channel.socket().setSoLinger(true, 0);
+        channel.socket().setTcpNoDelay(true);
+
+        this.socketState = new CONNECTED();
+    }
+
+    protected void initializeCodec() {
+        codec.setReadableByteChannel(readChannel());
+        codec.setWritableByteChannel(writeChannel());
+    }
+
+    /**
+     * Opens a non-blocking channel and initiates a connect to {@code remoteLocation}; the
+     * connect is completed when the transport is started.
+     *
+     * @param remoteLocation the remote endpoint
+     * @param localLocation  optional local address to bind to first (may be null)
+     */
+    public void connecting(URI remoteLocation, URI localLocation) throws IOException, Exception {
+        this.channel = SocketChannel.open();
+        try {
+            this.channel.configureBlocking(false);
+            this.remoteLocation = remoteLocation;
+            this.localLocation = localLocation;
+
+            if (localLocation != null) {
+                InetSocketAddress localAddress = new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort());
+                channel.socket().bind(localAddress);
+            }
+
+            String host = resolveHostName(remoteLocation.getHost());
+            InetSocketAddress remoteEndpoint = new InetSocketAddress(host, remoteLocation.getPort());
+            channel.connect(remoteEndpoint);
+        } catch (Exception e) {
+            // Don't leak the freshly opened socket when binding/resolving/connecting fails.
+            try {
+                channel.close();
+            } catch (IOException ignored) {
+                // the original failure is the one worth reporting
+            }
+            throw e;
+        }
+        this.socketState = new CONNECTING();
+    }
+
+
+    public DispatchQueue queue() {
+        return dispatchQueue;
+    }
+
+    public void setDispatchQueue(DispatchQueue queue) {
+        this.dispatchQueue = queue;
+    }
+
+    /**
+     * Service start hook: finishes a pending connect (CONNECTING) or schedules
+     * {@link #onConnected()} (CONNECTED). {@code onCompleted} always runs before returning.
+     */
+    public void _start(Runnable onCompleted) {
+        try {
+            if (socketState.is(CONNECTING.class)) {
+                trace("connecting...");
+                // this allows the connect to complete..
+                readSource = Dispatch.createSource(channel, SelectionKey.OP_CONNECT, dispatchQueue);
+                readSource.setEventHandler(new Runnable() {
+                    public void run() {
+                        if (getServiceState() != STARTED) {
+                            return;
+                        }
+                        try {
+                            trace("connected.");
+                            channel.finishConnect();
+                            readSource.setCancelHandler(null);
+                            readSource.cancel();
+                            readSource = null;
+                            socketState = new CONNECTED();
+                            onConnected();
+                        } catch (IOException e) {
+                            onTransportFailure(e);
+                        }
+                    }
+                });
+                readSource.setCancelHandler(CANCEL_HANDLER);
+                readSource.resume();
+
+            } else if (socketState.is(CONNECTED.class)) {
+                dispatchQueue.execute(new Runnable() {
+                    public void run() {
+                        try {
+                            trace("was connected.");
+                            onConnected();
+                        } catch (IOException e) {
+                            onTransportFailure(e);
+                        }
+                    }
+                });
+            } else {
+                LOG.warn("cannot be started. socket state is: " + socketState);
+            }
+        } finally {
+            if (onCompleted != null) {
+                onCompleted.run();
+            }
+        }
+    }
+
+    /** Service stop hook: delegates to the current socket state. */
+    public void _stop(final Runnable onCompleted) {
+        trace("stopping.. at state: " + socketState);
+        socketState.onStop(onCompleted);
+    }
+
+    /** Maps this machine's own host name to "localhost" when {@link #isUseLocalHost()} is set. */
+    protected String resolveHostName(String host) throws UnknownHostException {
+        String localName = InetAddress.getLocalHost().getHostName();
+        if (localName != null && isUseLocalHost()) {
+            if (localName.equals(host)) {
+                return "localhost";
+            }
+        }
+        return host;
+    }
+
+    /** Registers the read/write dispatch sources and notifies the listener of the connection. */
+    protected void onConnected() throws IOException {
+
+        readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
+        writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
+
+        readSource.setCancelHandler(CANCEL_HANDLER);
+        writeSource.setCancelHandler(CANCEL_HANDLER);
+
+        readSource.setEventHandler(new Runnable() {
+            public void run() {
+                drainInbound();
+            }
+        });
+        writeSource.setEventHandler(new Runnable() {
+            public void run() {
+                drainOutbound();
+            }
+        });
+
+        if (max_read_rate != 0 || max_write_rate != 0) {
+            rateLimitingChannel = new RateLimitingChannel();
+            // NOTE(review): the codec's channels are wired in initializeCodec(), which runs
+            // before this limiter exists, so codec I/O appears to bypass it. Re-wiring would
+            // activate throttling; confirm the limiter's suspend/resume handling first.
+            scheduleRateAllowanceReset();
+        }
+
+        remoteAddress = channel.socket().getRemoteSocketAddress().toString();
+        listener.onTransportConnected(this);
+    }
+
+    /** Resets the rate allowances every second for as long as the socket stays CONNECTED. */
+    private void scheduleRateAllowanceReset() {
+        dispatchQueue.executeAfter(1, TimeUnit.SECONDS, new Runnable() {
+            public void run() {
+                if (!socketState.is(CONNECTED.class)) {
+                    return;
+                }
+                rateLimitingChannel.resetAllowance();
+                scheduleRateAllowanceReset();
+            }
+        });
+    }
+
+    private void dispose() {
+        if (readSource != null) {
+            readSource.cancel();
+            readSource = null;
+        }
+
+        if (writeSource != null) {
+            writeSource.cancel();
+            writeSource = null;
+        }
+        this.codec = null;
+    }
+
+    /** Reports {@code error} to the listener, then cancels the socket. */
+    public void onTransportFailure(IOException error) {
+        listener.onTransportFailure(this, error);
+        socketState.onCanceled();
+    }
+
+
+    public boolean full() {
+        return codec.full();
+    }
+
+    /**
+     * Encodes {@code command} for sending. Must be called on the dispatch queue.
+     *
+     * @return false if the codec buffer is full or the transport failed; true if accepted
+     */
+    public boolean offer(Object command) {
+        assert Dispatch.getCurrentQueue() == dispatchQueue;
+        try {
+            if (!socketState.is(CONNECTED.class)) {
+                throw new IOException("Not connected.");
+            }
+            if (getServiceState() != STARTED) {
+                throw new IOException("Not running.");
+            }
+
+            ProtocolCodec.BufferState rc = codec.write(command);
+            switch (rc) {
+                case FULL:
+                    return false;
+                default:
+                    if (drained) {
+                        drained = false;
+                        resumeWrite();
+                    }
+                    return true;
+            }
+        } catch (IOException e) {
+            onTransportFailure(e);
+            return false;
+        }
+
+    }
+
+
+    boolean drained = true;
+
+    /**
+     * Flushes pending output; once everything is written, suspends the write source and
+     * asks the listener for more data via {@code onRefill}.
+     */
+    protected void drainOutbound() {
+        assert Dispatch.getCurrentQueue() == dispatchQueue;
+        if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
+            return;
+        }
+        try {
+            if (codec.flush() == ProtocolCodec.BufferState.WAS_EMPTY && flush()) {
+                if (!drained) {
+                    drained = true;
+                    suspendWrite();
+                    listener.onRefill(this);
+                }
+            }
+        } catch (IOException e) {
+            onTransportFailure(e);
+        }
+    }
+
+    protected boolean flush() throws IOException {
+        return true;
+    }
+
+    /** Decodes incoming commands and hands each one to the listener. */
+    protected void drainInbound() {
+        if (!getServiceState().isStarted() || readSource.isSuspended()) {
+            return;
+        }
+        try {
+            long initial = codec.getReadCounter();
+            // Only process upto 64k worth of data at a time so we can give
+            // other connections a chance to process their requests.
+            while (codec.getReadCounter() - initial < 1024 * 64) {
+                Object command = codec.read();
+                if (command != null) {
+                    try {
+                        listener.onTransportCommand(this, command);
+                    } catch (Throwable e) {
+                        // keep the listener's exception as the cause so it is not lost
+                        onTransportFailure(new IOException("Transport listener failure.", e));
+                    }
+
+                    // the transport may be suspended after processing a command.
+                    if (getServiceState() == STOPPED || readSource.isSuspended()) {
+                        return;
+                    }
+                } else {
+                    return;
+                }
+            }
+        } catch (IOException e) {
+            onTransportFailure(e);
+        }
+    }
+
+
+    public String getRemoteAddress() {
+        return remoteAddress;
+    }
+
+    private boolean assertConnected() {
+        try {
+            if (!isConnected()) {
+                throw new IOException("Not connected.");
+            }
+            return true;
+        } catch (IOException e) {
+            onTransportFailure(e);
+        }
+        return false;
+    }
+
+    public void suspendRead() {
+        if (isConnected() && readSource != null) {
+            readSource.suspend();
+        }
+    }
+
+
+    public void resumeRead() {
+        if (isConnected() && readSource != null) {
+            if (rateLimitingChannel != null) {
+                rateLimitingChannel.resumeRead();
+            } else {
+                _resumeRead();
+            }
+        }
+    }
+
+    private void _resumeRead() {
+        readSource.resume();
+        // drain anything that may already be buffered in the codec
+        dispatchQueue.execute(new Runnable() {
+            public void run() {
+                drainInbound();
+            }
+        });
+    }
+
+    protected void suspendWrite() {
+        if (isConnected() && writeSource != null) {
+            writeSource.suspend();
+        }
+    }
+
+    protected void resumeWrite() {
+        if (isConnected() && writeSource != null) {
+            writeSource.resume();
+            dispatchQueue.execute(new Runnable() {
+                public void run() {
+                    drainOutbound();
+                }
+            });
+        }
+    }
+
+    public TransportListener getTransportListener() {
+        return listener;
+    }
+
+    public void setTransportListener(TransportListener listener) {
+        this.listener = listener;
+    }
+
+    public ProtocolCodec getProtocolCodec() {
+        return codec;
+    }
+
+    public void setProtocolCodec(ProtocolCodec protocolCodec) {
+        this.codec = protocolCodec;
+        if (channel != null && codec != null) {
+            initializeCodec();
+        }
+    }
+
+    public boolean isConnected() {
+        return socketState.is(CONNECTED.class);
+    }
+
+    public boolean isDisposed() {
+        return getServiceState() == STOPPED;
+    }
+
+    public void setSocketOptions(Map<String, Object> socketOptions) {
+        this.socketOptions = socketOptions;
+    }
+
+    public boolean isUseLocalHost() {
+        return useLocalHost;
+    }
+
+    /**
+     * Sets whether 'localhost' or the actual local host name should be used to
+     * make local connections. On some operating systems such as Macs its not
+     * possible to connect as the local host name so localhost is better.
+     */
+    public void setUseLocalHost(boolean useLocalHost) {
+        this.useLocalHost = useLocalHost;
+    }
+
+
+    private void trace(String message) {
+        if (LOG.isTraceEnabled()) {
+            final String label = dispatchQueue.getLabel();
+            if (label != null) {
+                LOG.trace(label + " | " + message);
+            } else {
+                LOG.trace(message);
+            }
+        }
+    }
+
+    public SocketChannel getSocketChannel() {
+        return channel;
+    }
+
+    public ReadableByteChannel readChannel() {
+        if (rateLimitingChannel != null) {
+            return rateLimitingChannel;
+        } else {
+            return channel;
+        }
+    }
+
+    public WritableByteChannel writeChannel() {
+        if (rateLimitingChannel != null) {
+            return rateLimitingChannel;
+        } else {
+            return channel;
+        }
+    }
+
+    public int getMax_read_rate() {
+        return max_read_rate;
+    }
+
+    public void setMax_read_rate(int max_read_rate) {
+        this.max_read_rate = max_read_rate;
+    }
+
+    public int getMax_write_rate() {
+        return max_write_rate;
+    }
+
+    public void setMax_write_rate(int max_write_rate) {
+        this.max_write_rate = max_write_rate;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportFactory.java
----------------------------------------------------------------------
diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportFactory.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportFactory.java
new file mode 100644
index 0000000..2404190
--- /dev/null
+++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportFactory.java
@@ -0,0 +1,123 @@
+/**
+ * Copyright 2005-2015 Red Hat, Inc.
+ *
+ * Red Hat licenses this file to you under the Apache License, version
+ * 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package org.apache.aries.rsa.provider.fastbin.tcp;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.security.NoSuchAlgorithmException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.aries.rsa.provider.fastbin.util.IntrospectionSupport;
+import org.apache.aries.rsa.provider.fastbin.util.URISupport;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Creates {@link TcpTransport}s and {@link TcpTransportServer}s from {@code tcp://} location
+ * URIs. URI query parameters are applied as bean properties on the created object.
+ */
+public class TcpTransportFactory {
+    private static final Logger LOG = LoggerFactory.getLogger(TcpTransportFactory.class);
+
+    /**
+     * Creates a server for {@code location}. The default server only opens and binds its
+     * socket when it is started.
+     *
+     * @return the server, or null if the URI scheme is not supported
+     */
+    public TcpTransportServer bind(String location) throws Exception {
+
+        URI uri = new URI(location);
+        TcpTransportServer server = createTcpTransportServer(uri);
+        if (server == null) {
+            return null;
+        }
+
+        Map<String, String> options = new HashMap<String, String>(URISupport.parseParameters(uri));
+        IntrospectionSupport.setProperties(server, options);
+        // "transport.*" parameters are forwarded to every accepted transport
+        Map<String, Object> transportOptions = IntrospectionSupport.extractProperties(options, "transport.");
+        server.setTransportOption(transportOptions);
+        return server;
+    }
+
+
+    /**
+     * Creates a transport for {@code location} and initiates the connect. The connect is
+     * completed once the transport is given a dispatch queue and started.
+     *
+     * @return the transport, or null if the URI scheme is not supported
+     * @throws IllegalArgumentException if the URI carries parameters the transport does not understand
+     */
+    public TcpTransport connect(String location) throws Exception {
+        URI uri = new URI(location);
+        TcpTransport transport = createTransport(uri);
+        if (transport == null) {
+            return null;
+        }
+
+        Map<String, String> options = new HashMap<String, String>(URISupport.parseParameters(uri));
+        URI localLocation = getLocalLocation(uri);
+
+        // Apply and validate every option *before* connecting. connecting() opens a
+        // SocketChannel that would leak if we bailed out afterwards: the transport has no
+        // dispatch queue yet and is still CREATED, so stop() cannot close it. Properties
+        // such as useLocalHost are also read while connecting.
+        Map<String, Object> socketOptions = IntrospectionSupport.extractProperties(options, "socket.");
+        transport.setSocketOptions(socketOptions);
+
+        IntrospectionSupport.setProperties(transport, options);
+        if (!options.isEmpty()) {
+            throw new IllegalArgumentException("Invalid connect parameters: " + options);
+        }
+
+        transport.connecting(uri, localLocation);
+        return transport;
+    }
+
+    /**
+     * Allows subclasses of TcpTransportFactory to create custom instances of
+     * TcpTransportServer.
+     */
+    protected TcpTransportServer createTcpTransportServer(final URI location) throws IOException, URISyntaxException, Exception {
+        if (!location.getScheme().equals("tcp")) {
+            return null;
+        }
+        return new TcpTransportServer(location);
+    }
+
+    /**
+     * Allows subclasses of TcpTransportFactory to create custom instances of
+     * TcpTransport.
+     */
+    protected TcpTransport createTransport(URI uri) throws NoSuchAlgorithmException, Exception {
+        if (!uri.getScheme().equals("tcp")) {
+            return null;
+        }
+        return new TcpTransport();
+    }
+
+    /**
+     * Interprets a URI path of the form {@code /host:port} as the local address to bind to,
+     * e.g. {@code tcp://remote:61616/localhost:5000}.
+     *
+     * @return the local location, or null if the path is empty or does not end in a port
+     */
+    protected URI getLocalLocation(URI location) {
+        URI localLocation = null;
+        String path = location.getPath();
+        // see if the path is a local URI location
+        if (path != null && path.length() > 0) {
+            int localPortIndex = path.indexOf(':');
+            try {
+                Integer.parseInt(path.substring(localPortIndex + 1, path.length()));
+                String localString = location.getScheme() + ":/" + path;
+                localLocation = new URI(localString);
+            } catch (Exception e) {
+                LOG.warn("path isn't a valid local location for TcpTransport to use", e);
+            }
+        }
+        return localLocation;
+    }
+
+    /** Removes {@code key} from {@code options} and returns its value, or {@code def} if absent. */
+    protected String getOption(Map options, String key, String def) {
+        String rc = (String) options.remove(key);
+        if (rc == null) {
+            rc = def;
+        }
+        return rc;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportServer.java ---------------------------------------------------------------------- diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportServer.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportServer.java new file mode 100644 index 0000000..3e828b5 --- /dev/null +++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TcpTransportServer.java @@ -0,0 +1,231 @@ +/** + * Copyright 2005-2015 Red Hat, Inc. + * + * Red Hat licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ */
+package org.apache.aries.rsa.provider.fastbin.tcp;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.UnknownHostException;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.aries.rsa.provider.fastbin.io.TransportAcceptListener;
+import org.apache.aries.rsa.provider.fastbin.io.TransportServer;
+import org.apache.aries.rsa.provider.fastbin.util.IntrospectionSupport;
+import org.fusesource.hawtdispatch.Dispatch;
+import org.fusesource.hawtdispatch.DispatchQueue;
+import org.fusesource.hawtdispatch.DispatchSource;
+
+/**
+ * A TCP based transport server: accepts connections on a non-blocking
+ * {@link ServerSocketChannel} and hands each one, wrapped in a {@link TcpTransport},
+ * to the configured {@link TransportAcceptListener}.
+ */
+public class TcpTransportServer implements TransportServer {
+
+    private final String bindScheme;
+    private final InetSocketAddress bindAddress;
+
+    private int backlog = 100;
+    private Map<String, Object> transportOptions;
+
+    private ServerSocketChannel channel;
+    private TransportAcceptListener listener;
+    private DispatchQueue dispatchQueue;
+    private DispatchSource acceptSource;
+
+    /**
+     * @param location bind location; an empty host binds the wildcard address "::"
+     * @throws UnknownHostException if the host cannot be resolved
+     */
+    public TcpTransportServer(URI location) throws UnknownHostException {
+        bindScheme = location.getScheme();
+        String host = location.getHost();
+        host = (host == null || host.length() == 0) ? "::" : host;
+        bindAddress = new InetSocketAddress(InetAddress.getByName(host), location.getPort());
+    }
+
+    public void setAcceptListener(TransportAcceptListener listener) {
+        this.listener = listener;
+    }
+
+    public InetSocketAddress getSocketAddress() {
+        return (InetSocketAddress) channel.socket().getLocalSocketAddress();
+    }
+
+    public DispatchQueue getDispatchQueue() {
+        return dispatchQueue;
+    }
+
+    public void setDispatchQueue(DispatchQueue dispatchQueue) {
+        this.dispatchQueue = dispatchQueue;
+    }
+
+    public void suspend() {
+        acceptSource.suspend();
+    }
+
+    public void resume() {
+        acceptSource.resume();
+    }
+
+    public void start() throws Exception {
+        start(null);
+    }
+
+    /**
+     * Binds the server socket and starts accepting connections on the dispatch queue.
+     *
+     * @param onCompleted optional callback executed on the dispatch queue once started (may be null)
+     * @throws IOException if the socket cannot be bound
+     */
+    public void start(Runnable onCompleted) throws Exception {
+        ServerSocketChannel serverChannel = null;
+        try {
+            serverChannel = ServerSocketChannel.open();
+            serverChannel.configureBlocking(false);
+            serverChannel.socket().bind(bindAddress, backlog);
+        } catch (IOException e) {
+            // don't leak an opened-but-unbound channel
+            closeQuietly(serverChannel);
+            throw new IOException("Failed to bind to server socket: " + bindAddress + " due to: " + e, e);
+        }
+        channel = serverChannel;
+
+        acceptSource = Dispatch.createSource(channel, SelectionKey.OP_ACCEPT, dispatchQueue);
+        acceptSource.setEventHandler(new Runnable() {
+            public void run() {
+                try {
+                    // non-blocking accept returns null once the backlog is drained
+                    SocketChannel client = channel.accept();
+                    while (client != null) {
+                        handleSocket(client);
+                        client = channel.accept();
+                    }
+                } catch (Exception e) {
+                    listener.onAcceptError(TcpTransportServer.this, e);
+                }
+            }
+        });
+        acceptSource.setCancelHandler(new Runnable() {
+            public void run() {
+                closeQuietly(channel);
+            }
+        });
+        acceptSource.resume();
+        if (onCompleted != null) {
+            dispatchQueue.execute(onCompleted);
+        }
+    }
+
+    public String getBoundAddress() {
+        try {
+            return new URI(bindScheme, null, bindAddress.getAddress().getHostAddress(), channel.socket().getLocalPort(), null, null, null).toString();
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public String getConnectAddress() {
+        try {
+            return new URI(bindScheme, null, resolveHostName(), channel.socket().getLocalPort(), null, null, null).toString();
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+
+    /** Returns a host name clients can use: the local canonical name when bound to the wildcard address. */
+    protected String resolveHostName() {
+        String result;
+        if (bindAddress.getAddress().isAnyLocalAddress()) {
+            // make it more human readable and useful, an alternative to 0.0.0.0
+            try {
+                result = InetAddress.getLocalHost().getCanonicalHostName();
+            } catch (UnknownHostException e) {
+                result = "localhost";
+            }
+        } else {
+            result = bindAddress.getAddress().getCanonicalHostName();
+        }
+        return result;
+    }
+
+    public void stop() {
+        stop(null);
+    }
+
+    /**
+     * Stops accepting connections and closes the server socket.
+     *
+     * @param onCompleted optional callback run once the socket is closed (may be null)
+     */
+    public void stop(final Runnable onCompleted) {
+        // never started, or already stopped: nothing to cancel
+        if (acceptSource == null || acceptSource.isCanceled()) {
+            if (onCompleted != null) {
+                onCompleted.run();
+            }
+        } else {
+            acceptSource.setCancelHandler(new Runnable() {
+                public void run() {
+                    closeQuietly(channel);
+                    if (onCompleted != null) {
+                        onCompleted.run();
+                    }
+                }
+            });
+            acceptSource.cancel();
+        }
+    }
+
+    public int getBacklog() {
+        return backlog;
+    }
+
+    public void setBacklog(int backlog) {
+        this.backlog = backlog;
+    }
+
+    /** Wraps an accepted socket in a transport and hands it to the accept listener. */
+    protected final void handleSocket(SocketChannel socket) throws Exception {
+        HashMap<String, Object> options = new HashMap<String, Object>();
+        TcpTransport transport = createTransport(socket, options);
+        listener.onAccept(this, transport);
+    }
+
+    /** Creates a connected transport for {@code socketChannel} and applies per-server and per-call options. */
+    protected TcpTransport createTransport(SocketChannel socketChannel, HashMap<String, Object> options) throws Exception {
+        TcpTransport transport = createTransport();
+        transport.connected(socketChannel);
+        if (options != null) {
+            IntrospectionSupport.setProperties(transport, options);
+        }
+        if (transportOptions != null) {
+            IntrospectionSupport.setProperties(transport, transportOptions);
+        }
+        return transport;
+    }
+
+    protected TcpTransport createTransport() {
+        return new TcpTransport();
+    }
+
+    public void setTransportOption(Map<String, Object> transportOptions) {
+        this.transportOptions = transportOptions;
+    }
+
+    /** Best-effort close; failures are ignored because the socket is being discarded. */
+    private static void closeQuietly(ServerSocketChannel serverChannel) {
+        if (serverChannel != null) {
+            try {
+                serverChannel.close();
+            } catch (IOException ignored) {
+                // nothing useful to do
+            }
+        }
+    }
+
+    /**
+     * @return pretty print of this
+     */
+    public String toString() {
+        return getBoundAddress();
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/aries-rsa/blob/62d835de/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TransportPool.java
----------------------------------------------------------------------
diff --git a/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TransportPool.java b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TransportPool.java
new file mode 100644
index 0000000..493212f
--- /dev/null
+++ b/provider/fastbin/src/main/java/org/apache/aries/rsa/provider/fastbin/tcp/TransportPool.java
@@ -0,0 +1,256 @@
+/**
+ * Copyright 2005-2015 Red Hat, Inc.
+ *
+ * Red Hat licenses this file to you under the Apache License, version
+ * 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+package org.apache.aries.rsa.provider.fastbin.tcp;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.aries.rsa.provider.fastbin.io.ProtocolCodec;
+import org.apache.aries.rsa.provider.fastbin.io.Service;
+import org.apache.aries.rsa.provider.fastbin.io.Transport;
+import org.apache.aries.rsa.provider.fastbin.io.TransportListener;
+import org.fusesource.hawtdispatch.DispatchQueue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A bounded pool of {@link Transport}s connected to a single URI.
+ * <p>
+ * Commands handed to {@link #offer(Object, Object)} are written to an idle transport when one
+ * is available. Otherwise they are queued in {@link #pending} and, while the pool holds fewer
+ * than {@code poolSize} transports, a new transport is started. Queued commands are drained
+ * when a transport reports that it can accept more data (see {@code Listener.onRefill}).
+ * A transport that stays idle for {@code evictionDelay} milliseconds is stopped and removed.
+ * <p>
+ * {@link #pending} and {@link #transports} are plain, unsynchronized collections; they are
+ * expected to be accessed only from {@link #queue}. {@code offer} and {@code stop} hop onto
+ * that queue themselves, and every transport is given that queue as its dispatch queue.
+ */
+public abstract class TransportPool implements Service {
+
+    protected static final Logger LOGGER = LoggerFactory.getLogger(TransportPool.class);
+
+    /** Default maximum number of transports held by the pool. */
+    public static final int DEFAULT_POOL_SIZE = 2;
+
+    /** Default idle time, in milliseconds, after which an idle transport is evicted. */
+    public static final long DEFAULT_EVICTION_DELAY = TimeUnit.MINUTES.toMillis(5);
+
+    protected final String uri;
+    protected final DispatchQueue queue;
+    /** Commands waiting for a transport with free capacity, in arrival order. */
+    protected final LinkedList<Pair> pending = new LinkedList<Pair>();
+    /** Every transport owned by the pool, with its idle timestamp and in-flight request ids. */
+    protected final Map<Transport, TransportState> transports = new HashMap<Transport, TransportState>();
+    protected AtomicBoolean running = new AtomicBoolean(false);
+
+    /** Maximum number of transports; with a value {@code <= 0} no transport is ever created. */
+    protected int poolSize;
+    /** Idle time in milliseconds before a transport is evicted; {@code <= 0} disables eviction. */
+    protected long evictionDelay;
+
+    /**
+     * Creates a pool using {@link #DEFAULT_POOL_SIZE} and {@link #DEFAULT_EVICTION_DELAY}.
+     *
+     * @param uri   the address every pooled transport connects to
+     * @param queue the dispatch queue that owns the pool's state
+     */
+    public TransportPool(String uri, DispatchQueue queue) {
+        this(uri, queue, DEFAULT_POOL_SIZE, DEFAULT_EVICTION_DELAY);
+    }
+
+    /**
+     * @param uri           the address every pooled transport connects to
+     * @param queue         the dispatch queue that owns the pool's state
+     * @param poolSize      maximum number of transports
+     * @param evictionDelay idle time in milliseconds before a transport is evicted
+     */
+    public TransportPool(String uri, DispatchQueue queue, int poolSize, long evictionDelay) {
+        this.uri = uri;
+        this.queue = queue;
+        this.poolSize = poolSize;
+        this.evictionDelay = evictionDelay;
+    }
+
+    /** Creates a transport for {@code uri}; the pool configures and starts it. */
+    protected abstract Transport createTransport(String uri) throws Exception;
+
+    /** Creates the protocol codec installed on each new transport. */
+    protected abstract ProtocolCodec createCodec();
+
+    /** Called for every command received on any pooled transport. */
+    protected abstract void onCommand(Object command);
+
+    /**
+     * Called when request {@code id} can no longer complete: its transport failed, was
+     * evicted, no transport could be obtained, or the pool was stopped.
+     */
+    protected abstract void onFailure(Object id, Throwable throwable);
+
+    /**
+     * Marks request {@code id} as completed by removing it from the in-flight set of whichever
+     * transport carries it, so a later failure of that transport no longer reports it.
+     */
+    protected void onDone(Object id) {
+        for (TransportState state : transports.values()) {
+            if (state.inflight.remove(id)) {
+                break;
+            }
+        }
+    }
+
+    /**
+     * Schedules {@code data} for sending on the pool's dispatch queue.
+     * <p>
+     * The command goes out on an idle transport if one exists; otherwise it waits in
+     * {@link #pending}. If the pool was stopped before the task ran, or the pool ends up with
+     * no transport at all (creation failed or {@code poolSize <= 0}), the request is failed
+     * immediately instead of waiting for a transport that will never drain it.
+     *
+     * @param data the command to send
+     * @param id   the request id later passed to {@link #onFailure(Object, Throwable)}
+     * @throws IllegalStateException if the pool is not running
+     */
+    public void offer(final Object data, final Object id) {
+        if (!running.get()) {
+            throw new IllegalStateException("Transport pool stopped");
+        }
+        queue.execute(new Runnable() {
+            public void run() {
+                if (!running.get()) {
+                    // stop() won the race; don't open a new transport after shutdown.
+                    onFailure(id, new IOException("Transport stopped"));
+                    return;
+                }
+                Transport transport = getIdleTransport();
+                if (transport != null) {
+                    doOffer(transport, data, id);
+                    if (transport.full()) {
+                        // No spare capacity: stop handing this transport out until onRefill.
+                        transports.get(transport).time = 0L;
+                    }
+                } else if (transports.isEmpty()) {
+                    // Nothing is connecting that could ever drain the pending queue.
+                    onFailure(id, new IOException("No transport available for: " + uri));
+                } else {
+                    pending.add(new Pair(data, id));
+                }
+            }
+        });
+    }
+
+    /**
+     * Records {@code id} as in flight on {@code transport} and writes {@code command} to it.
+     *
+     * @return the result of {@link Transport#offer(Object)}
+     */
+    protected boolean doOffer(Transport transport, Object command, Object id) {
+        transports.get(transport).inflight.add(id);
+        return transport.offer(command);
+    }
+
+    /**
+     * Returns a transport whose state marks it idle ({@code time > 0}), or {@code null}.
+     * <p>
+     * When none is idle and the pool is below {@code poolSize}, a new transport is started as
+     * a side effect. It becomes usable once it connects and reports a refill. A failure to
+     * start it is logged and swallowed.
+     */
+    protected Transport getIdleTransport() {
+        for (Map.Entry<Transport, TransportState> entry : transports.entrySet()) {
+            if (entry.getValue().time > 0) {
+                return entry.getKey();
+            }
+        }
+        if (transports.size() < poolSize) {
+            try {
+                startNewTransport();
+            } catch (Exception e) {
+                LOGGER.info("Unable to start new transport", e);
+            }
+        }
+        return null;
+    }
+
+    public void start() throws Exception {
+        start(null);
+    }
+
+    /**
+     * Marks the pool as running. Transports are created lazily by {@link #offer}.
+     *
+     * @param onComplete optional callback, run synchronously once the pool is running
+     */
+    public void start(Runnable onComplete) throws Exception {
+        running.set(true);
+        runIfNotNull(onComplete);
+    }
+
+    public void stop() {
+        stop(null);
+    }
+
+    /**
+     * Stops the pool and every transport it owns.
+     * <p>
+     * In-flight requests are failed as each transport is removed. Queued requests are failed
+     * once every transport has reported that it stopped (or immediately if there are none).
+     * After that, {@code onComplete} runs. If the pool is not running, {@code onComplete}
+     * runs synchronously.
+     *
+     * @param onComplete optional callback; may be {@code null}
+     */
+    public void stop(final Runnable onComplete) {
+        if (running.compareAndSet(true, false)) {
+            queue.execute(new Runnable() {
+                public void run() {
+                    if (transports.isEmpty()) {
+                        // No transport will call back, so finish right away.
+                        failPendingOnStop();
+                        runIfNotNull(onComplete);
+                        return;
+                    }
+                    final AtomicInteger latch = new AtomicInteger(transports.size());
+                    final Runnable countDown = new Runnable() {
+                        public void run() {
+                            if (latch.decrementAndGet() == 0) {
+                                failPendingOnStop();
+                                runIfNotNull(onComplete);
+                            }
+                        }
+                    };
+                    while (!transports.isEmpty()) {
+                        Transport transport = transports.keySet().iterator().next();
+                        TransportState state = transports.remove(transport);
+                        if (state != null) {
+                            for (Object id : state.inflight) {
+                                onFailure(id, new IOException("Transport stopped"));
+                            }
+                        }
+                        transport.stop(countDown);
+                    }
+                }
+            });
+        } else {
+            runIfNotNull(onComplete);
+        }
+    }
+
+    /** Fails every queued request, each with its own "Transport stopped" exception. */
+    private void failPendingOnStop() {
+        while (!pending.isEmpty()) {
+            Pair p = pending.removeFirst();
+            onFailure(p.id, new IOException("Transport stopped"));
+        }
+    }
+
+    /** Runs {@code task} unless it is {@code null}. */
+    private static void runIfNotNull(Runnable task) {
+        if (task != null) {
+            task.run();
+        }
+    }
+
+    /**
+     * Creates, configures, registers and starts a new transport for {@link #uri}.
+     * If starting fails, the transport is unregistered again so that it does not occupy a
+     * pool slot, and the exception is rethrown.
+     */
+    protected void startNewTransport() throws Exception {
+        LOGGER.debug("Creating new transport for: {}", this.uri);
+        Transport transport = createTransport(this.uri);
+        transport.setDispatchQueue(queue);
+        transport.setProtocolCodec(createCodec());
+        transport.setTransportListener(new Listener());
+        transports.put(transport, new TransportState());
+        try {
+            transport.start();
+        } catch (Exception e) {
+            transports.remove(transport);
+            throw e;
+        }
+    }
+
+    /** A command waiting in {@link #pending} together with its request id. */
+    protected static class Pair {
+        Object command;
+        Object id;
+
+        public Pair(Object command, Object id) {
+            this.command = command;
+            this.id = id;
+        }
+    }
+
+    /** Per-transport bookkeeping. */
+    protected static class TransportState {
+        /**
+         * 0 while the transport is connecting or full; otherwise the
+         * {@link System#currentTimeMillis()} value at which it last became idle.
+         */
+        long time;
+        /** Ids of requests written to this transport that have not been marked done. */
+        final Set<Object> inflight;
+
+        public TransportState() {
+            time = 0;
+            inflight = new HashSet<Object>();
+        }
+    }
+
+    /** Routes transport events back into the pool. */
+    protected class Listener implements TransportListener {
+
+        public void onTransportCommand(Transport transport, Object command) {
+            TransportPool.this.onCommand(command);
+        }
+
+        /**
+         * Drains {@link #pending} into {@code transport} while it has capacity. Then marks the
+         * transport busy (if full) or idle (otherwise). Idle transports are scheduled for
+         * eviction after {@code evictionDelay}.
+         */
+        public void onRefill(final Transport transport) {
+            TransportState state = transports.get(transport);
+            if (state == null) {
+                // Already failed, evicted or stopped: the pool no longer owns it.
+                return;
+            }
+            while (pending.size() > 0 && !transport.full()) {
+                Pair pair = pending.removeFirst();
+                boolean accepted = doOffer(transport, pair.command, pair.id);
+                assert accepted : "Should have been accepted since the transport was not full";
+            }
+
+            if (transport.full()) {
+                state.time = 0L;
+            } else {
+                final long time = System.currentTimeMillis();
+                state.time = time;
+                if (evictionDelay > 0) {
+                    queue.executeAfter(evictionDelay, TimeUnit.MILLISECONDS, new Runnable() {
+                        public void run() {
+                            TransportState current = transports.get(transport);
+                            // Evict only if the transport has stayed idle since this refill.
+                            if (current != null && current.time == time) {
+                                transports.remove(transport);
+                                // Like stop/failure, report requests that can no longer complete.
+                                for (Object id : current.inflight) {
+                                    onFailure(id, new IOException("Transport evicted"));
+                                }
+                                transport.stop();
+                            }
+                        }
+                    });
+                }
+            }
+        }
+
+        /**
+         * Removes a failed transport and fails its in-flight requests with {@code error}. If no
+         * transports remain, all queued requests are failed as well.
+         */
+        public void onTransportFailure(Transport transport, IOException error) {
+            if (!transport.isDisposed()) {
+                LOGGER.info("Transport failure", error);
+                TransportState state = transports.remove(transport);
+                if (state != null) {
+                    for (Object id : state.inflight) {
+                        onFailure(id, error);
+                    }
+                }
+                transport.stop();
+                if (transports.isEmpty()) {
+                    while (!pending.isEmpty()) {
+                        Pair p = pending.removeFirst();
+                        onFailure(p.id, error);
+                    }
+                }
+            }
+        }
+
+        public void onTransportConnected(Transport transport) {
+            transport.resumeRead();
+            onRefill(transport);
+        }
+
+        public void onTransportDisconnected(Transport transport) {
+            onTransportFailure(transport, new IOException("Transport disconnected"));
+        }
+    }
+}
