http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/LoadBalanceSession.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/LoadBalanceSession.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/LoadBalanceSession.java new file mode 100644 index 0000000..6386088 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/LoadBalanceSession.java @@ -0,0 +1,641 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.nifi.controller.queue.clustered.client.async.nio;

import org.apache.nifi.controller.queue.LoadBalanceCompression;
import org.apache.nifi.controller.queue.clustered.FlowFileContentAccess;
import org.apache.nifi.controller.queue.clustered.TransactionThreshold;
import org.apache.nifi.controller.queue.clustered.client.LoadBalanceFlowFileCodec;
import org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants;
import org.apache.nifi.controller.queue.clustered.server.TransactionAbortedException;
import org.apache.nifi.controller.repository.ContentNotFoundException;
import org.apache.nifi.controller.repository.FlowFileRecord;
import org.apache.nifi.remote.StandardVersionNegotiator;
import org.apache.nifi.remote.VersionNegotiator;
import org.apache.nifi.stream.io.ByteCountingOutputStream;
import org.apache.nifi.stream.io.GZIPOutputStream;
import org.apache.nifi.stream.io.StreamUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.zip.CRC32;
import java.util.zip.Checksum;

import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.ABORT_PROTOCOL_NEGOTIATION;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.ABORT_TRANSACTION;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.CONFIRM_CHECKSUM;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.CONFIRM_COMPLETE_TRANSACTION;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.QUEUE_FULL;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.REJECT_CHECKSUM;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.REQEUST_DIFFERENT_VERSION;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.SPACE_AVAILABLE;
import static org.apache.nifi.controller.queue.clustered.protocol.LoadBalanceProtocolConstants.VERSION_ACCEPTED;


/**
 * A non-blocking, resumable state machine that sends a single transaction of FlowFiles to a
 * remote peer over a {@link PeerChannel}. Each call to {@link #communicate()} performs at most
 * one unit of work (writing one prepared frame or reading one response byte) and then returns,
 * so that a single thread can multiplex many sessions via an NIO {@link java.nio.channels.Selector}.
 *
 * <p>The protocol proceeds through the phases defined in {@link TransactionPhase}: negotiate a
 * protocol version, send the Connection ID, optionally check for queue space, stream each
 * FlowFile (attributes, then content data frames), exchange a CRC32 checksum, and finally
 * confirm transaction completion. A running CRC32 of every byte sent after version negotiation
 * is maintained so the receiver can verify integrity.</p>
 *
 * <p>Thread-safety: all mutable state is guarded by synchronizing on {@code this}.</p>
 */
public class LoadBalanceSession {
    private static final Logger logger = LoggerFactory.getLogger(LoadBalanceSession.class);

    /** Maximum payload of a single content data frame; the frame length is encoded as an unsigned short. */
    static final int MAX_DATA_FRAME_SIZE = 65535;
    private static final long PENALTY_MILLIS = TimeUnit.SECONDS.toMillis(2L);

    private final RegisteredPartition partition;
    private final Supplier<FlowFileRecord> flowFileSupplier;
    private final FlowFileContentAccess flowFileContentAccess;
    private final LoadBalanceFlowFileCodec flowFileCodec;
    private final PeerChannel channel;
    private final int timeoutMillis;
    private final String peerDescription;
    private final String connectionId;
    private final TransactionThreshold transactionThreshold;

    // Package-private for visibility; only version 1 is currently supported.
    final VersionNegotiator negotiator = new StandardVersionNegotiator(1);
    private int protocolVersion = 1;

    // Rolling checksum over everything sent after version negotiation; verified by the peer.
    private final Checksum checksum = new CRC32();

    // guarded by synchronizing on 'this'
    private ByteBuffer preparedFrame;
    private FlowFileRecord currentFlowFile;
    private final List<FlowFileRecord> flowFilesSent = new ArrayList<>();
    private TransactionPhase phase = TransactionPhase.RECOMMEND_PROTOCOL_VERSION;
    private InputStream flowFileInputStream;
    private final byte[] byteBuffer = new byte[MAX_DATA_FRAME_SIZE];
    private boolean complete = false;
    private long readTimeout;
    private long penaltyExpiration = -1L;

    /**
     * @param partition the partition whose FlowFiles are to be sent
     * @param contentAccess provides access to the content of each FlowFile
     * @param flowFileCodec encodes FlowFile attributes onto the wire
     * @param peerChannel the channel to the remote peer
     * @param timeoutMillis how long to wait for a response from the peer before timing out; must be &gt;= 1
     * @param transactionThreshold limits how many FlowFiles / bytes a single transaction may contain
     * @throws IllegalArgumentException if {@code timeoutMillis} is less than 1
     */
    public LoadBalanceSession(final RegisteredPartition partition, final FlowFileContentAccess contentAccess, final LoadBalanceFlowFileCodec flowFileCodec, final PeerChannel peerChannel,
                              final int timeoutMillis, final TransactionThreshold transactionThreshold) {
        this.partition = partition;
        this.flowFileSupplier = partition.getFlowFileRecordSupplier();
        this.connectionId = partition.getConnectionId();
        this.flowFileContentAccess = contentAccess;
        this.flowFileCodec = flowFileCodec;
        this.channel = peerChannel;
        this.peerDescription = peerChannel.getPeerDescription();

        if (timeoutMillis < 1) {
            throw new IllegalArgumentException("timeoutMillis must be at least 1 but was " + timeoutMillis);
        }
        this.timeoutMillis = timeoutMillis;
        this.transactionThreshold = transactionThreshold;
    }

    public RegisteredPartition getPartition() {
        return partition;
    }

    /**
     * @return the SelectionKey interest op ({@link SelectionKey#OP_READ} or {@link SelectionKey#OP_WRITE})
     *         that must be ready before the next call to {@link #communicate()} can make progress
     */
    public synchronized int getDesiredReadinessFlag() {
        return phase.getRequiredSelectionKey();
    }

    public synchronized List<FlowFileRecord> getFlowFilesSent() {
        return Collections.unmodifiableList(flowFilesSent);
    }

    public synchronized boolean isComplete() {
        return complete;
    }

    /**
     * Performs one unit of protocol work: finishes writing any partially-written frame, reads a
     * pending response, or prepares and begins writing the next frame.
     *
     * @return {@code true} if any progress was made; {@code false} if the session is complete,
     *         penalized, or the channel was not ready
     * @throws IOException if communication with the peer fails; the session is marked complete
     */
    public synchronized boolean communicate() throws IOException {
        if (isComplete()) {
            return false;
        }

        if (isPenalized()) {
            logger.debug("Will not communicate with Peer {} for Connection {} because session is penalized", peerDescription, connectionId);
            return false;
        }

        // If there's already a data frame prepared for writing, just write to the channel.
        if (preparedFrame != null && preparedFrame.hasRemaining()) {
            logger.trace("Current Frame is already available. Will continue writing current frame to channel");
            final int bytesWritten = channel.write(preparedFrame);
            return bytesWritten > 0;
        }

        try {
            // Check if the phase is one that needs to receive data and if so, call the appropriate method.
            switch (phase) {
                case RECEIVE_SPACE_RESPONSE:
                    return receiveSpaceAvailableResponse();
                case VERIFY_CHECKSUM:
                    return verifyChecksum();
                case CONFIRM_TRANSACTION_COMPLETE:
                    return confirmTransactionComplete();
                case RECEIVE_PROTOCOL_VERSION_ACKNOWLEDGMENT:
                    return receiveProtocolVersionAcknowledgment();
                case RECEIVE_RECOMMENDED_PROTOCOL_VERSION:
                    return receiveRecommendedProtocolVersion();
            }

            // Otherwise, we need to send something so get the data frame that should be sent and write it to the channel.
            // NOTE: named 'dataFrame' (not 'byteBuffer') to avoid shadowing the byte[] field of the same name.
            final ByteBuffer dataFrame = getDataFrame();
            preparedFrame = channel.prepareForWrite(dataFrame); // Prepare data frame for writing. E.g., encrypt the data, etc.

            final int bytesWritten = channel.write(preparedFrame);
            return bytesWritten > 0;
        } catch (final Exception e) {
            complete = true;
            throw e;
        }
    }


    /**
     * Reads the peer's final response and verifies that it confirms completion of the transaction.
     *
     * @return {@code true} if a response byte was read; {@code false} if no data was available yet
     */
    private boolean confirmTransactionComplete() throws IOException {
        logger.debug("Confirming Transaction Complete for Peer {}", peerDescription);

        final OptionalInt transactionResponse = channel.read();
        if (!transactionResponse.isPresent()) {
            if (System.currentTimeMillis() > readTimeout) {
                throw new SocketTimeoutException("Timed out waiting for Peer " + peerDescription + " to confirm the transaction is complete");
            }

            return false;
        }

        final int response = transactionResponse.getAsInt();
        if (response < 0) {
            throw new EOFException("Confirmed checksum when writing data to Peer " + peerDescription + " but encountered End-of-File when expecting a Transaction Complete confirmation");
        }

        if (response == ABORT_TRANSACTION) {
            throw new TransactionAbortedException("Confirmed checksum when writing data to Peer " + peerDescription + " but Peer aborted transaction instead of completing it");
        }
        if (response != CONFIRM_COMPLETE_TRANSACTION) {
            throw new IOException("Expected a CONFIRM_COMPLETE_TRANSACTION response from Peer " + peerDescription + " but received a value of " + response);
        }

        complete = true;
        logger.debug("Successfully completed Transaction to send {} FlowFiles to Peer {} for Connection {}", flowFilesSent.size(), peerDescription, connectionId);

        return true;
    }


