Repository: kafka
Updated Branches:
  refs/heads/trunk 143a33bc5 -> 86aa0eb0f


http://git-wip-us.apache.org/repos/asf/kafka/blob/86aa0eb0/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java
new file mode 100644
index 0000000..2ff3b89
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java
@@ -0,0 +1,434 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.streams.KeyValue;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+public class ThreadCacheTest {
+
+    /**
+     * Puts five dirty entries into a cache sized for exactly five such entries and checks that
+     * every entry reads back unchanged and still dirty, with no evictions or flushes recorded.
+     */
+    @Test
+    public void basicPutGet() throws IOException {
+        List<KeyValue<String, String>> toInsert = Arrays.asList(
+                new KeyValue<>("K1", "V1"),
+                new KeyValue<>("K2", "V2"),
+                new KeyValue<>("K3", "V3"),
+                new KeyValue<>("K4", "V4"),
+                new KeyValue<>("K5", "V5"));
+        final KeyValue<String, String> kv = toInsert.get(0);
+        final String name = "name";
+        ThreadCache cache = new ThreadCache(
+                toInsert.size() * memoryCacheEntrySize(kv.key.getBytes(), kv.value.getBytes(), ""));
+
+        for (int i = 0; i < toInsert.size(); i++) {
+            byte[] key = toInsert.get(i).key.getBytes();
+            byte[] value = toInsert.get(i).value.getBytes();
+            cache.put(name, key, new LRUCacheEntry(value, true, 1L, 1L, 1, ""));
+        }
+
+        for (int i = 0; i < toInsert.size(); i++) {
+            byte[] key = toInsert.get(i).key.getBytes();
+            LRUCacheEntry entry = cache.get(name, key);
+            assertTrue(entry.isDirty);
+            assertEquals(toInsert.get(i).value, new String(entry.value));
+        }
+        assertEquals(5, cache.gets());
+        assertEquals(5, cache.puts());
+        assertEquals(0, cache.evicts());
+        assertEquals(0, cache.flushes());
+    }
+
+    /**
+     * Fills a cache of {@code desiredCacheSize} bytes with entries of the given key/value sizes, then checks:
+     * (a) the cache's own size accounting stays within {@code entryFactor} of the requested size, and
+     * (b) JVM heap growth stays within {@code systemFactor} times the cache's reported size.
+     *
+     * <p>NOTE: check (b) relies on {@link System#gc()}, which is only a hint to the JVM. It is therefore
+     * sensitive to the runtime's GC behaviour and heap settings.
+     */
+    private void checkOverheads(double entryFactor, double systemFactor, long desiredCacheSize, int keySizeBytes,
+                                int valueSizeBytes) {
+        Runtime runtime = Runtime.getRuntime();
+        byte[] key = new byte[keySizeBytes];
+        byte[] value = new byte[valueSizeBytes];
+        final String name = "name";
+        long numElements = desiredCacheSize / memoryCacheEntrySize(key, value, "");
+
+        System.gc();
+        long prevRuntimeMemory = runtime.totalMemory() - runtime.freeMemory();
+
+        ThreadCache cache = new ThreadCache(desiredCacheSize);
+        long size = cache.sizeBytes();
+        assertEquals(0, size);
+        for (int i = 0; i < numElements; i++) {
+            String keyStr = "K" + i;
+            key = keyStr.getBytes();
+            value = new byte[valueSizeBytes];
+            cache.put(name, key, new LRUCacheEntry(value, true, 1L, 1L, 1, ""));
+        }
+
+        System.gc();
+        double ceiling = desiredCacheSize + desiredCacheSize * entryFactor;
+        long usedRuntimeMemory = runtime.totalMemory() - runtime.freeMemory() - prevRuntimeMemory;
+        assertTrue("Cache size " + cache.sizeBytes() + " greater than ceiling " + ceiling,
+            (double) cache.sizeBytes() <= ceiling);
+
+        assertTrue("Used memory size " + usedRuntimeMemory + " greater than expected " + cache.sizeBytes() * systemFactor,
+            cache.sizeBytes() * systemFactor >= usedRuntimeMemory);
+    }
+
+    @Test
+    public void cacheOverheadsSmallValues() {
+        Runtime runtime = Runtime.getRuntime();
+        double factor = 0.05;
+        double systemFactor = 2.5;
+        long desiredCacheSize = Math.min(100 * 1024 * 1024L, runtime.maxMemory());
+        int keySizeBytes = 8;
+        int valueSizeBytes = 100;
+
+        checkOverheads(factor, systemFactor, desiredCacheSize, keySizeBytes, valueSizeBytes);
+    }
+
+    @Test
+    public void cacheOverheadsLargeValues() {
+        Runtime runtime = Runtime.getRuntime();
+        double factor = 0.05;
+        double systemFactor = 1.5;
+        long desiredCacheSize = Math.min(100 * 1024 * 1024L, runtime.maxMemory());
+        int keySizeBytes = 8;
+        int valueSizeBytes = 1000;
+
+        checkOverheads(factor, systemFactor, desiredCacheSize, keySizeBytes, valueSizeBytes);
+    }
+
+
+    /**
+     * Estimated footprint, in bytes, of one cache entry with the given key, value and topic. It covers:
+     * the raw key and value bytes; the fixed-size fields passed to {@code LRUCacheEntry} in these tests;
+     * and the LRU node's key copy plus three references.
+     * The tests use it to size caches so that they hold an exact number of entries.
+     */
+    static int memoryCacheEntrySize(byte[] key, byte[] value, final String topic) {
+        return key.length +
+                value.length +
+                1 + // isDirty
+                8 + // timestamp
+                8 + // offset
+                4 + // partition (assumed int, matching the int argument passed to LRUCacheEntry here)
+                topic.length() +
+                // LRU Node entries
+                key.length +
+                8 + // entry
+                8 + // previous
+                8; // next
+    }
+
+    /**
+     * With room for only one entry, each further put evicts an earlier dirty entry. The first record
+     * delivered to the namespace's flush listener must be K1/V1, and four evictions must be counted.
+     */
+    @Test
+    public void evict() throws IOException {
+        final List<KeyValue<String, String>> received = new ArrayList<>();
+        List<KeyValue<String, String>> expected = Arrays.asList(
+                new KeyValue<>("K1", "V1"));
+
+        List<KeyValue<String, String>> toInsert = Arrays.asList(
+                new KeyValue<>("K1", "V1"),
+                new KeyValue<>("K2", "V2"),
+                new KeyValue<>("K3", "V3"),
+                new KeyValue<>("K4", "V4"),
+                new KeyValue<>("K5", "V5"));
+        final KeyValue<String, String> kv = toInsert.get(0);
+        final String namespace = "kafka";
+        ThreadCache cache = new ThreadCache(
+                memoryCacheEntrySize(kv.key.getBytes(), kv.value.getBytes(), ""));
+        cache.addDirtyEntryFlushListener(namespace, new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                for (ThreadCache.DirtyEntry dirtyEntry : dirty) {
+                    received.add(new KeyValue<>(dirtyEntry.key().toString(), new String(dirtyEntry.newValue())));
+                }
+            }
+
+        });
+
+
+        for (int i = 0; i < toInsert.size(); i++) {
+            byte[] key = toInsert.get(i).key.getBytes();
+            byte[] value = toInsert.get(i).value.getBytes();
+            cache.put(namespace, key, new LRUCacheEntry(value, true, 1, 1, 1, ""));
+        }
+
+        // Fail with a readable message instead of an IndexOutOfBoundsException from received.get(i).
+        assertTrue("expected at least " + expected.size() + " flushed records but got " + received,
+                received.size() >= expected.size());
+        for (int i = 0; i < expected.size(); i++) {
+            KeyValue<String, String> expectedRecord = expected.get(i);
+            KeyValue<String, String> actualRecord = received.get(i);
+            assertEquals(expectedRecord, actualRecord);
+        }
+        assertEquals(4, cache.evicts());
+    }
+
+    @Test
+    public void shouldDelete() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final byte[] key = new byte[]{0};
+
+        cache.put("name", key, dirtyEntry(key));
+        // Compare contents, not array identity.
+        assertArrayEquals(key, cache.delete("name", key).value);
+        assertNull(cache.get("name", key));
+    }
+
+    @Test
+    public void shouldNotFlushAfterDelete() throws Exception {
+        final byte[] key = new byte[]{0};
+        final ThreadCache cache = new ThreadCache(10000L);
+        final List<ThreadCache.DirtyEntry> received = new ArrayList<>();
+        final String namespace = "namespace";
+        cache.addDirtyEntryFlushListener(namespace, new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                received.addAll(dirty);
+            }
+        });
+        cache.put(namespace, key, dirtyEntry(key));
+        assertArrayEquals(key, cache.delete(namespace, key).value);
+
+        // flushing should have no further effect
+        cache.flush(namespace);
+        assertEquals(0, received.size());
+        assertEquals(1, cache.flushes());
+    }
+
+    @Test
+    public void shouldNotBlowUpOnNonExistentKeyWhenDeleting() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final byte[] key = new byte[]{0};
+
+        cache.put("name", key, dirtyEntry(key));
+        assertNull(cache.delete("name", new byte[]{1}));
+    }
+
+    @Test
+    public void shouldNotBlowUpOnNonExistentNamespaceWhenDeleting() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        assertNull(cache.delete("name", new byte[]{1}));
+    }
+
+    @Test
+    public void shouldNotClashWithOverlappingNames() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final byte[] nameByte = new byte[]{0};
+        final byte[] name1Byte = new byte[]{1};
+        cache.put("name", nameByte, dirtyEntry(nameByte));
+        cache.put("name1", nameByte, dirtyEntry(name1Byte));
+
+        assertArrayEquals(nameByte, cache.get("name", nameByte).value);
+        assertArrayEquals(name1Byte, cache.get("name1", nameByte).value);
+    }
+
+    @Test
+    public void shouldPeekNextKey() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final byte[] theByte = {0};
+        final String namespace = "streams";
+        cache.put(namespace, theByte, dirtyEntry(theByte));
+        final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, theByte, new byte[]{1});
+        // Peeking twice must not advance the iterator.
+        assertArrayEquals(theByte, iterator.peekNextKey());
+        assertArrayEquals(theByte, iterator.peekNextKey());
+    }
+
+    @Test
+    public void shouldGetSameKeyAsPeekNext() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final byte[] theByte = {0};
+        final String namespace = "streams";
+        cache.put(namespace, theByte, dirtyEntry(theByte));
+        final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, theByte, new byte[]{1});
+        assertArrayEquals(iterator.peekNextKey(), iterator.next().key);
+    }
+
+    @Test(expected = NoSuchElementException.class)
+    public void shouldThrowIfNoPeekNextKey() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range("", new byte[]{0}, new byte[]{1});
+        iterator.peekNextKey();
+    }
+
+    @Test
+    public void shouldReturnFalseIfNoNextKey() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range("", new byte[]{0}, new byte[]{1});
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
+    public void shouldPeekAndIterateOverRange() throws Exception {
+        final ThreadCache cache = new ThreadCache(10000L);
+        final byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}};
+        final String namespace = "streams";
+        for (final byte[] aByte : bytes) {
+            cache.put(namespace, aByte, dirtyEntry(aByte));
+        }
+        final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, new byte[]{1}, new byte[]{4});
+        int bytesIndex = 1;
+        while (iterator.hasNext()) {
+            byte[] peekedKey = iterator.peekNextKey();
+            final KeyValue<byte[], LRUCacheEntry> next = iterator.next();
+            assertArrayEquals(bytes[bytesIndex], peekedKey);
+            assertArrayEquals(bytes[bytesIndex], next.key);
+            bytesIndex++;
+        }
+        // Keys {1}..{4} were visited, i.e. both range bounds are inclusive.
+        assertEquals(5, bytesIndex);
+    }
+
+    @Test
+    public void shouldSkipEntriesWhereValueHasBeenEvictedFromCache() throws Exception {
+        final String namespace = "streams";
+        final int entrySize = memoryCacheEntrySize(new byte[1], new byte[1], "");
+        final ThreadCache cache = new ThreadCache(entrySize * 5);
+        cache.addDirtyEntryFlushListener(namespace, new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                // no-op: only registered so evicted dirty entries have a listener to go to
+            }
+        });
+        byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}};
+        for (int i = 0; i < 5; i++) {
+            cache.put(namespace, bytes[i], dirtyEntry(bytes[i]));
+        }
+        assertEquals(5, cache.size());
+
+        final ThreadCache.MemoryLRUCacheBytesIterator range = cache.range(namespace, new byte[]{0}, new byte[]{5});
+        // should evict byte[] {0}
+        cache.put(namespace, new byte[]{6}, dirtyEntry(new byte[]{6}));
+
+        assertArrayEquals(new byte[]{1}, range.peekNextKey());
+    }
+
+    @Test
+    public void shouldFlushDirtyEntriesForNamespace() throws Exception {
+        final ThreadCache cache = new ThreadCache(100000);
+        final List<byte[]> received = new ArrayList<>();
+        cache.addDirtyEntryFlushListener("1", new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                for (ThreadCache.DirtyEntry dirtyEntry : dirty) {
+                    received.add(dirtyEntry.key().get());
+                }
+            }
+        });
+        final List<byte[]> expected = Arrays.asList(new byte[]{0}, new byte[]{1}, new byte[]{2});
+        for (byte[] bytes : expected) {
+            cache.put("1", bytes, dirtyEntry(bytes));
+        }
+        cache.put("2", new byte[]{4}, dirtyEntry(new byte[]{4}));
+
+        cache.flush("1");
+        // List.equals on byte[] elements compares references, so compare element contents instead.
+        assertEquals(expected.size(), received.size());
+        for (int i = 0; i < expected.size(); i++) {
+            assertArrayEquals(expected.get(i), received.get(i));
+        }
+    }
+
+    @Test
+    public void shouldNotFlushCleanEntriesForNamespace() throws Exception {
+        final ThreadCache cache = new ThreadCache(100000);
+        final List<byte[]> received = new ArrayList<>();
+        cache.addDirtyEntryFlushListener("1", new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                for (ThreadCache.DirtyEntry dirtyEntry : dirty) {
+                    received.add(dirtyEntry.key().get());
+                }
+            }
+        });
+        final List<byte[]> toInsert = Arrays.asList(new byte[]{0}, new byte[]{1}, new byte[]{2});
+        for (byte[] bytes : toInsert) {
+            cache.put("1", bytes, cleanEntry(bytes));
+        }
+        cache.put("2", new byte[]{4}, cleanEntry(new byte[]{4}));
+
+        cache.flush("1");
+        assertEquals(Collections.emptyList(), received);
+    }
+
+
+    /**
+     * Shared body for the zero/very-small cache tests: a single dirty put must reach the listener
+     * immediately, and a subsequent flush must not deliver it again.
+     */
+    private void shouldEvictImmediatelyIfCacheSizeIsZeroOrVerySmall(final ThreadCache cache) {
+        final List<ThreadCache.DirtyEntry> received = new ArrayList<>();
+        final String namespace = "namespace";
+        cache.addDirtyEntryFlushListener(namespace, new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                received.addAll(dirty);
+            }
+        });
+        cache.put(namespace, new byte[]{0}, dirtyEntry(new byte[]{0}));
+        assertEquals(1, received.size());
+
+        // flushing should have no further effect
+        cache.flush(namespace);
+        assertEquals(1, received.size());
+    }
+
+    @Test
+    public void shouldEvictImmediatelyIfCacheSizeIsVerySmall() throws Exception {
+        final ThreadCache cache = new ThreadCache(1);
+        shouldEvictImmediatelyIfCacheSizeIsZeroOrVerySmall(cache);
+    }
+
+    @Test
+    public void shouldEvictImmediatelyIfCacheSizeIsZero() throws Exception {
+        final ThreadCache cache = new ThreadCache(0);
+        shouldEvictImmediatelyIfCacheSizeIsZeroOrVerySmall(cache);
+    }
+
+    @Test
+    public void shouldPutAll() throws Exception {
+        final ThreadCache cache = new ThreadCache(100000);
+
+        cache.putAll("name", Arrays.asList(KeyValue.pair(new byte[]{0}, dirtyEntry(new byte[]{5})),
+                                           KeyValue.pair(new byte[]{1}, dirtyEntry(new byte[]{6}))));
+
+        assertArrayEquals(new byte[]{5}, cache.get("name", new byte[]{0}).value);
+        assertArrayEquals(new byte[]{6}, cache.get("name", new byte[]{1}).value);
+    }
+
+    @Test
+    public void shouldNotForwardCleanEntryOnEviction() throws Exception {
+        final ThreadCache cache = new ThreadCache(0);
+        final List<ThreadCache.DirtyEntry> received = new ArrayList<>();
+        cache.addDirtyEntryFlushListener("name", new ThreadCache.DirtyEntryFlushListener() {
+            @Override
+            public void apply(final List<ThreadCache.DirtyEntry> dirty) {
+                received.addAll(dirty);
+            }
+        });
+        cache.put("name", new byte[] {1}, cleanEntry(new byte[]{0}));
+        assertEquals(0, received.size());
+    }
+
+    @Test
+    public void shouldPutIfAbsent() throws Exception {
+        final ThreadCache cache = new ThreadCache(100000);
+        final byte[] key = {10};
+        final byte[] value = {30};
+        assertNull(cache.putIfAbsent("n", key, dirtyEntry(value)));
+        assertArrayEquals(value, cache.putIfAbsent("n", key, dirtyEntry(new byte[]{8})).value);
+        assertArrayEquals(value, cache.get("n", key).value);
+    }
+
+    /** Dirty entry holding {@code value}, with placeholder offset/timestamp/partition of -1 and an empty topic. */
+    private LRUCacheEntry dirtyEntry(final byte[] value) {
+        return new LRUCacheEntry(value, true, -1, -1, -1, "");
+    }
+
+    /** Clean entry holding {@code value}, built with the single-argument LRUCacheEntry constructor. */
+    private LRUCacheEntry cleanEntry(final byte[] value) {
+        return new LRUCacheEntry(value);
+    }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/86aa0eb0/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
index ccc9cb1..ac58f37 100644
--- a/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/test/KStreamTestDriver.java
@@ -23,13 +23,14 @@ import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.streams.kstream.KStreamBuilder;
 import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.StateStore;
