This is an automated email from the ASF dual-hosted git repository. aichrist pushed a commit to branch analytics-framework in repository https://gitbox.apache.org/repos/asf/nifi.git
commit dbbd41bb7a36916e5feab7efa360b12c0e29a522 Author: Yolanda Davis <[email protected]> AuthorDate: Fri Aug 23 08:36:19 2019 -0400 NIFI-6585 - Refactored tests to use mocked models and extract functions. Added check in ConnectionStatusAnalytics to confirm expected model by type --- .../analytics/ConnectionStatusAnalytics.java | 16 +- .../analytics/TestConnectionStatusAnalytics.java | 370 ++++++++++++++++----- .../analytics/TestStatusAnalyticsEngine.java | 34 +- 3 files changed, 324 insertions(+), 96 deletions(-) diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java index bc2270b..68c472e 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java @@ -99,6 +99,14 @@ public class ConnectionStatusAnalytics implements StatusAnalytics { }); } + protected StatusAnalyticsModel getModel(String modelType){ + + if(modelMap.containsKey(modelType)){ + return modelMap.get(modelType).getKey(); + }else{ + throw new IllegalArgumentException("Model cannot be found for provided type: " + modelType); + } + } /** * Returns the predicted time (in milliseconds) when backpressure is expected to be applied to this connection, based on the total number of bytes in the queue. * @@ -106,7 +114,7 @@ public class ConnectionStatusAnalytics implements StatusAnalytics { */ public Long getTimeToBytesBackpressureMillis() { - final StatusAnalyticsModel bytesModel = modelMap.get("queuedBytes").getKey(); + final StatusAnalyticsModel bytesModel = getModel("queuedBytes"); FlowFileEvent flowFileEvent = getStatusReport(); final Connection connection = getConnection(); @@ -133,7 +141,7 @@ public class ConnectionStatusAnalytics implements StatusAnalytics { */ public Long getTimeToCountBackpressureMillis() { - final StatusAnalyticsModel countModel = modelMap.get("queuedCount").getKey(); + final StatusAnalyticsModel countModel = getModel("queuedCount"); FlowFileEvent flowFileEvent = getStatusReport(); final Connection connection = getConnection(); @@ -160,7 +168,7 @@ public class ConnectionStatusAnalytics implements StatusAnalytics { */ public Long getNextIntervalBytes() { - final StatusAnalyticsModel bytesModel = modelMap.get("queuedBytes").getKey(); + final StatusAnalyticsModel bytesModel = getModel("queuedBytes"); FlowFileEvent flowFileEvent = getStatusReport(); if (validModel(bytesModel) && flowFileEvent != null) { @@ -182,7 +190,7 @@ public class ConnectionStatusAnalytics implements StatusAnalytics { */ public Long getNextIntervalCount() { - final StatusAnalyticsModel countModel = modelMap.get("queuedCount").getKey(); + final StatusAnalyticsModel countModel = getModel("queuedCount"); FlowFileEvent flowFileEvent = getStatusReport(); if (validModel(countModel) && flowFileEvent != null) { diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java index ff92215..5d9279f 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java @@ -16,7 +16,10 @@ */ package org.apache.nifi.controller.status.analytics; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; @@ -32,7 +35,9 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.lang3.time.DateUtils; import org.apache.nifi.bundle.Bundle; import org.apache.nifi.connectable.Connection; import org.apache.nifi.controller.flow.FlowManager; @@ -43,15 +48,16 @@ import org.apache.nifi.controller.repository.RepositoryStatusReport; import org.apache.nifi.controller.status.history.ComponentStatusRepository; import org.apache.nifi.controller.status.history.ConnectionStatusDescriptor; import org.apache.nifi.controller.status.history.MetricDescriptor; -import org.apache.nifi.controller.status.history.StandardStatusSnapshot; import org.apache.nifi.controller.status.history.StatusHistory; -import org.apache.nifi.controller.status.history.StatusSnapshot; import org.apache.nifi.groups.ProcessGroup; import org.apache.nifi.nar.StandardExtensionDiscoveringManager; import org.apache.nifi.nar.SystemBundle; import org.apache.nifi.util.NiFiProperties; +import org.apache.nifi.util.Tuple; import org.junit.Test; import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; public class TestConnectionStatusAnalytics { @@ -59,8 +65,7 @@ public class TestConnectionStatusAnalytics { .map(ConnectionStatusDescriptor::getDescriptor) .collect(Collectors.toSet()); - protected ConnectionStatusAnalytics getConnectionStatusAnalytics(Long queuedBytes, Long queuedCount, String backPressureDataSizeThreshhold, - Long backPressureObjectThreshold, Boolean isConstantStatus) { + protected ConnectionStatusAnalytics getConnectionStatusAnalytics(Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap) { ComponentStatusRepository statusRepository = Mockito.mock(ComponentStatusRepository.class); FlowManager flowManager; @@ -87,135 +92,342 @@ public class TestConnectionStatusAnalytics { final String connectionIdentifier = "1"; connections.add(connection); - List<StatusSnapshot> snapshotList = new ArrayList<>(); - final long startTime = System.currentTimeMillis(); - int iterations = 10; - - Long inputBytes = queuedBytes * 2; - Long outputBytes = inputBytes - queuedBytes; - Long inputCount = queuedCount * 2; - Long outputCount = inputCount - queuedCount; - - for (int i = 0; i < iterations; i++) { - final StandardStatusSnapshot snapshot = new StandardStatusSnapshot(CONNECTION_METRICS); - snapshot.setTimestamp(new Date(startTime + i * 1000)); - snapshot.addStatusMetric(ConnectionStatusDescriptor.QUEUED_BYTES.getDescriptor(), (isConstantStatus || i < 5) ? queuedBytes : queuedBytes * 2); - snapshot.addStatusMetric(ConnectionStatusDescriptor.QUEUED_COUNT.getDescriptor(), (isConstantStatus || i < 5) ? queuedCount : queuedCount * 2); - snapshot.addStatusMetric(ConnectionStatusDescriptor.INPUT_BYTES.getDescriptor(), (isConstantStatus || i < 5) ? inputBytes : inputBytes * 2); - snapshot.addStatusMetric(ConnectionStatusDescriptor.INPUT_COUNT.getDescriptor(), (isConstantStatus || i < 5) ? inputCount : inputCount * 2); - snapshot.addStatusMetric(ConnectionStatusDescriptor.OUTPUT_BYTES.getDescriptor(), (isConstantStatus || i < 5) ? outputBytes : outputBytes * 2); - snapshot.addStatusMetric(ConnectionStatusDescriptor.OUTPUT_COUNT.getDescriptor(), (isConstantStatus || i < 5) ? outputCount : outputCount * 2); - snapshotList.add(snapshot); - } - - when(flowFileQueue.getBackPressureDataSizeThreshold()).thenReturn(backPressureDataSizeThreshhold); - when(flowFileQueue.getBackPressureObjectThreshold()).thenReturn(backPressureObjectThreshold); + when(flowFileQueue.getBackPressureDataSizeThreshold()).thenReturn("100MB"); + when(flowFileQueue.getBackPressureObjectThreshold()).thenReturn(100L); when(connection.getIdentifier()).thenReturn(connectionIdentifier); when(connection.getFlowFileQueue()).thenReturn(flowFileQueue); when(processGroup.findAllConnections()).thenReturn(connections); - when(statusHistory.getStatusSnapshots()).thenReturn(snapshotList); when(flowManager.getRootGroup()).thenReturn(processGroup); - when(flowFileEvent.getContentSizeIn()).thenReturn(inputBytes); - when(flowFileEvent.getContentSizeOut()).thenReturn(outputBytes); - when(flowFileEvent.getFlowFilesIn()).thenReturn(inputCount.intValue()); - when(flowFileEvent.getFlowFilesOut()).thenReturn(outputCount.intValue()); + when(flowFileEvent.getContentSizeIn()).thenReturn(10L); + when(flowFileEvent.getContentSizeOut()).thenReturn(10L); + when(flowFileEvent.getFlowFilesIn()).thenReturn(10); + when(flowFileEvent.getFlowFilesOut()).thenReturn(10); when(flowFileEventRepository.reportTransferEvents(anyLong())).thenReturn(repositoryStatusReport); when(repositoryStatusReport.getReportEntry(anyString())).thenReturn(flowFileEvent); when(statusRepository.getConnectionStatusHistory(anyString(), any(), any(), anyInt())).thenReturn(statusHistory); ConnectionStatusAnalytics connectionStatusAnalytics = new ConnectionStatusAnalytics(statusRepository, flowManager,flowFileEventRepository, - StatusAnalyticsModelMapFactory.getConnectionStatusModelMap(extensionManager,nifiProperties), connectionIdentifier, false); + modelMap, connectionIdentifier, false); connectionStatusAnalytics.refresh(); return connectionStatusAnalytics; } + public Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> getModelMap( String predictionType, Double score, + Double targetPrediction, Double variablePrediction){ + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = new HashMap<>(); + StatusAnalyticsModel model = Mockito.mock(StatusAnalyticsModel.class); + StatusMetricExtractFunction extractFunction = Mockito.mock(StatusMetricExtractFunction.class); + Tuple<StatusAnalyticsModel,StatusMetricExtractFunction> modelTuple = new Tuple<>(model,extractFunction); + modelMap.put(predictionType,modelTuple); + Map<String,Double> scores = new HashMap<>(); + scores.put("rSquared",score); + + Double[][] features = new Double[1][1]; + Double[] target = new Double[1]; + + when(extractFunction.extractMetric(anyString(),any(StatusHistory.class))).then(new Answer<Tuple<Stream<Double[]>,Stream<Double>>>() { + @Override + public Tuple<Stream<Double[]>, Stream<Double>> answer(InvocationOnMock invocationOnMock) throws Throwable { + return new Tuple<>(Stream.of(features), Stream.of(target)); + } + }); + + when(model.getScores()).thenReturn(scores); + when(model.predict(any(Double[].class))).thenReturn(targetPrediction); + when(model.predictVariable(anyInt(),any(),any())).thenReturn(variablePrediction); + return modelMap; + + } + + @Test + public void testInvalidModelLowScore() { + Date now = new Date(); + Long tomorrowMillis = DateUtils.addDays(now,1).toInstant().toEpochMilli(); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.5,100.0,tomorrowMillis.doubleValue()); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); + } + + @Test + public void testInvalidModelNaNScore() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",Double.NaN,Double.NaN,Double.NaN); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); + } + + @Test + public void testInvalidModelInfiniteScore() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); + } + @Test public void testGetIntervalTimeMillis() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,100.0,100.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); Long interval = connectionStatusAnalytics.getIntervalTimeMillis(); assertNotNull(interval); assert (interval == 180000); } @Test - public void testGetTimeToCountBackpressureMillisConstantStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true); + public void testGetTimeToCountBackpressureMillis() { + Date now = new Date(); + Long tomorrowMillis = DateUtils.addDays(now,1).toInstant().toEpochMilli(); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,100.0,tomorrowMillis.doubleValue()); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); assertNotNull(countTime); assert (countTime > 0); } @Test - public void testGetTimeToCountBackpressureMillisVaryingStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false); + public void testCannotPredictTimeToCountNaN() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,Double.NaN,Double.NaN); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); assertNotNull(countTime); - assert (countTime == -1L); + assert (countTime == -1); } @Test - public void testGetTimeToBytesBackpressureMillisConstantStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true); - Long bytesTime = connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); - assertNotNull(bytesTime); - assert (bytesTime == -1L || bytesTime == 0); + public void testCannotPredictTimeToCountInfinite() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); } @Test - public void testGetTimeToBytesBackpressureMillisVaryingStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false); - Long bytesTime = connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); - assertNotNull(bytesTime); - assert (bytesTime == -1L); + public void testCannotPredictTimeToCountNegative() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,-1.0,-1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToCountBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); } @Test - public void testGetNextIntervalBytesConstantStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true); - Long nextBytes = connectionStatusAnalytics.getNextIntervalBytes(); - assertNotNull(nextBytes); - assert (nextBytes == 5000L); + public void testMissingModelGetTimeToCount() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + try { + connectionStatusAnalytics.getTimeToCountBackpressureMillis(); + fail(); + }catch(IllegalArgumentException iae){ + assertTrue(true); + } } @Test - public void testGetNextIntervalBytesVaryingStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false); - Long nextBytes = connectionStatusAnalytics.getNextIntervalBytes(); - assertNotNull(nextBytes); - assert (nextBytes == -1L); + public void testGetTimeToBytesBackpressureMillis() { + Date now = new Date(); + Long tomorrowMillis = DateUtils.addDays(now,1).toInstant().toEpochMilli(); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,100.0,tomorrowMillis.doubleValue()); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); + assertNotNull(countTime); + assert (countTime > 0); } @Test - public void testGetNextIntervalCountConstantStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true); - Long nextCount = connectionStatusAnalytics.getNextIntervalCount(); - assertNotNull(nextCount); - assert (nextCount == 50L); + public void testCannotPredictTimeToBytesNaN() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,Double.NaN,Double.NaN); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); } @Test - public void testGetNextIntervalCountVaryingStatus() { + public void testCannotPredictTimeToBytesInfinite() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); + } - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false); - Long nextCount = connectionStatusAnalytics.getNextIntervalCount(); - assertNotNull(nextCount); - assert (nextCount == -1L); + @Test + public void testCannotPredictTimeToBytesNegative() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,-1.0,-1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long countTime = connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); + assertNotNull(countTime); + assert (countTime == -1); + } + + @Test + public void testMissingModelGetTimeToBytes() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + try { + connectionStatusAnalytics.getTimeToBytesBackpressureMillis(); + fail(); + }catch(IllegalArgumentException iae){ + assertTrue(true); + } } @Test - public void testGetNextIntervalPercentageUseBytesConstantStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(50000L, 50L, "1MB", 100L, true); - Long nextBytesPercentage = connectionStatusAnalytics.getNextIntervalPercentageUseBytes(); - assertNotNull(nextBytesPercentage); - assert (nextBytesPercentage == 5); + public void testGetNextIntervalBytes() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,1.0,1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalBytes(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes > 0); } @Test - public void testGetNextIntervalPercentageUseCountConstantStatus() { - ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true); - Long nextCountPercentage = connectionStatusAnalytics.getNextIntervalPercentageUseCount(); - assertNotNull(nextCountPercentage); - assert (nextCountPercentage == 50); + public void testCannotPredictNextIntervalBytesNegative() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,-1.0,-1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalBytes(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes == -1); } + @Test + public void testCannotPredictNextIntervalBytesNaN() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,Double.NaN,Double.NaN); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalBytes(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes == -1); + } + + @Test + public void testCannotPredictNextIntervalBytesInfinity() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalBytes(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes == -1); + } + + @Test + public void testMissingModelNextIntervalBytes() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + try { + connectionStatusAnalytics.getNextIntervalBytes(); + fail(); + }catch(IllegalArgumentException iae){ + assertTrue(true); + } + } + + @Test + public void testGetNextIntervalCount() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,1.0,1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalCount(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes > 0); + } + + @Test + public void testCannotPredictNextIntervalCountNegative() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,-1.0,-1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalCount(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes == -1); + } + + @Test + public void testCannotPredictNextIntervalCountNaN() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,Double.NaN,Double.NaN); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalCount(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes == -1); + } + + @Test + public void testCannotPredictNextIntervalCountInfinity() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long nextIntervalBytes = connectionStatusAnalytics.getNextIntervalCount(); + assertNotNull(nextIntervalBytes); + assert (nextIntervalBytes == -1); + } + + @Test + public void testMissingModelGetNextIntervalCount() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + try { + connectionStatusAnalytics.getNextIntervalCount(); + fail(); + }catch(IllegalArgumentException iae){ + assertTrue(true); + } + } + + @Test + public void testGetNextIntervalPercentageUseCount() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedCount",.9,50.0,1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long percentage = connectionStatusAnalytics.getNextIntervalPercentageUseCount(); + assertNotNull(percentage); + assert (percentage == 50); + } + + @Test + public void testGetNextIntervalPercentageUseBytes() { + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> modelMap = getModelMap("queuedBytes",.9,10000000.0,1.0); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(modelMap); + Long percentage = connectionStatusAnalytics.getNextIntervalPercentageUseBytes(); + assertNotNull(percentage); + assert (percentage == 10); + } + + @Test + public void testGetScores() { + Date now = new Date(); + Long tomorrowMillis = DateUtils.addDays(now,1).toInstant().toEpochMilli(); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> bytesModelMap = getModelMap("queuedBytes",.9,10000000.0,tomorrowMillis.doubleValue()); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> countModelMap = getModelMap("queuedCount",.9,50.0,tomorrowMillis.doubleValue()); + countModelMap.putAll(bytesModelMap); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(countModelMap); + Map<String,Long> scores = connectionStatusAnalytics.getPredictions(); + assertNotNull(scores); + assertFalse(scores.isEmpty()); + assertTrue(scores.get("nextIntervalPercentageUseCount").equals(50L)); + assertTrue(scores.get("nextIntervalBytes").equals(10000000L)); + assertTrue(scores.get("timeToBytesBackpressureMillis") > 0); + assertTrue(scores.get("nextIntervalCount").equals(50L)); + assertTrue(scores.get("nextIntervalPercentageUseBytes").equals(10L)); + assertTrue(scores.get("intervalTimeMillis").equals(180000L)); + assertTrue(scores.get("timeToCountBackpressureMillis") > 0); + } + + @Test + public void testGetScoresWithBadModel() { + Date now = new Date(); + Long tomorrowMillis = DateUtils.addDays(now,1).toInstant().toEpochMilli(); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> bytesModelMap = getModelMap("queuedBytes",.9,10000000.0,tomorrowMillis.doubleValue()); + Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> countModelMap = getModelMap("queuedCount",.1,50.0,tomorrowMillis.doubleValue()); + countModelMap.putAll(bytesModelMap); + ConnectionStatusAnalytics connectionStatusAnalytics = getConnectionStatusAnalytics(countModelMap); + Map<String,Long> scores = connectionStatusAnalytics.getPredictions(); + assertNotNull(scores); + assertFalse(scores.isEmpty()); + assertTrue(scores.get("nextIntervalPercentageUseCount").equals(-1L)); + assertTrue(scores.get("nextIntervalBytes").equals(10000000L)); + assertTrue(scores.get("timeToBytesBackpressureMillis") > 0); + assertTrue(scores.get("nextIntervalCount").equals(-1L)); + assertTrue(scores.get("nextIntervalPercentageUseBytes").equals(10L)); + assertTrue(scores.get("intervalTimeMillis").equals(180000L)); + assertTrue(scores.get("timeToCountBackpressureMillis") == -1); + } } diff --git a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java index fce167d..477216c 100644 --- a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java +++ b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java @@ -25,22 +25,20 @@ import static org.mockito.Mockito.when; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.stream.Stream; -import org.apache.nifi.bundle.Bundle; import org.apache.nifi.controller.flow.FlowManager; import org.apache.nifi.controller.repository.FlowFileEventRepository; import org.apache.nifi.controller.status.history.ComponentStatusRepository; import org.apache.nifi.controller.status.history.StatusHistory; import org.apache.nifi.controller.status.history.StatusSnapshot; import org.apache.nifi.groups.ProcessGroup; -import org.apache.nifi.nar.ExtensionManager; -import org.apache.nifi.nar.StandardExtensionDiscoveringManager; -import org.apache.nifi.nar.SystemBundle; -import org.apache.nifi.util.NiFiProperties; import org.apache.nifi.util.Tuple; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; public abstract class TestStatusAnalyticsEngine { @@ -58,21 +56,31 @@ public abstract class TestStatusAnalyticsEngine { statusRepository = Mockito.mock(ComponentStatusRepository.class); flowManager = Mockito.mock(FlowManager.class); + modelMap = new HashMap<>(); - final Map<String, String> otherProps = new HashMap<>(); - final String propsFile = "src/test/resources/conf/nifi.properties"; - NiFiProperties nifiProperties = NiFiProperties.createBasicNiFiProperties(propsFile, otherProps); + StatusAnalyticsModel countModel = Mockito.mock(StatusAnalyticsModel.class); + StatusAnalyticsModel byteModel = Mockito.mock(StatusAnalyticsModel.class); + StatusMetricExtractFunction extractFunction = Mockito.mock(StatusMetricExtractFunction.class); + Tuple<StatusAnalyticsModel,StatusMetricExtractFunction> countTuple = new Tuple<>(countModel,extractFunction); + Tuple<StatusAnalyticsModel,StatusMetricExtractFunction> byteTuple = new Tuple<>(byteModel,extractFunction); + modelMap.put("queuedCount",countTuple); + modelMap.put("queuedBytes",byteTuple); - // use the system bundle - Bundle systemBundle = SystemBundle.create(nifiProperties); - ExtensionManager extensionManager = new StandardExtensionDiscoveringManager(); - ((StandardExtensionDiscoveringManager)extensionManager).discoverExtensions(systemBundle, Collections.emptySet()); + Double[][] features = new Double[1][1]; + Double[] target = new Double[1]; - modelMap = StatusAnalyticsModelMapFactory.getConnectionStatusModelMap(extensionManager,nifiProperties); ProcessGroup processGroup = Mockito.mock(ProcessGroup.class); StatusHistory statusHistory = Mockito.mock(StatusHistory.class); StatusSnapshot statusSnapshot = Mockito.mock(StatusSnapshot.class); + + when(extractFunction.extractMetric(anyString(),any(StatusHistory.class))).then(new Answer<Tuple<Stream<Double[]>,Stream<Double>>>() { + @Override + public Tuple<Stream<Double[]>, Stream<Double>> answer(InvocationOnMock invocationOnMock) throws Throwable { + return new Tuple<>(Stream.of(features), Stream.of(target)); + } + }); + when(statusSnapshot.getMetricDescriptors()).thenReturn(Collections.emptySet()); when(flowManager.getRootGroup()).thenReturn(processGroup); when(statusRepository.getConnectionStatusHistory(anyString(), any(), any(), anyInt())).thenReturn(statusHistory);