    /**
     * Reads the peer's response to the checksum we sent. On confirmation, advances to
     * {@link TransactionPhase#SEND_TRANSACTION_COMPLETE}; on rejection, aborts the transaction.
     *
     * @return {@code true} if a response byte was read; {@code false} if no data was available yet
     */
    private boolean verifyChecksum() throws IOException {
        logger.debug("Verifying Checksum for Peer {}", peerDescription);

        final OptionalInt checksumResponse = channel.read();
        if (!checksumResponse.isPresent()) {
            if (System.currentTimeMillis() > readTimeout) {
                throw new SocketTimeoutException("Timed out waiting for Peer " + peerDescription + " to verify the checksum");
            }

            return false;
        }

        final int response = checksumResponse.getAsInt();
        if (response < 0) {
            throw new EOFException("Encountered End-of-File when trying to verify Checksum with Peer " + peerDescription);
        }

        if (response == REJECT_CHECKSUM) {
            throw new TransactionAbortedException("After transferring FlowFiles to Peer " + peerDescription + " received a REJECT_CHECKSUM response. Aborting transaction.");
        }
        if (response != CONFIRM_CHECKSUM) {
            throw new TransactionAbortedException("After transferring FlowFiles to Peer " + peerDescription + " received an unexpected response code " + response
                + ". Aborting transaction.");
        }

        logger.debug("Checksum confirmed. Writing COMPLETE_TRANSACTION flag");
        phase = TransactionPhase.SEND_TRANSACTION_COMPLETE;

        return true;
    }


    /**
     * Builds the next frame to send, based on the current phase. Only invoked for phases whose
     * readiness flag is OP_WRITE; the default branch is therefore not expected to be reached.
     */
    private ByteBuffer getDataFrame() throws IOException {
        switch (phase) {
            case RECOMMEND_PROTOCOL_VERSION:
                return recommendProtocolVersion();
            case ABORT_PROTOCOL_NEGOTIATION:
                return abortProtocolNegotiation();
            case SEND_CONNECTION_ID:
                return getConnectionId();
            case CHECK_SPACE:
                return checkSpace();
            case GET_NEXT_FLOWFILE:
                return getNextFlowFile();
            case SEND_FLOWFILE_DEFINITION:
            case SEND_FLOWFILE_CONTENTS:
                return getFlowFileContent();
            case SEND_CHECKSUM:
                return getChecksum();
            case SEND_TRANSACTION_COMPLETE:
                return getTransactionComplete();
            default:
                logger.debug("Phase of {}, returning null ByteBuffer", phase);
                return null;
        }
    }


    /** Builds the COMPLETE_TRANSACTION frame and starts the read timeout for the peer's confirmation. */
    private ByteBuffer getTransactionComplete() {
        logger.debug("Sending Transaction Complete Indicator to Peer {}", peerDescription);

        final ByteBuffer buffer = ByteBuffer.allocate(1);
        buffer.put((byte) LoadBalanceProtocolConstants.COMPLETE_TRANSACTION);
        buffer.rewind();

        readTimeout = System.currentTimeMillis() + timeoutMillis;
        phase = TransactionPhase.CONFIRM_TRANSACTION_COMPLETE;
        return buffer;
    }

    /** Builds the 8-byte checksum frame (sent after the last FlowFile) and starts the read timeout. */
    private ByteBuffer getChecksum() {
        logger.debug("Sending Checksum of {} to Peer {}", checksum.getValue(), peerDescription);

        // No more FlowFiles.
        final ByteBuffer buffer = ByteBuffer.allocate(8);
        buffer.putLong(checksum.getValue());

        readTimeout = System.currentTimeMillis() + timeoutMillis;
        phase = TransactionPhase.VERIFY_CHECKSUM;
        buffer.rewind();
        return buffer;
    }

    /**
     * Builds the next content data frame for the current FlowFile: either a DATA_FRAME_FOLLOWS
     * frame carrying up to {@link #MAX_DATA_FRAME_SIZE} bytes of (possibly compressed) content,
     * or a NO_DATA_FRAME frame when the content stream is exhausted. Every frame byte is folded
     * into the running checksum.
     */
    private ByteBuffer getFlowFileContent() throws IOException {
        // This method is fairly inefficient, copying lots of byte[]. Can do better. But keeping it simple for
        // now to get this working. Revisit with optimizations later.
        try {
            if (flowFileInputStream == null) {
                flowFileInputStream = flowFileContentAccess.read(currentFlowFile);
            }

            final int bytesRead = StreamUtils.fillBuffer(flowFileInputStream, byteBuffer, false);
            if (bytesRead < 1) {
                // If no data available, close the stream and move on to the next phase, returning a NO_DATA_FRAME buffer.
                flowFileInputStream.close();
                flowFileInputStream = null;
                phase = TransactionPhase.GET_NEXT_FLOWFILE;

                final ByteBuffer buffer = ByteBuffer.allocate(1);
                buffer.put((byte) LoadBalanceProtocolConstants.NO_DATA_FRAME);
                buffer.rewind();

                checksum.update(LoadBalanceProtocolConstants.NO_DATA_FRAME);

                logger.debug("Sending NO_DATA_FRAME indicator to Peer {}", peerDescription);

                return buffer;
            }

            logger.trace("Sending Data Frame that is {} bytes long to Peer {}", bytesRead, peerDescription);
            final ByteBuffer buffer;

            if (partition.getCompression() == LoadBalanceCompression.COMPRESS_ATTRIBUTES_AND_CONTENT) {
                // NOTE(review): GZIP can expand incompressible input slightly, so for a full 65535-byte
                // read the compressed length may exceed 65535 and the (short) cast below would wrap.
                // TODO confirm whether the receiver tolerates this or whether the read size should be
                // reduced when compression is enabled.
                final byte[] compressed = compressDataFrame(byteBuffer, bytesRead);
                final int compressedLength = compressed.length;

                buffer = ByteBuffer.allocate(3 + compressedLength);
                buffer.put((byte) LoadBalanceProtocolConstants.DATA_FRAME_FOLLOWS);
                buffer.putShort((short) compressedLength);

                buffer.put(compressed, 0, compressedLength);
            } else {
                buffer = ByteBuffer.allocate(3 + bytesRead);
                buffer.put((byte) LoadBalanceProtocolConstants.DATA_FRAME_FOLLOWS);
                buffer.putShort((short) bytesRead);

                buffer.put(byteBuffer, 0, bytesRead);
            }

            final byte[] frameArray = buffer.array();
            checksum.update(frameArray, 0, frameArray.length);

            phase = TransactionPhase.SEND_FLOWFILE_CONTENTS;
            buffer.rewind();
            return buffer;
        } catch (final ContentNotFoundException cnfe) {
            // Re-throw with the FlowFile attached so the caller knows which FlowFile's content is missing.
            throw new ContentNotFoundException(currentFlowFile, cnfe.getMissingClaim(), cnfe.getMessage());
        }
    }

    /** GZIP-compresses the first {@code byteCount} bytes of {@code uncompressed} and returns the result. */
    private byte[] compressDataFrame(final byte[] uncompressed, final int byteCount) throws IOException {
        try (final ByteArrayOutputStream baos = new ByteArrayOutputStream();
             final OutputStream gzipOut = new GZIPOutputStream(baos, 1)) {

            gzipOut.write(uncompressed, 0, byteCount);
            gzipOut.close(); // explicit close flushes the GZIP trailer before toByteArray()

            return baos.toByteArray();
        }
    }

    /**
     * Pulls the next FlowFile from the supplier and builds its MORE_FLOWFILES frame (indicator,
     * 4-byte attribute-block length, encoded attributes). If the transaction threshold has been
     * met or the supplier is exhausted, sends NO_MORE_FLOWFILES and advances to SEND_CHECKSUM.
     */
    private ByteBuffer getNextFlowFile() throws IOException {
        if (transactionThreshold.isThresholdMet()) {
            currentFlowFile = null;
            logger.debug("Transaction Threshold reached sending to Peer {}; Transitioning phase to SEND_CHECKSUM", peerDescription);
        } else {
            currentFlowFile = flowFileSupplier.get();

            if (currentFlowFile == null) {
                logger.debug("No more FlowFiles to send to Peer {}; Transitioning phase to SEND_CHECKSUM", peerDescription);
            }
        }

        if (currentFlowFile == null) {
            phase = TransactionPhase.SEND_CHECKSUM;
            return noMoreFlowFiles();
        }

        transactionThreshold.adjust(1, currentFlowFile.getSize());
        logger.debug("Next FlowFile to send to Peer {} is {}", peerDescription, currentFlowFile);
        flowFilesSent.add(currentFlowFile);

        final LoadBalanceCompression compression = partition.getCompression();
        final boolean compressAttributes = compression != LoadBalanceCompression.DO_NOT_COMPRESS;
        logger.debug("Compression to use for sending to Peer {} is {}", peerDescription, compression);

        final byte[] flowFileEncoded;
        try (final ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            if (compressAttributes) {
                try (final OutputStream gzipOut = new GZIPOutputStream(baos, 1);
                     final ByteCountingOutputStream out = new ByteCountingOutputStream(gzipOut)) {

                    flowFileCodec.encode(currentFlowFile, out);
                }
            } else {
                flowFileCodec.encode(currentFlowFile, baos);
            }

            flowFileEncoded = baos.toByteArray();
        }

        final int metadataLength = flowFileEncoded.length;
        final ByteBuffer buffer = ByteBuffer.allocate(metadataLength + 5);
        buffer.put((byte) LoadBalanceProtocolConstants.MORE_FLOWFILES);
        checksum.update(LoadBalanceProtocolConstants.MORE_FLOWFILES);

        buffer.putInt(metadataLength);
        // Fold the 4 big-endian length bytes into the checksum exactly as they appear on the wire.
        checksum.update((metadataLength >> 24) & 0xFF);
        checksum.update((metadataLength >> 16) & 0xFF);
        checksum.update((metadataLength >> 8) & 0xFF);
        checksum.update(metadataLength & 0xFF);

        buffer.put(flowFileEncoded);
        checksum.update(flowFileEncoded, 0, flowFileEncoded.length);

        phase = TransactionPhase.SEND_FLOWFILE_DEFINITION;
        buffer.rewind();
        return buffer;
    }


