This is an automated email from the ASF dual-hosted git repository.

aichrist pushed a commit to branch analytics-framework
in repository https://gitbox.apache.org/repos/asf/nifi.git

commit dbbd41bb7a36916e5feab7efa360b12c0e29a522
Author: Yolanda Davis <[email protected]>
AuthorDate: Fri Aug 23 08:36:19 2019 -0400

    NIFI-6585 - Refactored tests to use mocked models and extract functions.
    Added a check in ConnectionStatusAnalytics to confirm that the expected model exists for the requested type.
---
 .../analytics/ConnectionStatusAnalytics.java       |  16 +-
 .../analytics/TestConnectionStatusAnalytics.java   | 370 ++++++++++++++++-----
 .../analytics/TestStatusAnalyticsEngine.java       |  34 +-
 3 files changed, 324 insertions(+), 96 deletions(-)

diff --git 
a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java
 
b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java
index bc2270b..68c472e 100644
--- 
a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java
+++ 
b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/main/java/org/apache/nifi/controller/status/analytics/ConnectionStatusAnalytics.java
@@ -99,6 +99,14 @@ public class ConnectionStatusAnalytics implements 
StatusAnalytics {
         });
     }
 
+    protected StatusAnalyticsModel getModel(String modelType){
+
+        if(modelMap.containsKey(modelType)){
+            return modelMap.get(modelType).getKey();
+        }else{
+            throw new IllegalArgumentException("Model cannot be found for 
provided type: " + modelType);
+        }
+    }
     /**
      * Returns the predicted time (in milliseconds) when backpressure is 
expected to be applied to this connection, based on the total number of bytes 
in the queue.
      *
@@ -106,7 +114,7 @@ public class ConnectionStatusAnalytics implements 
StatusAnalytics {
      */
     public Long getTimeToBytesBackpressureMillis() {
 
-        final StatusAnalyticsModel bytesModel = 
modelMap.get("queuedBytes").getKey();
+        final StatusAnalyticsModel bytesModel = getModel("queuedBytes");
         FlowFileEvent flowFileEvent = getStatusReport();
 
         final Connection connection = getConnection();
@@ -133,7 +141,7 @@ public class ConnectionStatusAnalytics implements 
StatusAnalytics {
      */
     public Long getTimeToCountBackpressureMillis() {
 
-        final StatusAnalyticsModel countModel = 
modelMap.get("queuedCount").getKey();
+        final StatusAnalyticsModel countModel = getModel("queuedCount");
         FlowFileEvent flowFileEvent = getStatusReport();
 
         final Connection connection = getConnection();
@@ -160,7 +168,7 @@ public class ConnectionStatusAnalytics implements 
StatusAnalytics {
      */
 
     public Long getNextIntervalBytes() {
-        final StatusAnalyticsModel bytesModel = 
modelMap.get("queuedBytes").getKey();
+        final StatusAnalyticsModel bytesModel = getModel("queuedBytes");
         FlowFileEvent flowFileEvent = getStatusReport();
 
         if (validModel(bytesModel) && flowFileEvent != null) {
@@ -182,7 +190,7 @@ public class ConnectionStatusAnalytics implements 
StatusAnalytics {
      */
 
     public Long getNextIntervalCount() {
-        final StatusAnalyticsModel countModel = 
modelMap.get("queuedCount").getKey();
+        final StatusAnalyticsModel countModel = getModel("queuedCount");
         FlowFileEvent flowFileEvent = getStatusReport();
 
         if (validModel(countModel) && flowFileEvent != null) {
diff --git 
a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java
 
b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java
index ff92215..5d9279f 100644
--- 
a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java
+++ 
b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestConnectionStatusAnalytics.java
@@ -16,7 +16,10 @@
  */
 package org.apache.nifi.controller.status.analytics;
 
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyLong;
@@ -32,7 +35,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
+import org.apache.commons.lang3.time.DateUtils;
 import org.apache.nifi.bundle.Bundle;
 import org.apache.nifi.connectable.Connection;
 import org.apache.nifi.controller.flow.FlowManager;
@@ -43,15 +48,16 @@ import 
org.apache.nifi.controller.repository.RepositoryStatusReport;
 import org.apache.nifi.controller.status.history.ComponentStatusRepository;
 import org.apache.nifi.controller.status.history.ConnectionStatusDescriptor;
 import org.apache.nifi.controller.status.history.MetricDescriptor;
-import org.apache.nifi.controller.status.history.StandardStatusSnapshot;
 import org.apache.nifi.controller.status.history.StatusHistory;
-import org.apache.nifi.controller.status.history.StatusSnapshot;
 import org.apache.nifi.groups.ProcessGroup;
 import org.apache.nifi.nar.StandardExtensionDiscoveringManager;
 import org.apache.nifi.nar.SystemBundle;
 import org.apache.nifi.util.NiFiProperties;
+import org.apache.nifi.util.Tuple;
 import org.junit.Test;
 import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 public class TestConnectionStatusAnalytics {
 
@@ -59,8 +65,7 @@ public class TestConnectionStatusAnalytics {
             .map(ConnectionStatusDescriptor::getDescriptor)
             .collect(Collectors.toSet());
 
-    protected ConnectionStatusAnalytics getConnectionStatusAnalytics(Long 
queuedBytes, Long queuedCount, String backPressureDataSizeThreshhold,
-                                                                     Long 
backPressureObjectThreshold, Boolean isConstantStatus) {
+    protected ConnectionStatusAnalytics 
getConnectionStatusAnalytics(Map<String, Tuple<StatusAnalyticsModel, 
StatusMetricExtractFunction>> modelMap) {
 
         ComponentStatusRepository statusRepository = 
Mockito.mock(ComponentStatusRepository.class);
         FlowManager flowManager;
@@ -87,135 +92,342 @@ public class TestConnectionStatusAnalytics {
         final String connectionIdentifier = "1";
         connections.add(connection);
 
-        List<StatusSnapshot> snapshotList = new ArrayList<>();
-        final long startTime = System.currentTimeMillis();
-        int iterations = 10;
-
-        Long inputBytes = queuedBytes * 2;
-        Long outputBytes = inputBytes - queuedBytes;
-        Long inputCount = queuedCount * 2;
-        Long outputCount = inputCount - queuedCount;
-
-        for (int i = 0; i < iterations; i++) {
-            final StandardStatusSnapshot snapshot = new 
StandardStatusSnapshot(CONNECTION_METRICS);
-            snapshot.setTimestamp(new Date(startTime + i * 1000));
-            
snapshot.addStatusMetric(ConnectionStatusDescriptor.QUEUED_BYTES.getDescriptor(),
 (isConstantStatus || i < 5) ? queuedBytes : queuedBytes * 2);
-            
snapshot.addStatusMetric(ConnectionStatusDescriptor.QUEUED_COUNT.getDescriptor(),
 (isConstantStatus || i < 5) ? queuedCount : queuedCount * 2);
-            
snapshot.addStatusMetric(ConnectionStatusDescriptor.INPUT_BYTES.getDescriptor(),
 (isConstantStatus || i < 5) ? inputBytes : inputBytes * 2);
-            
snapshot.addStatusMetric(ConnectionStatusDescriptor.INPUT_COUNT.getDescriptor(),
 (isConstantStatus || i < 5) ? inputCount : inputCount * 2);
-            
snapshot.addStatusMetric(ConnectionStatusDescriptor.OUTPUT_BYTES.getDescriptor(),
 (isConstantStatus || i < 5) ? outputBytes : outputBytes * 2);
-            
snapshot.addStatusMetric(ConnectionStatusDescriptor.OUTPUT_COUNT.getDescriptor(),
 (isConstantStatus || i < 5) ? outputCount : outputCount * 2);
-            snapshotList.add(snapshot);
-        }
-
-        
when(flowFileQueue.getBackPressureDataSizeThreshold()).thenReturn(backPressureDataSizeThreshhold);
-        
when(flowFileQueue.getBackPressureObjectThreshold()).thenReturn(backPressureObjectThreshold);
+        
when(flowFileQueue.getBackPressureDataSizeThreshold()).thenReturn("100MB");
+        when(flowFileQueue.getBackPressureObjectThreshold()).thenReturn(100L);
         when(connection.getIdentifier()).thenReturn(connectionIdentifier);
         when(connection.getFlowFileQueue()).thenReturn(flowFileQueue);
         when(processGroup.findAllConnections()).thenReturn(connections);
-        when(statusHistory.getStatusSnapshots()).thenReturn(snapshotList);
         when(flowManager.getRootGroup()).thenReturn(processGroup);
-        when(flowFileEvent.getContentSizeIn()).thenReturn(inputBytes);
-        when(flowFileEvent.getContentSizeOut()).thenReturn(outputBytes);
-        when(flowFileEvent.getFlowFilesIn()).thenReturn(inputCount.intValue());
-        
when(flowFileEvent.getFlowFilesOut()).thenReturn(outputCount.intValue());
+        when(flowFileEvent.getContentSizeIn()).thenReturn(10L);
+        when(flowFileEvent.getContentSizeOut()).thenReturn(10L);
+        when(flowFileEvent.getFlowFilesIn()).thenReturn(10);
+        when(flowFileEvent.getFlowFilesOut()).thenReturn(10);
         
when(flowFileEventRepository.reportTransferEvents(anyLong())).thenReturn(repositoryStatusReport);
         
when(repositoryStatusReport.getReportEntry(anyString())).thenReturn(flowFileEvent);
         when(statusRepository.getConnectionStatusHistory(anyString(), any(), 
any(), anyInt())).thenReturn(statusHistory);
 
         ConnectionStatusAnalytics connectionStatusAnalytics = new 
ConnectionStatusAnalytics(statusRepository, flowManager,flowFileEventRepository,
-                
StatusAnalyticsModelMapFactory.getConnectionStatusModelMap(extensionManager,nifiProperties),
 connectionIdentifier, false);
+                                                                               
             modelMap, connectionIdentifier, false);
         connectionStatusAnalytics.refresh();
         return connectionStatusAnalytics;
     }
 
+    public Map<String, Tuple<StatusAnalyticsModel, 
StatusMetricExtractFunction>> getModelMap( String predictionType, Double score,
+                                                                               
 Double targetPrediction, Double variablePrediction){
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = new HashMap<>();
+        StatusAnalyticsModel model = Mockito.mock(StatusAnalyticsModel.class);
+        StatusMetricExtractFunction extractFunction = 
Mockito.mock(StatusMetricExtractFunction.class);
+        Tuple<StatusAnalyticsModel,StatusMetricExtractFunction> modelTuple = 
new Tuple<>(model,extractFunction);
+        modelMap.put(predictionType,modelTuple);
+        Map<String,Double> scores = new HashMap<>();
+        scores.put("rSquared",score);
+
+        Double[][] features = new Double[1][1];
+        Double[] target = new Double[1];
+
+        
when(extractFunction.extractMetric(anyString(),any(StatusHistory.class))).then(new
 Answer<Tuple<Stream<Double[]>,Stream<Double>>>() {
+            @Override
+            public Tuple<Stream<Double[]>, Stream<Double>> 
answer(InvocationOnMock invocationOnMock) throws Throwable {
+                return new Tuple<>(Stream.of(features), Stream.of(target));
+            }
+        });
+
+        when(model.getScores()).thenReturn(scores);
+        when(model.predict(any(Double[].class))).thenReturn(targetPrediction);
+        
when(model.predictVariable(anyInt(),any(),any())).thenReturn(variablePrediction);
+        return modelMap;
+
+    }
+
+    @Test
+    public void testInvalidModelLowScore() {
+        Date now = new Date();
+        Long tomorrowMillis = 
DateUtils.addDays(now,1).toInstant().toEpochMilli();
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.5,100.0,tomorrowMillis.doubleValue());
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
+    }
+
+      @Test
+    public void testInvalidModelNaNScore() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",Double.NaN,Double.NaN,Double.NaN);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
+    }
+
+    @Test
+    public void testInvalidModelInfiniteScore() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("queuedCount",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
+    }
+
     @Test
     public void testGetIntervalTimeMillis() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true);
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,100.0,100.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
         Long interval = connectionStatusAnalytics.getIntervalTimeMillis();
         assertNotNull(interval);
         assert (interval == 180000);
     }
 
     @Test
-    public void testGetTimeToCountBackpressureMillisConstantStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true);
+    public void testGetTimeToCountBackpressureMillis() {
+        Date now = new Date();
+        Long tomorrowMillis = 
DateUtils.addDays(now,1).toInstant().toEpochMilli();
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,100.0,tomorrowMillis.doubleValue());
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
         Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
         assertNotNull(countTime);
         assert (countTime > 0);
     }
 
     @Test
-    public void testGetTimeToCountBackpressureMillisVaryingStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false);
+    public void testCannotPredictTimeToCountNaN() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,Double.NaN,Double.NaN);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
         Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
         assertNotNull(countTime);
-        assert (countTime == -1L);
+        assert (countTime == -1);
     }
 
     @Test
-    public void testGetTimeToBytesBackpressureMillisConstantStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true);
-        Long bytesTime = 
connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
-        assertNotNull(bytesTime);
-        assert (bytesTime == -1L || bytesTime == 0);
+    public void testCannotPredictTimeToCountInfinite() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("queuedCount",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
     }
 
     @Test
-    public void testGetTimeToBytesBackpressureMillisVaryingStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false);
-        Long bytesTime = 
connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
-        assertNotNull(bytesTime);
-        assert (bytesTime == -1L);
+    public void testCannotPredictTimeToCountNegative() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,-1.0,-1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToCountBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
     }
 
     @Test
-    public void testGetNextIntervalBytesConstantStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true);
-        Long nextBytes = connectionStatusAnalytics.getNextIntervalBytes();
-        assertNotNull(nextBytes);
-        assert (nextBytes == 5000L);
+    public void testMissingModelGetTimeToCount() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        try {
+            connectionStatusAnalytics.getTimeToCountBackpressureMillis();
+            fail();
+        }catch(IllegalArgumentException iae){
+            assertTrue(true);
+        }
     }
 
     @Test
-    public void testGetNextIntervalBytesVaryingStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false);
-        Long nextBytes = connectionStatusAnalytics.getNextIntervalBytes();
-        assertNotNull(nextBytes);
-        assert (nextBytes == -1L);
+    public void testGetTimeToBytesBackpressureMillis() {
+        Date now = new Date();
+        Long tomorrowMillis = 
DateUtils.addDays(now,1).toInstant().toEpochMilli();
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,100.0,tomorrowMillis.doubleValue());
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime > 0);
     }
 
     @Test
-    public void testGetNextIntervalCountConstantStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true);
-        Long nextCount = connectionStatusAnalytics.getNextIntervalCount();
-        assertNotNull(nextCount);
-        assert (nextCount == 50L);
+    public void testCannotPredictTimeToBytesNaN() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,Double.NaN,Double.NaN);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
     }
 
     @Test
-    public void testGetNextIntervalCountVaryingStatus() {
+    public void testCannotPredictTimeToBytesInfinite() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("queuedBytes",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
+    }
 
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, false);
-        Long nextCount = connectionStatusAnalytics.getNextIntervalCount();
-        assertNotNull(nextCount);
-        assert (nextCount == -1L);
+    @Test
+    public void testCannotPredictTimeToBytesNegative() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,-1.0,-1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long countTime = 
connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
+        assertNotNull(countTime);
+        assert (countTime == -1);
+    }
+
+    @Test
+    public void testMissingModelGetTimeToBytes() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        try {
+            connectionStatusAnalytics.getTimeToBytesBackpressureMillis();
+            fail();
+        }catch(IllegalArgumentException iae){
+            assertTrue(true);
+        }
     }
 
     @Test
