http://git-wip-us.apache.org/repos/asf/nifi/blob/c120c498/nifi-commons/nifi-site-to-site-client/src/main/java/org/apache/nifi/remote/util/SiteToSiteRestApiClient.java ---------------------------------------------------------------------- diff --git a/nifi-commons/nifi-site-to-site-client/src/main/java/org/apache/nifi/remote/util/SiteToSiteRestApiClient.java b/nifi-commons/nifi-site-to-site-client/src/main/java/org/apache/nifi/remote/util/SiteToSiteRestApiClient.java new file mode 100644 index 0000000..4195ae9 --- /dev/null +++ b/nifi-commons/nifi-site-to-site-client/src/main/java/org/apache/nifi/remote/util/SiteToSiteRestApiClient.java @@ -0,0 +1,992 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.nifi.remote.util; + +import org.apache.commons.lang3.StringUtils; +import org.apache.http.Header; +import org.apache.http.HttpEntity; +import org.apache.http.HttpException; +import org.apache.http.HttpHost; +import org.apache.http.HttpInetConnection; +import org.apache.http.HttpRequest; +import org.apache.http.HttpResponse; +import org.apache.http.HttpResponseInterceptor; +import org.apache.http.StatusLine; +import org.apache.http.auth.AuthScope; +import org.apache.http.auth.UsernamePasswordCredentials; +import org.apache.http.client.CredentialsProvider; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpDelete; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.client.utils.URIUtils; +import org.apache.http.conn.ManagedHttpClientConnection; +import org.apache.http.entity.BasicHttpEntity; +import org.apache.http.impl.client.BasicCredentialsProvider; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.nio.client.CloseableHttpAsyncClient; +import org.apache.http.impl.nio.client.HttpAsyncClientBuilder; +import org.apache.http.impl.nio.client.HttpAsyncClients; +import org.apache.http.nio.ContentEncoder; +import org.apache.http.nio.IOControl; +import org.apache.http.nio.conn.ManagedNHttpClientConnection; +import org.apache.http.nio.protocol.BasicAsyncResponseConsumer; +import org.apache.http.nio.protocol.HttpAsyncRequestProducer; +import org.apache.http.protocol.HttpContext; +import org.apache.http.protocol.HttpCoreContext; +import org.apache.http.util.EntityUtils; +import org.apache.nifi.remote.TransferDirection; +import 
org.apache.nifi.remote.client.http.TransportProtocolVersionNegotiator; +import org.apache.nifi.remote.exception.PortNotRunningException; +import org.apache.nifi.remote.exception.ProtocolException; +import org.apache.nifi.remote.exception.UnknownPortException; +import org.apache.nifi.remote.io.http.HttpCommunicationsSession; +import org.apache.nifi.remote.io.http.HttpInput; +import org.apache.nifi.remote.io.http.HttpOutput; +import org.apache.nifi.remote.protocol.CommunicationsSession; +import org.apache.nifi.remote.protocol.ResponseCode; +import org.apache.nifi.remote.protocol.http.HttpHeaders; +import org.apache.nifi.remote.protocol.http.HttpProxy; +import org.apache.nifi.security.util.CertificateUtils; +import org.apache.nifi.stream.io.ByteArrayInputStream; +import org.apache.nifi.stream.io.ByteArrayOutputStream; +import org.apache.nifi.stream.io.StreamUtils; +import org.apache.nifi.web.api.dto.ControllerDTO; +import org.apache.nifi.web.api.dto.remote.PeerDTO; +import org.apache.nifi.web.api.entity.ControllerEntity; +import org.apache.nifi.web.api.entity.PeersEntity; +import org.apache.nifi.web.api.entity.TransactionResultEntity; +import org.codehaus.jackson.JsonParseException; +import org.codehaus.jackson.map.DeserializationConfig; +import org.codehaus.jackson.map.JsonMappingException; +import org.codehaus.jackson.map.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import java.io.Closeable; +import java.io.IOException; +import java.io.InputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; 
+import java.util.Collection; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.regex.Pattern; + +import static org.apache.commons.lang3.StringUtils.isEmpty; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.HANDSHAKE_PROPERTY_BATCH_COUNT; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.HANDSHAKE_PROPERTY_BATCH_DURATION; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.HANDSHAKE_PROPERTY_BATCH_SIZE; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.HANDSHAKE_PROPERTY_REQUEST_EXPIRATION; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.HANDSHAKE_PROPERTY_USE_COMPRESSION; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_HEADER_NAME; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE; + +public class SiteToSiteRestApiClient implements Closeable { + + private static final int RESPONSE_CODE_OK = 200; + private static final int RESPONSE_CODE_CREATED = 201; + private static final int RESPONSE_CODE_ACCEPTED = 202; + private static final int RESPONSE_CODE_SEE_OTHER = 303; + private static final int RESPONSE_CODE_BAD_REQUEST = 400; + private static final int RESPONSE_CODE_UNAUTHORIZED = 401; + private static final int RESPONSE_CODE_NOT_FOUND = 404; + private static final int RESPONSE_CODE_SERVICE_UNAVAILABLE = 503; + + private static final Logger logger = LoggerFactory.getLogger(SiteToSiteRestApiClient.class); + + private String baseUrl; + protected final SSLContext sslContext; + protected 
final HttpProxy proxy; + private RequestConfig requestConfig; + private CredentialsProvider credentialsProvider; + private CloseableHttpClient httpClient; + private CloseableHttpAsyncClient httpAsyncClient; + + private boolean compress = false; + private long requestExpirationMillis = 0; + private int serverTransactionTtl = 0; + private int batchCount = 0; + private long batchSize = 0; + private long batchDurationMillis = 0; + private TransportProtocolVersionNegotiator transportProtocolVersionNegotiator = new TransportProtocolVersionNegotiator(1); + + private String trustedPeerDn; + private final ScheduledExecutorService ttlExtendTaskExecutor; + private ScheduledFuture<?> ttlExtendingThread; + private SiteToSiteRestApiClient extendingApiClient; + + private int connectTimeoutMillis; + private int readTimeoutMillis; + private static final Pattern HTTP_ABS_URL = Pattern.compile("^https?://.+$"); + + public SiteToSiteRestApiClient(final SSLContext sslContext, final HttpProxy proxy) { + this.sslContext = sslContext; + this.proxy = proxy; + ttlExtendTaskExecutor = Executors.newScheduledThreadPool(1, new ThreadFactory() { + private final ThreadFactory defaultFactory = Executors.defaultThreadFactory(); + + @Override + public Thread newThread(final Runnable r) { + final Thread thread = defaultFactory.newThread(r); + thread.setName(Thread.currentThread().getName() + " TTLExtend"); + return thread; + } + }); + } + + @Override + public void close() throws IOException { + stopExtendingTtl(); + closeSilently(httpClient); + closeSilently(httpAsyncClient); + } + + private CloseableHttpClient getHttpClient() { + if (httpClient == null) { + setupClient(); + } + return httpClient; + } + + private CloseableHttpAsyncClient getHttpAsyncClient() { + if (httpAsyncClient == null) { + setupAsyncClient(); + } + return httpAsyncClient; + } + + private RequestConfig getRequestConfig() { + if (requestConfig == null) { + setupRequestConfig(); + } + return requestConfig; + } + + private 
CredentialsProvider getCredentialsProvider() { + if (credentialsProvider == null) { + setupCredentialsProvider(); + } + return credentialsProvider; + } + + private void setupRequestConfig() { + final RequestConfig.Builder requestConfigBuilder = RequestConfig.custom() + .setConnectionRequestTimeout(connectTimeoutMillis) + .setConnectTimeout(connectTimeoutMillis) + .setSocketTimeout(readTimeoutMillis); + + if (proxy != null) { + requestConfigBuilder.setProxy(proxy.getHttpHost()); + } + + requestConfig = requestConfigBuilder.build(); + } + + private void setupCredentialsProvider() { + credentialsProvider = new BasicCredentialsProvider(); + if (proxy != null) { + if (!isEmpty(proxy.getUsername()) && !isEmpty(proxy.getPassword())) { + credentialsProvider.setCredentials( + new AuthScope(proxy.getHttpHost()), + new UsernamePasswordCredentials(proxy.getUsername(), proxy.getPassword())); + } + + } + } + + private void setupClient() { + final HttpClientBuilder clientBuilder = HttpClients.custom(); + + if (sslContext != null) { + clientBuilder.setSslcontext(sslContext); + clientBuilder.addInterceptorFirst(new HttpsResponseInterceptor()); + } + + httpClient = clientBuilder + .setDefaultCredentialsProvider(getCredentialsProvider()).build(); + } + + private void setupAsyncClient() { + final HttpAsyncClientBuilder clientBuilder = HttpAsyncClients.custom(); + + if (sslContext != null) { + clientBuilder.setSSLContext(sslContext); + clientBuilder.addInterceptorFirst(new HttpsResponseInterceptor()); + } + + httpAsyncClient = clientBuilder.setDefaultCredentialsProvider(getCredentialsProvider()).build(); + httpAsyncClient.start(); + } + + private class HttpsResponseInterceptor implements HttpResponseInterceptor { + @Override + public void process(final HttpResponse response, final HttpContext httpContext) throws HttpException, IOException { + final HttpCoreContext coreContext = HttpCoreContext.adapt(httpContext); + final HttpInetConnection conn = 
coreContext.getConnection(HttpInetConnection.class); + if (!conn.isOpen()) { + return; + } + + final SSLSession sslSession; + if (conn instanceof ManagedHttpClientConnection) { + sslSession = ((ManagedHttpClientConnection)conn).getSSLSession(); + } else if (conn instanceof ManagedNHttpClientConnection) { + sslSession = ((ManagedNHttpClientConnection)conn).getSSLSession(); + } else { + throw new RuntimeException("Unexpected connection type was used, " + conn); + } + + + if (sslSession != null) { + final Certificate[] certChain = sslSession.getPeerCertificates(); + if (certChain == null || certChain.length == 0) { + throw new SSLPeerUnverifiedException("No certificates found"); + } + + try { + final X509Certificate cert = CertificateUtils.convertAbstractX509Certificate(certChain[0]); + trustedPeerDn = cert.getSubjectDN().getName().trim(); + } catch (CertificateException e) { + final String msg = "Could not extract subject DN from SSL session peer certificate"; + logger.warn(msg); + throw new SSLPeerUnverifiedException(msg); + } + } + } + } + + public ControllerDTO getController() throws IOException { + try { + HttpGet get = createGet("/site-to-site"); + get.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + return execute(get, ControllerEntity.class).getController(); + + } catch (HttpGetFailedException e) { + if (RESPONSE_CODE_NOT_FOUND == e.getResponseCode()) { + logger.debug("getController received NOT_FOUND, trying to access the old NiFi version resource url..."); + HttpGet get = createGet("/controller"); + return execute(get, ControllerEntity.class).getController(); + } + throw e; + } + } + + public Collection<PeerDTO> getPeers() throws IOException { + HttpGet get = createGet("/site-to-site/peers"); + get.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + return execute(get, PeersEntity.class).getPeers(); + } + + public String 
initiateTransaction(TransferDirection direction, String portId) throws IOException { + if (TransferDirection.RECEIVE.equals(direction)) { + return initiateTransaction("output-ports", portId); + } else { + return initiateTransaction("input-ports", portId); + } + } + + private String initiateTransaction(String portType, String portId) throws IOException { + logger.debug("initiateTransaction handshaking portType={}, portId={}", portType, portId); + HttpPost post = createPost("/site-to-site/" + portType + "/" + portId + "/transactions"); + + + post.setHeader("Accept", "application/json"); + post.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + + setHandshakeProperties(post); + + try (CloseableHttpResponse response = getHttpClient().execute(post)) { + int responseCode = response.getStatusLine().getStatusCode(); + logger.debug("initiateTransaction responseCode={}", responseCode); + + String transactionUrl; + switch (responseCode) { + case RESPONSE_CODE_CREATED : + EntityUtils.consume(response.getEntity()); + + transactionUrl = readTransactionUrl(response); + if (isEmpty(transactionUrl)) { + throw new ProtocolException("Server returned RESPONSE_CODE_CREATED without Location header"); + } + Header transportProtocolVersionHeader = response.getFirstHeader(HttpHeaders.PROTOCOL_VERSION); + if (transportProtocolVersionHeader == null) { + throw new ProtocolException("Server didn't return confirmed protocol version"); + } + Integer protocolVersionConfirmedByServer = Integer.valueOf(transportProtocolVersionHeader.getValue()); + logger.debug("Finished version negotiation, protocolVersionConfirmedByServer={}", protocolVersionConfirmedByServer); + transportProtocolVersionNegotiator.setVersion(protocolVersionConfirmedByServer); + + Header serverTransactionTtlHeader = response.getFirstHeader(HttpHeaders.SERVER_SIDE_TRANSACTION_TTL); + if (serverTransactionTtlHeader == null) { + throw new ProtocolException("Server didn't return 
" + HttpHeaders.SERVER_SIDE_TRANSACTION_TTL); + } + serverTransactionTtl = Integer.parseInt(serverTransactionTtlHeader.getValue()); + break; + + default: + try (InputStream content = response.getEntity().getContent()) { + throw handleErrResponse(responseCode, content); + } + } + logger.debug("initiateTransaction handshaking finished, transactionUrl={}", transactionUrl); + return transactionUrl; + } + + } + + public boolean openConnectionForReceive(String transactionUrl, CommunicationsSession commSession) throws IOException { + + HttpGet get = createGet(transactionUrl + "/flow-files"); + get.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + + setHandshakeProperties(get); + + CloseableHttpResponse response = getHttpClient().execute(get); + int responseCode = response.getStatusLine().getStatusCode(); + logger.debug("responseCode={}", responseCode); + + boolean keepItOpen = false; + try { + switch (responseCode) { + case RESPONSE_CODE_OK : + logger.debug("Server returned RESPONSE_CODE_OK, indicating there was no data."); + EntityUtils.consume(response.getEntity()); + return false; + + case RESPONSE_CODE_ACCEPTED : + InputStream httpIn = response.getEntity().getContent(); + InputStream streamCapture = new InputStream() { + boolean closed = false; + @Override + public int read() throws IOException { + if(closed) return -1; + int r = httpIn.read(); + if (r < 0) { + closed = true; + logger.debug("Reached to end of input stream. 
Closing resources..."); + stopExtendingTtl(); + closeSilently(httpIn); + closeSilently(response); + } + return r; + } + }; + ((HttpInput)commSession.getInput()).setInputStream(streamCapture); + + startExtendingTtl(transactionUrl, httpIn, response); + keepItOpen = true; + return true; + + default: + try (InputStream content = response.getEntity().getContent()) { + throw handleErrResponse(responseCode, content); + } + } + } finally { + if (!keepItOpen) { + response.close(); + } + } + } + + private final int DATA_PACKET_CHANNEL_READ_BUFFER_SIZE = 16384; + private Future<HttpResponse> postResult; + private CountDownLatch transferDataLatch = new CountDownLatch(1); + public void openConnectionForSend(String transactionUrl, CommunicationsSession commSession) throws IOException { + + final String flowFilesPath = transactionUrl + "/flow-files"; + HttpPost post = createPost(flowFilesPath); + + post.setHeader("Content-Type", "application/octet-stream"); + post.setHeader("Accept", "text/plain"); + post.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + + setHandshakeProperties(post); + + CountDownLatch initConnectionLatch = new CountDownLatch(1); + + final URI requestUri = post.getURI(); + final PipedOutputStream outputStream = new PipedOutputStream(); + final PipedInputStream inputStream = new PipedInputStream(outputStream, DATA_PACKET_CHANNEL_READ_BUFFER_SIZE); + final ReadableByteChannel dataPacketChannel = Channels.newChannel(inputStream); + final HttpAsyncRequestProducer asyncRequestProducer = new HttpAsyncRequestProducer() { + + private final ByteBuffer buffer = ByteBuffer.allocate(DATA_PACKET_CHANNEL_READ_BUFFER_SIZE); + + @Override + public HttpHost getTarget() { + return URIUtils.extractHost(requestUri); + } + + @Override + public HttpRequest generateRequest() throws IOException, HttpException { + + // Pass the output stream so that Site-to-Site client thread can send + // data packet through this connection. 
+ logger.debug("sending data to {} has started...", flowFilesPath); + ((HttpOutput)commSession.getOutput()).setOutputStream(outputStream); + initConnectionLatch.countDown(); + + final BasicHttpEntity entity = new BasicHttpEntity(); + entity.setChunked(true); + entity.setContentType("application/octet-stream"); + post.setEntity(entity); + return post; + } + + @Override + public void produceContent(ContentEncoder encoder, IOControl ioControl) throws IOException { + + int totalRead = 0; + int totalProduced = 0; + int read; + // This read() blocks until data becomes available, + // or corresponding outputStream is closed. + while ((read = dataPacketChannel.read(buffer)) > -1) { + + buffer.flip(); + while (buffer.hasRemaining()) { + totalProduced += encoder.write(buffer); + } + buffer.clear(); + logger.trace("Read {} bytes from dataPacketChannel. {}", read, flowFilesPath); + totalRead += read; + + } + + // There might be remaining bytes in buffer. Make sure it's fully drained. + buffer.flip(); + while (buffer.hasRemaining()) { + totalProduced += encoder.write(buffer); + } + + final long totalWritten = commSession.getOutput().getBytesWritten(); + logger.debug("sending data to {} has reached to its end. produced {} bytes by reading {} bytes from channel. {} bytes written in this transaction.", + flowFilesPath, totalProduced, totalRead, totalWritten); + if (totalRead != totalWritten || totalProduced != totalWritten) { + final String msg = "Sending data to %s has reached to its end, but produced : read : wrote byte sizes (%d : $d : %d) were not equal. 
Something went wrong."; + throw new RuntimeException(String.format(msg, flowFilesPath, totalProduced, totalRead, totalWritten)); + } + transferDataLatch.countDown(); + encoder.complete(); + dataPacketChannel.close(); + + } + + @Override + public void requestCompleted(HttpContext context) { + logger.debug("Sending data to {} completed.", flowFilesPath); + } + + @Override + public void failed(Exception ex) { + logger.error("Sending data to {} has failed", flowFilesPath, ex); + } + + @Override + public boolean isRepeatable() { + // In order to pass authentication, request has to be repeatable. + return true; + } + + @Override + public void resetRequest() throws IOException { + logger.debug("Sending data request to {} has been reset...", flowFilesPath); + } + + @Override + public void close() throws IOException { + logger.debug("Closing sending data request to {}", flowFilesPath); + closeSilently(outputStream); + closeSilently(dataPacketChannel); + stopExtendingTtl(); + } + }; + + postResult = getHttpAsyncClient().execute(asyncRequestProducer, new BasicAsyncResponseConsumer(), null); + + try { + // Need to wait the post request actually started so that we can write to its output stream. + if (!initConnectionLatch.await(connectTimeoutMillis, TimeUnit.MILLISECONDS)) { + throw new IOException("Awaiting initConnectionLatch has been timeout."); + } + + // Started. + transferDataLatch = new CountDownLatch(1); + startExtendingTtl(transactionUrl, dataPacketChannel, null); + + } catch (InterruptedException e) { + throw new IOException("Awaiting initConnectionLatch has been interrupted.", e); + } + + } + + public void finishTransferFlowFiles(CommunicationsSession commSession) throws IOException { + + if (postResult == null) { + new IllegalStateException("Data transfer has not started yet."); + } + + // No more data can be sent. + // Close PipedOutputStream so that dataPacketChannel doesn't blocked. 
+ // If we don't close this output stream, then PipedInputStream loops infinitely at read(). + commSession.getOutput().getOutputStream().close(); + logger.debug("{} FinishTransferFlowFiles no more data can be sent", this); + + try { + if (!transferDataLatch.await(requestExpirationMillis, TimeUnit.MILLISECONDS)) { + throw new IOException("Awaiting transferDataLatch has been timeout."); + } + } catch (InterruptedException e) { + throw new IOException("Awaiting transferDataLatch has been interrupted.", e); + } + + stopExtendingTtl(); + + final HttpResponse response; + try { + response = postResult.get(readTimeoutMillis, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + logger.debug("Something has happened at sending thread. {}", e.getMessage()); + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + throw (IOException) cause; + } else { + throw new IOException(cause); + } + } catch (TimeoutException|InterruptedException e) { + throw new IOException(e); + } + + int responseCode = response.getStatusLine().getStatusCode(); + switch (responseCode) { + case RESPONSE_CODE_ACCEPTED : + String receivedChecksum = EntityUtils.toString(response.getEntity()); + ((HttpInput)commSession.getInput()).setInputStream(new ByteArrayInputStream(receivedChecksum.getBytes())); + ((HttpCommunicationsSession)commSession).setChecksum(receivedChecksum); + logger.debug("receivedChecksum={}", receivedChecksum); + break; + + default: + try (InputStream content = response.getEntity().getContent()) { + throw handleErrResponse(responseCode, content); + } + } + } + + private void startExtendingTtl(final String transactionUrl, final Closeable stream, final CloseableHttpResponse response) { + if (ttlExtendingThread != null) { + // Already started. 
+ return; + } + logger.debug("Starting extending TTL thread..."); + extendingApiClient = new SiteToSiteRestApiClient(sslContext, proxy); + extendingApiClient.transportProtocolVersionNegotiator = this.transportProtocolVersionNegotiator; + extendingApiClient.connectTimeoutMillis = this.connectTimeoutMillis; + extendingApiClient.readTimeoutMillis = this.readTimeoutMillis; + int extendFrequency = serverTransactionTtl / 2; + ttlExtendingThread = ttlExtendTaskExecutor.scheduleWithFixedDelay(() -> { + try { + extendingApiClient.extendTransaction(transactionUrl); + } catch (Exception e) { + logger.warn("Failed to extend transaction ttl", e); + try { + // Without disconnecting, Site-to-Site client keep reading data packet, + // while server has already rollback. + this.close(); + } catch (IOException ec) { + logger.warn("Failed to close", e); + } + } + }, extendFrequency, extendFrequency, TimeUnit.SECONDS); + } + + private void closeSilently(final Closeable closeable) { + try { + if (closeable != null) { + closeable.close(); + } + } catch (IOException e) { + logger.warn("Got an exception during closing {}: {}", closeable, e.getMessage()); + if (logger.isDebugEnabled()) { + logger.warn("", e); + } + } + } + + public TransactionResultEntity extendTransaction(String transactionUrl) throws IOException { + logger.debug("Sending extendTransaction request to transactionUrl: {}", transactionUrl); + + final HttpPut put = createPut(transactionUrl); + + put.setHeader("Accept", "application/json"); + put.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + + setHandshakeProperties(put); + + try (CloseableHttpResponse response = getHttpClient().execute(put)) { + int responseCode = response.getStatusLine().getStatusCode(); + logger.debug("extendTransaction responseCode={}", responseCode); + + try (InputStream content = response.getEntity().getContent()) { + switch (responseCode) { + case RESPONSE_CODE_OK : + return 
readResponse(content); + + default: + throw handleErrResponse(responseCode, content); + } + } + } + + } + + private void stopExtendingTtl() { + if (!ttlExtendTaskExecutor.isShutdown()) { + ttlExtendTaskExecutor.shutdown(); + } + + if (ttlExtendingThread != null && !ttlExtendingThread.isCancelled()) { + logger.debug("Cancelling extending ttl..."); + ttlExtendingThread.cancel(true); + } + + closeSilently(extendingApiClient); + } + + private IOException handleErrResponse(final int responseCode, final InputStream in) throws IOException { + if(in == null) { + return new IOException("Unexpected response code: " + responseCode); + } + TransactionResultEntity errEntity = readResponse(in); + ResponseCode errCode = ResponseCode.fromCode(errEntity.getResponseCode()); + switch (errCode) { + case UNKNOWN_PORT: + return new UnknownPortException(errEntity.getMessage()); + case PORT_NOT_IN_VALID_STATE: + return new PortNotRunningException(errEntity.getMessage()); + default: + return new IOException("Unexpected response code: " + responseCode + + " errCode:" + errCode + " errMessage:" + errEntity.getMessage()); + } + } + + private TransactionResultEntity readResponse(InputStream inputStream) throws IOException { + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + StreamUtils.copy(inputStream, bos); + String responseMessage = null; + try { + responseMessage = new String(bos.toByteArray(), "UTF-8"); + logger.debug("readResponse responseMessage={}", responseMessage); + + final ObjectMapper mapper = new ObjectMapper(); + return mapper.readValue(responseMessage, TransactionResultEntity.class); + + } catch (JsonParseException | JsonMappingException e) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to parse JSON.", e); + } + TransactionResultEntity entity = new TransactionResultEntity(); + entity.setResponseCode(ResponseCode.ABORT.getCode()); + entity.setMessage(responseMessage); + return entity; + } + } + + private String readTransactionUrl(final 
CloseableHttpResponse response) { + final Header locationUriIntentHeader = response.getFirstHeader(LOCATION_URI_INTENT_NAME); + logger.debug("locationUriIntentHeader={}", locationUriIntentHeader); + if (locationUriIntentHeader != null) { + if (LOCATION_URI_INTENT_VALUE.equals(locationUriIntentHeader.getValue())) { + Header transactionUrl = response.getFirstHeader(LOCATION_HEADER_NAME); + logger.debug("transactionUrl={}", transactionUrl); + if (transactionUrl != null) { + return transactionUrl.getValue(); + } + } + } + return null; + } + + private void setHandshakeProperties(final HttpRequestBase httpRequest) { + if(compress) httpRequest.setHeader(HANDSHAKE_PROPERTY_USE_COMPRESSION, "true"); + if(requestExpirationMillis > 0) httpRequest.setHeader(HANDSHAKE_PROPERTY_REQUEST_EXPIRATION, String.valueOf(requestExpirationMillis)); + if(batchCount > 0) httpRequest.setHeader(HANDSHAKE_PROPERTY_BATCH_COUNT, String.valueOf(batchCount)); + if(batchSize > 0) httpRequest.setHeader(HANDSHAKE_PROPERTY_BATCH_SIZE, String.valueOf(batchSize)); + if(batchDurationMillis > 0) httpRequest.setHeader(HANDSHAKE_PROPERTY_BATCH_DURATION, String.valueOf(batchDurationMillis)); + } + + private HttpGet createGet(final String path) { + final URI url = getUri(path); + HttpGet get = new HttpGet(url); + get.setConfig(getRequestConfig()); + return get; + } + + private URI getUri(String path) { + final URI url; + try { + if(HTTP_ABS_URL.matcher(path).find()){ + url = new URI(path); + } else { + if(StringUtils.isEmpty(getBaseUrl())){ + throw new IllegalStateException("API baseUrl is not resolved yet, call setBaseUrl or resolveBaseUrl before sending requests with relative path."); + } + url = new URI(baseUrl + path); + } + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e.getMessage()); + } + return url; + } + + private HttpPost createPost(final String path) { + final URI url = getUri(path); + HttpPost post = new HttpPost(url); + post.setConfig(getRequestConfig()); + return post; + 
} + + private HttpPut createPut(final String path) { + final URI url = getUri(path); + HttpPut put = new HttpPut(url); + put.setConfig(getRequestConfig()); + return put; + } + + private HttpDelete createDelete(final String path) { + final URI url = getUri(path); + HttpDelete delete = new HttpDelete(url); + delete.setConfig(getRequestConfig()); + return delete; + } + + private String execute(final HttpGet get) throws IOException { + + CloseableHttpClient httpClient = getHttpClient(); + try (CloseableHttpResponse response = httpClient.execute(get)) { + StatusLine statusLine = response.getStatusLine(); + int statusCode = statusLine.getStatusCode(); + if (RESPONSE_CODE_OK != statusCode) { + throw new HttpGetFailedException(statusCode, statusLine.getReasonPhrase(), null); + } + HttpEntity entity = response.getEntity(); + String responseMessage = EntityUtils.toString(entity); + return responseMessage; + } + } + + public class HttpGetFailedException extends IOException { + private final int responseCode; + private final String responseMessage; + private final String explanation; + public HttpGetFailedException(final int responseCode, final String responseMessage, final String explanation) { + super("response code " + responseCode + ":" + responseMessage + " with explanation: " + explanation); + this.responseCode = responseCode; + this.responseMessage = responseMessage; + this.explanation = explanation; + } + + public int getResponseCode() { + return responseCode; + } + + public String getDescription() { + return !isEmpty(explanation) ? 
explanation : responseMessage; + } + } + + + private <T> T execute(final HttpGet get, final Class<T> entityClass) throws IOException { + get.setHeader("Accept", "application/json"); + final String responseMessage = execute(get); + + final ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false); + return mapper.readValue(responseMessage, entityClass); + } + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public void setConnectTimeoutMillis(int connectTimeoutMillis) { + this.connectTimeoutMillis = connectTimeoutMillis; + } + + public void setReadTimeoutMillis(int readTimeoutMillis) { + this.readTimeoutMillis = readTimeoutMillis; + } + + public String resolveBaseUrl(String clusterUrl) { + URI clusterUri; + try { + clusterUri = new URI(clusterUrl); + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Specified clusterUrl was: " + clusterUrl, e); + } + return this.resolveBaseUrl(clusterUri); + } + + public String resolveBaseUrl(URI clusterUrl) { + String urlPath = clusterUrl.getPath(); + if (urlPath.endsWith("/")) { + urlPath = urlPath.substring(0, urlPath.length() - 1); + } + return resolveBaseUrl(clusterUrl.getScheme(), clusterUrl.getHost(), clusterUrl.getPort(), urlPath + "-api"); + } + + public String resolveBaseUrl(final String scheme, final String host, final int port) { + return resolveBaseUrl(scheme, host, port, "/nifi-api"); + } + + public String resolveBaseUrl(final String scheme, final String host, final int port, String path) { + String baseUri = scheme + "://" + host + ":" + port + path; + this.setBaseUrl(baseUri); + return baseUri; + } + + public void setCompress(boolean compress) { + this.compress = compress; + } + + public void setRequestExpirationMillis(long requestExpirationMillis) { + if(requestExpirationMillis < 0) throw new IllegalArgumentException("requestExpirationMillis 
can't be a negative value."); + this.requestExpirationMillis = requestExpirationMillis; + } + + public void setBatchCount(int batchCount) { + if(batchCount < 0) throw new IllegalArgumentException("batchCount can't be a negative value."); + this.batchCount = batchCount; + } + + public void setBatchSize(long batchSize) { + if(batchSize < 0) throw new IllegalArgumentException("batchSize can't be a negative value."); + this.batchSize = batchSize; + } + + public void setBatchDurationMillis(long batchDurationMillis) { + if(batchDurationMillis < 0) throw new IllegalArgumentException("batchDurationMillis can't be a negative value."); + this.batchDurationMillis = batchDurationMillis; + } + + public Integer getTransactionProtocolVersion() { + return transportProtocolVersionNegotiator.getTransactionProtocolVersion(); + } + + public String getTrustedPeerDn() { + return this.trustedPeerDn; + } + + public TransactionResultEntity commitReceivingFlowFiles(String transactionUrl, ResponseCode clientResponse, String checksum) throws IOException { + logger.debug("Sending commitReceivingFlowFiles request to transactionUrl: {}, clientResponse={}, checksum={}", + transactionUrl, clientResponse, checksum); + + stopExtendingTtl(); + + StringBuilder urlBuilder = new StringBuilder(transactionUrl).append("?responseCode=").append(clientResponse.getCode()); + if (ResponseCode.CONFIRM_TRANSACTION.equals(clientResponse)) { + urlBuilder.append("&checksum=").append(checksum); + } + + HttpDelete delete = createDelete(urlBuilder.toString()); + delete.setHeader("Accept", "application/json"); + delete.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + + setHandshakeProperties(delete); + + try (CloseableHttpResponse response = getHttpClient().execute(delete)) { + int responseCode = response.getStatusLine().getStatusCode(); + logger.debug("commitReceivingFlowFiles responseCode={}", responseCode); + + try (InputStream content = 
response.getEntity().getContent()) { + switch (responseCode) { + case RESPONSE_CODE_OK : + return readResponse(content); + + case RESPONSE_CODE_BAD_REQUEST : + return readResponse(content); + + default: + throw handleErrResponse(responseCode, content); + } + } + } + + } + + public TransactionResultEntity commitTransferFlowFiles(String transactionUrl, ResponseCode clientResponse) throws IOException { + String requestUrl = transactionUrl + "?responseCode=" + clientResponse.getCode(); + logger.debug("Sending commitTransferFlowFiles request to transactionUrl: {}", requestUrl); + + HttpDelete delete = createDelete(requestUrl); + delete.setHeader("Accept", "application/json"); + delete.setHeader(HttpHeaders.PROTOCOL_VERSION, String.valueOf(transportProtocolVersionNegotiator.getVersion())); + + setHandshakeProperties(delete); + + try (CloseableHttpResponse response = getHttpClient().execute(delete)) { + int responseCode = response.getStatusLine().getStatusCode(); + logger.debug("commitTransferFlowFiles responseCode={}", responseCode); + + try (InputStream content = response.getEntity().getContent()) { + switch (responseCode) { + case RESPONSE_CODE_OK : + return readResponse(content); + + case RESPONSE_CODE_BAD_REQUEST : + return readResponse(content); + + default: + throw handleErrResponse(responseCode, content); + } + } + } + + } + +}
http://git-wip-us.apache.org/repos/asf/nifi/blob/c120c498/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/TestPeerSelector.java ---------------------------------------------------------------------- diff --git a/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/TestPeerSelector.java b/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/TestPeerSelector.java new file mode 100644 index 0000000..ca820f8 --- /dev/null +++ b/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/TestPeerSelector.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/
package org.apache.nifi.remote.client;

import org.apache.nifi.remote.PeerDescription;
import org.apache.nifi.remote.PeerStatus;
import org.apache.nifi.remote.TransferDirection;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.reducing;
import static java.util.stream.Collectors.toMap;
import static org.junit.Assert.assertTrue;

/**
 * Tests for {@link PeerSelector#formulateDestinationList}: each test checks that the per-hostname selection
 * frequency is ordered according to the integer given to each PeerStatus (presumably its flow file count).
 */
public class TestPeerSelector {

    private static final Logger logger = LoggerFactory.getLogger(TestPeerSelector.class);

    /**
     * Creates a PeerSelector backed by a mocked PeerStatusProvider and formulates a destination list for
     * {@code collection} in the given {@code direction}.
     */
    private List<PeerStatus> formulateDestinationList(final Set<PeerStatus> collection, final TransferDirection direction) throws IOException {
        final PeerStatusProvider peerStatusProvider = Mockito.mock(PeerStatusProvider.class);
        final PeerSelector peerSelector = new PeerSelector(peerStatusProvider, null);
        return peerSelector.formulateDestinationList(collection, direction);
    }

    /**
     * Computes, for each hostname, the average number of times one of its entries appears in {@code destinations}.
     *
     * Several entries can share a hostname (same host, different ports), so the per-host selection count is divided
     * by the number of entries for that host. Floating-point division prevents truncation from turning a strict
     * ordering into a tie, and hosts that were never selected map to 0 instead of being absent (an absent key
     * would make {@code get(...)} return null and the comparisons fail with a NullPointerException).
     */
    private Map<String, Double> calculateAverageSelectedCount(Set<PeerStatus> collection, List<PeerStatus> destinations) {
        // Number of entries per hostname: the divisor of the average.
        final Map<String, Integer> hostNameCounts
                = collection.stream().collect(groupingBy(p -> p.getPeerDescription().getHostname(), reducing(0, p -> 1, Integer::sum)));

        // Number of times each hostname appears in the destination list.
        final Map<String, Integer> hostSelectedCounts
                = destinations.stream().collect(groupingBy(p -> p.getPeerDescription().getHostname(), reducing(0, p -> 1, Integer::sum)));

        return hostNameCounts.entrySet().stream().collect(toMap(Map.Entry::getKey,
                e -> hostSelectedCounts.getOrDefault(e.getKey(), 0) / (double) e.getValue()));
    }

    @Test
    public void testFormulateDestinationListForOutput() throws IOException {
        final Set<PeerStatus> collection = new HashSet<>();
        collection.add(new PeerStatus(new PeerDescription("HasMedium", 1111, true), 4096));
        collection.add(new PeerStatus(new PeerDescription("HasLots", 2222, true), 10240));
        collection.add(new PeerStatus(new PeerDescription("HasLittle", 3333, true), 1024));
        collection.add(new PeerStatus(new PeerDescription("HasMedium", 4444, true), 4096));
        collection.add(new PeerStatus(new PeerDescription("HasMedium", 5555, true), 4096));

        final List<PeerStatus> destinations = formulateDestinationList(collection, TransferDirection.RECEIVE);
        final Map<String, Double> selectedCounts = calculateAverageSelectedCount(collection, destinations);

        logger.info("selectedCounts={}", selectedCounts);
        assertTrue("HasLots should send lots", selectedCounts.get("HasLots") > selectedCounts.get("HasMedium"));
        assertTrue("HasMedium should send medium", selectedCounts.get("HasMedium") > selectedCounts.get("HasLittle"));
    }

    @Test
    public void testFormulateDestinationListForOutputHugeDifference() throws IOException {
        final Set<PeerStatus> collection = new HashSet<>();
        collection.add(new PeerStatus(new PeerDescription("HasLittle", 1111, true), 500));
        collection.add(new PeerStatus(new PeerDescription("HasLots", 2222, true), 50000));

        final List<PeerStatus> destinations = formulateDestinationList(collection, TransferDirection.RECEIVE);
        final Map<String, Double> selectedCounts = calculateAverageSelectedCount(collection, destinations);

        logger.info("selectedCounts={}", selectedCounts);
        assertTrue("HasLots should send lots", selectedCounts.get("HasLots") > selectedCounts.get("HasLittle"));
    }

    // NOTE(review): the input-port tests use TransferDirection.RECEIVE, the same direction as the output-port tests;
    // an input-port scenario would presumably exercise TransferDirection.SEND. Confirm against PeerSelector semantics.
    @Test
    public void testFormulateDestinationListForInputPorts() throws IOException {
        final Set<PeerStatus> collection = new HashSet<>();
        collection.add(new PeerStatus(new PeerDescription("HasMedium", 1111, true), 4096));
        collection.add(new PeerStatus(new PeerDescription("HasLittle", 2222, true), 10240));
        collection.add(new PeerStatus(new PeerDescription("HasLots", 3333, true), 1024));
        collection.add(new PeerStatus(new PeerDescription("HasMedium", 4444, true), 4096));
        collection.add(new PeerStatus(new PeerDescription("HasMedium", 5555, true), 4096));

        final List<PeerStatus> destinations = formulateDestinationList(collection, TransferDirection.RECEIVE);
        final Map<String, Double> selectedCounts = calculateAverageSelectedCount(collection, destinations);

        logger.info("selectedCounts={}", selectedCounts);
        assertTrue("HasLots should get little", selectedCounts.get("HasLots") < selectedCounts.get("HasMedium"));
        assertTrue("HasMedium should get medium", selectedCounts.get("HasMedium") < selectedCounts.get("HasLittle"));
    }

    @Test
    public void testFormulateDestinationListForInputPortsHugeDifference() throws IOException {
        final Set<PeerStatus> collection = new HashSet<>();
        collection.add(new PeerStatus(new PeerDescription("HasLots", 1111, true), 500));
        collection.add(new PeerStatus(new PeerDescription("HasLittle", 2222, true), 50000));

        final List<PeerStatus> destinations = formulateDestinationList(collection, TransferDirection.RECEIVE);
        final Map<String, Double> selectedCounts = calculateAverageSelectedCount(collection, destinations);

        logger.info("selectedCounts={}", selectedCounts);
        assertTrue("HasLots should get little", selectedCounts.get("HasLots") < selectedCounts.get("HasLittle"));
    }
}
http://git-wip-us.apache.org/repos/asf/nifi/blob/c120c498/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/http/TestHttpClient.java ---------------------------------------------------------------------- diff --git a/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/http/TestHttpClient.java b/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/http/TestHttpClient.java new file mode 100644 index 0000000..7240c7a --- /dev/null +++ b/nifi-commons/nifi-site-to-site-client/src/test/java/org/apache/nifi/remote/client/http/TestHttpClient.java @@ -0,0 +1,950 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.nifi.remote.client.http; + +import org.apache.nifi.controller.ScheduledState; +import org.apache.nifi.remote.Peer; +import org.apache.nifi.remote.Transaction; +import org.apache.nifi.remote.TransferDirection; +import org.apache.nifi.remote.client.SiteToSiteClient; +import org.apache.nifi.remote.codec.StandardFlowFileCodec; +import org.apache.nifi.remote.io.CompressionInputStream; +import org.apache.nifi.remote.io.CompressionOutputStream; +import org.apache.nifi.remote.protocol.DataPacket; +import org.apache.nifi.remote.protocol.ResponseCode; +import org.apache.nifi.remote.protocol.SiteToSiteTransportProtocol; +import org.apache.nifi.remote.protocol.http.HttpHeaders; +import org.apache.nifi.remote.util.StandardDataPacket; +import org.apache.nifi.stream.io.ByteArrayInputStream; +import org.apache.nifi.stream.io.ByteArrayOutputStream; +import org.apache.nifi.stream.io.StreamUtils; +import org.apache.nifi.web.api.dto.ControllerDTO; +import org.apache.nifi.web.api.dto.PortDTO; +import org.apache.nifi.web.api.dto.remote.PeerDTO; +import org.apache.nifi.web.api.entity.ControllerEntity; +import org.apache.nifi.web.api.entity.PeersEntity; +import org.apache.nifi.web.api.entity.TransactionResultEntity; +import org.codehaus.jackson.map.ObjectMapper; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHandler; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.ServletException; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.SocketTimeoutException; +import java.net.URI; 
+import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.commons.lang3.StringUtils.isEmpty; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_HEADER_NAME; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_NAME; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.LOCATION_URI_INTENT_VALUE; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.PROTOCOL_VERSION; +import static org.apache.nifi.remote.protocol.http.HttpHeaders.SERVER_SIDE_TRANSACTION_TTL; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class TestHttpClient { + + private static Logger logger = LoggerFactory.getLogger(TestHttpClient.class); + + private static Server server; + final private static AtomicBoolean isTestCaseFinished = new AtomicBoolean(false); + + private static Set<PortDTO> inputPorts; + private static Set<PortDTO> outputPorts; + private static Set<PeerDTO> peers; + private static String serverChecksum; + + public static class SiteInfoServlet extends HttpServlet { + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final ControllerDTO controller = new ControllerDTO(); + controller.setRemoteSiteHttpListeningPort(server.getURI().getPort()); + controller.setId("remote-controller-id"); + controller.setInstanceId("remote-instance-id"); + controller.setName("Remote NiFi Flow"); + controller.setSiteToSiteSecure(false); + + assertNotNull("Test case should set <inputPorts> depending on the test scenario.", inputPorts); + controller.setInputPorts(inputPorts); + controller.setInputPortCount(inputPorts.size()); + + assertNotNull("Test case should set 
<outputPorts> depending on the test scenario.", outputPorts); + controller.setOutputPorts(outputPorts); + controller.setOutputPortCount(outputPorts.size()); + + final ControllerEntity controllerEntity = new ControllerEntity(); + controllerEntity.setController(controller); + + respondWithJson(resp, controllerEntity); + } + } + + public static class PeersServlet extends HttpServlet { + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final PeersEntity peersEntity = new PeersEntity(); + + assertNotNull("Test case should set <peers> depending on the test scenario.", peers); + peersEntity.setPeers(peers); + + respondWithJson(resp, peersEntity); + } + } + + public static class PortTransactionsServlet extends HttpServlet { + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + TransactionResultEntity entity = new TransactionResultEntity(); + entity.setResponseCode(ResponseCode.PROPERTIES_OK.getCode()); + entity.setMessage("A transaction is created."); + + resp.setHeader(LOCATION_URI_INTENT_NAME, LOCATION_URI_INTENT_VALUE); + resp.setHeader(LOCATION_HEADER_NAME, req.getRequestURL() + "/transaction-id"); + setCommonResponseHeaders(resp, reqProtocolVersion); + + respondWithJson(resp, entity, HttpServletResponse.SC_CREATED); + } + + } + + public static class InputPortTransactionServlet extends HttpServlet { + + @Override + protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + final int reqProtocolVersion = getReqProtocolVersion(req); + + final TransactionResultEntity entity = new TransactionResultEntity(); + entity.setResponseCode(ResponseCode.CONTINUE_TRANSACTION.getCode()); + entity.setMessage("Extended TTL."); + + setCommonResponseHeaders(resp, reqProtocolVersion); + + respondWithJson(resp, entity, 
HttpServletResponse.SC_OK); + } + + @Override + protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + TransactionResultEntity entity = new TransactionResultEntity(); + entity.setResponseCode(ResponseCode.TRANSACTION_FINISHED.getCode()); + entity.setMessage("The transaction is finished."); + + setCommonResponseHeaders(resp, reqProtocolVersion); + + respondWithJson(resp, entity, HttpServletResponse.SC_OK); + } + + } + + public static class OutputPortTransactionServlet extends HttpServlet { + + @Override + protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + final int reqProtocolVersion = getReqProtocolVersion(req); + + final TransactionResultEntity entity = new TransactionResultEntity(); + entity.setResponseCode(ResponseCode.CONTINUE_TRANSACTION.getCode()); + entity.setMessage("Extended TTL."); + + setCommonResponseHeaders(resp, reqProtocolVersion); + + respondWithJson(resp, entity, HttpServletResponse.SC_OK); + } + + @Override + protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + TransactionResultEntity entity = new TransactionResultEntity(); + entity.setResponseCode(ResponseCode.CONFIRM_TRANSACTION.getCode()); + entity.setMessage("The transaction is confirmed."); + + setCommonResponseHeaders(resp, reqProtocolVersion); + + respondWithJson(resp, entity, HttpServletResponse.SC_OK); + } + + } + + public static class FlowFilesServlet extends HttpServlet { + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + setCommonResponseHeaders(resp, reqProtocolVersion); + + DataPacket dataPacket; + while ((dataPacket = 
readIncomingPacket(req)) != null) { + logger.info("received {}", dataPacket); + consumeDataPacket(dataPacket); + } + logger.info("finish receiving data packets."); + + assertNotNull("Test case should set <serverChecksum> depending on the test scenario.", serverChecksum); + respondWithText(resp, serverChecksum, HttpServletResponse.SC_ACCEPTED); + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + resp.setStatus(HttpServletResponse.SC_ACCEPTED); + resp.setContentType("application/octet-stream"); + setCommonResponseHeaders(resp, reqProtocolVersion); + + final OutputStream outputStream = getOutputStream(req, resp); + writeOutgoingPacket(outputStream); + writeOutgoingPacket(outputStream); + writeOutgoingPacket(outputStream); + resp.flushBuffer(); + } + } + + public static class FlowFilesTimeoutServlet extends FlowFilesServlet { + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + sleepUntilTestCaseFinish(); + + super.doPost(req, resp); + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + sleepUntilTestCaseFinish(); + + super.doGet(req, resp); + } + + } + + public static class FlowFilesTimeoutAfterDataExchangeServlet extends HttpServlet { + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + setCommonResponseHeaders(resp, reqProtocolVersion); + + consumeDataPacket(readIncomingPacket(req)); + + sleepUntilTestCaseFinish(); + + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + + final int reqProtocolVersion = getReqProtocolVersion(req); + + 
resp.setStatus(HttpServletResponse.SC_ACCEPTED); + resp.setContentType("application/octet-stream"); + setCommonResponseHeaders(resp, reqProtocolVersion); + + writeOutgoingPacket(getOutputStream(req, resp)); + + sleepUntilTestCaseFinish(); + } + } + + private static void sleepUntilTestCaseFinish() { + while (!isTestCaseFinished.get()) { + try { + logger.info("Sleeping..."); + Thread.sleep(1000); + } catch (InterruptedException e) { + logger.info("Got an exception while sleeping.", e); + break; + } + } + } + + private static void writeOutgoingPacket(OutputStream outputStream) throws IOException { + final DataPacket packet = new DataPacketBuilder() + .contents("Example contents from server.") + .attr("Server attr 1", "Server attr 1 value") + .attr("Server attr 2", "Server attr 2 value") + .build(); + new StandardFlowFileCodec().encode(packet, outputStream); + outputStream.flush(); + } + + private static OutputStream getOutputStream(HttpServletRequest req, HttpServletResponse resp) throws IOException { + OutputStream outputStream = resp.getOutputStream(); + if (Boolean.valueOf(req.getHeader(HttpHeaders.HANDSHAKE_PROPERTY_USE_COMPRESSION))){ + outputStream = new CompressionOutputStream(outputStream); + } + return outputStream; + } + + private static DataPacket readIncomingPacket(HttpServletRequest req) throws IOException { + final StandardFlowFileCodec codec = new StandardFlowFileCodec(); + InputStream inputStream = req.getInputStream(); + if (Boolean.valueOf(req.getHeader(HttpHeaders.HANDSHAKE_PROPERTY_USE_COMPRESSION))){ + inputStream = new CompressionInputStream(inputStream); + } + + return codec.decode(inputStream); + } + + private static int getReqProtocolVersion(HttpServletRequest req) { + final String reqProtocolVersionStr = req.getHeader(PROTOCOL_VERSION); + assertTrue(!isEmpty(reqProtocolVersionStr)); + return Integer.parseInt(reqProtocolVersionStr); + } + + private static void setCommonResponseHeaders(HttpServletResponse resp, int reqProtocolVersion) { + 
resp.setHeader(PROTOCOL_VERSION, String.valueOf(reqProtocolVersion)); + resp.setHeader(SERVER_SIDE_TRANSACTION_TTL, "3"); + } + + private static void respondWithJson(HttpServletResponse resp, Object entity) throws IOException { + respondWithJson(resp, entity, HttpServletResponse.SC_OK); + } + + private static void respondWithJson(HttpServletResponse resp, Object entity, int statusCode) throws IOException { + resp.setContentType("application/json"); + resp.setStatus(statusCode); + final ServletOutputStream out = resp.getOutputStream(); + new ObjectMapper().writer().writeValue(out, entity); + out.flush(); + } + + private static void respondWithText(HttpServletResponse resp, String result, int statusCode) throws IOException { + resp.setContentType("text/plain"); + resp.setStatus(statusCode); + final ServletOutputStream out = resp.getOutputStream(); + out.write(result.getBytes()); + out.flush(); + } + + @BeforeClass + public static void setup() throws Exception { + // Create embedded Jetty server + server = new Server(0); + + ServletContextHandler contextHandler = new ServletContextHandler(); + contextHandler.setContextPath("/nifi-api"); + server.setHandler(contextHandler); + + ServletHandler servletHandler = new ServletHandler(); + contextHandler.insertHandler(servletHandler); + + servletHandler.addServletWithMapping(SiteInfoServlet.class, "/site-to-site"); + servletHandler.addServletWithMapping(PeersServlet.class, "/site-to-site/peers"); + + servletHandler.addServletWithMapping(PortTransactionsServlet.class, "/site-to-site/input-ports/input-running-id/transactions"); + servletHandler.addServletWithMapping(InputPortTransactionServlet.class, "/site-to-site/input-ports/input-running-id/transactions/transaction-id"); + servletHandler.addServletWithMapping(FlowFilesServlet.class, "/site-to-site/input-ports/input-running-id/transactions/transaction-id/flow-files"); + + servletHandler.addServletWithMapping(PortTransactionsServlet.class, 
"/site-to-site/input-ports/input-timeout-id/transactions"); + servletHandler.addServletWithMapping(InputPortTransactionServlet.class, "/site-to-site/input-ports/input-timeout-id/transactions/transaction-id"); + servletHandler.addServletWithMapping(FlowFilesTimeoutServlet.class, "/site-to-site/input-ports/input-timeout-id/transactions/transaction-id/flow-files"); + + servletHandler.addServletWithMapping(PortTransactionsServlet.class, "/site-to-site/input-ports/input-timeout-data-ex-id/transactions"); + servletHandler.addServletWithMapping(InputPortTransactionServlet.class, "/site-to-site/input-ports/input-timeout-data-ex-id/transactions/transaction-id"); + servletHandler.addServletWithMapping(FlowFilesTimeoutAfterDataExchangeServlet.class, "/site-to-site/input-ports/input-timeout-data-ex-id/transactions/transaction-id/flow-files"); + + servletHandler.addServletWithMapping(PortTransactionsServlet.class, "/site-to-site/output-ports/output-running-id/transactions"); + servletHandler.addServletWithMapping(OutputPortTransactionServlet.class, "/site-to-site/output-ports/output-running-id/transactions/transaction-id"); + servletHandler.addServletWithMapping(FlowFilesServlet.class, "/site-to-site/output-ports/output-running-id/transactions/transaction-id/flow-files"); + + servletHandler.addServletWithMapping(PortTransactionsServlet.class, "/site-to-site/output-ports/output-timeout-id/transactions"); + servletHandler.addServletWithMapping(OutputPortTransactionServlet.class, "/site-to-site/output-ports/output-timeout-id/transactions/transaction-id"); + servletHandler.addServletWithMapping(FlowFilesTimeoutServlet.class, "/site-to-site/output-ports/output-timeout-id/transactions/transaction-id/flow-files"); + + servletHandler.addServletWithMapping(PortTransactionsServlet.class, "/site-to-site/output-ports/output-timeout-data-ex-id/transactions"); + servletHandler.addServletWithMapping(OutputPortTransactionServlet.class, 
"/site-to-site/output-ports/output-timeout-data-ex-id/transactions/transaction-id"); + servletHandler.addServletWithMapping(FlowFilesTimeoutAfterDataExchangeServlet.class, "/site-to-site/output-ports/output-timeout-data-ex-id/transactions/transaction-id/flow-files"); + + server.start(); + + int serverPort = server.getURI().getPort(); + logger.info("Starting server on port {}", serverPort); + } + + @AfterClass + public static void teardown() throws Exception { + logger.info("Stopping server."); + server.stop(); + } + + private static class DataPacketBuilder { + private final Map<String, String> attributes = new HashMap<>(); + private String contents; + + private DataPacketBuilder attr(final String k, final String v) { + attributes.put(k, v); + return this; + } + + private DataPacketBuilder contents(final String contents) { + this.contents = contents; + return this; + } + + private DataPacket build() { + byte[] bytes = contents.getBytes(); + return new StandardDataPacket(attributes, new ByteArrayInputStream(bytes), bytes.length); + } + + } + + @Before + public void before() throws Exception { + + System.setProperty("org.slf4j.simpleLogger.log.org.apache.nifi.remote", "TRACE"); + System.setProperty("org.slf4j.simpleLogger.log.org.apache.nifi.remote.protocol.http.HttpClientTransaction", "DEBUG"); + + final URI uri = server.getURI(); + final PeerDTO peer = new PeerDTO(); + peer.setHostname(uri.getHost()); + peer.setPort(uri.getPort()); + peer.setFlowFileCount(10); + peer.setSecure(false); + + isTestCaseFinished.set(false); + + peers = new HashSet<>(); + peers.add(peer); + + inputPorts = new HashSet<>(); + + final PortDTO runningInputPort = new PortDTO(); + runningInputPort.setId("running-input-port"); + inputPorts.add(runningInputPort); + runningInputPort.setName("input-running"); + runningInputPort.setId("input-running-id"); + runningInputPort.setType("INPUT_PORT"); + runningInputPort.setState(ScheduledState.RUNNING.name()); + + final PortDTO timeoutInputPort = new 
PortDTO();
+        timeoutInputPort.setId("timeout-input-port");
+        inputPorts.add(timeoutInputPort);
+        timeoutInputPort.setName("input-timeout");
+        timeoutInputPort.setId("input-timeout-id");
+        timeoutInputPort.setType("INPUT_PORT");
+        timeoutInputPort.setState(ScheduledState.RUNNING.name());
+
+        final PortDTO timeoutDataExInputPort = new PortDTO();
+        timeoutDataExInputPort.setId("timeout-dataex-input-port");
+        inputPorts.add(timeoutDataExInputPort);
+        timeoutDataExInputPort.setName("input-timeout-data-ex");
+        timeoutDataExInputPort.setId("input-timeout-data-ex-id");
+        timeoutDataExInputPort.setType("INPUT_PORT");
+        timeoutDataExInputPort.setState(ScheduledState.RUNNING.name());
+
+        outputPorts = new HashSet<>();
+
+        final PortDTO runningOutputPort = new PortDTO();
+        runningOutputPort.setId("running-output-port");
+        outputPorts.add(runningOutputPort);
+        runningOutputPort.setName("output-running");
+        runningOutputPort.setId("output-running-id");
+        runningOutputPort.setType("OUTPUT_PORT");
+        runningOutputPort.setState(ScheduledState.RUNNING.name());
+
+        final PortDTO timeoutOutputPort = new PortDTO();
+        timeoutOutputPort.setId("timeout-output-port");
+        outputPorts.add(timeoutOutputPort);
+        timeoutOutputPort.setName("output-timeout");
+        timeoutOutputPort.setId("output-timeout-id");
+        timeoutOutputPort.setType("OUTPUT_PORT");
+        timeoutOutputPort.setState(ScheduledState.RUNNING.name());
+
+        final PortDTO timeoutDataExOutputPort = new PortDTO();
+        timeoutDataExOutputPort.setId("timeout-dataex-output-port");
+        outputPorts.add(timeoutDataExOutputPort);
+        timeoutDataExOutputPort.setName("output-timeout-data-ex");
+        timeoutDataExOutputPort.setId("output-timeout-data-ex-id");
+        timeoutDataExOutputPort.setType("OUTPUT_PORT");
+        timeoutDataExOutputPort.setState(ScheduledState.RUNNING.name());
+
+
+    }
+
+    /**
+     * Runs after each test case and sets {@code isTestCaseFinished} to true.
+     * NOTE(review): presumably polled by the test server's handlers so that they stop
+     * waiting once a test is over - confirm against the servlet code.
+     */
+    @After
+    public void after() throws Exception {
+        isTestCaseFinished.set(true);
+    }
+
+    /**
+     * Creates a client builder that uses the HTTP transport protocol and targets the
+     * {@code /nifi} path on the host and port reported by {@code server.getURI()}.
+     */
+    private SiteToSiteClient.Builder getDefaultBuilder() {
+        final URI uri = server.getURI();
+        return new SiteToSiteClient.Builder()
+                .transportProtocol(SiteToSiteTransportProtocol.HTTP)
+                .url("http://" + uri.getHost() + ":" + uri.getPort() + "/nifi");
+    }
+
+    /**
+     * Reads the whole content of a received packet and logs it together with its attributes.
+     * The bytes are decoded as UTF-8 explicitly so the logged text does not depend on the
+     * platform default charset.
+     */
+    private static void consumeDataPacket(DataPacket packet) throws IOException {
+        final ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        StreamUtils.copy(packet.getData(), bos);
+        final String contents = new String(bos.toByteArray(), java.nio.charset.StandardCharsets.UTF_8);
+        logger.info("received: {}, {}", contents, packet.getAttributes());
+    }
+
+    /**
+     * Builds the packet that every send test transmits: a fixed content string plus two
+     * client attributes. The expected {@code serverChecksum} values in the send tests
+     * presumably depend on these exact bytes - keep them in sync if this changes.
+     */
+    private static DataPacket createExamplePacket() {
+        return new DataPacketBuilder()
+                .contents("Example contents from client.")
+                .attr("Client attr 1", "Client attr 1 value")
+                .attr("Client attr 2", "Client attr 2 value")
+                .build();
+    }
+
+    /**
+     * Sends {@code count} example packets through the transaction and, after each one,
+     * logs the bytes-written count reported by the peer's communications session.
+     */
+    private static void sendExamplePackets(final Transaction transaction, final int count) throws IOException {
+        for (int i = 0; i < count; i++) {
+            transaction.send(createExamplePacket());
+            final long written = ((Peer) transaction.getCommunicant()).getCommunicationsSession().getBytesWritten();
+            logger.info("{}: {} bytes have been written.", i, written);
+        }
+    }
+
+
+    /**
+     * A cluster URL whose path is not the {@code /nifi} API path must yield no transaction.
+     */
+    @Test
+    public void testUnkownClusterUrl() throws Exception {
+
+        final URI uri = server.getURI();
+
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .url("http://" + uri.getHost() + ":" + uri.getPort() + "/unkown")
+                .portName("input-running")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNull(transaction);
+        }
+    }
+
+    /**
+     * With the {@code peers} field emptied (presumably what the test server reports as the
+     * peer list), no transaction can be created.
+     */
+    @Test
+    public void testNoAvailablePeer() throws Exception {
+
+        peers = new HashSet<>();
+
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("input-running")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNull(transaction);
+        }
+    }
+
+    /**
+     * Sending to a port name that cannot be resolved must fail with an IOException about
+     * the port identifier lookup.
+     */
+    @Test
+    public void testSendUnknownPort() throws Exception {
+
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("input-unknown")
+                .build()
+        ) {
+            try {
+                client.createTransaction(TransferDirection.SEND);
+                fail();
+            } catch (IOException e) {
+                logger.info("Exception message: {}", e.getMessage());
+                assertTrue(e.getMessage().contains("Failed to determine the identifier of port"));
+            }
+        }
+    }
+
+    /**
+     * Sends 20 example packets uncompressed, then confirms and completes the transaction.
+     */
+    @Test
+    public void testSendSuccess() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("input-running")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNotNull(transaction);
+
+            // Expected checksum for 20 example packets; presumably verified on the server side.
+            serverChecksum = "1071206772";
+
+            sendExamplePackets(transaction, 20);
+
+            transaction.confirm();
+            transaction.complete();
+        }
+    }
+
+    /**
+     * Same as {@link #testSendSuccess()} but with compression enabled; the expected
+     * checksum is unchanged.
+     */
+    @Test
+    public void testSendSuccessCompressed() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("input-running")
+                .useCompression(true)
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNotNull(transaction);
+
+            serverChecksum = "1071206772";
+
+            sendExamplePackets(transaction, 20);
+
+            transaction.confirm();
+            transaction.complete();
+        }
+    }
+
+    /**
+     * Sends three packets with a 50 ms pause between them while the client's idle
+     * expiration is 1 second, and expects the transaction to succeed.
+     */
+    @Test
+    public void testSendSlowClientSuccess() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .idleExpiration(1000, TimeUnit.MILLISECONDS)
+                .portName("input-running")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNotNull(transaction);
+
+            serverChecksum = "3882825556";
+
+            for (int i = 0; i < 3; i++) {
+                transaction.send(createExamplePacket());
+                final long written = ((Peer) transaction.getCommunicant()).getCommunicationsSession().getBytesWritten();
+                logger.info("{} bytes have been written.", written);
+                // Pause between packets: this is the "slow client" the test name refers to.
+                Thread.sleep(50);
+            }
+
+            transaction.confirm();
+            transaction.complete();
+        }
+    }
+
+    /**
+     * Asserts that {@link Transaction#complete()} is rejected with an
+     * {@link IllegalStateException}, as expected once the transaction has failed.
+     */
+    private void completeShouldFail(Transaction transaction) throws IOException {
+        try {
+            transaction.complete();
+            fail("Complete operation should fail since transaction has already failed.");
+        } catch (IllegalStateException e) {
+            logger.info("An exception was thrown as expected.", e);
+        }
+    }
+
+    /**
+     * Asserts that {@link Transaction#confirm()} is rejected with an
+     * {@link IllegalStateException}, as expected once the transaction has failed.
+     * Any IOException from {@code confirm()} propagates to the caller unchanged.
+     */
+    private void confirmShouldFail(Transaction transaction) throws IOException {
+        try {
+            transaction.confirm();
+            fail("Confirm operation should fail since transaction has already failed.");
+        } catch (IllegalStateException e) {
+            logger.info("An exception was thrown as expected.", e);
+        }
+    }
+
+    /**
+     * With a 1 second client timeout against the "input-timeout" port, confirming after a
+     * single send must fail with an IOException mentioning TimeoutException, after which
+     * the transaction can no longer be completed.
+     */
+    @Test
+    public void testSendTimeout() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .timeout(1, TimeUnit.SECONDS)
+                .portName("input-timeout")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNotNull(transaction);
+
+            final DataPacket packet = createExamplePacket();
+            serverChecksum = "1345413116";
+
+            transaction.send(packet);
+            try {
+                transaction.confirm();
+                fail();
+            } catch (IOException e) {
+                logger.info("An exception was thrown as expected.", e);
+                assertTrue(e.getMessage().contains("TimeoutException"));
+            }
+
+            completeShouldFail(transaction);
+        }
+    }
+
+    /**
+     * Sends 100 packets to the "input-timeout-data-ex" port with 500 ms idle expiration and
+     * timeout; confirming must then fail with an IOException mentioning TimeoutException,
+     * and completing the failed transaction must be rejected.
+     */
+    @Test
+    public void testSendTimeoutAfterDataExchange() throws Exception {
+
+        // NOTE(review): this JVM-global property is never cleared, and it only takes effect if
+        // the HttpClientTransaction logger has not been created yet - confirm it is still needed.
+        System.setProperty("org.slf4j.simpleLogger.log.org.apache.nifi.remote.protocol.http.HttpClientTransaction", "INFO");
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .idleExpiration(500, TimeUnit.MILLISECONDS)
+                .timeout(500, TimeUnit.MILLISECONDS)
+                .portName("input-timeout-data-ex")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.SEND);
+
+            assertNotNull(transaction);
+
+            final DataPacket packet = createExamplePacket();
+
+            for (int i = 0; i < 100; i++) {
+                transaction.send(packet);
+                if (i % 10 == 0) {
+                    logger.info("Sent {} packets...", i);
+                }
+            }
+
+            try {
+                // Call confirm() directly: the timeout is expected to surface as an IOException.
+                // confirmShouldFail() is not used because it only tolerates IllegalStateException
+                // (an already-failed transaction), which is not the condition checked here.
+                transaction.confirm();
+                fail("Should be timeout.");
+            } catch (IOException e) {
+                logger.info("Exception message: {}", e.getMessage());
+                assertTrue(e.getMessage().contains("TimeoutException"));
+            }
+
+            completeShouldFail(transaction);
+        }
+    }
+
+    /**
+     * Receiving from a port name that cannot be resolved must fail with an IOException
+     * about the port identifier lookup.
+     */
+    @Test
+    public void testReceiveUnknownPort() throws Exception {
+
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("output-unknown")
+                .build()
+        ) {
+            try {
+                client.createTransaction(TransferDirection.RECEIVE);
+                fail();
+            } catch (IOException e) {
+                logger.info("Exception message: {}", e.getMessage());
+                assertTrue(e.getMessage().contains("Failed to determine the identifier of port"));
+            }
+        }
+    }
+
+    /**
+     * Drains every packet from the "output-running" port, then confirms and completes.
+     */
+    @Test
+    public void testReceiveSuccess() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("output-running")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.RECEIVE);
+
+            assertNotNull(transaction);
+
+            DataPacket packet;
+            while ((packet = transaction.receive()) != null) {
+                consumeDataPacket(packet);
+            }
+            transaction.confirm();
+            transaction.complete();
+        }
+    }
+
+    /**
+     * Same as {@link #testReceiveSuccess()} but with compression enabled.
+     */
+    @Test
+    public void testReceiveSuccessCompressed() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("output-running")
+                .useCompression(true)
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.RECEIVE);
+
+            assertNotNull(transaction);
+
+            DataPacket packet;
+            while ((packet = transaction.receive()) != null) {
+                consumeDataPacket(packet);
+            }
+            transaction.confirm();
+            transaction.complete();
+        }
+    }
+
+    /**
+     * Drains the "output-running" port while pausing 500 ms after each packet, and expects
+     * the transaction to still confirm and complete.
+     */
+    @Test
+    public void testReceiveSlowClientSuccess() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .portName("output-running")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.RECEIVE);
+
+            assertNotNull(transaction);
+
+            DataPacket packet;
+            while ((packet = transaction.receive()) != null) {
+                consumeDataPacket(packet);
+                Thread.sleep(500);
+            }
+            transaction.confirm();
+            transaction.complete();
+        }
+    }
+
+    /**
+     * With a 1 second client timeout against the "output-timeout" port, creating a receive
+     * transaction must fail with a SocketTimeoutException.
+     */
+    @Test
+    public void testReceiveTimeout() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .timeout(1, TimeUnit.SECONDS)
+                .portName("output-timeout")
+                .build()
+        ) {
+            try {
+                client.createTransaction(TransferDirection.RECEIVE);
+                fail();
+            } catch (IOException e) {
+                logger.info("An exception was thrown as expected.", e);
+                assertTrue(e instanceof SocketTimeoutException);
+            }
+        }
+    }
+
+    /**
+     * Receives one packet from the "output-timeout-data-ex" port. The next receive must fail
+     * with an IOException caused by a SocketTimeoutException, after which both confirm and
+     * complete must be rejected.
+     */
+    @Test
+    public void testReceiveTimeoutAfterDataExchange() throws Exception {
+
+        final URI uri = server.getURI();
+
+        logger.info("uri={}", uri);
+        try (
+            SiteToSiteClient client = getDefaultBuilder()
+                .timeout(1, TimeUnit.SECONDS)
+                .portName("output-timeout-data-ex")
+                .build()
+        ) {
+            final Transaction transaction = client.createTransaction(TransferDirection.RECEIVE);
+            assertNotNull(transaction);
+
+            final DataPacket packet = transaction.receive();
+            assertNotNull(packet);
+            consumeDataPacket(packet);
+
+            try {
+                transaction.receive();
+                fail();
+            } catch (IOException e) {
+                logger.info("An exception was thrown as expected.", e);
+                assertTrue(e.getCause() instanceof SocketTimeoutException);
+            }
+
+            confirmShouldFail(transaction);
+            completeShouldFail(transaction);
+        }
+    }
+
+}