    /** Builds the 1-byte frame recommending our protocol version and starts the read timeout. */
    private ByteBuffer recommendProtocolVersion() {
        logger.debug("Recommending to Peer {} that Protocol Version {} be used", peerDescription, protocolVersion);

        final ByteBuffer buffer = ByteBuffer.allocate(1);
        buffer.put((byte) protocolVersion);
        buffer.rewind();

        readTimeout = System.currentTimeMillis() + timeoutMillis;
        phase = TransactionPhase.RECEIVE_PROTOCOL_VERSION_ACKNOWLEDGMENT;
        return buffer;
    }

    /**
     * Reads the peer's response to our recommended protocol version: VERSION_ACCEPTED proceeds to
     * SEND_CONNECTION_ID; REQEUST_DIFFERENT_VERSION (constant name misspelled in the protocol
     * constants class) moves to receiving the peer's counter-recommendation.
     *
     * @return {@code true} if a response byte was read; {@code false} if no data was available yet
     */
    private boolean receiveProtocolVersionAcknowledgment() throws IOException {
        logger.debug("Receiving Protocol Version Acknowledgment from Peer {}", peerDescription);

        final OptionalInt ackResponse = channel.read();
        if (!ackResponse.isPresent()) {
            if (System.currentTimeMillis() > readTimeout) {
                throw new SocketTimeoutException("Timed out waiting for Peer " + peerDescription + " to acknowledge Protocol Version");
            }

            return false;
        }

        final int response = ackResponse.getAsInt();
        if (response < 0) {
            throw new EOFException("Encountered End-of-File with Peer " + peerDescription + " when expecting a Protocol Version Acknowledgment");
        }

        if (response == VERSION_ACCEPTED) {
            logger.debug("Peer {} accepted Protocol Version {}", peerDescription, protocolVersion);
            phase = TransactionPhase.SEND_CONNECTION_ID;
            return true;
        }

        if (response == REQEUST_DIFFERENT_VERSION) {
            logger.debug("Recommended using Protocol Version of {} with Peer {} but received REQUEST_DIFFERENT_VERSION response", protocolVersion, peerDescription);
            readTimeout = System.currentTimeMillis() + timeoutMillis;
            phase = TransactionPhase.RECEIVE_RECOMMENDED_PROTOCOL_VERSION;
            return true;
        }

        throw new IOException("Failed to negotiate Protocol Version with Peer " + peerDescription + ". Recommended version " + protocolVersion + " but instead of an ACCEPT or REJECT "
            + "response got back a response of " + response);
    }

    /**
     * Reads the protocol version that the peer recommends. Accepts it if supported; otherwise
     * counter-recommends our preferred version, or aborts negotiation if no version is acceptable.
     *
     * @return {@code true} if a response byte was read; {@code false} if no data was available yet
     */
    private boolean receiveRecommendedProtocolVersion() throws IOException {
        logger.debug("Receiving Protocol Version from Peer {}", peerDescription);

        final OptionalInt recommendationResponse = channel.read();
        if (!recommendationResponse.isPresent()) {
            if (System.currentTimeMillis() > readTimeout) {
                throw new SocketTimeoutException("Timed out waiting for Peer " + peerDescription + " to recommend Protocol Version");
            }

            return false;
        }

        final int requestedVersion = recommendationResponse.getAsInt();
        if (requestedVersion < 0) {
            throw new EOFException("Encountered End-of-File with Peer " + peerDescription + " when expecting a Protocol Version Recommendation");
        }

        if (negotiator.isVersionSupported(requestedVersion)) {
            protocolVersion = requestedVersion;
            phase = TransactionPhase.SEND_CONNECTION_ID;
            logger.debug("Peer {} recommended Protocol Version of {}. Accepting version.", peerDescription, requestedVersion);

            return true;
        } else {
            final Integer preferred = negotiator.getPreferredVersion(requestedVersion);
            if (preferred == null) {
                logger.debug("Peer {} requested version {} of the Load Balance Protocol. This version is not acceptable. Aborting communications.", peerDescription, requestedVersion);
                phase = TransactionPhase.ABORT_PROTOCOL_NEGOTIATION;
                return true;
            } else {
                logger.debug("Peer {} requested version {} of the Protocol. Recommending version {} instead", peerDescription, requestedVersion, preferred);
                protocolVersion = preferred;
                phase = TransactionPhase.RECOMMEND_PROTOCOL_VERSION;
                return true;
            }
        }
    }

    /** Builds the NO_MORE_FLOWFILES frame and folds it into the checksum. */
    private ByteBuffer noMoreFlowFiles() {
        final ByteBuffer buffer = ByteBuffer.allocate(1);
        buffer.put((byte) LoadBalanceProtocolConstants.NO_MORE_FLOWFILES);
        buffer.rewind();

        checksum.update(LoadBalanceProtocolConstants.NO_MORE_FLOWFILES);
        return buffer;
    }

    /** Builds the ABORT_PROTOCOL_NEGOTIATION frame; not checksummed since negotiation failed. */
    private ByteBuffer abortProtocolNegotiation() {
        final ByteBuffer buffer = ByteBuffer.allocate(1);
        buffer.put((byte) ABORT_PROTOCOL_NEGOTIATION);
        buffer.rewind();

        return buffer;
    }

    /** Builds the Connection ID frame: a 2-byte length followed by the UTF-8 encoded ID. */
    private ByteBuffer getConnectionId() {
        logger.debug("Sending Connection ID {} to Peer {}", connectionId, peerDescription);

        final ByteBuffer buffer = ByteBuffer.allocate(connectionId.length() + 2);
        buffer.putShort((short) connectionId.length());
        buffer.put(connectionId.getBytes(StandardCharsets.UTF_8));
        buffer.rewind();

        final byte[] frameBytes = buffer.array();
        checksum.update(frameBytes, 0, frameBytes.length);

        phase = TransactionPhase.CHECK_SPACE;
        return buffer;
    }

    /**
     * Builds either a CHECK_SPACE frame (when backpressure is honored — the peer will then tell us
     * whether its queue has room) or a SKIP_SPACE_CHECK frame (proceed straight to sending).
     */
    private ByteBuffer checkSpace() {
        logger.debug("Sending a 'Check Space' request to Peer {} to determine if there is space in the queue for more FlowFiles", peerDescription);

        final ByteBuffer buffer = ByteBuffer.allocate(1);

        if (partition.isHonorBackpressure()) {
            buffer.put((byte) LoadBalanceProtocolConstants.CHECK_SPACE);
            checksum.update(LoadBalanceProtocolConstants.CHECK_SPACE);

            readTimeout = System.currentTimeMillis() + timeoutMillis;
            phase = TransactionPhase.RECEIVE_SPACE_RESPONSE;
        } else {
            buffer.put((byte) LoadBalanceProtocolConstants.SKIP_SPACE_CHECK);
            checksum.update(LoadBalanceProtocolConstants.SKIP_SPACE_CHECK);

            phase = TransactionPhase.GET_NEXT_FLOWFILE;
        }

        buffer.rewind();
        return buffer;
    }


    /**
     * Reads the peer's space-availability response. SPACE_AVAILABLE proceeds to GET_NEXT_FLOWFILE;
     * QUEUE_FULL restarts the whole session (back to version negotiation, checksum reset) and
     * penalizes this session so we don't hammer a full peer.
     *
     * @return {@code true} if a response byte was read; {@code false} if no data was available yet
     */
    private boolean receiveSpaceAvailableResponse() throws IOException {
        logger.debug("Receiving response from Peer {} to determine whether or not space is available in queue {}", peerDescription, connectionId);

        final OptionalInt spaceAvailableResponse = channel.read();
        if (!spaceAvailableResponse.isPresent()) {
            if (System.currentTimeMillis() > readTimeout) {
                throw new SocketTimeoutException("Timed out waiting for Peer " + peerDescription + " to verify whether or not space is available for Connection " + connectionId);
            }

            return false;
        }

        final int response = spaceAvailableResponse.getAsInt();
        if (response < 0) {
            throw new EOFException("Encountered End-of-File when trying to verify with Peer " + peerDescription + " whether or not space is available in Connection " + connectionId);
        }

        if (response == SPACE_AVAILABLE) {
            logger.debug("Peer {} has confirmed that space is available in Connection {}", peerDescription, connectionId);
            phase = TransactionPhase.GET_NEXT_FLOWFILE;
        } else if (response == QUEUE_FULL) {
            logger.debug("Peer {} has confirmed that the queue is full for Connection {}", peerDescription, connectionId);
            phase = TransactionPhase.RECOMMEND_PROTOCOL_VERSION;
            checksum.reset(); // We are restarting the session entirely so we need to reset our checksum
            penalize();
        } else {
            throw new TransactionAbortedException("After requesting to know whether or not Peer " + peerDescription + " has space available in Connection " + connectionId
                + ", received unexpected response of " + response + ". Aborting transaction.");
        }

        return true;
    }

    /** Marks this session as penalized for {@link #PENALTY_MILLIS} from now. */
    private void penalize() {
        penaltyExpiration = System.currentTimeMillis() + PENALTY_MILLIS;
    }

    private boolean isPenalized() {
        // check for penaltyExpiration > -1L is not strictly necessary as it's implied by the second check but is still
        // here because it's more efficient to check this than to make the system call to System.currentTimeMillis().
        return penaltyExpiration > -1L && System.currentTimeMillis() < penaltyExpiration;
    }


    /**
     * The phases of the load-balance transaction. Each phase declares the SelectionKey readiness
     * (OP_READ or OP_WRITE) that the channel must report before the phase can make progress.
     */
    private enum TransactionPhase {
        RECOMMEND_PROTOCOL_VERSION(SelectionKey.OP_WRITE),

        RECEIVE_PROTOCOL_VERSION_ACKNOWLEDGMENT(SelectionKey.OP_READ),

        RECEIVE_RECOMMENDED_PROTOCOL_VERSION(SelectionKey.OP_READ),