-    public void testGetNextIntervalPercentageUseBytesConstantStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(50000L, 50L, "1MB", 100L, true);
-        Long nextBytesPercentage = 
connectionStatusAnalytics.getNextIntervalPercentageUseBytes();
-        assertNotNull(nextBytesPercentage);
-        assert (nextBytesPercentage == 5);
+    public void testGetNextIntervalBytes() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,1.0,1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalBytes();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes > 0);
     }
 
     @Test
-    public void testGetNextIntervalPercentageUseCountConstantStatus() {
-        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(5000L, 50L, "100MB", 100L, true);
-        Long nextCountPercentage = 
connectionStatusAnalytics.getNextIntervalPercentageUseCount();
-        assertNotNull(nextCountPercentage);
-        assert (nextCountPercentage == 50);
+    public void testCannotPredictNextIntervalBytesNegative() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,-1.0,-1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalBytes();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes == -1);
     }
 
+    @Test
+    public void testCannotPredictNextIntervalBytesNaN() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,Double.NaN,Double.NaN);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalBytes();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes == -1);
+    }
+
+    @Test
+    public void testCannotPredictNextIntervalBytesInfinity() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("queuedBytes",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalBytes();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes == -1);
+    }
+
+    @Test
+    public void testMissingModelNextIntervalBytes() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        try {
+            connectionStatusAnalytics.getNextIntervalBytes();
+            fail();
+        }catch(IllegalArgumentException iae){
+            assertTrue(true);
+        }
+    }
+
+    @Test
+    public void testGetNextIntervalCount() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,1.0,1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalCount();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes > 0);
+    }
+
+    @Test
+    public void testCannotPredictNextIntervalCountNegative() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,-1.0,-1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalCount();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes == -1);
+    }
+
+    @Test
+    public void testCannotPredictNextIntervalCountNaN() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,Double.NaN,Double.NaN);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalCount();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes == -1);
+    }
+
+    @Test
+    public void testCannotPredictNextIntervalCountInfinity() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("queuedCount",.9,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long nextIntervalBytes = 
connectionStatusAnalytics.getNextIntervalCount();
+        assertNotNull(nextIntervalBytes);
+        assert (nextIntervalBytes == -1);
+    }
+
+    @Test
+    public void testMissingModelGetNextIntervalCount() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = 
getModelMap("fakeModel",Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY,Double.POSITIVE_INFINITY);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        try {
+            connectionStatusAnalytics.getNextIntervalCount();
+            fail();
+        }catch(IllegalArgumentException iae){
+            assertTrue(true);
+        }
+    }
+
+    @Test
+    public void testGetNextIntervalPercentageUseCount() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedCount",.9,50.0,1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long percentage = 
connectionStatusAnalytics.getNextIntervalPercentageUseCount();
+        assertNotNull(percentage);
+        assert (percentage == 50);
+    }
+
+    @Test
+    public void testGetNextIntervalPercentageUseBytes() {
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
modelMap = getModelMap("queuedBytes",.9,10000000.0,1.0);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(modelMap);
+        Long percentage = 
connectionStatusAnalytics.getNextIntervalPercentageUseBytes();
+        assertNotNull(percentage);
+        assert (percentage == 10);
+    }
+
+    @Test
+    public void testGetScores() {
+        Date now = new Date();
+        Long tomorrowMillis = 
DateUtils.addDays(now,1).toInstant().toEpochMilli();
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
bytesModelMap = 
getModelMap("queuedBytes",.9,10000000.0,tomorrowMillis.doubleValue());
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
countModelMap = getModelMap("queuedCount",.9,50.0,tomorrowMillis.doubleValue());
+        countModelMap.putAll(bytesModelMap);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(countModelMap);
+        Map<String,Long> scores = connectionStatusAnalytics.getPredictions();
+        assertNotNull(scores);
+        assertFalse(scores.isEmpty());
+        assertTrue(scores.get("nextIntervalPercentageUseCount").equals(50L));
+        assertTrue(scores.get("nextIntervalBytes").equals(10000000L));
+        assertTrue(scores.get("timeToBytesBackpressureMillis") > 0);
+        assertTrue(scores.get("nextIntervalCount").equals(50L));
+        assertTrue(scores.get("nextIntervalPercentageUseBytes").equals(10L));
+        assertTrue(scores.get("intervalTimeMillis").equals(180000L));
+        assertTrue(scores.get("timeToCountBackpressureMillis") > 0);
+    }
+
+    @Test
+    public void testGetScoresWithBadModel() {
+        Date now = new Date();
+        Long tomorrowMillis = 
DateUtils.addDays(now,1).toInstant().toEpochMilli();
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
bytesModelMap = 
getModelMap("queuedBytes",.9,10000000.0,tomorrowMillis.doubleValue());
+        Map<String, Tuple<StatusAnalyticsModel, StatusMetricExtractFunction>> 
countModelMap = getModelMap("queuedCount",.1,50.0,tomorrowMillis.doubleValue());
+        countModelMap.putAll(bytesModelMap);
+        ConnectionStatusAnalytics connectionStatusAnalytics = 
getConnectionStatusAnalytics(countModelMap);
+        Map<String,Long> scores = connectionStatusAnalytics.getPredictions();
+        assertNotNull(scores);
+        assertFalse(scores.isEmpty());
+        assertTrue(scores.get("nextIntervalPercentageUseCount").equals(-1L));
+        assertTrue(scores.get("nextIntervalBytes").equals(10000000L));
+        assertTrue(scores.get("timeToBytesBackpressureMillis") > 0);
+        assertTrue(scores.get("nextIntervalCount").equals(-1L));
+        assertTrue(scores.get("nextIntervalPercentageUseBytes").equals(10L));
+        assertTrue(scores.get("intervalTimeMillis").equals(180000L));
+        assertTrue(scores.get("timeToCountBackpressureMillis") == -1);
+    }
 }
