Repository: metron Updated Branches: refs/heads/master 43e529fe0 -> d0e1ba508
METRON-569 Enrichment topology duplicates messages (merrimanr) closes apache/metron#603 Project: http://git-wip-us.apache.org/repos/asf/metron/repo Commit: http://git-wip-us.apache.org/repos/asf/metron/commit/d0e1ba50 Tree: http://git-wip-us.apache.org/repos/asf/metron/tree/d0e1ba50 Diff: http://git-wip-us.apache.org/repos/asf/metron/diff/d0e1ba50 Branch: refs/heads/master Commit: d0e1ba5089eb08187fd36c55fbe40539ef1f7890 Parents: 43e529f Author: merrimanr <[email protected]> Authored: Mon Jun 5 16:51:18 2017 -0500 Committer: merrimanr <[email protected]> Committed: Mon Jun 5 16:51:18 2017 -0500 ---------------------------------------------------------------------- .../src/main/flux/enrichment/remote.yaml | 4 +- .../enrichment/bolt/EnrichmentJoinBolt.java | 7 +- .../apache/metron/enrichment/bolt/JoinBolt.java | 29 +++--- .../enrichment/bolt/ThreatIntelJoinBolt.java | 6 +- .../enrichment/bolt/EnrichmentJoinBoltTest.java | 18 +++- .../metron/enrichment/bolt/JoinBoltTest.java | 98 ++++++++++++++------ .../bolt/ThreatIntelJoinBoltTest.java | 40 ++++---- 7 files changed, 135 insertions(+), 67 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/main/flux/enrichment/remote.yaml ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/main/flux/enrichment/remote.yaml b/metron-platform/metron-enrichment/src/main/flux/enrichment/remote.yaml index e4f119e..0e50f77 100644 --- a/metron-platform/metron-enrichment/src/main/flux/enrichment/remote.yaml +++ b/metron-platform/metron-enrichment/src/main/flux/enrichment/remote.yaml @@ -312,7 +312,7 @@ bolts: - "${kafka.zk}" configMethods: - name: "withMaxCacheSize" - args: [10000] + args: [100000] - name: "withMaxTimeRetain" args: [10] - id: "enrichmentErrorOutputBolt" @@ -366,7 +366,7 @@ bolts: - "${kafka.zk}" configMethods: - name: "withMaxCacheSize" - args: [10000] + args: [100000] - name: "withMaxTimeRetain" args: [10] - id: "threatIntelErrorOutputBolt" http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBolt.java ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBolt.java b/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBolt.java index 2adf430..4b88399 100644 --- a/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBolt.java +++ b/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBolt.java @@ -17,11 +17,13 @@ */ package org.apache.metron.enrichment.bolt; +import org.apache.metron.common.message.MessageGetStrategy; import org.apache.storm.task.TopologyContext; import com.google.common.base.Joiner; import org.apache.metron.common.configuration.enrichment.SensorEnrichmentConfig; import org.apache.metron.common.configuration.enrichment.handler.ConfigHandler; import org.apache.metron.common.utils.MessageUtils; +import org.apache.storm.tuple.Tuple; import org.json.simple.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,10 +68,11 @@ public class EnrichmentJoinBolt extends JoinBolt<JSONObject> { @Override - public JSONObject joinMessages(Map<String, JSONObject> streamMessageMap) { + public JSONObject joinMessages(Map<String, Tuple> streamMessageMap, MessageGetStrategy messageGetStrategy) { JSONObject message = new JSONObject(); for (String key : streamMessageMap.keySet()) { - JSONObject obj = streamMessageMap.get(key); + Tuple tuple = streamMessageMap.get(key); + JSONObject obj = (JSONObject) messageGetStrategy.get(tuple); message.putAll(obj); } List<Object> emptyKeys = new ArrayList<>(); http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/JoinBolt.java ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/JoinBolt.java b/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/JoinBolt.java index a8e793d..f3fe52c 100644 --- a/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/JoinBolt.java +++ b/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/JoinBolt.java @@ -51,11 +51,11 @@ public abstract class JoinBolt<V> extends ConfiguredEnrichmentBolt { .getLogger(JoinBolt.class); protected OutputCollector collector; - protected transient CacheLoader<String, Map<String, V>> loader; - protected transient LoadingCache<String, Map<String, V>> cache; - private transient MessageGetStrategy keyGetStrategy; - private transient MessageGetStrategy subgroupGetStrategy; - private transient MessageGetStrategy messageGetStrategy; + protected transient CacheLoader<String, Map<String, Tuple>> loader; + protected transient LoadingCache<String, Map<String, Tuple>> cache; + protected transient MessageGetStrategy keyGetStrategy; + protected transient MessageGetStrategy subgroupGetStrategy; + protected transient MessageGetStrategy messageGetStrategy; protected Long maxCacheSize; protected Long maxTimeRetain; @@ -86,9 +86,9 @@ public abstract class JoinBolt<V> extends ConfiguredEnrichmentBolt { if (this.maxTimeRetain == null) { throw new IllegalStateException("maxTimeRetain must be specified"); } - loader = new CacheLoader<String, Map<String, V>>() { + loader = new CacheLoader<String, Map<String, Tuple>>() { @Override - public Map<String, V> load(String key) throws Exception { + public Map<String, Tuple> load(String key) throws Exception { return new HashMap<>(); } }; @@ -98,10 +98,10 @@ public abstract class JoinBolt<V> extends ConfiguredEnrichmentBolt { prepare(map, topologyContext); } - class JoinRemoveListener implements RemovalListener<String, Map<String, V>> { + class JoinRemoveListener implements RemovalListener<String, Map<String, Tuple>> { @Override - public void onRemoval(RemovalNotification<String, Map<String, V>> removalNotification) { + public void onRemoval(RemovalNotification<String, Map<String, Tuple>> removalNotification) { if (removalNotification.getCause() == RemovalCause.SIZE) { String errorMessage = "Join cache reached max size limit. Increase the maxCacheSize setting or add more tasks to enrichment/threatintel join bolt."; Exception exception = new Exception(errorMessage); @@ -126,12 +126,12 @@ public abstract class JoinBolt<V> extends ConfiguredEnrichmentBolt { streamId = Joiner.on(":").join("" + streamId, subgroup == null?"":subgroup); V message = (V) messageGetStrategy.get(tuple); try { - Map<String, V> streamMessageMap = cache.get(key); + Map<String, Tuple> streamMessageMap = cache.get(key); if (streamMessageMap.containsKey(streamId)) { LOG.warn(String.format("Received key %s twice for " + "stream %s", key, streamId)); } - streamMessageMap.put(streamId, message); + streamMessageMap.put(streamId, tuple); Set<String> streamIds = getStreamIds(message); Set<String> streamMessageKeys = streamMessageMap.keySet(); if ( streamMessageKeys.size() == streamIds.size() @@ -141,11 +141,12 @@ public abstract class JoinBolt<V> extends ConfiguredEnrichmentBolt { collector.emit( "message" , tuple , new Values( key - , joinMessages(streamMessageMap) + , joinMessages(streamMessageMap, this.messageGetStrategy) ) ); cache.invalidate(key); - collector.ack(tuple); + Tuple messageTuple = streamMessageMap.get("message:"); + collector.ack(messageTuple); LOG.trace("Emitted message for key: {}", key); } else { cache.put(key, streamMessageMap); @@ -177,5 +178,5 @@ public abstract class JoinBolt<V> extends ConfiguredEnrichmentBolt { public abstract Set<String> getStreamIds(V value); - public abstract V joinMessages(Map<String, V> streamMessageMap); + public abstract V joinMessages(Map<String, Tuple> streamMessageMap, MessageGetStrategy messageGetStrategy); } http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBolt.java ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBolt.java b/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBolt.java index 4d924c3..d4865e2 100644 --- a/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBolt.java +++ b/metron-platform/metron-enrichment/src/main/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBolt.java @@ -27,11 +27,13 @@ import org.apache.metron.common.configuration.enrichment.threatintel.ThreatTriag import org.apache.metron.common.dsl.Context; import org.apache.metron.common.dsl.StellarFunctions; import org.apache.metron.common.dsl.functions.resolver.FunctionResolver; +import org.apache.metron.common.message.MessageGetStrategy; import org.apache.metron.common.utils.ConversionUtils; import org.apache.metron.common.utils.MessageUtils; import org.apache.metron.enrichment.adapters.geo.GeoLiteDatabase; import org.apache.metron.threatintel.triage.ThreatTriageProcessor; import org.apache.storm.task.TopologyContext; +import org.apache.storm.tuple.Tuple; import org.json.simple.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -132,8 +134,8 @@ public class ThreatIntelJoinBolt extends EnrichmentJoinBolt { } @Override - public JSONObject joinMessages(Map<String, JSONObject> streamMessageMap) { - JSONObject ret = super.joinMessages(streamMessageMap); + public JSONObject joinMessages(Map<String, Tuple> streamMessageMap, MessageGetStrategy messageGetStrategy) { + JSONObject ret = super.joinMessages(streamMessageMap, messageGetStrategy); LOG.trace("Received joined messages: {}", ret); boolean isAlert = ret.containsKey("is_alert"); if(!isAlert) { http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBoltTest.java ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBoltTest.java b/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBoltTest.java index 56ddf08..77dd4cf 100644 --- a/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBoltTest.java +++ b/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/EnrichmentJoinBoltTest.java @@ -18,7 +18,9 @@ package org.apache.metron.enrichment.bolt; import org.adrianwalker.multilinestring.Multiline; +import org.apache.metron.common.message.MessageGetStrategy; import org.apache.metron.test.bolt.BaseEnrichmentBoltTest; +import org.apache.storm.tuple.Tuple; import org.json.simple.JSONObject; import org.json.simple.parser.JSONParser; import org.json.simple.parser.ParseException; @@ -32,6 +34,9 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class EnrichmentJoinBoltTest extends BaseEnrichmentBoltTest { /** @@ -75,10 +80,15 @@ public class EnrichmentJoinBoltTest extends BaseEnrichmentBoltTest { enrichmentJoinBolt.prepare(new HashMap<>(), topologyContext, outputCollector); Set<String> actualStreamIds = enrichmentJoinBolt.getStreamIds(sampleMessage); Assert.assertEquals(joinStreamIds, actualStreamIds); - Map<String, JSONObject> streamMessageMap = new HashMap<>(); - streamMessageMap.put("message", sampleMessage); - streamMessageMap.put("enriched", enrichedMessage); - JSONObject joinedMessage = enrichmentJoinBolt.joinMessages(streamMessageMap); + Map<String, Tuple> streamMessageMap = new HashMap<>(); + MessageGetStrategy messageGetStrategy = mock(MessageGetStrategy.class); + Tuple sampleTuple = mock(Tuple.class); + when(messageGetStrategy.get(sampleTuple)).thenReturn(sampleMessage); + Tuple enrichedTuple = mock(Tuple.class); + when(messageGetStrategy.get(enrichedTuple)).thenReturn(enrichedMessage); + streamMessageMap.put("message", sampleTuple); + streamMessageMap.put("enriched", enrichedTuple); + JSONObject joinedMessage = enrichmentJoinBolt.joinMessages(streamMessageMap, messageGetStrategy); removeTimingFields(joinedMessage); Assert.assertEquals(expectedJoinedMessage, joinedMessage); } http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/JoinBoltTest.java ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/JoinBoltTest.java b/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/JoinBoltTest.java index 9f12fcd..e03dc71 100644 --- a/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/JoinBoltTest.java +++ b/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/JoinBoltTest.java @@ -21,9 +21,11 @@ import com.google.common.cache.LoadingCache; import org.adrianwalker.multilinestring.Multiline; import org.apache.metron.common.Constants; import org.apache.metron.common.error.MetronError; +import org.apache.metron.common.message.MessageGetStrategy; import org.apache.metron.test.bolt.BaseEnrichmentBoltTest; import org.apache.metron.test.error.MetronErrorJSONMatcher; import org.apache.storm.task.TopologyContext; +import org.apache.storm.tuple.Tuple; import org.apache.storm.tuple.Values; import org.json.simple.JSONObject; import org.json.simple.parser.JSONParser; @@ -44,6 +46,7 @@ import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class JoinBoltTest extends BaseEnrichmentBoltTest { @@ -65,11 +68,12 @@ public class JoinBoltTest extends BaseEnrichmentBoltTest { for(String s : streamIds) { ret.add(s + ":"); } + ret.add("message:"); return ret; } @Override - public JSONObject joinMessages(Map<String, JSONObject> streamMessageMap) { + public JSONObject joinMessages(Map<String, Tuple> streamMessageMap, MessageGetStrategy messageGetStrategy) { return joinedMessage; } } @@ -83,6 +87,7 @@ public class JoinBoltTest extends BaseEnrichmentBoltTest { private String joinedMessageString; private JSONObject joinedMessage; + private JoinBolt<JSONObject> joinBolt; @Before public void parseMessages() { @@ -92,13 +97,13 @@ public class JoinBoltTest extends BaseEnrichmentBoltTest { } catch (ParseException e) { e.printStackTrace(); } + joinBolt = new StandAloneJoinBolt("zookeeperUrl"); + joinBolt.setCuratorFramework(client); + joinBolt.setTreeCache(cache); } @Test - public void test() throws Exception { - StandAloneJoinBolt joinBolt = new StandAloneJoinBolt("zookeeperUrl"); - joinBolt.setCuratorFramework(client); - joinBolt.setTreeCache(cache); + public void testPrepare() { try { joinBolt.prepare(new HashMap(), topologyContext, outputCollector); fail("Should fail if a maxCacheSize property is not set"); @@ -110,38 +115,79 @@ public class JoinBoltTest extends BaseEnrichmentBoltTest { } catch(IllegalStateException e) {} joinBolt.withMaxTimeRetain(10000); joinBolt.prepare(new HashMap(), topologyContext, outputCollector); + } + + @Test + public void testDeclareOutputFields() { joinBolt.declareOutputFields(declarer); verify(declarer, times(1)).declareStream(eq("message"), argThat(new FieldsMatcher("key", "message"))); - when(tuple.getValueByField("key")).thenReturn(key); - when(tuple.getSourceStreamId()).thenReturn("geo"); - when(tuple.getValueByField("message")).thenReturn(geoMessage); - joinBolt.execute(tuple); - verify(outputCollector, times(0)).emit(eq("message"), any(tuple.getClass()), any(Values.class)); - verify(outputCollector, times(0)).ack(tuple); - when(tuple.getSourceStreamId()).thenReturn("host"); - when(tuple.getValueByField("message")).thenReturn(hostMessage); - joinBolt.execute(tuple); - verify(outputCollector, times(0)).emit(eq("message"), any(tuple.getClass()), any(Values.class)); - verify(outputCollector, times(0)).ack(tuple); - when(tuple.getSourceStreamId()).thenReturn("hbaseEnrichment"); - when(tuple.getValueByField("message")).thenReturn(hbaseEnrichmentMessage); - joinBolt.execute(tuple); - when(tuple.getSourceStreamId()).thenReturn("stellar"); - when(tuple.getValueByField("message")).thenReturn(new JSONObject()); - verify(outputCollector, times(0)).emit(eq("message"), any(tuple.getClass()), eq(new Values(key, joinedMessage))); - joinBolt.execute(tuple); + verify(declarer, times(1)).declareStream(eq("error"), argThat(new FieldsMatcher("message"))); + verifyNoMoreInteractions(declarer); + } + + @Test + public void testExecute() { + joinBolt.withMaxCacheSize(100); + joinBolt.withMaxTimeRetain(10000); + joinBolt.prepare(new HashMap(), topologyContext, outputCollector); + + Tuple geoTuple = mock(Tuple.class); + when(geoTuple.getValueByField("key")).thenReturn(key); + when(geoTuple.getSourceStreamId()).thenReturn("geo"); + when(geoTuple.getValueByField("message")).thenReturn(geoMessage); + joinBolt.execute(geoTuple); + + Tuple messageTuple = mock(Tuple.class); + when(messageTuple.getValueByField("key")).thenReturn(key); + when(messageTuple.getSourceStreamId()).thenReturn("message"); + when(messageTuple.getValueByField("message")).thenReturn(sampleMessage); + joinBolt.execute(messageTuple); + + Tuple hostTuple = mock(Tuple.class); + when(hostTuple.getValueByField("key")).thenReturn(key); + when(hostTuple.getSourceStreamId()).thenReturn("host"); + when(hostTuple.getValueByField("message")).thenReturn(hostMessage); + joinBolt.execute(hostTuple); + + Tuple hbaseEnrichmentTuple = mock(Tuple.class); + when(hbaseEnrichmentTuple.getValueByField("key")).thenReturn(key); + when(hbaseEnrichmentTuple.getSourceStreamId()).thenReturn("hbaseEnrichment"); + when(hbaseEnrichmentTuple.getValueByField("message")).thenReturn(hbaseEnrichmentMessage); + joinBolt.execute(hbaseEnrichmentTuple); + + Tuple stellarTuple = mock(Tuple.class); + when(stellarTuple.getValueByField("key")).thenReturn(key); + when(stellarTuple.getSourceStreamId()).thenReturn("stellar"); + when(stellarTuple.getValueByField("message")).thenReturn(new JSONObject()); + joinBolt.execute(stellarTuple); + verify(outputCollector, times(1)).emit(eq("message"), any(tuple.getClass()), eq(new Values(key, joinedMessage))); - verify(outputCollector, times(1)).ack(tuple); + verify(outputCollector, times(1)).ack(messageTuple); + + verifyNoMoreInteractions(outputCollector); + } + @SuppressWarnings("unchecked") + @Test + public void testExecuteShouldReportError() throws ExecutionException { + joinBolt.withMaxCacheSize(100); + joinBolt.withMaxTimeRetain(10000); + joinBolt.prepare(new HashMap(), topologyContext, outputCollector); + when(tuple.getValueByField("key")).thenReturn(key); + when(tuple.getValueByField("message")).thenReturn(new JSONObject()); joinBolt.cache = mock(LoadingCache.class); when(joinBolt.cache.get(key)).thenThrow(new ExecutionException(new Exception("join exception"))); - joinBolt.execute(tuple); + joinBolt.execute(tuple); + ExecutionException expectedExecutionException = new ExecutionException(new Exception("join exception")); MetronError error = new MetronError() .withErrorType(Constants.ErrorType.ENRICHMENT_ERROR) .withMessage("Joining problem: {}") - .withThrowable(new ExecutionException(new Exception("join exception"))) + .withThrowable(expectedExecutionException) .addRawMessage(new JSONObject()); verify(outputCollector, times(1)).emit(eq(Constants.ERROR_STREAM), argThat(new MetronErrorJSONMatcher(error.getJSONObject()))); + verify(outputCollector, times(1)).reportError(any(ExecutionException.class)); + verify(outputCollector, times(1)).ack(eq(tuple)); + verifyNoMoreInteractions(outputCollector); } } http://git-wip-us.apache.org/repos/asf/metron/blob/d0e1ba50/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBoltTest.java ---------------------------------------------------------------------- diff --git a/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBoltTest.java b/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBoltTest.java index 6fe318e..0f3cc8c 100644 --- a/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBoltTest.java +++ b/metron-platform/metron-enrichment/src/test/java/org/apache/metron/enrichment/bolt/ThreatIntelJoinBoltTest.java @@ -19,16 +19,15 @@ package org.apache.metron.enrichment.bolt; import com.fasterxml.jackson.databind.JsonMappingException; import junit.framework.Assert; -import junit.framework.TestCase; import org.adrianwalker.multilinestring.Multiline; -import org.apache.hadoop.fs.Path; import org.apache.metron.common.configuration.enrichment.SensorEnrichmentConfig; -import org.apache.metron.common.configuration.enrichment.threatintel.ThreatScore; import org.apache.metron.common.configuration.enrichment.threatintel.ThreatTriageConfig; +import org.apache.metron.common.message.MessageGetStrategy; import org.apache.metron.common.utils.JSONUtils; import org.apache.metron.enrichment.adapters.geo.GeoLiteDatabase; import org.apache.metron.test.bolt.BaseEnrichmentBoltTest; import org.apache.metron.test.utils.UnitTestHelper; +import org.apache.storm.tuple.Tuple; import org.json.simple.JSONObject; import org.json.simple.parser.JSONParser; import org.json.simple.parser.ParseException; @@ -39,9 +38,13 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.util.HashMap; -import java.util.List; import java.util.Map; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + public class ThreatIntelJoinBoltTest extends BaseEnrichmentBoltTest { /** @@ -187,26 +190,29 @@ public class ThreatIntelJoinBoltTest extends BaseEnrichmentBoltTest { fieldMap = threatIntelJoinBolt.getFieldMap(sensorType); Assert.assertTrue(fieldMap.containsKey("hbaseThreatIntel")); - Map<String, JSONObject> streamMessageMap = new HashMap<>(); - streamMessageMap.put("message", message); - JSONObject joinedMessage = threatIntelJoinBolt.joinMessages(streamMessageMap); - Assert.assertFalse(joinedMessage.containsKey("is_alert")); + MessageGetStrategy messageGetStrategy = mock(MessageGetStrategy.class); + Tuple messageTuple = mock(Tuple.class); + when(messageGetStrategy.get(messageTuple)).thenReturn(message); + Map<String, Tuple> streamMessageMap = new HashMap<>(); + streamMessageMap.put("message", messageTuple); + JSONObject joinedMessage = threatIntelJoinBolt.joinMessages(streamMessageMap, messageGetStrategy); + assertFalse(joinedMessage.containsKey("is_alert")); - streamMessageMap.put("message", messageWithTiming); - joinedMessage = threatIntelJoinBolt.joinMessages(streamMessageMap); - Assert.assertFalse(joinedMessage.containsKey("is_alert")); + when(messageGetStrategy.get(messageTuple)).thenReturn(messageWithTiming); + joinedMessage = threatIntelJoinBolt.joinMessages(streamMessageMap, messageGetStrategy); + assertFalse(joinedMessage.containsKey("is_alert")); - streamMessageMap.put("message", alertMessage); - joinedMessage = threatIntelJoinBolt.joinMessages(streamMessageMap); - Assert.assertTrue(joinedMessage.containsKey("is_alert") && "true".equals(joinedMessage.get("is_alert"))); + when(messageGetStrategy.get(messageTuple)).thenReturn(alertMessage); + joinedMessage = threatIntelJoinBolt.joinMessages(streamMessageMap, messageGetStrategy); + assertTrue(joinedMessage.containsKey("is_alert") && "true".equals(joinedMessage.get("is_alert"))); if(withThreatTriage && !badConfig) { - Assert.assertTrue(joinedMessage.containsKey("threat.triage.score")); + assertTrue(joinedMessage.containsKey("threat.triage.score")); Double score = (Double) joinedMessage.get("threat.triage.score"); - Assert.assertTrue(Math.abs(10d - score) < 1e-10); + assertTrue(Math.abs(10d - score) < 1e-10); } else { - Assert.assertFalse(joinedMessage.containsKey("threat.triage.score")); + assertFalse(joinedMessage.containsKey("threat.triage.score")); } } }