        ABORT_PROTOCOL_NEGOTIATION(SelectionKey.OP_WRITE),

        SEND_CONNECTION_ID(SelectionKey.OP_WRITE),

        CHECK_SPACE(SelectionKey.OP_WRITE),

        RECEIVE_SPACE_RESPONSE(SelectionKey.OP_READ),

        SEND_FLOWFILE_DEFINITION(SelectionKey.OP_WRITE),

        SEND_FLOWFILE_CONTENTS(SelectionKey.OP_WRITE),

        GET_NEXT_FLOWFILE(SelectionKey.OP_WRITE),

        SEND_CHECKSUM(SelectionKey.OP_WRITE),

        VERIFY_CHECKSUM(SelectionKey.OP_READ),

        SEND_TRANSACTION_COMPLETE(SelectionKey.OP_WRITE),

        CONFIRM_TRANSACTION_COMPLETE(SelectionKey.OP_READ);


        private final int requiredSelectionKey;

        TransactionPhase(final int requiredSelectionKey) {
            this.requiredSelectionKey = requiredSelectionKey;
        }

        public int getRequiredSelectionKey() {
            return requiredSelectionKey;
        }
    }
}
http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClient.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClient.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClient.java new file mode 100644 index 0000000..066b597 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClient.java @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.nifi.controller.queue.clustered.client.async.nio; + +import org.apache.nifi.cluster.protocol.NodeIdentifier; +import org.apache.nifi.controller.queue.LoadBalanceCompression; +import org.apache.nifi.controller.queue.clustered.FlowFileContentAccess; +import org.apache.nifi.controller.queue.clustered.SimpleLimitThreshold; +import org.apache.nifi.controller.queue.clustered.TransactionThreshold; +import org.apache.nifi.controller.queue.clustered.client.LoadBalanceFlowFileCodec; +import org.apache.nifi.controller.queue.clustered.client.async.AsyncLoadBalanceClient; +import org.apache.nifi.controller.queue.clustered.client.async.TransactionCompleteCallback; +import org.apache.nifi.controller.queue.clustered.client.async.TransactionFailureCallback; +import org.apache.nifi.controller.repository.FlowFileRecord; +import org.apache.nifi.events.EventReporter; +import org.apache.nifi.reporting.Severity; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BooleanSupplier; +import java.util.function.Predicate; +import java.util.function.Supplier; + + +public class NioAsyncLoadBalanceClient implements AsyncLoadBalanceClient { + private static final Logger logger = LoggerFactory.getLogger(NioAsyncLoadBalanceClient.class); + private static final long PENALIZATION_MILLIS = 
TimeUnit.SECONDS.toMillis(1L); + + private final NodeIdentifier nodeIdentifier; + private final SSLContext sslContext; + private final int timeoutMillis; + private final FlowFileContentAccess flowFileContentAccess; + private final LoadBalanceFlowFileCodec flowFileCodec; + private final EventReporter eventReporter; + + private volatile boolean running = false; + private final AtomicLong penalizationEnd = new AtomicLong(0L); + + private final Map<String, RegisteredPartition> registeredPartitions = new HashMap<>(); + private final Queue<RegisteredPartition> partitionQueue = new LinkedBlockingQueue<>(); + + // guarded by synchronizing on this + private PeerChannel channel; + private Selector selector; + private SelectionKey selectionKey; + + // While we use synchronization to guard most of the Class's state, we use a separate lock for the LoadBalanceSession. + // We do this because we need to atomically decide whether or not we are able to communicate over the socket with another node and if so, continue on and do so. + // However, we cannot do this within a synchronized block because if we did, then if Thread 1 were communicating with the remote node, and Thread 2 wanted to attempt + // to do so, it would have to wait until Thread 1 released the synchronization. Instead, we want Thread 2 to determine that the resource is not free and move on. + // I.e., we need to use the capability of Lock#tryLock, and the synchronized keyword does not offer this sort of functionality. 
    // ---- Members of NioAsyncLoadBalanceClient (class declaration and the remaining fields precede this chunk). ----

    // Guards loadBalanceSession so that only one thread at a time communicates with this peer.
    private final Lock loadBalanceSessionLock = new ReentrantLock();
    // The in-flight transaction with the peer, or null when no transaction is active. Guarded by loadBalanceSessionLock
    // for communicate()/nodeDisconnected(); also read/written under 'synchronized' in the helper methods below.
    private LoadBalanceSession loadBalanceSession = null;


    public NioAsyncLoadBalanceClient(final NodeIdentifier nodeIdentifier, final SSLContext sslContext, final int timeoutMillis, final FlowFileContentAccess flowFileContentAccess,
                                     final LoadBalanceFlowFileCodec flowFileCodec, final EventReporter eventReporter) {
        this.nodeIdentifier = nodeIdentifier;
        this.sslContext = sslContext;
        this.timeoutMillis = timeoutMillis;
        this.flowFileContentAccess = flowFileContentAccess;
        this.flowFileCodec = flowFileCodec;
        this.eventReporter = eventReporter;
    }

    @Override
    public NodeIdentifier getNodeIdentifier() {
        return nodeIdentifier;
    }

    /**
     * Registers a Connection (queue partition) whose data should be sent to this node.
     *
     * @throws IllegalStateException if the Connection ID is already registered
     */
    public synchronized void register(final String connectionId, final BooleanSupplier emptySupplier, final Supplier<FlowFileRecord> flowFileSupplier,
                                      final TransactionFailureCallback failureCallback, final TransactionCompleteCallback successCallback,
                                      final Supplier<LoadBalanceCompression> compressionSupplier, final BooleanSupplier honorBackpressureSupplier) {

        if (registeredPartitions.containsKey(connectionId)) {
            throw new IllegalStateException("Connection with ID " + connectionId + " is already registered");
        }

        final RegisteredPartition partition = new RegisteredPartition(connectionId, emptySupplier, flowFileSupplier, failureCallback, successCallback, compressionSupplier, honorBackpressureSupplier);
        registeredPartitions.put(connectionId, partition);
        partitionQueue.add(partition);
    }

    // NOTE(review): only the map entry is removed here; any instance of this partition still sitting in
    // partitionQueue is left behind (getReadyPartition() will simply re-offer it) — confirm that is intended.
    public synchronized void unregister(final String connectionId) {
        registeredPartitions.remove(connectionId);
    }

    // Defensive snapshot so callers can iterate without holding this monitor.
    private synchronized Map<String, RegisteredPartition> getRegisteredPartitions() {
        return new HashMap<>(registeredPartitions);
    }

    public void start() {
        running = true;
        logger.debug("{} started", this);
    }

    public void stop() {
        running = false;
        logger.debug("{} stopped", this);
        close();
    }

    // Closes and nulls out the Selector and PeerChannel; safe to call repeatedly.
    private synchronized void close() {
        if (selector != null && selector.isOpen()) {
            try {
                selector.close();
            } catch (final Exception e) {
                logger.warn("Failed to close NIO Selector", e);
            }
        }

        if (channel != null && channel.isOpen()) {
            try {
                channel.close();
            } catch (final Exception e) {
                logger.warn("Failed to close Socket Channel to {} for Load Balancing", nodeIdentifier, e);
            }
        }

        channel = null;
        selector = null;
    }

    public boolean isRunning() {
        return running;
    }

    /**
     * @return true while this client is inside its penalization window; a timestamp of 0 means "not penalized".
     */
    public boolean isPenalized() {
        final long endTimestamp = penalizationEnd.get();
        if (endTimestamp == 0) {
            return false;
        }

        if (endTimestamp < System.currentTimeMillis()) {
            // set penalization end to 0 so that next time we don't need to check System.currentTimeMillis() because
            // systems calls are expensive enough that we'd like to avoid them when we can.
            penalizationEnd.compareAndSet(endTimestamp, 0L);
            return false;
        }

        return true;
    }

    private void penalize() {
        logger.debug("Penalizing {}", this);
        this.penalizationEnd.set(System.currentTimeMillis() + PENALIZATION_MILLIS);
    }


    /**
     * Attempts to make progress communicating with the peer: establishes the connection if needed, obtains (or resumes)
     * a session, and pumps it until it can make no further progress without blocking.
     *
     * @return true if any progress was made, false otherwise (not running, lock contended, nothing ready, or failure)
     * @throws IOException if communication fails after the connection was established
     */
    public boolean communicate() throws IOException {
        if (!running) {
            return false;
        }

        // Use #tryLock here so that if another thread is already communicating with this Client, this thread
        // will not block and wait but instead will just return so that the Thread Pool can proceed to the next Client.
        if (!loadBalanceSessionLock.tryLock()) {
            return false;
        }

        try {
            RegisteredPartition readyPartition = null;

            if (!isConnectionEstablished()) {
                readyPartition = getReadyPartition();
                if (readyPartition == null) {
                    logger.debug("{} has no connection with data ready to be transmitted so will penalize Client without communicating", this);
                    penalize();
                    return false;
                }

                try {
                    establishConnection();
                } catch (IOException e) {
                    penalize();

                    // Re-offer the partition we polled so its data can be retried later.
                    partitionQueue.offer(readyPartition);

                    for (final RegisteredPartition partition : getRegisteredPartitions().values()) {
                        logger.debug("Triggering Transaction Failure Callback for {} with Transaction Phase of CONNECTING", partition);
                        partition.getFailureCallback().onTransactionFailed(Collections.emptyList(), e, TransactionFailureCallback.TransactionPhase.CONNECTING);
                    }

                    return false;
                }
            }

            // NOTE(review): this local deliberately shadows the 'loadBalanceSession' field; the catch block
            // below (outside this local's scope) assigns null to the FIELD, not to this local.
            final LoadBalanceSession loadBalanceSession = getActiveTransaction(readyPartition);
            if (loadBalanceSession == null) {
                penalize();
                return false;
            }

            selector.selectNow();
            final boolean ready = (loadBalanceSession.getDesiredReadinessFlag() & selectionKey.readyOps()) != 0;
            if (!ready) {
                return false;
            }

            // Pump the session until it reports no further progress is possible without blocking.
            boolean anySuccess = false;
            boolean success;
            do {
                try {
                    success = loadBalanceSession.communicate();
                } catch (final Exception e) {
                    logger.error("Failed to communicate with Peer {}", nodeIdentifier.toString(), e);
                    eventReporter.reportEvent(Severity.ERROR, "Load Balanced Connection", "Failed to communicate with Peer " + nodeIdentifier + " when load balancing data for Connection with ID " +
                        loadBalanceSession.getPartition().getConnectionId() + " due to " + e);

                    penalize();
                    loadBalanceSession.getPartition().getFailureCallback().onTransactionFailed(loadBalanceSession.getFlowFilesSent(), e, TransactionFailureCallback.TransactionPhase.SENDING);
                    close();

                    return false;
                }

                anySuccess = anySuccess || success;
            } while (success);

            if (loadBalanceSession.isComplete()) {
                loadBalanceSession.getPartition().getSuccessCallback().onTransactionComplete(loadBalanceSession.getFlowFilesSent());
            }

            return anySuccess;
        } catch (final Exception e) {
            close();
            loadBalanceSession = null;   // resets the FIELD (the local above is out of scope here)
            throw e;
        } finally {
            loadBalanceSessionLock.unlock();
        }
    }

    /**
     * If any FlowFiles have been transferred in an active session, fail the transaction. Otherwise, gather up to the Transaction Threshold's limits
     * worth of FlowFiles and treat them as a failed transaction. In either case, terminate the session. This allows us to transfer FlowFiles from
     * queue partitions where the partitioner indicates that the data should be rebalanced, but does so in a way that we don't immediately rebalance
     * all FlowFiles. This is desirable in a case such as when we have a lot of data queued up in a connection and then a node temporarily disconnects.
     * We don't want to then just push all data to other nodes. We'd rather push the data out to other nodes slowly while waiting for the disconnected
     * node to reconnect. And if the node reconnects, we want to keep sending it data.
     */
    public void nodeDisconnected() {
        if (!loadBalanceSessionLock.tryLock()) {
            // If we are not able to obtain the loadBalanceSessionLock, we cannot access the load balance session.
            return;
        }

        try {
            final LoadBalanceSession session = getFailoverSession();
            if (session != null) {
                loadBalanceSession = null;

                logger.debug("Node {} disconnected so will terminate the Load Balancing Session", nodeIdentifier);
                final List<FlowFileRecord> flowFilesSent = session.getFlowFilesSent();

                if (!flowFilesSent.isEmpty()) {
                    session.getPartition().getFailureCallback().onTransactionFailed(session.getFlowFilesSent(), TransactionFailureCallback.TransactionPhase.SENDING);
                }

                close();
                penalize();
                return;
            }

            // Obtain a partition that needs to be rebalanced on failure
            final RegisteredPartition readyPartition = getReadyPartition(partition -> partition.getFailureCallback().isRebalanceOnFailure());
            if (readyPartition == null) {
                return;
            }

            partitionQueue.offer(readyPartition); // allow partition to be obtained again
            final TransactionThreshold threshold = newTransactionThreshold();

            final List<FlowFileRecord> flowFiles = new ArrayList<>();
            while (!threshold.isThresholdMet()) {
                final FlowFileRecord flowFile = readyPartition.getFlowFileRecordSupplier().get();
                if (flowFile == null) {
                    break;
                }

                flowFiles.add(flowFile);
                threshold.adjust(1, flowFile.getSize());
            }

            logger.debug("Node {} not connected so failing {} FlowFiles for Load Balancing", nodeIdentifier, flowFiles.size());
            readyPartition.getFailureCallback().onTransactionFailed(flowFiles, TransactionFailureCallback.TransactionPhase.SENDING);
            penalize(); // Don't just transfer FlowFiles out of queue's partition as fast as possible, because the node may only be disconnected for a short time.
        } finally {
            loadBalanceSessionLock.unlock();
        }
    }

    // Returns the active (incomplete) session, if any, so it can be failed over; null otherwise.
    private synchronized LoadBalanceSession getFailoverSession() {
        if (loadBalanceSession != null && !loadBalanceSession.isComplete()) {
            return loadBalanceSession;
        }

        return null;
    }


    private RegisteredPartition getReadyPartition() {
        return getReadyPartition(partition -> true);
    }

    /**
     * Polls the partition queue for the first non-empty partition that satisfies the filter.
     * Partitions that were polled but not chosen are re-offered in the finally block so none are lost.
     */
    private synchronized RegisteredPartition getReadyPartition(final Predicate<RegisteredPartition> filter) {
        final List<RegisteredPartition> polledPartitions = new ArrayList<>();

        try {
            RegisteredPartition partition;
            while ((partition = partitionQueue.poll()) != null) {
                if (partition.isEmpty() || !filter.test(partition)) {
                    polledPartitions.add(partition);
                    continue;
                }

                return partition;
            }

            return null;
        } finally {
            polledPartitions.forEach(partitionQueue::offer);
        }
    }

    /**
     * Returns the in-flight session if one exists; otherwise starts a new session for the proposed partition
     * (or for the next ready partition when none was proposed). May return null if nothing is ready.
     */
    private synchronized LoadBalanceSession getActiveTransaction(final RegisteredPartition proposedPartition) {
        if (loadBalanceSession != null && !loadBalanceSession.isComplete()) {
            return loadBalanceSession;
        }

        final RegisteredPartition readyPartition = proposedPartition == null ? getReadyPartition() : proposedPartition;
        if (readyPartition == null) {
            return null;
        }

        loadBalanceSession = new LoadBalanceSession(readyPartition, flowFileContentAccess, flowFileCodec, channel, timeoutMillis, newTransactionThreshold());
        partitionQueue.offer(readyPartition);

        return loadBalanceSession;
    }

    // Per-transaction limits: at most 1000 FlowFiles or 10 MB per transaction.
    private TransactionThreshold newTransactionThreshold() {
        return new SimpleLimitThreshold(1000, 10_000_000L);
    }

    private synchronized boolean isConnectionEstablished() {
        return selector != null && channel != null && channel.isConnected();
    }

    /**
     * Opens the socket, performs the (possibly TLS) handshake in blocking mode, then switches to
     * non-blocking mode and registers with the Selector for both READ and WRITE readiness.
     * On any failure, everything opened so far is closed (close failures suppressed) and the error rethrown.
     */
    private synchronized void establishConnection() throws IOException {
        SocketChannel socketChannel = null;

        try {
            selector = Selector.open();
            socketChannel = createChannel();

            socketChannel.configureBlocking(true);

            channel = createPeerChannel(socketChannel, nodeIdentifier.toString());
            channel.performHandshake();

            socketChannel.configureBlocking(false);
            selectionKey = socketChannel.register(selector, SelectionKey.OP_WRITE | SelectionKey.OP_READ);
        } catch (Exception e) {
            logger.error("Unable to connect to {} for load balancing", nodeIdentifier, e);

            if (selector != null) {
                try {
                    selector.close();
                } catch (final Exception e1) {
                    e.addSuppressed(e1);
                }
            }

            if (channel != null) {
                try {
                    channel.close();
                } catch (final Exception e1) {
                    e.addSuppressed(e1);
                }
            }

            if (socketChannel != null) {
                try {
                    socketChannel.close();
                } catch (final Exception e1) {
                    e.addSuppressed(e1);
                }
            }

            throw e;
        }
    }


    // Wraps the raw SocketChannel in a PeerChannel, attaching a client-mode SSLEngine when TLS is configured.
    private PeerChannel createPeerChannel(final SocketChannel channel, final String peerDescription) {
        if (sslContext == null) {
            logger.debug("No SSL Context is available so will not perform SSL Handshake with Peer {}", peerDescription);
            return new PeerChannel(channel, null, peerDescription);
        }

        logger.debug("Performing SSL Handshake with Peer {}", peerDescription);

        final SSLEngine sslEngine = sslContext.createSSLEngine();
        sslEngine.setUseClientMode(true);
        // NOTE(review): setNeedClientAuth only has effect on a server-mode engine; on this client-mode
        // engine it appears to be a no-op — confirm whether it was meant for the server side.
        sslEngine.setNeedClientAuth(true);

        return new PeerChannel(channel, sslEngine, peerDescription);
    }


    /**
     * Opens and connects a blocking SocketChannel to the peer's load-balance address/port,
     * applying the configured SO_TIMEOUT. Closes the channel (suppressing close failures) on error.
     */
    private SocketChannel createChannel() throws IOException {
        final SocketChannel socketChannel = SocketChannel.open();
        try {
            socketChannel.configureBlocking(true);
            final Socket socket = socketChannel.socket();
            socket.setSoTimeout(timeoutMillis);

            socket.connect(new InetSocketAddress(nodeIdentifier.getLoadBalanceAddress(), nodeIdentifier.getLoadBalancePort()));
            socket.setSoTimeout(timeoutMillis);

            return socketChannel;
        } catch (final Exception e) {
            try {
                socketChannel.close();
            } catch (final Exception closeException) {
                e.addSuppressed(closeException);
            }

            throw e;
        }
    }


    @Override
    public String toString() {
        return "NioAsyncLoadBalanceClient[nodeId=" + nodeIdentifier + "]";
    }
}
http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientFactory.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientFactory.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientFactory.java new file mode 100644 index 0000000..79fe4be --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientFactory.java @@ -0,0 +1,50 @@ +/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.
See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.controller.queue.clustered.client.async.nio;