diff --git 
a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java
 
b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java
index fce167d..477216c 100644
--- 
a/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java
+++ 
b/nifi-nar-bundles/nifi-framework-bundle/nifi-framework/nifi-framework-core/src/test/java/org/apache/nifi/controller/status/analytics/TestStatusAnalyticsEngine.java
@@ -25,22 +25,20 @@ import static org.mockito.Mockito.when;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.stream.Stream;
 
-import org.apache.nifi.bundle.Bundle;
 import org.apache.nifi.controller.flow.FlowManager;
 import org.apache.nifi.controller.repository.FlowFileEventRepository;
 import org.apache.nifi.controller.status.history.ComponentStatusRepository;
 import org.apache.nifi.controller.status.history.StatusHistory;
 import org.apache.nifi.controller.status.history.StatusSnapshot;
 import org.apache.nifi.groups.ProcessGroup;
-import org.apache.nifi.nar.ExtensionManager;
-import org.apache.nifi.nar.StandardExtensionDiscoveringManager;
-import org.apache.nifi.nar.SystemBundle;
-import org.apache.nifi.util.NiFiProperties;
 import org.apache.nifi.util.Tuple;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 public abstract class TestStatusAnalyticsEngine {
 
@@ -58,21 +56,31 @@ public abstract class TestStatusAnalyticsEngine {
 
         statusRepository = Mockito.mock(ComponentStatusRepository.class);
         flowManager = Mockito.mock(FlowManager.class);
+        modelMap = new HashMap<>();
 
-        final Map<String, String> otherProps = new HashMap<>();
-        final String propsFile = "src/test/resources/conf/nifi.properties";
-        NiFiProperties nifiProperties = 
NiFiProperties.createBasicNiFiProperties(propsFile, otherProps);
+        StatusAnalyticsModel countModel = 
Mockito.mock(StatusAnalyticsModel.class);
+        StatusAnalyticsModel byteModel = 
Mockito.mock(StatusAnalyticsModel.class);
+        StatusMetricExtractFunction extractFunction = 
Mockito.mock(StatusMetricExtractFunction.class);
+        Tuple<StatusAnalyticsModel,StatusMetricExtractFunction> countTuple = 
new Tuple<>(countModel,extractFunction);
+        Tuple<StatusAnalyticsModel,StatusMetricExtractFunction> byteTuple = 
new Tuple<>(byteModel,extractFunction);
+        modelMap.put("queuedCount",countTuple);
+        modelMap.put("queuedBytes",byteTuple);
 
-        // use the system bundle
-        Bundle systemBundle = SystemBundle.create(nifiProperties);
-        ExtensionManager extensionManager = new 
StandardExtensionDiscoveringManager();
-        
((StandardExtensionDiscoveringManager)extensionManager).discoverExtensions(systemBundle,
 Collections.emptySet());
+        Double[][] features = new Double[1][1];
+        Double[] target = new Double[1];
 
-        modelMap = 
StatusAnalyticsModelMapFactory.getConnectionStatusModelMap(extensionManager,nifiProperties);
 
         ProcessGroup processGroup = Mockito.mock(ProcessGroup.class);
         StatusHistory statusHistory = Mockito.mock(StatusHistory.class);
         StatusSnapshot statusSnapshot = Mockito.mock(StatusSnapshot.class);
+
+        
when(extractFunction.extractMetric(anyString(),any(StatusHistory.class))).then(new
 Answer<Tuple<Stream<Double[]>,Stream<Double>>>() {
+            @Override
+            public Tuple<Stream<Double[]>, Stream<Double>> 
answer(InvocationOnMock invocationOnMock) throws Throwable {
+                return new Tuple<>(Stream.of(features), Stream.of(target));
+            }
+        });
+
         
when(statusSnapshot.getMetricDescriptors()).thenReturn(Collections.emptySet());
         when(flowManager.getRootGroup()).thenReturn(processGroup);
         when(statusRepository.getConnectionStatusHistory(anyString(), any(), 
any(), anyInt())).thenReturn(statusHistory);

Reply via email to