HBASE-16414 Improve performance for RPC encryption with Apache Common Crypto (Colin Ma)
Project: http://git-wip-us.apache.org/repos/asf/hbase/repo Commit: http://git-wip-us.apache.org/repos/asf/hbase/commit/0ae211eb Tree: http://git-wip-us.apache.org/repos/asf/hbase/tree/0ae211eb Diff: http://git-wip-us.apache.org/repos/asf/hbase/diff/0ae211eb Branch: refs/heads/master Commit: 0ae211eb399e5524196d89af8eac1941c8b61b60 Parents: d3decaa Author: Ramkrishna <[email protected]> Authored: Fri Oct 21 16:02:39 2016 +0530 Committer: Ramkrishna <[email protected]> Committed: Fri Oct 21 16:02:39 2016 +0530 ---------------------------------------------------------------------- .../hadoop/hbase/ipc/BlockingRpcConnection.java | 64 +- .../hadoop/hbase/ipc/NettyRpcConnection.java | 49 +- .../apache/hadoop/hbase/ipc/RpcConnection.java | 13 + .../hbase/ipc/UnsupportedCryptoException.java | 38 + .../hbase/security/CryptoAESUnwrapHandler.java | 47 + .../hbase/security/CryptoAESWrapHandler.java | 98 + .../hadoop/hbase/security/EncryptionUtil.java | 27 + .../hbase/security/HBaseSaslRpcClient.java | 124 +- .../NettyHBaseRpcConnectionHeaderHandler.java | 99 + .../hbase/security/NettyHBaseSaslRpcClient.java | 5 + .../NettyHBaseSaslRpcClientHandler.java | 27 +- .../hbase/security/TestHBaseSaslRpcClient.java | 6 +- hbase-common/pom.xml | 4 + .../hadoop/hbase/io/crypto/aes/CryptoAES.java | 241 ++ .../shaded/protobuf/generated/RPCProtos.java | 3419 ++++++++++++++---- .../src/main/protobuf/RPC.proto | 19 + .../org/apache/hadoop/hbase/ipc/RpcServer.java | 164 +- .../hadoop/hbase/security/TestSecureIPC.java | 41 +- pom.xml | 12 + 19 files changed, 3666 insertions(+), 831 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java 
b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java index 5ae5508..15eb10c 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/BlockingRpcConnection.java @@ -60,12 +60,14 @@ import org.apache.hadoop.hbase.exceptions.ConnectionClosingException; import org.apache.hadoop.hbase.io.ByteArrayOutputStream; import org.apache.hadoop.hbase.ipc.HBaseRpcController.CancellationCallback; import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.CellBlockMeta; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ConnectionHeader; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ExceptionResponse; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ResponseHeader; import org.apache.hadoop.hbase.security.HBaseSaslRpcClient; +import org.apache.hadoop.hbase.security.SaslUtil; import org.apache.hadoop.hbase.security.SaslUtil.QualityOfProtection; import org.apache.hadoop.hbase.util.EnvironmentEdgeManager; import org.apache.hadoop.hbase.util.ExceptionUtil; @@ -111,6 +113,8 @@ class BlockingRpcConnection extends RpcConnection implements Runnable { private byte[] connectionHeaderWithLength; + private boolean waitingConnectionHeaderResponse = false; + /** * If the client wants to interrupt its calls easily (i.e. call Thread#interrupt), it gets into a * java issue: an interruption during a write closes the socket/channel. 
A way to avoid this is to @@ -349,7 +353,8 @@ class BlockingRpcConnection extends RpcConnection implements Runnable { throws IOException { saslRpcClient = new HBaseSaslRpcClient(authMethod, token, serverPrincipal, this.rpcClient.fallbackAllowed, this.rpcClient.conf.get("hbase.rpc.protection", - QualityOfProtection.AUTHENTICATION.name().toLowerCase(Locale.ROOT))); + QualityOfProtection.AUTHENTICATION.name().toLowerCase(Locale.ROOT)), + this.rpcClient.conf.getBoolean(CRYPTO_AES_ENABLED_KEY, CRYPTO_AES_ENABLED_DEFAULT)); return saslRpcClient.saslConnect(in2, out2); } @@ -462,8 +467,8 @@ class BlockingRpcConnection extends RpcConnection implements Runnable { } if (continueSasl) { // Sasl connect is successful. Let's set up Sasl i/o streams. - inStream = saslRpcClient.getInputStream(inStream); - outStream = saslRpcClient.getOutputStream(outStream); + inStream = saslRpcClient.getInputStream(); + outStream = saslRpcClient.getOutputStream(); } else { // fall back to simple auth because server told us so. // do not change authMethod and useSasl here, we should start from secure when @@ -474,6 +479,9 @@ class BlockingRpcConnection extends RpcConnection implements Runnable { this.out = new DataOutputStream(new BufferedOutputStream(outStream)); // Now write out the connection header writeConnectionHeader(); + // process the response from server for connection header if necessary + processResponseForConnectionHeader(); + break; } } catch (Throwable t) { @@ -511,10 +519,60 @@ class BlockingRpcConnection extends RpcConnection implements Runnable { * Write the connection header. */ private void writeConnectionHeader() throws IOException { + boolean isCryptoAesEnable = false; + // check if Crypto AES is enabled + if (saslRpcClient != null) { + boolean saslEncryptionEnabled = SaslUtil.QualityOfProtection.PRIVACY. 
+ getSaslQop().equalsIgnoreCase(saslRpcClient.getSaslQOP()); + isCryptoAesEnable = saslEncryptionEnabled && conf.getBoolean( + CRYPTO_AES_ENABLED_KEY, CRYPTO_AES_ENABLED_DEFAULT); + } + + // if Crypto AES is enabled, set transformation and negotiate with server + if (isCryptoAesEnable) { + waitingConnectionHeaderResponse = true; + } this.out.write(connectionHeaderWithLength); this.out.flush(); } + private void processResponseForConnectionHeader() throws IOException { + // if no response excepted, return + if (!waitingConnectionHeaderResponse) return; + try { + // read the ConnectionHeaderResponse from server + int len = this.in.readInt(); + byte[] buff = new byte[len]; + int readSize = this.in.read(buff); + if (LOG.isDebugEnabled()) { + LOG.debug("Length of response for connection header:" + readSize); + } + + RPCProtos.ConnectionHeaderResponse connectionHeaderResponse = + RPCProtos.ConnectionHeaderResponse.parseFrom(buff); + + // Get the CryptoCipherMeta, update the HBaseSaslRpcClient for Crypto Cipher + if (connectionHeaderResponse.hasCryptoCipherMeta()) { + negotiateCryptoAes(connectionHeaderResponse.getCryptoCipherMeta()); + } + waitingConnectionHeaderResponse = false; + } catch (SocketTimeoutException ste) { + LOG.fatal("Can't get the connection header response for rpc timeout, please check if" + + " server has the correct configuration to support the additional function.", ste); + // timeout when waiting the connection header response, ignore the additional function + throw new IOException("Timeout while waiting connection header response", ste); + } + } + + private void negotiateCryptoAes(RPCProtos.CryptoCipherMeta cryptoCipherMeta) + throws IOException { + // initilize the Crypto AES with CryptoCipherMeta + saslRpcClient.initCryptoCipher(cryptoCipherMeta, this.rpcClient.conf); + // reset the inputStream/outputStream for Crypto AES encryption + this.in = new DataInputStream(new BufferedInputStream(saslRpcClient.getInputStream())); + this.out = new 
DataOutputStream(new BufferedOutputStream(saslRpcClient.getOutputStream())); + } + private void tracedWriteRequest(Call call) throws IOException { try (TraceScope ignored = Trace.startSpan("RpcClientImpl.tracedWriteRequest", call.span)) { writeRequest(call); http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java index ce5adda..47d7234 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/NettyRpcConnection.java @@ -22,6 +22,8 @@ import static org.apache.hadoop.hbase.ipc.CallEvent.Type.TIMEOUT; import static org.apache.hadoop.hbase.ipc.IPCUtil.setCancelled; import static org.apache.hadoop.hbase.ipc.IPCUtil.toIOE; +import io.netty.handler.timeout.ReadTimeoutHandler; +import org.apache.hadoop.hbase.security.NettyHBaseRpcConnectionHeaderHandler; import org.apache.hadoop.hbase.shaded.com.google.protobuf.RpcCallback; import io.netty.bootstrap.Bootstrap; @@ -55,7 +57,6 @@ import org.apache.hadoop.hbase.ipc.HBaseRpcController.CancellationCallback; import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ConnectionHeader; import org.apache.hadoop.hbase.security.NettyHBaseSaslRpcClientHandler; import org.apache.hadoop.hbase.security.SaslChallengeDecoder; -import org.apache.hadoop.hbase.security.SaslUtil.QualityOfProtection; import org.apache.hadoop.hbase.util.Threads; import org.apache.hadoop.security.UserGroupInformation; @@ -130,8 +131,7 @@ class NettyRpcConnection extends RpcConnection { } } - private void established(Channel ch) { - ch.write(connectionHeaderWithLength.retainedDuplicate()); + private void established(Channel ch) throws 
IOException { ChannelPipeline p = ch.pipeline(); String addBeforeHandler = p.context(BufferCallBeforeInitHandler.class).name(); p.addBefore(addBeforeHandler, null, @@ -188,11 +188,10 @@ class NettyRpcConnection extends RpcConnection { return; } Promise<Boolean> saslPromise = ch.eventLoop().newPromise(); - ChannelHandler saslHandler; + final NettyHBaseSaslRpcClientHandler saslHandler; try { saslHandler = new NettyHBaseSaslRpcClientHandler(saslPromise, ticket, authMethod, token, - serverPrincipal, rpcClient.fallbackAllowed, this.rpcClient.conf.get( - "hbase.rpc.protection", QualityOfProtection.AUTHENTICATION.name().toLowerCase())); + serverPrincipal, rpcClient.fallbackAllowed, this.rpcClient.conf); } catch (IOException e) { failInit(ch, e); return; @@ -206,7 +205,41 @@ class NettyRpcConnection extends RpcConnection { ChannelPipeline p = ch.pipeline(); p.remove(SaslChallengeDecoder.class); p.remove(NettyHBaseSaslRpcClientHandler.class); - established(ch); + + // check if negotiate with server for connection header is necessary + if (saslHandler.isNeedProcessConnectionHeader()) { + Promise<Boolean> connectionHeaderPromise = ch.eventLoop().newPromise(); + // create the handler to handle the connection header + ChannelHandler chHandler = new NettyHBaseRpcConnectionHeaderHandler( + connectionHeaderPromise, conf, connectionHeaderWithLength); + + // add ReadTimeoutHandler to deal with server doesn't response connection header + // because of the different configuration in client side and server side + p.addFirst(new ReadTimeoutHandler( + RpcClient.DEFAULT_SOCKET_TIMEOUT_READ, TimeUnit.MILLISECONDS)); + p.addLast(chHandler); + connectionHeaderPromise.addListener(new FutureListener<Boolean>() { + @Override + public void operationComplete(Future<Boolean> future) throws Exception { + if (future.isSuccess()) { + ChannelPipeline p = ch.pipeline(); + p.remove(ReadTimeoutHandler.class); + p.remove(NettyHBaseRpcConnectionHeaderHandler.class); + // don't send connection header, 
NettyHbaseRpcConnectionHeaderHandler + // sent it already + established(ch); + } else { + final Throwable error = future.cause(); + scheduleRelogin(error); + failInit(ch, toIOE(error)); + } + } + }); + } else { + // send the connection header to server + ch.write(connectionHeaderWithLength.retainedDuplicate()); + established(ch); + } } else { final Throwable error = future.cause(); scheduleRelogin(error); @@ -240,6 +273,8 @@ class NettyRpcConnection extends RpcConnection { if (useSasl) { saslNegotiate(ch); } else { + // send the connection header to server + ch.write(connectionHeaderWithLength.retainedDuplicate()); established(ch); } } http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcConnection.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcConnection.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcConnection.java index a60528e..c9002e7 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcConnection.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/RpcConnection.java @@ -72,6 +72,12 @@ abstract class RpcConnection { protected final HashedWheelTimer timeoutTimer; + protected final Configuration conf; + + protected static String CRYPTO_AES_ENABLED_KEY = "hbase.rpc.crypto.encryption.aes.enabled"; + + protected static boolean CRYPTO_AES_ENABLED_DEFAULT = false; + // the last time we were picked up from connection pool. 
protected long lastTouched; @@ -84,6 +90,7 @@ abstract class RpcConnection { this.timeoutTimer = timeoutTimer; this.codec = codec; this.compressor = compressor; + this.conf = conf; UserGroupInformation ticket = remoteId.getTicket().getUGI(); SecurityInfo securityInfo = SecurityInfo.getInfo(remoteId.getServiceName()); @@ -224,6 +231,12 @@ abstract class RpcConnection { builder.setCellBlockCompressorClass(this.compressor.getClass().getCanonicalName()); } builder.setVersionInfo(ProtobufUtil.getVersionInfo()); + boolean isCryptoAESEnable = conf.getBoolean(CRYPTO_AES_ENABLED_KEY, CRYPTO_AES_ENABLED_DEFAULT); + // if Crypto AES enable, setup Cipher transformation + if (isCryptoAESEnable) { + builder.setRpcCryptoCipherTransformation( + conf.get("hbase.rpc.crypto.encryption.aes.cipher.transform", "AES/CTR/NoPadding")); + } return builder.build(); } http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/UnsupportedCryptoException.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/UnsupportedCryptoException.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/UnsupportedCryptoException.java new file mode 100644 index 0000000..12e4a7a --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/ipc/UnsupportedCryptoException.java @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.ipc; + +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.classification.InterfaceStability; + [email protected] [email protected] +public class UnsupportedCryptoException extends FatalConnectionException { + public UnsupportedCryptoException() { + super(); + } + + public UnsupportedCryptoException(String msg) { + super(msg); + } + + public UnsupportedCryptoException(String msg, Throwable t) { + super(msg, t); + } +} http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java new file mode 100644 index 0000000..31abeba --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESUnwrapHandler.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; + +/** + * Unwrap messages with Crypto AES. Should be placed after a + * {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder} + */ [email protected] +public class CryptoAESUnwrapHandler extends SimpleChannelInboundHandler<ByteBuf> { + + private final CryptoAES cryptoAES; + + public CryptoAESUnwrapHandler(CryptoAES cryptoAES) { + this.cryptoAES = cryptoAES; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + byte[] bytes = new byte[msg.readableBytes()]; + msg.readBytes(bytes); + ctx.fireChannelRead(Unpooled.wrappedBuffer(cryptoAES.unwrap(bytes, 0, bytes.length))); + } +} http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java new file mode 100644 index 0000000..6c74ed8 --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/CryptoAESWrapHandler.java @@ -0,0 +1,98 @@ +/** + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.CoalescingBufferQueue; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.PromiseCombiner; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; + +import java.io.IOException; + +/** + * wrap messages with Crypto AES. 
+ */ [email protected] +public class CryptoAESWrapHandler extends ChannelOutboundHandlerAdapter { + + private final CryptoAES cryptoAES; + + private CoalescingBufferQueue queue; + + public CryptoAESWrapHandler(CryptoAES cryptoAES) { + this.cryptoAES = cryptoAES; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + queue = new CoalescingBufferQueue(ctx.channel()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (msg instanceof ByteBuf) { + queue.add((ByteBuf) msg, promise); + } else { + ctx.write(msg, promise); + } + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + if (queue.isEmpty()) { + return; + } + ByteBuf buf = null; + try { + ChannelPromise promise = ctx.newPromise(); + int readableBytes = queue.readableBytes(); + buf = queue.remove(readableBytes, promise); + byte[] bytes = new byte[readableBytes]; + buf.readBytes(bytes); + byte[] wrapperBytes = cryptoAES.wrap(bytes, 0, bytes.length); + ChannelPromise lenPromise = ctx.newPromise(); + ctx.write(ctx.alloc().buffer(4).writeInt(wrapperBytes.length), lenPromise); + ChannelPromise contentPromise = ctx.newPromise(); + ctx.write(Unpooled.wrappedBuffer(wrapperBytes), contentPromise); + PromiseCombiner combiner = new PromiseCombiner(); + combiner.addAll(lenPromise, contentPromise); + combiner.finish(promise); + ctx.flush(); + } finally { + if (buf != null) { + ReferenceCountUtil.safeRelease(buf); + } + } + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + if (!queue.isEmpty()) { + queue.releaseAndFailAll(new IOException("Connection closed")); + } + ctx.close(promise); + } +} http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/EncryptionUtil.java ---------------------------------------------------------------------- diff --git 
a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/EncryptionUtil.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/EncryptionUtil.java index b5009e0..c7e0be7 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/EncryptionUtil.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/EncryptionUtil.java @@ -23,9 +23,11 @@ import java.io.IOException; import java.security.Key; import java.security.KeyException; import java.security.SecureRandom; +import java.util.Properties; import javax.crypto.spec.SecretKeySpec; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; @@ -37,6 +39,8 @@ import org.apache.hadoop.hbase.io.crypto.Cipher; import org.apache.hadoop.hbase.io.crypto.Encryption; import org.apache.hadoop.hbase.shaded.com.google.protobuf.UnsafeByteOperations; import org.apache.hadoop.hbase.shaded.protobuf.generated.EncryptionProtos; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; +import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; import org.apache.hadoop.hbase.util.Bytes; /** @@ -255,4 +259,27 @@ public final class EncryptionUtil { } return key; } + + /** + * Helper to create an instance of CryptoAES. + * + * @param conf The current configuration. + * @param cryptoCipherMeta The metadata for create CryptoAES. + * @return The instance of CryptoAES. 
+ * @throws IOException if create CryptoAES failed + */ + public static CryptoAES createCryptoAES(RPCProtos.CryptoCipherMeta cryptoCipherMeta, + Configuration conf) throws IOException { + Properties properties = new Properties(); + // the property for cipher class + properties.setProperty(CryptoCipherFactory.CLASSES_KEY, + conf.get("hbase.rpc.crypto.encryption.aes.cipher.class", + "org.apache.commons.crypto.cipher.JceCipher")); + // create SaslAES for client + return new CryptoAES(cryptoCipherMeta.getTransformation(), properties, + cryptoCipherMeta.getInKey().toByteArray(), + cryptoCipherMeta.getOutKey().toByteArray(), + cryptoCipherMeta.getInIv().toByteArray(), + cryptoCipherMeta.getOutIv().toByteArray()); + } } http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java index 3f43f7f..e644cb9 100644 --- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/HBaseSaslRpcClient.java @@ -22,16 +22,22 @@ import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; +import java.io.FilterInputStream; +import java.io.FilterOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.nio.ByteBuffer; import javax.security.sasl.Sasl; import javax.security.sasl.SaslException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.hbase.classification.InterfaceAudience; +import 
org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; import org.apache.hadoop.io.WritableUtils; import org.apache.hadoop.ipc.RemoteException; import org.apache.hadoop.security.SaslInputStream; @@ -47,6 +53,13 @@ import org.apache.hadoop.security.token.TokenIdentifier; public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { private static final Log LOG = LogFactory.getLog(HBaseSaslRpcClient.class); + private boolean cryptoAesEnable; + private CryptoAES cryptoAES; + private InputStream saslInputStream; + private InputStream cryptoInputStream; + private OutputStream saslOutputStream; + private OutputStream cryptoOutputStream; + private boolean initStreamForCrypto; public HBaseSaslRpcClient(AuthMethod method, Token<? extends TokenIdentifier> token, String serverPrincipal, boolean fallbackAllowed) throws IOException { @@ -54,8 +67,10 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { } public HBaseSaslRpcClient(AuthMethod method, Token<? extends TokenIdentifier> token, - String serverPrincipal, boolean fallbackAllowed, String rpcProtection) throws IOException { + String serverPrincipal, boolean fallbackAllowed, String rpcProtection, + boolean initStreamForCrypto) throws IOException { super(method, token, serverPrincipal, fallbackAllowed, rpcProtection); + this.initStreamForCrypto = initStreamForCrypto; } private static void readStatus(DataInputStream inStream) throws IOException { @@ -133,6 +148,18 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { LOG.debug("SASL client context established. 
Negotiated QoP: " + saslClient.getNegotiatedProperty(Sasl.QOP)); } + // initial the inputStream, outputStream for both Sasl encryption + // and Crypto AES encryption if necessary + // if Crypto AES encryption enabled, the saslInputStream/saslOutputStream is + // only responsible for connection header negotiation, + // cryptoInputStream/cryptoOutputStream is responsible for rpc encryption with Crypto AES + saslInputStream = new SaslInputStream(inS, saslClient); + saslOutputStream = new SaslOutputStream(outS, saslClient); + if (initStreamForCrypto) { + cryptoInputStream = new WrappedInputStream(inS); + cryptoOutputStream = new WrappedOutputStream(outS); + } + return true; } catch (IOException e) { try { @@ -144,29 +171,112 @@ public class HBaseSaslRpcClient extends AbstractHBaseSaslRpcClient { } } + public String getSaslQOP() { + return (String) saslClient.getNegotiatedProperty(Sasl.QOP); + } + + public void initCryptoCipher(RPCProtos.CryptoCipherMeta cryptoCipherMeta, + Configuration conf) throws IOException { + // create SaslAES for client + cryptoAES = EncryptionUtil.createCryptoAES(cryptoCipherMeta, conf); + cryptoAesEnable = true; + } + /** * Get a SASL wrapped InputStream. Can be called only after saslConnect() has been called. - * @param in the InputStream to wrap * @return a SASL wrapped InputStream * @throws IOException */ - public InputStream getInputStream(InputStream in) throws IOException { + public InputStream getInputStream() throws IOException { if (!saslClient.isComplete()) { throw new IOException("Sasl authentication exchange hasn't completed yet"); } - return new SaslInputStream(in, saslClient); + // If Crypto AES is enabled, return cryptoInputStream which unwrap the data with Crypto AES. 
+ if (cryptoAesEnable && cryptoInputStream != null) { + return cryptoInputStream; + } + return saslInputStream; + } + + class WrappedInputStream extends FilterInputStream { + private ByteBuffer unwrappedRpcBuffer = ByteBuffer.allocate(0); + public WrappedInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + public int read() throws IOException { + byte[] b = new byte[1]; + int n = read(b, 0, 1); + return (n != -1) ? b[0] : -1; + } + + @Override + public int read(byte b[]) throws IOException { + return read(b, 0, b.length); + } + + @Override + public synchronized int read(byte[] buf, int off, int len) throws IOException { + // fill the buffer with the next RPC message + if (unwrappedRpcBuffer.remaining() == 0) { + readNextRpcPacket(); + } + // satisfy as much of the request as possible + int readLen = Math.min(len, unwrappedRpcBuffer.remaining()); + unwrappedRpcBuffer.get(buf, off, readLen); + return readLen; + } + + // unwrap messages with Crypto AES + private void readNextRpcPacket() throws IOException { + LOG.debug("reading next wrapped RPC packet"); + DataInputStream dis = new DataInputStream(in); + int rpcLen = dis.readInt(); + byte[] rpcBuf = new byte[rpcLen]; + dis.readFully(rpcBuf); + + // unwrap with Crypto AES + rpcBuf = cryptoAES.unwrap(rpcBuf, 0, rpcBuf.length); + if (LOG.isDebugEnabled()) { + LOG.debug("unwrapping token of length:" + rpcBuf.length); + } + unwrappedRpcBuffer = ByteBuffer.wrap(rpcBuf); + } } /** * Get a SASL wrapped OutputStream. Can be called only after saslConnect() has been called. 
- * @param out the OutputStream to wrap * @return a SASL wrapped OutputStream * @throws IOException */ - public OutputStream getOutputStream(OutputStream out) throws IOException { + public OutputStream getOutputStream() throws IOException { if (!saslClient.isComplete()) { throw new IOException("Sasl authentication exchange hasn't completed yet"); } - return new SaslOutputStream(out, saslClient); + // If Crypto AES is enabled, return cryptoOutputStream which wrap the data with Crypto AES. + if (cryptoAesEnable && cryptoOutputStream != null) { + return cryptoOutputStream; + } + return saslOutputStream; + } + + class WrappedOutputStream extends FilterOutputStream { + public WrappedOutputStream(OutputStream out) throws IOException { + super(out); + } + @Override + public void write(byte[] buf, int off, int len) throws IOException { + if (LOG.isDebugEnabled()) { + LOG.debug("wrapping token of length:" + len); + } + + // wrap with Crypto AES + byte[] wrapped = cryptoAES.wrap(buf, off, len); + DataOutputStream dob = new DataOutputStream(out); + dob.writeInt(wrapped.length); + dob.write(wrapped, 0, wrapped.length); + dob.flush(); + } } } http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java ---------------------------------------------------------------------- diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java new file mode 100644 index 0000000..5608874 --- /dev/null +++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseRpcConnectionHeaderHandler.java @@ -0,0 +1,99 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.security; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.concurrent.Promise; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.classification.InterfaceAudience; +import org.apache.hadoop.hbase.io.crypto.aes.CryptoAES; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos; + +/** + * Implement logic to deal with the rpc connection header. 
+ */
+@InterfaceAudience.Private
+public class NettyHBaseRpcConnectionHeaderHandler extends SimpleChannelInboundHandler<ByteBuf> {
+
+  /** Completed with {@code true} once the connection header response has been processed. */
+  private final Promise<Boolean> saslPromise;
+
+  private final Configuration conf;
+
+  /** Serialized connection header, written to the server when this handler is added. */
+  private final ByteBuf connectionHeaderWithLength;
+
+  /**
+   * @param saslPromise completed with {@code true} after the server's connection header response
+   *          has been handled, or failed if an error is caught
+   * @param conf configuration handed to {@link EncryptionUtil#createCryptoAES}
+   * @param connectionHeaderWithLength connection header bytes to send; a retained duplicate is
+   *          written, so this buffer's own reference is left untouched
+   */
+  public NettyHBaseRpcConnectionHeaderHandler(Promise<Boolean> saslPromise, Configuration conf,
+      ByteBuf connectionHeaderWithLength) {
+    this.saslPromise = saslPromise;
+    this.conf = conf;
+    this.connectionHeaderWithLength = connectionHeaderWithLength;
+  }
+
+  /**
+   * Reads the server's ConnectionHeaderResponse (a 4-byte length followed by the serialized
+   * message). If it carries a CryptoCipherMeta, the SASL wrap/unwrap handlers are swapped for
+   * Crypto AES ones before the promise is completed.
+   */
+  @Override
+  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
+    // read the ConnectionHeaderResponse from server
+    int len = msg.readInt();
+    byte[] buff = new byte[len];
+    msg.readBytes(buff);
+
+    RPCProtos.ConnectionHeaderResponse connectionHeaderResponse =
+        RPCProtos.ConnectionHeaderResponse.parseFrom(buff);
+
+    // Get the CryptoCipherMeta, update the HBaseSaslRpcClient for Crypto Cipher
+    if (connectionHeaderResponse.hasCryptoCipherMeta()) {
+      CryptoAES cryptoAES = EncryptionUtil.createCryptoAES(
+          connectionHeaderResponse.getCryptoCipherMeta(), conf);
+      // replace the Sasl handler with Crypto AES handler
+      setupCryptoAESHandler(ctx.pipeline(), cryptoAES);
+    }
+
+    saslPromise.setSuccess(true);
+  }
+
+  @Override
+  public void handlerAdded(ChannelHandlerContext ctx) {
+    try {
+      // send the connection header to server first
+      ctx.writeAndFlush(connectionHeaderWithLength.retainedDuplicate());
+    } catch (Exception e) {
+      // the exception thrown by handlerAdded will not be passed to the exceptionCaught below
+      // because netty will remove a handler if handlerAdded throws an exception.
+      exceptionCaught(ctx, e);
+    }
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+    saslPromise.tryFailure(cause);
+  }
+
+  /**
+   * Remove handlers for sasl encryption and add handlers for Crypto AES encryption.
+   * Both new handlers are inserted directly after the length-field frame decoder; because the
+   * unwrap handler is added first, the resulting order is decoder, wrap handler, unwrap handler.
+   */
+  private void setupCryptoAESHandler(ChannelPipeline p, CryptoAES cryptoAES) {
+    p.remove(SaslWrapHandler.class);
+    p.remove(SaslUnwrapHandler.class);
+    String lengthDecoder = p.context(LengthFieldBasedFrameDecoder.class).name();
+    p.addAfter(lengthDecoder, null, new CryptoAESUnwrapHandler(cryptoAES));
+    p.addAfter(lengthDecoder, null, new CryptoAESWrapHandler(cryptoAES));
+  }
+}
http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java
----------------------------------------------------------------------
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java
index f624608..9ae31a4 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClient.java
@@ -47,6 +47,7 @@ public class NettyHBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
     if (LOG.isDebugEnabled()) {
       LOG.debug("SASL client context established. Negotiated QoP: " + qop);
     }
+
     if (qop == null || "auth".equalsIgnoreCase(qop)) {
       return;
     }
@@ -55,4 +56,13 @@ public class NettyHBaseSaslRpcClient extends AbstractHBaseSaslRpcClient {
       new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
       new SaslUnwrapHandler(saslClient));
   }
+
+  /**
+   * Returns the quality of protection negotiated by the SASL client. Only valid once the SASL
+   * exchange has completed.
+   * @return the negotiated {@link Sasl#QOP} value, or {@code null} if none was negotiated
+   */
+  public String getSaslQOP() {
+    return (String) saslClient.getNegotiatedProperty(Sasl.QOP);
+  }
 }
http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java
----------------------------------------------------------------------
diff --git a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java
index 50609b4..4525aef 100644
--- a/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java
+++ b/hbase-client/src/main/java/org/apache/hadoop/hbase/security/NettyHBaseSaslRpcClientHandler.java
@@ -27,6 +27,7 @@ import java.security.PrivilegedExceptionAction;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.hbase.classification.InterfaceAudience;
 import org.apache.hadoop.hbase.ipc.FallbackDisallowedException;
 import org.apache.hadoop.security.UserGroupInformation;
@@ -47,17 +48,25 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
 
   private final NettyHBaseSaslRpcClient saslRpcClient;
 
+  private final Configuration conf;
+
+  // flag to mark if Crypto AES encryption is enabled
+  private boolean needProcessConnectionHeader = false;
+
   /**
    * @param saslPromise {@code true} if success, {@code false} if server tells us to fallback to
    *          simple.
    */
   public NettyHBaseSaslRpcClientHandler(Promise<Boolean> saslPromise, UserGroupInformation ugi,
       AuthMethod method, Token<? extends TokenIdentifier> token, String serverPrincipal,
-      boolean fallbackAllowed, String rpcProtection) throws IOException {
+      boolean fallbackAllowed, Configuration conf)
+      throws IOException {
     this.saslPromise = saslPromise;
     this.ugi = ugi;
+    this.conf = conf;
     this.saslRpcClient = new NettyHBaseSaslRpcClient(method, token, serverPrincipal,
-        fallbackAllowed, rpcProtection);
+        fallbackAllowed, conf.get("hbase.rpc.protection",
+          SaslUtil.QualityOfProtection.AUTHENTICATION.name().toLowerCase(java.util.Locale.ROOT)));
   }
 
   private void writeResponse(ChannelHandlerContext ctx, byte[] response) {
@@ -72,10 +81,32 @@ public class NettyHBaseSaslRpcClientHandler extends SimpleChannelInboundHandler<
     if (!saslRpcClient.isComplete()) {
       return;
     }
+
     saslRpcClient.setupSaslHandler(ctx.pipeline());
+    setCryptoAESOption();
+
     saslPromise.setSuccess(true);
   }
 
+  /**
+   * Records whether the connection header response must be processed after SASL: only when the
+   * negotiated QOP equals the PRIVACY level and Crypto AES is enabled in the configuration.
+   */
+  private void setCryptoAESOption() {
+    boolean saslEncryptionEnabled = SaslUtil.QualityOfProtection.PRIVACY.getSaslQop()
+        .equalsIgnoreCase(saslRpcClient.getSaslQOP());
+    needProcessConnectionHeader = saslEncryptionEnabled && conf.getBoolean(
+        "hbase.rpc.crypto.encryption.aes.enabled", false);
+  }
+
+  /**
+   * @return {@code true} if the connection header response (Crypto AES negotiation) must be
+   *         processed; stays {@code false} until the SASL exchange has completed
+   */
+  public boolean isNeedProcessConnectionHeader() {
+    return needProcessConnectionHeader;
+  }
+
   @Override
   public void handlerAdded(ChannelHandlerContext ctx) {
     try {
http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java
----------------------------------------------------------------------
diff --git a/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java b/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java
index 12b3661..7573a78 100644
--- a/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java
+++ b/hbase-client/src/test/java/org/apache/hadoop/hbase/security/TestHBaseSaslRpcClient.java
@@ -88,7 +88,7 @@ public class TestHBaseSaslRpcClient {
DEFAULT_USER_PASSWORD); for (SaslUtil.QualityOfProtection qop : SaslUtil.QualityOfProtection.values()) { String negotiatedQop = new HBaseSaslRpcClient(AuthMethod.DIGEST, token, - "principal/[email protected]", false, qop.name()) { + "principal/[email protected]", false, qop.name(), false) { public String getQop() { return saslProps.get(Sasl.QOP); } @@ -211,14 +211,14 @@ public class TestHBaseSaslRpcClient { }; try { - rpcClient.getInputStream(Mockito.mock(InputStream.class)); + rpcClient.getInputStream(); } catch(IOException ex) { //Sasl authentication exchange hasn't completed yet inState = true; } try { - rpcClient.getOutputStream(Mockito.mock(OutputStream.class)); + rpcClient.getOutputStream(); } catch(IOException ex) { //Sasl authentication exchange hasn't completed yet outState = true; http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-common/pom.xml ---------------------------------------------------------------------- diff --git a/hbase-common/pom.xml b/hbase-common/pom.xml index c5f5a81..03f1682 100644 --- a/hbase-common/pom.xml +++ b/hbase-common/pom.xml @@ -275,6 +275,10 @@ <groupId>org.apache.htrace</groupId> <artifactId>htrace-core</artifactId> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-crypto</artifactId> + </dependency> </dependencies> <profiles> http://git-wip-us.apache.org/repos/asf/hbase/blob/0ae211eb/hbase-common/src/main/java/org/apache/hadoop/hbase/io/crypto/aes/CryptoAES.java ---------------------------------------------------------------------- diff --git a/hbase-common/src/main/java/org/apache/hadoop/hbase/io/crypto/aes/CryptoAES.java b/hbase-common/src/main/java/org/apache/hadoop/hbase/io/crypto/aes/CryptoAES.java new file mode 100644 index 0000000..57ce2e1 --- /dev/null +++ b/hbase-common/src/main/java/org/apache/hadoop/hbase/io/crypto/aes/CryptoAES.java @@ -0,0 +1,241 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.io.crypto.aes;
+
+import org.apache.commons.crypto.cipher.CryptoCipher;
+import org.apache.commons.crypto.utils.Utils;
+import org.apache.hadoop.hbase.classification.InterfaceAudience;
+import org.apache.hadoop.hbase.classification.InterfaceStability;
+
+import javax.crypto.Cipher;
+import javax.crypto.Mac;
+import javax.crypto.SecretKey;
+import javax.crypto.ShortBufferException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+import java.security.InvalidAlgorithmParameterException;
+import java.security.InvalidKeyException;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.util.Arrays;
+import java.util.Properties;
+
+/**
+ * AES encryption and decryption.
+ */
+@InterfaceAudience.Private
+@InterfaceStability.Evolving
+public class CryptoAES {
+
+  /** Number of (truncated) HMAC-MD5 bytes carried inside every wrapped message. */
+  private static final int MAC_LENGTH = 10;
+  /** Number of bytes of the sequence number appended to every wrapped message. */
+  private static final int SEQ_NUM_LENGTH = 4;
+
+  private final CryptoCipher encryptor;
+  private final CryptoCipher decryptor;
+
+  private final Integrity integrity;
+
+  /**
+   * Creates the AES encryptor and decryptor and the HMAC-MD5 integrity state.
+   * The instance keeps mutable cipher and sequence-number state without synchronization.
+   * @param transformation cipher transformation; only {@code AES/CTR/NoPadding} is accepted
+   * @param properties properties passed to {@code Utils.getCipherInstance}
+   * @param inKey key used to decrypt and verify data received from the peer
+   * @param outKey key used to encrypt and authenticate data sent to the peer
+   * @param inIv IV of the decryptor
+   * @param outIv IV of the encryptor
+   * @throws IOException if the transformation is unsupported or a cipher or MAC cannot be set up
+   */
+  public CryptoAES(String transformation, Properties properties,
+      byte[] inKey, byte[] outKey, byte[] inIv, byte[] outIv) throws IOException {
+    checkTransformation(transformation);
+    // encryptor
+    encryptor = Utils.getCipherInstance(transformation, properties);
+    try {
+      SecretKeySpec outKEYSpec = new SecretKeySpec(outKey, "AES");
+      IvParameterSpec outIVSpec = new IvParameterSpec(outIv);
+      encryptor.init(Cipher.ENCRYPT_MODE, outKEYSpec, outIVSpec);
+    } catch (InvalidKeyException | InvalidAlgorithmParameterException e) {
+      throw new IOException("Failed to initialize encryptor", e);
+    }
+
+    // decryptor
+    decryptor = Utils.getCipherInstance(transformation, properties);
+    try {
+      SecretKeySpec inKEYSpec = new SecretKeySpec(inKey, "AES");
+      IvParameterSpec inIVSpec = new IvParameterSpec(inIv);
+      decryptor.init(Cipher.DECRYPT_MODE, inKEYSpec, inIVSpec);
+    } catch (InvalidKeyException | InvalidAlgorithmParameterException e) {
+      throw new IOException("Failed to initialize decryptor", e);
+    }
+
+    integrity = new Integrity(outKey, inKey);
+  }
+
+  /**
+   * Encrypts input data. The result is encrypt(msg || mac) followed by the 4-byte sequence
+   * number the mac was computed with; no padding is used (the transformation is CTR/NoPadding).
+   * @param data the input byte array
+   * @param offset the offset in input where the input starts
+   * @param len the input length
+   * @return the new encrypted byte array, {@code len + 14} bytes long
+   * @throws SaslException if error happens
+   */
+  public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
+    // mac over (current sequence number, plaintext); the number advances afterwards
+    byte[] mac = integrity.getHMAC(data, offset, len);
+    integrity.incMySeqNum();
+
+    // encrypt msg and mac directly into the output, then append the seqNum used for the mac
+    byte[] wrapped = new byte[len + MAC_LENGTH + SEQ_NUM_LENGTH];
+    try {
+      int n = encryptor.update(data, offset, len, wrapped, 0);
+      encryptor.update(mac, 0, MAC_LENGTH, wrapped, n);
+    } catch (ShortBufferException sbe) {
+      // this should not happen
+      throw new SaslException("Error happens during encrypt data", sbe);
+    }
+    System.arraycopy(integrity.getSeqNum(), 0, wrapped, len + MAC_LENGTH, SEQ_NUM_LENGTH);
+
+    return wrapped;
+  }
+
+  /**
+   * Decrypts input data. The input is encrypt(msg || mac) followed by the 4-byte sequence number;
+   * the mac and the sequence number are verified and msg is returned.
+   * @param data the input byte array
+   * @param offset the offset in input where the input starts
+   * @param len the input length
+   * @return the new decrypted byte array.
+   * @throws SaslException if the input is too short, the mac does not match, or the sequence
+   *           number is out of order
+   */
+  public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
+    // reject truncated input before touching the decryptor, whose CTR state would advance
+    if (len < MAC_LENGTH + SEQ_NUM_LENGTH) {
+      throw new SaslException("Wrapped data too short: " + len + " bytes, expected at least "
+          + (MAC_LENGTH + SEQ_NUM_LENGTH));
+    }
+    // get plaintext and seqNum
+    byte[] decrypted = new byte[len - SEQ_NUM_LENGTH];
+    byte[] peerSeqNum = new byte[SEQ_NUM_LENGTH];
+    try {
+      decryptor.update(data, offset, len - SEQ_NUM_LENGTH, decrypted, 0);
+    } catch (ShortBufferException sbe) {
+      // this should not happen
+      throw new SaslException("Error happens during decrypt data", sbe);
+    }
+    System.arraycopy(data, offset + decrypted.length, peerSeqNum, 0, SEQ_NUM_LENGTH);
+
+    // get msg and mac
+    int msgLen = decrypted.length - MAC_LENGTH;
+    byte[] msg = Arrays.copyOf(decrypted, msgLen);
+    byte[] mac = Arrays.copyOfRange(decrypted, msgLen, decrypted.length);
+
+    // check mac integrity and msg sequence
+    if (!integrity.compareHMAC(mac, peerSeqNum, msg, 0, msg.length)) {
+      throw new SaslException("Unmatched MAC");
+    }
+    if (!integrity.comparePeerSeqNum(peerSeqNum)) {
+      throw new SaslException("Out of order sequencing of messages. Got: "
+          + integrity.byteToInt(peerSeqNum) + " Expected: " + integrity.peerSeqNum);
+    }
+    integrity.incPeerSeqNum();
+
+    return msg;
+  }
+
+  private void checkTransformation(String transformation) throws IOException {
+    if ("AES/CTR/NoPadding".equalsIgnoreCase(transformation)) {
+      return;
+    }
+    throw new IOException("AES cipher transformation is not supported: " + transformation);
+  }
+
+  /**
+   * Helper class for providing integrity protection: HMAC-MD5 over (sequence number, message),
+   * truncated to 10 bytes, plus sequence-number tracking for both directions.
+   * One {@link Mac} per direction is created up front and reused, instead of looking up and
+   * initializing a new Mac for every message.
+   */
+  private static class Integrity {
+
+    private int mySeqNum = 0;
+    private int peerSeqNum = 0;
+    private final byte[] seqNum = new byte[SEQ_NUM_LENGTH];
+
+    private final Mac myMac;
+    private final Mac peerMac;
+
+    Integrity(byte[] outKey, byte[] inKey) throws IOException {
+      myMac = createHMAC(outKey);
+      peerMac = createHMAC(inKey);
+    }
+
+    byte[] getHMAC(byte[] msg, int start, int len) {
+      intToByte(mySeqNum);
+      return calculateHMAC(myMac, seqNum, msg, start, len);
+    }
+
+    boolean compareHMAC(byte[] expectedHMAC, byte[] peerSeqNum, byte[] msg, int start,
+        int len) {
+      byte[] mac = calculateHMAC(peerMac, peerSeqNum, msg, start, len);
+      // constant-time comparison, so a forger cannot learn the expected MAC byte by byte
+      return MessageDigest.isEqual(mac, expectedHMAC);
+    }
+
+    boolean comparePeerSeqNum(byte[] peerSeqNum) {
+      return this.peerSeqNum == byteToInt(peerSeqNum);
+    }
+
+    byte[] getSeqNum() {
+      return seqNum;
+    }
+
+    void incMySeqNum() {
+      mySeqNum++;
+    }
+
+    void incPeerSeqNum() {
+      peerSeqNum++;
+    }
+
+    /** Creates an HMAC-MD5 instance initialized with {@code key}. */
+    private static Mac createHMAC(byte[] key) throws SaslException {
+      try {
+        SecretKey keyKi = new SecretKeySpec(key, "HmacMD5");
+        Mac m = Mac.getInstance("HmacMD5");
+        m.init(keyKi);
+        return m;
+      } catch (InvalidKeyException e) {
+        throw new SaslException("Invalid bytes used for key of HMAC-MD5 hash.", e);
+      } catch (NoSuchAlgorithmException e) {
+        throw new SaslException("Error creating instance of MD5 MAC algorithm", e);
+      }
+    }
+
+    /** Returns the first 10 bytes of HMAC(seqNum || msg[start, start + len)). */
+    private static byte[] calculateHMAC(Mac mac, byte[] seqNum, byte[] msg, int start,
+        int len) {
+      mac.update(seqNum, 0, SEQ_NUM_LENGTH);
+      mac.update(msg, start, len);
+      // doFinal() also resets the Mac to its initialized state, ready for the next message
+      byte[] digest = mac.doFinal();
+      return Arrays.copyOf(digest, MAC_LENGTH);
+    }
+
+    /** Writes {@code num} big-endian into {@link #seqNum}. */
+    private void intToByte(int num) {
+      for (int i = 3; i >= 0; i--) {
+        seqNum[i] = (byte) (num & 0xff);
+        num >>>= 8;
+      }
+    }
+
+    /** Decodes a big-endian 4-byte sequence number. */
+    private int byteToInt(byte[] seqNum) {
+      int answer = 0;
+      for (int i = 0; i < 4; i++) {
+        answer <<= 8;
+        answer |= ((int) seqNum[i] & 0xff);
+      }
+      return answer;
+    }
+  }
+}