import org.apache.nifi.cluster.protocol.NodeIdentifier;
import org.apache.nifi.controller.queue.clustered.FlowFileContentAccess;
import org.apache.nifi.controller.queue.clustered.client.LoadBalanceFlowFileCodec;
import org.apache.nifi.controller.queue.clustered.client.StandardLoadBalanceFlowFileCodec;
import org.apache.nifi.controller.queue.clustered.client.async.AsyncLoadBalanceClientFactory;
import org.apache.nifi.events.EventReporter;

import javax.net.ssl.SSLContext;

/**
 * Factory that creates {@link NioAsyncLoadBalanceClient} instances for a given cluster node,
 * sharing the TLS configuration, timeout, content access, codec, and event reporter across all clients.
 */
public class NioAsyncLoadBalanceClientFactory implements AsyncLoadBalanceClientFactory {
    private final SSLContext sslContext;                        // null => plaintext connections
    private final int timeoutMillis;                            // socket/communication timeout
    private final FlowFileContentAccess flowFileContentAccess;  // provides FlowFile content streams
    private final EventReporter eventReporter;                  // bulletin/event sink for failures
    private final LoadBalanceFlowFileCodec flowFileCodec;       // codec used to serialize FlowFiles to peers

    public NioAsyncLoadBalanceClientFactory(final SSLContext sslContext, final int timeoutMillis, final FlowFileContentAccess flowFileContentAccess, final EventReporter eventReporter,
                                            final LoadBalanceFlowFileCodec loadBalanceFlowFileCodec) {
        this.sslContext = sslContext;
        this.timeoutMillis = timeoutMillis;
        this.flowFileContentAccess = flowFileContentAccess;
        this.eventReporter = eventReporter;
        // Fall back to the standard codec when none is supplied so createClient() never hands a client a null codec.
        this.flowFileCodec = loadBalanceFlowFileCodec == null ? new StandardLoadBalanceFlowFileCodec() : loadBalanceFlowFileCodec;
    }