-import org.apache.kafka.streams.processor.StateStoreSupplier;
 import org.apache.kafka.streams.processor.StreamPartitioner;
 import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.ProcessorStateManager;
 import org.apache.kafka.streams.processor.internals.ProcessorTopology;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import java.io.File;
 import java.util.HashSet;
@@ -41,6 +42,8 @@ public class KStreamTestDriver {
 
     private final ProcessorTopology topology;
     private final MockProcessorContext context;
+    private ThreadCache cache;
+    private static final long DEFAULT_CACHE_SIZE_BYTES = 1 * 1024 * 1024L;
     public final File stateDir;
 
     private ProcessorNode currNode;
@@ -60,11 +63,12 @@ public class KStreamTestDriver {
         builder.setApplicationId("TestDriver");
         this.topology = builder.build(null);
         this.stateDir = stateDir;
-        this.context = new MockProcessorContext(this, stateDir, keySerde, 
valSerde, new MockRecordCollector());
-        this.context.setTime(0L);
+        this.cache = new ThreadCache(DEFAULT_CACHE_SIZE_BYTES);
+        this.context = new MockProcessorContext(this, stateDir, keySerde, 
valSerde, new MockRecordCollector(), cache);
+        this.context.setRecordContext(new ProcessorRecordContext(0, 0, 0, 
"topic"));
 
-        for (StateStoreSupplier stateStoreSupplier : 
topology.stateStoreSuppliers()) {
-            StateStore store = stateStoreSupplier.get();
+
+        for (StateStore store : topology.stateStores()) {
             store.init(context, store);
         }
 
@@ -89,21 +93,27 @@ public class KStreamTestDriver {
         // if yes, skip
         if 
(topicName.endsWith(ProcessorStateManager.STATE_CHANGELOG_TOPIC_SUFFIX))
             return;
-
+        context.setRecordContext(createRecordContext(context.timestamp()));
+        context.setCurrentNode(currNode);
         try {
             forward(key, value);
         } finally {
             currNode = null;
+            context.setCurrentNode(null);
         }
     }
 
-    public void punctuate(long timestamp) {
-        setTime(timestamp);
+    /**
+     * Builds a record context for driver-initiated processing. It carries the given timestamp,
+     * placeholder offset and partition values of -1, and the fixed topic name "topic".
+     */
+    private ProcessorRecordContext createRecordContext(final long timestamp) {
+        final String topic = "topic";
+        return new ProcessorRecordContext(timestamp, -1, -1, topic);
+    }
+
 
+    public void punctuate(long timestamp) {
         for (ProcessorNode processor : topology.processors()) {
             if (processor.processor() != null) {
                 currNode = processor;
                 try {
+                    context.setRecordContext(createRecordContext(timestamp));
                     processor.processor().punctuate(timestamp);
                 } finally {
                     currNode = null;
@@ -119,7 +129,7 @@ public class KStreamTestDriver {
     @SuppressWarnings("unchecked")
     public <K, V> void forward(K key, V value) {
         ProcessorNode thisNode = currNode;
-        for (ProcessorNode childNode : (List<ProcessorNode<K, V>>) 
thisNode.children()) {
+        for (ProcessorNode childNode : (List<ProcessorNode<K, V>>) 
currNode.children()) {
             currNode = childNode;
             try {
                 childNode.process(key, value);
@@ -168,10 +178,7 @@ public class KStreamTestDriver {
             }
         }
 
-        // close all state stores
-        for (StateStore store : context.allStateStores().values()) {
-            store.close();
-        }
+        flushState();
     }
 
     public Set<String> allProcessorNames() {
@@ -201,6 +208,24 @@ public class KStreamTestDriver {
         return context.allStateStores();
     }
 
+    /**
+     * Flushes every state store registered with the mock context.
+     *
+     * <p>While a store is flushed, {@code currNode} points at the processor node that owns the store,
+     * as given by {@code topology.storeToProcessorNodeMap()}. This makes anything forwarded during the
+     * flush reach that node's children. The node that was current before the call is restored afterwards,
+     * even if a flush throws.
+     */
+    public void flushState() {
+        final ProcessorNode current = currNode;
+        try {
+            for (StateStore stateStore : context.allStateStores().values()) {
+                final ProcessorNode processorNode = topology.storeToProcessorNodeMap().get(stateStore);
+                // For a store with no owning node, fall back to the caller's node rather than reusing
+                // whichever node the previously flushed store left in currNode.
+                currNode = processorNode != null ? processorNode : current;
+                stateStore.flush();
+            }
+        } finally {
+            currNode = current;
+        }
+    }
+
+
     private class MockRecordCollector extends RecordCollector {
         public MockRecordCollector() {
             super(null, "KStreamTestDriver");

http://git-wip-us.apache.org/repos/asf/kafka/blob/86aa0eb0/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
index d82580d..8ad2fa9 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorContext.java
@@ -21,58 +21,66 @@ import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.StreamsMetrics;
-import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.internals.RecordContext;
 import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.processor.internals.ProcessorNode;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
 import org.apache.kafka.streams.state.StateSerdes;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import java.io.File;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
-public class MockProcessorContext implements ProcessorContext, 
RecordCollector.Supplier {
+public class MockProcessorContext implements InternalProcessorContext, 
RecordCollector.Supplier {
 
     private final KStreamTestDriver driver;
     private final Serde<?> keySerde;
     private final Serde<?> valSerde;
     private final RecordCollector.Supplier recordCollectorSupplier;
     private final File stateDir;
-
-    private Map<String, StateStore> storeMap = new HashMap<>();
+    private final ThreadCache cache;
+    private Map<String, StateStore> storeMap = new LinkedHashMap<>();
     private Map<String, StateRestoreCallback> restoreFuncs = new HashMap<>();
 
     long timestamp = -1L;
+    private RecordContext recordContext;
 
     public MockProcessorContext(StateSerdes<?, ?> serdes, RecordCollector 
collector) {
-        this(null, null, serdes.keySerde(), serdes.valueSerde(), collector);
+        this(null, null, serdes.keySerde(), serdes.valueSerde(), collector, 
null);
     }
 
     public MockProcessorContext(KStreamTestDriver driver, File stateDir,
                                 Serde<?> keySerde,
                                 Serde<?> valSerde,
-                                final RecordCollector collector) {
+                                final RecordCollector collector,
+                                final ThreadCache cache) {
         this(driver, stateDir, keySerde, valSerde,
                 new RecordCollector.Supplier() {
                     @Override
                     public RecordCollector recordCollector() {
                         return collector;
                     }
-                });
+                }, cache);
     }
 
     public MockProcessorContext(KStreamTestDriver driver, File stateDir,
                                 Serde<?> keySerde,
                                 Serde<?> valSerde,
-                                RecordCollector.Supplier collectorSupplier) {
+                                RecordCollector.Supplier collectorSupplier,
+                                final ThreadCache cache) {
         this.driver = driver;
         this.stateDir = stateDir;
         this.keySerde = keySerde;
         this.valSerde = valSerde;
         this.recordCollectorSupplier = collectorSupplier;
+        this.cache = cache;
     }
 
     @Override
@@ -110,6 +118,11 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
     }
 
     @Override
+    public ThreadCache getCache() {
+        // May be null: the StateSerdes-based convenience constructor supplies no cache.
+        return this.cache;
+    }
+
+    @Override
     public File stateDir() {
         if (stateDir == null)
             throw new UnsupportedOperationException("State directory not 
specified");
@@ -164,6 +177,7 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
         driver.forward(key, value, childName);
     }
 
+
     @Override
     public void commit() {
         throw new UnsupportedOperationException("commit() not supported.");
@@ -171,22 +185,34 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
 
     @Override
     public String topic() {
-        return null;
+        if (recordContext == null) {
+            return null;
+        }
+        return recordContext.topic();
     }
 
     @Override
     public int partition() {
-        return -1;
+        if (recordContext == null) {
+            return -1;
+        }
+        return recordContext.partition();
     }
 
     @Override
     public long offset() {
-        return -1L;
+        if (recordContext == null) {
+            return -1L;
+        }
+        return recordContext.offset();
     }
 
     @Override
     public long timestamp() {
-        return this.timestamp;
+        if (recordContext == null) {
+            return timestamp;
+        }
+        return recordContext.timestamp();
     }
 
     @Override
@@ -199,6 +225,11 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
         return Collections.emptyMap();
     }
 
+    /** Returns the record context last passed to {@link #setRecordContext}, or {@code null} if none was set. */
+    @Override
+    public RecordContext recordContext() {
+        return this.recordContext;
+    }
+
     public Map<String, StateStore> allStateStores() {
         return Collections.unmodifiableMap(storeMap);
     }
@@ -209,4 +240,15 @@ public class MockProcessorContext implements ProcessorContext, RecordCollector.S
             restoreCallback.restore(entry.key, entry.value);
         }
     }
+
+    /**
+     * Replaces the current record context. While it is non-null, the topic/partition/offset/timestamp
+     * accessors read their values from it.
+     */
+    @Override
+    public void setRecordContext(final RecordContext recordContext) {
+        this.recordContext = recordContext;
+    }
+
+    @Override
+    public void setCurrentNode(final ProcessorNode currentNode) {
+        // Intentionally a no-op: this mock delegates forwarding to KStreamTestDriver,
+        // which tracks the current node in its own currNode field.
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/86aa0eb0/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java b/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java
index 67d25f5..a0ffd49 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java
@@ -59,12 +59,12 @@ public class MockProcessorSupplier<K, V> implements ProcessorSupplier<K, V> {
         public void process(K key, V value) {
             processed.add((key == null ? "null" : key) + ":" +
                     (value == null ? "null" : value));
+
         }
 
         @Override
         public void punctuate(long streamTime) {
             assertEquals(streamTime, context().timestamp());
-            assertEquals(null, context().topic());
             assertEquals(-1, context().partition());
             assertEquals(-1L, context().offset());
 
@@ -73,8 +73,7 @@ public class MockProcessorSupplier<K, V> implements ProcessorSupplier<K, V> {
     }
 
     public void checkAndClearProcessResult(String... expected) {
-        assertEquals("the number of outputs:", expected.length, 
processed.size());
-
+        assertEquals("the number of outputs:" + processed, expected.length, 
processed.size());
         for (int i = 0; i < expected.length; i++) {
             assertEquals("output[" + i + "]:", expected[i], processed.get(i));
         }

http://git-wip-us.apache.org/repos/asf/kafka/blob/86aa0eb0/streams/src/test/java/org/apache/kafka/test/MockReducer.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/MockReducer.java b/streams/src/test/java/org/apache/kafka/test/MockReducer.java
index 24a8fea..fc71ada 100644
--- a/streams/src/test/java/org/apache/kafka/test/MockReducer.java
+++ b/streams/src/test/java/org/apache/kafka/test/MockReducer.java
@@ -37,7 +37,28 @@ public class MockReducer {
         }
     }
 
+
+    /**
+     * Reducer that adds its two arguments. Both are unboxed, so a null argument
+     * throws NullPointerException.
+     */
+    private static class IntegerAdd implements Reducer<Integer> {
+
+        @Override
+        public Integer apply(final Integer value1, final Integer value2) {
+            final int sum = value1 + value2;
+            return sum;
+        }
+    }
+
+    /**
+     * Reducer that subtracts its second argument from its first. Both are unboxed, so a null
+     * argument throws NullPointerException.
+     */
+    private static class IntegerSubtract implements Reducer<Integer> {
+
+        @Override
+        public Integer apply(final Integer value1, final Integer value2) {
+            final int difference = value1 - value2;
+            return difference;
+        }
+    }
+
     public final static Reducer<String> STRING_ADDER = new StringAdd();
 
     public final static Reducer<String> STRING_REMOVER = new StringRemove();
+
+    /** Shared {@link Reducer} that sums Integer values. */
+    public final static Reducer<Integer> INTEGER_ADDER = new IntegerAdd();
+
+    /** Shared {@link Reducer} that subtracts the second Integer value from the first. */
+    public final static Reducer<Integer> INTEGER_SUBTRACTOR = new IntegerSubtract();
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/kafka/blob/86aa0eb0/streams/src/test/java/org/apache/kafka/test/ProcessorTopologyTestDriver.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/test/ProcessorTopologyTestDriver.java b/streams/src/test/java/org/apache/kafka/test/ProcessorTopologyTestDriver.java
index 8d2ad08..0db69be 100644
--- a/streams/src/test/java/org/apache/kafka/test/ProcessorTopologyTestDriver.java
+++ b/streams/src/test/java/org/apache/kafka/test/ProcessorTopologyTestDriver.java
@@ -33,12 +33,15 @@ import org.apache.kafka.streams.StreamsMetrics;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.TopologyBuilder;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
 import org.apache.kafka.streams.processor.internals.ProcessorContextImpl;
+import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.ProcessorStateManager;
 import org.apache.kafka.streams.processor.internals.ProcessorTopology;
 import org.apache.kafka.streams.processor.internals.StateDirectory;
 import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -186,7 +189,7 @@ public class ProcessorTopologyTestDriver {
                 public void recordLatency(Sensor sensor, long startNs, long 
endNs) {
                     // do nothing
                 }
-            }, new StateDirectory(applicationId, 
TestUtils.tempDirectory().getPath()));
+            }, new StateDirectory(applicationId, 
TestUtils.tempDirectory().getPath()), new ThreadCache(1024 * 1024));
     }
 
     /**
@@ -207,6 +210,7 @@ public class ProcessorTopologyTestDriver {
         producer.clear();
         // Process the record ...
         task.process();
+        ((InternalProcessorContext) task.context()).setRecordContext(new 
ProcessorRecordContext(0L, offset, tp.partition(), topicName));
         task.commit();
         // Capture all the records sent to the producer ...
         for (ProducerRecord<byte[], byte[]> record : producer.history()) {

Reply via email to