    /**
     * @param nodeIdentifier the node the new client will send data to
     * @return a new client configured with this factory's shared settings
     */
    @Override
    public NioAsyncLoadBalanceClient createClient(final NodeIdentifier nodeIdentifier) {
        // BUG FIX: previously this ignored the codec injected through the constructor and always created a
        // new StandardLoadBalanceFlowFileCodec, leaving the 'flowFileCodec' field dead. Use the configured codec.
        return new NioAsyncLoadBalanceClient(nodeIdentifier, sslContext, timeoutMillis, flowFileContentAccess, flowFileCodec, eventReporter);
    }
}
http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientRegistry.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientRegistry.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientRegistry.java new file mode 100644 index 0000000..514a58c --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientRegistry.java @@ -0,0 +1,122 @@ +/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.controller.queue.clustered.client.async.nio;

import org.apache.nifi.cluster.protocol.NodeIdentifier;
import org.apache.nifi.controller.queue.LoadBalanceCompression;
import org.apache.nifi.controller.queue.clustered.client.async.AsyncLoadBalanceClient;
import org.apache.nifi.controller.queue.clustered.client.async.AsyncLoadBalanceClientRegistry;
import org.apache.nifi.controller.queue.clustered.client.async.TransactionCompleteCallback;
import org.apache.nifi.controller.queue.clustered.client.async.TransactionFailureCallback;
import org.apache.nifi.controller.repository.FlowFileRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.function.BooleanSupplier;
import java.util.function.Supplier;

/**
 * Keeps track of the pool of {@link AsyncLoadBalanceClient}s per cluster node ('clientsPerNode' clients each),
 * registering/unregistering Connections with them and starting/stopping all clients together.
 * All mutating methods are synchronized on this instance.
 */
public class NioAsyncLoadBalanceClientRegistry implements AsyncLoadBalanceClientRegistry {
    private static final Logger logger = LoggerFactory.getLogger(NioAsyncLoadBalanceClientRegistry.class);

    private final NioAsyncLoadBalanceClientFactory clientFactory;
    private final int clientsPerNode;   // number of parallel clients created per node

    // nodeId -> the clients created for that node. Guarded by 'this'.
    private Map<NodeIdentifier, Set<AsyncLoadBalanceClient>> clientMap = new HashMap<>();
    // Flat view of every client; CopyOnWriteArraySet so getAllClients() can be iterated without this monitor.
    private Set<AsyncLoadBalanceClient> allClients = new CopyOnWriteArraySet<>();
    private boolean running = false;    // whether start() has been called; newly created clients are started lazily

    public NioAsyncLoadBalanceClientRegistry(final NioAsyncLoadBalanceClientFactory clientFactory, final int clientsPerNode) {
        this.clientFactory = clientFactory;
        this.clientsPerNode = clientsPerNode;
    }

    /**
     * Registers the given Connection with every client for the given node, creating the clients on first use.
     */
    @Override
    public synchronized void register(final String connectionId, final NodeIdentifier nodeId, final BooleanSupplier emptySupplier, final Supplier<FlowFileRecord> flowFileSupplier,
                                     final TransactionFailureCallback failureCallback, final TransactionCompleteCallback successCallback,
                                     final Supplier<LoadBalanceCompression> compressionSupplier, final BooleanSupplier honorBackpressureSupplier) {

        Set<AsyncLoadBalanceClient> clients = clientMap.get(nodeId);
        if (clients == null) {
            clients = registerClients(nodeId);
        }

        clients.forEach(client -> client.register(connectionId, emptySupplier, flowFileSupplier, failureCallback, successCallback, compressionSupplier, honorBackpressureSupplier));
        logger.debug("Registered Connection with ID {} to send to Node {}", connectionId, nodeId);
    }


    // NOTE(review): un-registering a SINGLE connection removes ALL of the node's clients from the registry
    // (clientMap.remove + allClients.removeAll), which stops load balancing for every other connection still
    // registered with those clients. Confirm this is intended rather than removing clients only when no
    // connections remain registered.
    @Override
    public synchronized void unregister(final String connectionId, final NodeIdentifier nodeId) {
        final Set<AsyncLoadBalanceClient> clients = clientMap.remove(nodeId);
        if (clients == null) {
            return;
        }

        clients.forEach(client -> client.unregister(connectionId));

        allClients.removeAll(clients);
        logger.debug("Un-registered Connection with ID {} so that it will no longer send data to Node {}", connectionId, nodeId);
    }

    // Creates 'clientsPerNode' clients for the node, records them, and starts them if the registry is running.
    private Set<AsyncLoadBalanceClient> registerClients(final NodeIdentifier nodeId) {
        final Set<AsyncLoadBalanceClient> clients = new HashSet<>();

        for (int i=0; i < clientsPerNode; i++) {
            final AsyncLoadBalanceClient client = clientFactory.createClient(nodeId);
            clients.add(client);

            logger.debug("Added client {} for communicating with Node {}", client, nodeId);
        }

        clientMap.put(nodeId, clients);
        allClients.addAll(clients);

        if (running) {
            clients.forEach(AsyncLoadBalanceClient::start);
        }

        return clients;
    }

    // Returns the LIVE set (not a copy); safe for iteration only because it is a CopyOnWriteArraySet.
    public synchronized Set<AsyncLoadBalanceClient> getAllClients() {
        return allClients;
    }

    // Idempotent: starts all known clients once.
    public synchronized void start() {
        if (running) {
            return;
        }

        running = true;
        allClients.forEach(AsyncLoadBalanceClient::start);
    }

    // Idempotent: stops all known clients once.
    public synchronized void stop() {
        if (!running) {
            return;
        }

        running = false;
        allClients.forEach(AsyncLoadBalanceClient::stop);
    }
}
http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientTask.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientTask.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientTask.java new file mode 100644 index 0000000..35ea5f9 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/NioAsyncLoadBalanceClientTask.java @@ -0,0 +1,107 @@ +/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.controller.queue.clustered.client.async.nio;

import org.apache.nifi.cluster.coordination.ClusterCoordinator;
import org.apache.nifi.cluster.coordination.node.NodeConnectionState;
import org.apache.nifi.cluster.coordination.node.NodeConnectionStatus;
import org.apache.nifi.cluster.protocol.NodeIdentifier;
import org.apache.nifi.controller.queue.clustered.client.async.AsyncLoadBalanceClient;
import org.apache.nifi.events.EventReporter;
import org.apache.nifi.reporting.Severity;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Long-running task that repeatedly iterates over all registered {@link AsyncLoadBalanceClient}s,
 * driving communication with each eligible (running, non-penalized, CONNECTED) peer. When no client
 * makes progress in a full pass, it sleeps briefly to avoid busy-spinning. Stops when {@link #stop()}
 * is called or the executing thread is interrupted.
 */
public class NioAsyncLoadBalanceClientTask implements Runnable {
    private static final Logger logger = LoggerFactory.getLogger(NioAsyncLoadBalanceClientTask.class);
    private static final String EVENT_CATEGORY = "Load-Balanced Connection";

    private final NioAsyncLoadBalanceClientRegistry clientRegistry;
    private final ClusterCoordinator clusterCoordinator;   // source of truth for peer connection state
    private final EventReporter eventReporter;             // bulletin sink for communication failures
    private volatile boolean running = true;               // cleared by stop()

    public NioAsyncLoadBalanceClientTask(final NioAsyncLoadBalanceClientRegistry clientRegistry, final ClusterCoordinator clusterCoordinator, final EventReporter eventReporter) {
        this.clientRegistry = clientRegistry;
        this.clusterCoordinator = clusterCoordinator;
        this.eventReporter = eventReporter;
    }

    @Override
    public void run() {
        while (running) {
            try {
                boolean success = false;
                for (final AsyncLoadBalanceClient client : clientRegistry.getAllClients()) {
                    if (communicateWithClient(client)) {
                        success = true;
                    }
                }

                if (!success) {
                    logger.trace("Was unable to communicate with any client. Will sleep for 10 milliseconds.");
                    Thread.sleep(10L);
                }
            } catch (final InterruptedException ie) {
                // BUG FIX: previously InterruptedException fell into the generic catch below, which logged a
                // spurious ERROR, discarded the interrupt status, and kept looping. Restore the flag and exit.
                Thread.currentThread().interrupt();
                logger.info("Load-Balancing Client Task interrupted; stopping");
                return;
            } catch (final Exception e) {
                logger.error("Failed to communicate with peer while trying to load balance data across the cluster", e);
                eventReporter.reportEvent(Severity.ERROR, EVENT_CATEGORY, "Failed to communicate with Peer while trying to load balance data across the cluster due to " + e);
            }
        }
    }

    /**
     * Drives communication with a single client if it is eligible.
     *
     * @return true if the client made any progress communicating with its peer
     */
    private boolean communicateWithClient(final AsyncLoadBalanceClient client) {
        if (!client.isRunning()) {
            logger.trace("Client {} is not running so will not communicate with it", client);
            return false;
        }

        if (client.isPenalized()) {
            logger.trace("Client {} is penalized so will not communicate with it", client);
            return false;
        }

        final NodeIdentifier clientNodeId = client.getNodeIdentifier();
        final NodeConnectionStatus connectionStatus = clusterCoordinator.getConnectionStatus(clientNodeId);
        if (connectionStatus == null) {
            logger.debug("Could not determine Connection Status for Node with ID {}; will not communicate with it", clientNodeId);
            return false;
        }

        final NodeConnectionState connectionState = connectionStatus.getState();
        if (connectionState == NodeConnectionState.DISCONNECTED || connectionState == NodeConnectionState.DISCONNECTING) {
            // Let the client fail over / drain its active session for the disconnected node.
            client.nodeDisconnected();
            return false;
        }

        if (connectionState != NodeConnectionState.CONNECTED) {
            logger.debug("Client {} is for node that is not currently connected (state = {}) so will not communicate with node", client, connectionState);
            return false;
        }

        boolean progress = false;
        try {
            // Keep pumping while the client reports progress.
            while (client.communicate()) {
                progress = true;
                logger.trace("Client {} was able to make progress communicating with peer. Will continue to communicate with peer.", client);
            }
        } catch (final Exception e) {
            eventReporter.reportEvent(Severity.ERROR, EVENT_CATEGORY, "Failed to communicate with Peer "
                + client.getNodeIdentifier() + " while trying to load balance data across the cluster due to " + e.toString());
            logger.error("Failed to communicate with Peer {} while trying to load balance data across the cluster.", client.getNodeIdentifier(), e);
        }

        logger.trace("Client {} was no longer able to make progress communicating with peer. Will move on to the next client", client);
        return progress;
    }

    public void stop() {
        running = false;
    }
}
http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/PeerChannel.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/PeerChannel.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/PeerChannel.java new file mode 100644 index 0000000..67afb4a --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/PeerChannel.java @@ -0,0 +1,358 @@ +/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.nifi.controller.queue.clustered.client.async.nio;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.OptionalInt;

/**
 * Wraps a SocketChannel with optional TLS (via SSLEngine). When an SSLEngine is supplied, data written through
 * this channel is encrypted by {@link #prepareForWrite} and data read is decrypted transparently.
 * NOT thread-safe: the internal buffers are shared mutable state.
 */
public class PeerChannel implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(PeerChannel.class);

    private final SocketChannel socketChannel;
    private final SSLEngine sslEngine;          // null => plaintext channel
    private final String peerDescription;

    // Scratch buffer reused by the single-byte read()/write(byte) convenience methods.
    private final ByteBuffer singleByteBuffer = ByteBuffer.allocate(1);
    private ByteBuffer destinationBuffer = ByteBuffer.allocate(16 * 1024); // buffer that SSLEngine is to write into
    private ByteBuffer streamBuffer = ByteBuffer.allocate(16 * 1024); // buffer for data that is read from SocketChannel
    private ByteBuffer applicationBuffer = ByteBuffer.allocate(0); // buffer for application-level data that is ready to be served up (i.e., already decrypted if necessary)

    public PeerChannel(final SocketChannel socketChannel, final SSLEngine sslEngine, final String peerDescription) {
        this.socketChannel = socketChannel;
        this.sslEngine = sslEngine;
        this.peerDescription = peerDescription;
    }


    @Override
    public void close() throws IOException {
        socketChannel.close();
    }

    public boolean isConnected() {
        return socketChannel.isConnected();
    }

    public boolean isOpen() {
        return socketChannel.isOpen();
    }

    public String getPeerDescription() {
        return peerDescription;
    }

    /**
     * Writes a single byte (encrypting first when TLS is enabled).
     *
     * @return true if at least one byte was written to the channel
     */
    public boolean write(final byte b) throws IOException {
        singleByteBuffer.clear();
        singleByteBuffer.put(b);
        singleByteBuffer.rewind();

        final ByteBuffer prepared = prepareForWrite(singleByteBuffer);
        final int bytesWritten = write(prepared);
        // NOTE(review): when TLS is enabled and the channel accepts 0 bytes, the encrypted record produced
        // above is discarded; a retry re-encrypts the byte. Confirm callers tolerate this.
        return bytesWritten > 0;
    }

    /**
     * Reads a single byte. Returns OptionalInt.of(-1) on end-of-stream, an empty OptionalInt when no byte
     * is currently available, otherwise the unsigned byte value (0-255).
     */
    public OptionalInt read() throws IOException {
        singleByteBuffer.clear();
        final int bytesRead = read(singleByteBuffer);
        if (bytesRead < 0) {
            return OptionalInt.of(-1);
        }
        if (bytesRead == 0) {
            return OptionalInt.empty();
        }

        singleByteBuffer.flip();

        final byte read = singleByteBuffer.get();
        return OptionalInt.of(read & 0xFF);
    }




    /**
     * Reads the given ByteBuffer of data and returns a new ByteBuffer (which is "flipped" / ready to be read). The newly returned
     * ByteBuffer will be written to be written via the {@link #write(ByteBuffer)} method. I.e., it will have already been encrypted, if
     * necessary, and any other decorations that need to be applied before sending will already have been applied.
     *
     * @param plaintext the data to be prepped
     * @return a ByteBuffer containing the prepared data
     * @throws IOException if a failure occurs while encrypting the data
     */
    public ByteBuffer prepareForWrite(final ByteBuffer plaintext) throws IOException {
        if (sslEngine == null) {
            return plaintext;
        }


        // NOTE(review): the initial capacity Math.min(85, remaining) is a magic constant; the buffer is grown
        // below as needed, so this only affects how soon the first grow happens — confirm 85 is intentional.
        ByteBuffer prepared = ByteBuffer.allocate(Math.min(85, plaintext.capacity() - plaintext.position()));
        while (plaintext.hasRemaining()) {
            encrypt(plaintext);

            final int bytesRemaining = prepared.capacity() - prepared.position();
            if (bytesRemaining < destinationBuffer.remaining()) {
                // Grow 'prepared' to make room for the newly produced ciphertext.
                final ByteBuffer temp = ByteBuffer.allocate(prepared.capacity() + sslEngine.getSession().getApplicationBufferSize());
                prepared.flip();
                temp.put(prepared);
                prepared = temp;
            }

            prepared.put(destinationBuffer);
        }

        prepared.flip();
        return prepared;
    }

    // Writes an already-prepared (encrypted, if applicable) buffer; returns bytes actually written.
    public int write(final ByteBuffer preparedBuffer) throws IOException {
        return socketChannel.write(preparedBuffer);
    }


    /**
     * Reads decrypted application data into dst. Serves buffered data first; otherwise reads from the socket
     * and (when TLS is enabled) decrypts. Returns bytes copied, 0 when no data is available yet, or a negative
     * value on end-of-stream.
     */
    public int read(final ByteBuffer dst) throws IOException {
        // If we have data ready to go, then go ahead and copy it.
        final int bytesCopied = copy(applicationBuffer, dst);
        if (bytesCopied != 0) {
            return bytesCopied;
        }

        final int bytesRead = socketChannel.read(streamBuffer);
        if (bytesRead < 1) {
            return bytesRead;
        }

        if (bytesRead > 0) {  // always true here (bytesRead >= 1 after the guard above)
            logger.trace("Read {} bytes from SocketChannel", bytesRead);
        }

        streamBuffer.flip();

        try {
            if (sslEngine == null) {
                cloneToApplicationBuffer(streamBuffer);
                return copy(applicationBuffer, dst);
            } else {
                final boolean decrypted = decrypt(streamBuffer);
                logger.trace("Decryption after reading those bytes successful = {}", decrypted);

                if (decrypted) {
                    cloneToApplicationBuffer(destinationBuffer);
                    logger.trace("Cloned destination buffer to application buffer");

                    return copy(applicationBuffer, dst);
                } else {
                    // Not enough data to decrypt. Compact the buffer so that we keep the data we have
                    // but prepare the buffer to be written to again.
                    logger.debug("Not enough data to decrypt. Will need to consume more data before decrypting");
                    streamBuffer.compact();
                    // NOTE(review): on this path streamBuffer is compacted HERE and then AGAIN in the finally
                    // block below; a second compact() on an already-compacted buffer treats the free region as
                    // data and can re-append stale bytes — verify this double-compact is not corrupting input.
                    return 0;
                }
            }
        } finally {
            streamBuffer.compact();
        }
    }

    // Copies the remaining bytes of 'buffer' into applicationBuffer (growing it if needed) and flips it for reading.
    private void cloneToApplicationBuffer(final ByteBuffer buffer) {
        if (applicationBuffer.capacity() < buffer.remaining()) {
            applicationBuffer = ByteBuffer.allocate(buffer.remaining());
        } else {
            applicationBuffer.clear();
        }

        applicationBuffer.put(buffer);
        applicationBuffer.flip();
    }

    // Copies min(src.remaining, dst.remaining) bytes from src to dst; returns the count (0 if src is null/empty).
    private int copy(final ByteBuffer src, final ByteBuffer dst) {
        if (src != null && src.hasRemaining()) {
            final int bytesToCopy = Math.min(dst.remaining(), src.remaining());
            if (bytesToCopy < 1) {
                return bytesToCopy;
            }

            final byte[] buff = new byte[bytesToCopy];
            src.get(buff);
            dst.put(buff);
            return bytesToCopy;
        }

        return 0;
    }


    /**
     * Encrypts the given buffer of data, writing the result into {@link #destinationBuffer}.
     * @param plaintext the data to encrypt
     * @throws IOException if the Peer closes the connection abruptly or if unable to perform the encryption
     */
    private void encrypt(final ByteBuffer plaintext) throws IOException {
        if (sslEngine == null) {
            throw new SSLException("Unable to encrypt message because no SSLEngine has been configured");
        }

        destinationBuffer.clear();

        while (true) {
            final SSLEngineResult result = sslEngine.wrap(plaintext, destinationBuffer);

            switch (result.getStatus()) {
                case OK:
                    destinationBuffer.flip();
                    return;
                case CLOSED:
                    throw new IOException("Failed to encrypt data to write to Peer " + peerDescription + " because Peer unexpectedly closed connection");
                case BUFFER_OVERFLOW:
                    // destinationBuffer is not large enough. Need to increase the size.
                    final ByteBuffer tempBuffer = ByteBuffer.allocate(destinationBuffer.capacity() + sslEngine.getSession().getApplicationBufferSize());
                    destinationBuffer.flip();
                    tempBuffer.put(destinationBuffer);
                    destinationBuffer = tempBuffer;
                    break;
                case BUFFER_UNDERFLOW:
                    // We should never get this result on a call to SSLEngine.wrap(), only on a call to unwrap().
                    throw new IOException("Received unexpected Buffer Underflow result when encrypting data to write to Peer " + peerDescription);
            }
        }
    }




    /**
     * Attempts to decrypt the given buffer of data, writing the result into {@link #destinationBuffer}. If successful, will return <code>true</code>.
     * If more data is needed in order to perform the decryption, will return <code>false</code>.
     *
     * @param encrypted the ByteBuffer containing the data to decrypt
     * @return <code>true</code> if decryption was successful, <code>false</code> otherwise
     * @throws IOException if the Peer closed the connection or if unable to decrypt the message
     */
    private boolean decrypt(final ByteBuffer encrypted) throws IOException {
        if (sslEngine == null) {
            throw new SSLException("Unable to decrypt message because no SSLEngine has been configured");
        }

        destinationBuffer.clear();

        while (true) {
            final SSLEngineResult result = sslEngine.unwrap(encrypted, destinationBuffer);

            switch (result.getStatus()) {
                case OK:
                    destinationBuffer.flip();
                    return true;
                case CLOSED:
                    throw new IOException("Failed to decrypt data from Peer " + peerDescription + " because Peer unexpectedly closed connection");
                case BUFFER_OVERFLOW:
                    // ecnryptedBuffer is not large enough. Need to increase the size.
                    // NOTE(review): the new size is based on encrypted.position() even though the buffer being
                    // grown is destinationBuffer (the unwrap OUTPUT) — compare with encrypt(), which grows from
                    // destinationBuffer.capacity(). Verify this sizing is intentional.
                    final ByteBuffer tempBuffer = ByteBuffer.allocate(encrypted.position() + sslEngine.getSession().getApplicationBufferSize());
                    destinationBuffer.flip();
                    tempBuffer.put(destinationBuffer);
                    destinationBuffer = tempBuffer;

                    break;
                case BUFFER_UNDERFLOW:
                    // Not enough data to decrypt. Must read more from the channel.
                    return false;
            }
        }
    }


    // Drives the SSL/TLS handshake to completion in a loop over the engine's handshake status.
    // (This method continues beyond the end of this chunk; the remainder is not visible here.)
    public void performHandshake() throws IOException {
        if (sslEngine == null) {
            return;
        }

        sslEngine.beginHandshake();

        final ByteBuffer emptyMessage = ByteBuffer.allocate(0);
        ByteBuffer unwrapBuffer = ByteBuffer.allocate(0);

        while (true) {
            final SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();

            switch (handshakeStatus) {
                case FINISHED:
                case NOT_HANDSHAKING:
                    streamBuffer.clear();
                    destinationBuffer.clear();
                    logger.debug("Completed SSL Handshake with Peer {}", peerDescription);
                    return;

                case NEED_TASK:
                    logger.debug("SSL Handshake with Peer {} Needs Task", peerDescription);

                    Runnable runnable;
                    while ((runnable = sslEngine.getDelegatedTask()) != null) {
                        runnable.run();
                    }
                    break;

                case NEED_WRAP:
                    logger.trace("SSL Handshake with Peer {} Needs Wrap", peerDescription);

                    encrypt(emptyMessage);
                    final int bytesWritten = write(destinationBuffer);
                    logger.debug("Wrote {} bytes for NEED_WRAP portion of Handshake", bytesWritten);
                    break;

                case NEED_UNWRAP:
                    logger.trace("SSL Handshake with Peer {} Needs Unwrap", peerDescription);

                    while (sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
                        final boolean decrypted = decrypt(unwrapBuffer);
                        if (decrypted) {
                            logger.trace("Decryption was successful for NEED_UNWRAP portion of Handshake");
                            break;
                        }

                        if (unwrapBuffer.capacity() - unwrapBuffer.position() < 1) {
                            logger.trace("Enlarging size of Buffer for NEED_UNWRAP portion of Handshake");

                            // destinationBuffer is not large enough. Need to increase the size.
+ final ByteBuffer tempBuffer = ByteBuffer.allocate(unwrapBuffer.capacity() + sslEngine.getSession().getApplicationBufferSize()); + tempBuffer.put(unwrapBuffer); + unwrapBuffer = tempBuffer; + unwrapBuffer.flip(); + continue; + } + + logger.trace("Need to read more bytes for NEED_UNWRAP portion of Handshake"); + + // Need to read more data. + unwrapBuffer.compact(); + final int bytesRead = socketChannel.read(unwrapBuffer); + unwrapBuffer.flip(); + logger.debug("Read {} bytes for NEED_UNWRAP portion of Handshake", bytesRead); + } + + break; + } + } + } +} http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/RegisteredPartition.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/RegisteredPartition.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/RegisteredPartition.java new file mode 100644 index 0000000..e427b21 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/client/async/nio/RegisteredPartition.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.controller.queue.clustered.client.async.nio; + +import org.apache.nifi.controller.queue.LoadBalanceCompression; +import org.apache.nifi.controller.queue.clustered.client.async.TransactionCompleteCallback; +import org.apache.nifi.controller.queue.clustered.client.async.TransactionFailureCallback; +import org.apache.nifi.controller.repository.FlowFileRecord; + +import java.util.function.BooleanSupplier; +import java.util.function.Supplier; + +public class RegisteredPartition { + private final String connectionId; + private final Supplier<FlowFileRecord> flowFileRecordSupplier; + private final TransactionFailureCallback failureCallback; + private final BooleanSupplier emptySupplier; + private final TransactionCompleteCallback successCallback; + private final Supplier<LoadBalanceCompression> compressionSupplier; + private final BooleanSupplier honorBackpressureSupplier; + + public RegisteredPartition(final String connectionId, final BooleanSupplier emptySupplier, final Supplier<FlowFileRecord> flowFileSupplier, final TransactionFailureCallback failureCallback, + final TransactionCompleteCallback successCallback, final Supplier<LoadBalanceCompression> compressionSupplier, final BooleanSupplier honorBackpressureSupplier) { + this.connectionId = connectionId; + this.emptySupplier = emptySupplier; + this.flowFileRecordSupplier = flowFileSupplier; + this.failureCallback = failureCallback; + this.successCallback = successCallback; + this.compressionSupplier = compressionSupplier; + this.honorBackpressureSupplier = 
honorBackpressureSupplier; + } + + public boolean isEmpty() { + return emptySupplier.getAsBoolean(); + } + + public String getConnectionId() { + return connectionId; + } + + public Supplier<FlowFileRecord> getFlowFileRecordSupplier() { + return flowFileRecordSupplier; + } + + public TransactionFailureCallback getFailureCallback() { + return failureCallback; + } + + public TransactionCompleteCallback getSuccessCallback() { + return successCallback; + } + + public LoadBalanceCompression getCompression() { + return compressionSupplier.get(); + } + + public boolean isHonorBackpressure() { + return honorBackpressureSupplier.getAsBoolean(); + } +} http://git-wip-us.apache.org/repos/asf/nifi/blob/619f1ffe/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/partition/CorrelationAttributePartitioner.java ---------------------------------------------------------------------- diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/partition/CorrelationAttributePartitioner.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/partition/CorrelationAttributePartitioner.java new file mode 100644 index 0000000..12560d4 --- /dev/null +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/queue/clustered/partition/CorrelationAttributePartitioner.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.controller.queue.clustered.partition; + +import com.google.common.hash.Hashing; +import org.apache.nifi.controller.repository.FlowFileRecord; + +public class CorrelationAttributePartitioner implements FlowFilePartitioner { + private final String partitioningAttribute; + + public CorrelationAttributePartitioner(final String partitioningAttribute) { + this.partitioningAttribute = partitioningAttribute; + } + + @Override + public QueuePartition getPartition(final FlowFileRecord flowFile, final QueuePartition[] partitions, final QueuePartition localPartition) { + final int hash = hash(flowFile); + + // The consistentHash method appears to always return a bucket of '1' if there are 2 possible buckets, + // so in this case we will just use modulo division to avoid this. I suspect this is a bug with the Guava + // implementation, but it's not clear at this point. + final int index; + if (partitions.length < 3) { + index = hash % partitions.length; + } else { + index = Hashing.consistentHash(hash, partitions.length); + } + + return partitions[index]; + } + + protected int hash(final FlowFileRecord flowFile) { + final String partitionAttributeValue = flowFile.getAttribute(partitioningAttribute); + return (partitionAttributeValue == null) ? 0 : partitionAttributeValue.hashCode(); + } + + @Override + public boolean isRebalanceOnClusterResize() { + return true; + } + + @Override + public boolean isRebalanceOnFailure() { + return false; + } +}
