This is an automated email from the ASF dual-hosted git repository.

vvcephei pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new c595470  KAFKA-9770: Close underlying state store also when flush throws (#8368)
c595470 is described below

commit c595470713be1fd2daf93816a5dbf0e245a707a0
Author: Bruno Cadonna <[email protected]>
AuthorDate: Sat Mar 28 02:36:10 2020 +0100

    KAFKA-9770: Close underlying state store also when flush throws (#8368)
    
    When a caching state store is closed, it calls its flush() method.
    If flush() throws an exception, the underlying state store is not closed.
    
    This commit ensures that the state stores underlying wrapped state stores
    are closed even when preceding operations in the close method throw.
    
    Co-authored-by: John Roesler <[email protected]>
    Reviewers: John Roesler <[email protected]>, Guozhang Wang <[email protected]>, Matthias J. Sax <[email protected]>
---
 .../internals/metrics/StreamsMetricsImpl.java      | 12 +--
 .../state/internals/CachingKeyValueStore.java      | 31 +++++--
 .../state/internals/CachingSessionStore.java       | 23 +++--
 .../state/internals/CachingWindowStore.java        | 23 +++--
 .../streams/state/internals/ExceptionUtils.java    | 46 ++++++++++
 .../state/internals/MeteredKeyValueStore.java      |  7 +-
 .../state/internals/MeteredSessionStore.java       |  7 +-
 .../state/internals/MeteredWindowStore.java        |  7 +-
 .../state/internals/CachingKeyValueStoreTest.java  | 81 ++++++++++++++----
 .../state/internals/CachingSessionStoreTest.java   | 88 ++++++++++++++++++--
 .../state/internals/CachingWindowStoreTest.java    | 97 ++++++++++++++++++----
 .../state/internals/MeteredKeyValueStoreTest.java  | 44 +++++++++-
 .../state/internals/MeteredSessionStoreTest.java   | 37 +++++++++
 .../state/internals/MeteredWindowStoreTest.java    | 71 +++++++++++++---
 14 files changed, 493 insertions(+), 81 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
index ee372f7..3c088f3 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java
@@ -402,11 +402,13 @@ public class StreamsMetricsImpl implements StreamsMetrics 
{
         final String key = storeSensorPrefix(threadId, taskId, storeName);
         synchronized (storeLevelSensors) {
             final String fullSensorName = key + SENSOR_NAME_DELIMITER + 
sensorName;
-            return Optional.ofNullable(metrics.getSensor(fullSensorName))
-                .orElseGet(() -> {
-                    storeLevelSensors.computeIfAbsent(key, ignored -> new 
LinkedList<>()).push(fullSensorName);
-                    return metrics.sensor(fullSensorName, recordingLevel, 
parents);
-                });
+            final Sensor sensor = metrics.getSensor(fullSensorName);
+            if (sensor == null) {
+                storeLevelSensors.computeIfAbsent(key, ignored -> new 
LinkedList<>()).push(fullSensorName);
+                return metrics.sensor(fullSensorName, recordingLevel, parents);
+            } else {
+                return sensor;
+            }
         }
     }
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
index 8aa0ceb..14f4e54 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
@@ -27,12 +27,16 @@ import org.apache.kafka.streams.state.KeyValueStore;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 
+import static 
org.apache.kafka.streams.state.internals.ExceptionUtils.executeAll;
+import static 
org.apache.kafka.streams.state.internals.ExceptionUtils.throwSuppressed;
+
 public class CachingKeyValueStore
     extends WrappedStateStore<KeyValueStore<Bytes, byte[]>, byte[], byte[]>
     implements KeyValueStore<Bytes, byte[]>, CachedStateStore<byte[], byte[]> {
@@ -119,6 +123,7 @@ public class CachingKeyValueStore
         validateStoreOpen();
         lock.writeLock().lock();
         try {
+            validateStoreOpen();
             // for null bytes, we still put it into cache indicating tombstones
             putInternal(key, value);
         } finally {
@@ -148,6 +153,7 @@ public class CachingKeyValueStore
         validateStoreOpen();
         lock.writeLock().lock();
         try {
+            validateStoreOpen();
             final byte[] v = getInternal(key);
             if (v == null) {
                 putInternal(key, value);
@@ -163,6 +169,7 @@ public class CachingKeyValueStore
         validateStoreOpen();
         lock.writeLock().lock();
         try {
+            validateStoreOpen();
             for (final KeyValue<Bytes, byte[]> entry : entries) {
                 Objects.requireNonNull(entry.key, "key cannot be null");
                 put(entry.key, entry.value);
@@ -178,6 +185,7 @@ public class CachingKeyValueStore
         validateStoreOpen();
         lock.writeLock().lock();
         try {
+            validateStoreOpen();
             return deleteInternal(key);
         } finally {
             lock.writeLock().unlock();
@@ -202,6 +210,7 @@ public class CachingKeyValueStore
         }
         theLock.lock();
         try {
+            validateStoreOpen();
             return getInternal(key);
         } finally {
             theLock.unlock();
@@ -259,6 +268,7 @@ public class CachingKeyValueStore
         validateStoreOpen();
         lock.readLock().lock();
         try {
+            validateStoreOpen();
             return wrapped().approximateNumEntries();
         } finally {
             lock.readLock().unlock();
@@ -267,10 +277,12 @@ public class CachingKeyValueStore
 
     @Override
     public void flush() {
+        validateStoreOpen();
         lock.writeLock().lock();
         try {
+            validateStoreOpen();
             cache.flush(cacheName);
-            super.flush();
+            wrapped().flush();
         } finally {
             lock.writeLock().unlock();
         }
@@ -278,14 +290,19 @@ public class CachingKeyValueStore
 
     @Override
     public void close() {
+        lock.writeLock().lock();
         try {
-            flush();
-        } finally {
-            try {
-                super.close();
-            } finally {
-                cache.close(cacheName);
+            final LinkedList<RuntimeException> suppressed = executeAll(
+                () -> cache.flush(cacheName),
+                () -> cache.close(cacheName),
+                wrapped()::close
+            );
+            if (!suppressed.isEmpty()) {
+                throwSuppressed("Caught an exception while closing caching key 
value store for store " + name(),
+                                suppressed);
             }
+        } finally {
+            lock.writeLock().unlock();
         }
     }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
index d48f540..f537d4c 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import java.util.NoSuchElementException;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.kstream.Windowed;
@@ -28,10 +27,16 @@ import 
org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import org.apache.kafka.streams.processor.internals.RecordQueue;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.SessionStore;
-import java.util.Objects;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.LinkedList;
+import java.util.NoSuchElementException;
+import java.util.Objects;
+
+import static 
org.apache.kafka.streams.state.internals.ExceptionUtils.executeAll;
+import static 
org.apache.kafka.streams.state.internals.ExceptionUtils.throwSuppressed;
+
 class CachingSessionStore
     extends WrappedStateStore<SessionStore<Bytes, byte[]>, byte[], byte[]>
     implements SessionStore<Bytes, byte[]>, CachedStateStore<byte[], byte[]> {
@@ -228,13 +233,19 @@ class CachingSessionStore
 
     public void flush() {
         cache.flush(cacheName);
-        super.flush();
+        wrapped().flush();
     }
 
     public void close() {
-        flush();
-        cache.close(cacheName);
-        super.close();
+        final LinkedList<RuntimeException> suppressed = executeAll(
+            () -> cache.flush(cacheName),
+            () -> cache.close(cacheName),
+            wrapped()::close
+        );
+        if (!suppressed.isEmpty()) {
+            throwSuppressed("Caught an exception while closing caching session 
store for store " + name(),
+                            suppressed);
+        }
     }
 
     private class CacheIteratorWrapper implements 
PeekingKeyValueIterator<Bytes, LRUCacheEntry> {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
index 4bb8116..d2bd02e 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java
@@ -16,7 +16,6 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import java.util.NoSuchElementException;
 import org.apache.kafka.common.serialization.Serdes;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
@@ -34,6 +33,12 @@ import org.apache.kafka.streams.state.WindowStoreIterator;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.LinkedList;
+import java.util.NoSuchElementException;
+
+import static 
org.apache.kafka.streams.state.internals.ExceptionUtils.executeAll;
+import static 
org.apache.kafka.streams.state.internals.ExceptionUtils.throwSuppressed;
+
 class CachingWindowStore
     extends WrappedStateStore<WindowStore<Bytes, byte[]>, byte[], byte[]>
     implements WindowStore<Bytes, byte[]>, CachedStateStore<byte[], byte[]> {
@@ -293,12 +298,20 @@ class CachingWindowStore
     }
 
     @Override
-    public void close() {
-        flush();
-        cache.close(name);
-        wrapped().close();
+    public synchronized void close() {
+        final LinkedList<RuntimeException> suppressed = executeAll(
+            () -> cache.flush(name),
+            () -> cache.close(name),
+            wrapped()::close
+        );
+        if (!suppressed.isEmpty()) {
+            throwSuppressed("Caught an exception while closing caching window 
store for store " + name(),
+                            suppressed);
+        }
     }
 
+
+
     private class CacheIteratorWrapper implements 
PeekingKeyValueIterator<Bytes, LRUCacheEntry> {
 
         private final long segmentInterval;
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/ExceptionUtils.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/ExceptionUtils.java
new file mode 100644
index 0000000..e40b6ad
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/ExceptionUtils.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.LinkedList;
+
+final class ExceptionUtils {
+    private ExceptionUtils() {}
+
+    static LinkedList<RuntimeException> executeAll(final Runnable... actions) {
+        final LinkedList<RuntimeException> suppressed = new LinkedList<>();
+        for (final Runnable action : actions) {
+            try {
+                action.run();
+            } catch (final RuntimeException exception) {
+                suppressed.add(exception);
+            }
+        }
+        return suppressed;
+    }
+
+    static void throwSuppressed(final String message, final 
LinkedList<RuntimeException> suppressed) {
+        if (!suppressed.isEmpty()) {
+            final RuntimeException firstCause = suppressed.pollFirst();
+            final RuntimeException toThrow = new RuntimeException(message, 
firstCause);
+            for (final RuntimeException e : suppressed) {
+                toThrow.addSuppressed(e);
+            }
+            throw toThrow;
+        }
+    }
+}
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
index 7a8b973..6076702 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
@@ -199,8 +199,11 @@ public class MeteredKeyValueStore<K, V>
 
     @Override
     public void close() {
-        super.close();
-        streamsMetrics.removeAllStoreLevelSensors(threadId, taskId, name());
+        try {
+            wrapped().close();
+        } finally {
+            streamsMetrics.removeAllStoreLevelSensors(threadId, taskId, 
name());
+        }
     }
 
     private V outerValue(final byte[] value) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
index 3e541e7..c7d4290 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
@@ -228,8 +228,11 @@ public class MeteredSessionStore<K, V>
 
     @Override
     public void close() {
-        super.close();
-        streamsMetrics.removeAllStoreLevelSensors(threadId, taskId, name());
+        try {
+            wrapped().close();
+        } finally {
+            streamsMetrics.removeAllStoreLevelSensors(threadId, taskId, 
name());
+        }
     }
 
     private Bytes keyBytes(final K key) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
index 27683f2..fd39468 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java
@@ -202,8 +202,11 @@ public class MeteredWindowStore<K, V>
 
     @Override
     public void close() {
-        super.close();
-        streamsMetrics.removeAllStoreLevelSensors(threadId, taskId, name());
+        try {
+            wrapped().close();
+        } finally {
+            streamsMetrics.removeAllStoreLevelSensors(threadId, taskId, 
name());
+        }
     }
 
     private Bytes keyBytes(final K key) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java
index f6f00b9..cd4b02c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingKeyValueStoreTest.java
@@ -33,6 +33,7 @@ import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.StoreBuilder;
 import org.apache.kafka.streams.state.Stores;
 import org.apache.kafka.test.InternalMockProcessorContext;
+import org.apache.kafka.test.TestUtils;
 import org.easymock.EasyMock;
 import org.junit.After;
 import org.junit.Before;
@@ -56,13 +57,14 @@ import static org.junit.Assert.fail;
 
 public class CachingKeyValueStoreTest extends AbstractKeyValueStoreTest {
 
+    private final static String TOPIC = "topic";
+    private static final String CACHE_NAMESPACE = "0_0-store-name";
     private final int maxCacheSizeBytes = 150;
     private InternalMockProcessorContext context;
     private CachingKeyValueStore store;
-    private InMemoryKeyValueStore underlyingStore;
+    private KeyValueStore<Bytes, byte[]> underlyingStore;
     private ThreadCache cache;
     private CacheFlushListenerStub<String, String> cacheFlushListener;
-    private String topic;
 
     @Before
     public void setUp() {
@@ -73,8 +75,7 @@ public class CachingKeyValueStoreTest extends 
AbstractKeyValueStoreTest {
         store.setFlushListener(cacheFlushListener, false);
         cache = new ThreadCache(new LogContext("testCache "), 
maxCacheSizeBytes, new MockStreamsMetrics(new Metrics()));
         context = new InternalMockProcessorContext(null, null, null, null, 
cache);
-        topic = "topic";
-        context.setRecordContext(new ProcessorRecordContext(10, 0, 0, topic, 
null));
+        context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, 
null));
         store.init(context, null);
     }
 
@@ -121,22 +122,72 @@ public class CachingKeyValueStoreTest extends 
AbstractKeyValueStoreTest {
     }
 
     @Test
-    public void shouldCloseAfterErrorWithFlush() {
+    public void shouldCloseWrappedStoreAfterErrorDuringCacheFlush() {
+        setUpCloseTests();
+        cache.flush(CACHE_NAMESPACE);
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on flush"));
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.replay(underlyingStore);
+
+        try {
+            store.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(underlyingStore);
+        }
+    }
+
+    @Test
+    public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() {
+        setUpCloseTests();
+        cache.flush(CACHE_NAMESPACE);
+        cache.close(CACHE_NAMESPACE);
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on close"));
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.replay(underlyingStore);
+
         try {
-            cache = EasyMock.niceMock(ThreadCache.class);
-            context = new InternalMockProcessorContext(null, null, null, null, 
cache);
-            context.setRecordContext(new ProcessorRecordContext(10, 0, 0, 
topic, null));
-            store.init(context, null);
-            cache.flush("0_0-store");
-            EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on flush"));
-            EasyMock.replay(cache);
             store.close();
-        } catch (final NullPointerException npe) {
-            assertFalse(underlyingStore.isOpen());
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(underlyingStore);
         }
     }
 
     @Test
+    public void shouldCloseCacheAfterErrorDuringStateStoreClose() {
+        setUpCloseTests();
+        EasyMock.reset(cache);
+        cache.flush(CACHE_NAMESPACE);
+        cache.close(CACHE_NAMESPACE);
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on close"));
+        EasyMock.replay(underlyingStore);
+
+        try {
+            store.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(cache);
+        }
+    }
+
+    private void setUpCloseTests() {
+        underlyingStore = EasyMock.createNiceMock(KeyValueStore.class);
+        EasyMock.expect(underlyingStore.name()).andStubReturn("store-name");
+        EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true);
+        EasyMock.replay(underlyingStore);
+        store = new CachingKeyValueStore(underlyingStore);
+        cache = EasyMock.niceMock(ThreadCache.class);
+        context = new InternalMockProcessorContext(TestUtils.tempDirectory(), 
null, null, null, cache);
+        context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, 
null));
+        store.init(context, store);
+    }
+
+    @Test
     public void shouldPutGetToFromCache() {
         store.put(bytesKey("key"), bytesValue("value"));
         store.put(bytesKey("key2"), bytesValue("value2"));
@@ -374,7 +425,7 @@ public class CachingKeyValueStoreTest extends 
AbstractKeyValueStoreTest {
         while (cachedSize < maxCacheSizeBytes) {
             final String kv = String.valueOf(i++);
             store.put(bytesKey(kv), bytesValue(kv));
-            cachedSize += memoryCacheEntrySize(kv.getBytes(), kv.getBytes(), 
topic);
+            cachedSize += memoryCacheEntrySize(kv.getBytes(), kv.getBytes(), 
TOPIC);
         }
         return i;
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
index 128cdc2..7b1604c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingSessionStoreTest.java
@@ -33,9 +33,12 @@ import 
org.apache.kafka.streams.processor.internals.MockStreamsMetrics;
 import org.apache.kafka.streams.processor.internals.ProcessorRecordContext;
 import 
org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.test.InternalMockProcessorContext;
 import org.apache.kafka.test.TestUtils;
+import org.easymock.EasyMock;
 import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.nio.charset.StandardCharsets;
@@ -67,27 +70,30 @@ public class CachingSessionStoreTest {
     private static final int MAX_CACHE_SIZE_BYTES = 600;
     private static final Long DEFAULT_TIMESTAMP = 10L;
     private static final long SEGMENT_INTERVAL = 100L;
+    private static final String TOPIC = "topic";
+    private static final String CACHE_NAMESPACE = "0_0-store-name";
+
     private final Bytes keyA = Bytes.wrap("a".getBytes());
     private final Bytes keyAA = Bytes.wrap("aa".getBytes());
     private final Bytes keyB = Bytes.wrap("b".getBytes());
 
+    private SessionStore<Bytes, byte[]> underlyingStore =
+        new InMemorySessionStore("store-name", Long.MAX_VALUE, "metric-scope");
+    private InternalMockProcessorContext context;
     private CachingSessionStore cachingStore;
     private ThreadCache cache;
 
-    public CachingSessionStoreTest() {
-        final SessionKeySchema schema = new SessionKeySchema();
-        final RocksDBSegmentedBytesStore root =
-            new RocksDBSegmentedBytesStore("test", "metrics-scope", 0L, 
SEGMENT_INTERVAL, schema);
-        final RocksDBSessionStore sessionStore = new RocksDBSessionStore(root);
-        cachingStore = new CachingSessionStore(sessionStore, SEGMENT_INTERVAL);
+    @Before
+    public void before() {
+        cachingStore = new CachingSessionStore(underlyingStore, 
SEGMENT_INTERVAL);
         cache = new ThreadCache(new LogContext("testCache "), 
MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
         final InternalMockProcessorContext context = new 
InternalMockProcessorContext(TestUtils.tempDirectory(), null, null, null, 
cache);
-        context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 
0, 0, "topic", null));
+        context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 
0, 0, TOPIC, null));
         cachingStore.init(context, cachingStore);
     }
 
     @After
-    public void close() {
+    public void after() {
         cachingStore.close();
     }
 
@@ -124,6 +130,72 @@ public class CachingSessionStoreTest {
     }
 
     @Test
+    public void shouldCloseWrappedStoreAfterErrorDuringCacheFlush() {
+        setUpCloseTests();
+        cache.flush(CACHE_NAMESPACE);
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on flush"));
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.replay(underlyingStore);
+
+        try {
+            cachingStore.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(underlyingStore);
+        }
+    }
+
+    @Test
+    public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() {
+        setUpCloseTests();
+        cache.flush(CACHE_NAMESPACE);
+        cache.close(CACHE_NAMESPACE);
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on close"));
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.replay(underlyingStore);
+
+        try {
+            cachingStore.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(underlyingStore);
+        }
+    }
+
+    @Test
+    public void shouldCloseCacheAfterErrorDuringStateStoreClose() {
+        setUpCloseTests();
+        EasyMock.reset(cache);
+        cache.flush(CACHE_NAMESPACE);
+        cache.close(CACHE_NAMESPACE);
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on close"));
+        EasyMock.replay(underlyingStore);
+
+        try {
+            cachingStore.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(cache);
+        }
+    }
+
+    private void setUpCloseTests() {
+        underlyingStore = EasyMock.createNiceMock(SessionStore.class);
+        EasyMock.expect(underlyingStore.name()).andStubReturn("store-name");
+        EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true);
+        EasyMock.replay(underlyingStore);
+        cachingStore = new CachingSessionStore(underlyingStore, 
SEGMENT_INTERVAL);
+        cache = EasyMock.niceMock(ThreadCache.class);
+        context = new InternalMockProcessorContext(TestUtils.tempDirectory(), 
null, null, null, cache);
+        context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, 
null));
+        cachingStore.init(context, cachingStore);
+    }
+
+    @Test
     public void shouldPutFetchRangeFromCache() {
         cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), 
"1".getBytes());
         cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), 
"1".getBytes());
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java
index 64ab25a..28fa06d 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingWindowStoreTest.java
@@ -44,6 +44,7 @@ import org.apache.kafka.streams.state.WindowStoreIterator;
 import org.apache.kafka.streams.TestInputTopic;
 import org.apache.kafka.test.InternalMockProcessorContext;
 import org.apache.kafka.test.TestUtils;
+import org.easymock.EasyMock;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -78,31 +79,33 @@ public class CachingWindowStoreTest {
     private static final long DEFAULT_TIMESTAMP = 10L;
     private static final Long WINDOW_SIZE = 10L;
     private static final long SEGMENT_INTERVAL = 100L;
+    private final static String TOPIC = "topic";
+    private static final String CACHE_NAMESPACE = "0_0-store-name";
+
     private InternalMockProcessorContext context;
-    private RocksDBSegmentedBytesStore underlying;
+    private RocksDBSegmentedBytesStore bytesStore;
+    private WindowStore<Bytes, byte[]> underlyingStore;
     private CachingWindowStore cachingStore;
     private CachingKeyValueStoreTest.CacheFlushListenerStub<Windowed<String>, 
String> cacheListener;
     private ThreadCache cache;
-    private String topic;
     private WindowKeySchema keySchema;
 
     @Before
     public void setUp() {
         keySchema = new WindowKeySchema();
-        underlying = new RocksDBSegmentedBytesStore("test", "metrics-scope", 
0, SEGMENT_INTERVAL, keySchema);
-        final RocksDBWindowStore windowStore = new RocksDBWindowStore(
-            underlying,
+        bytesStore = new RocksDBSegmentedBytesStore("test", "metrics-scope", 
0, SEGMENT_INTERVAL, keySchema);
+        underlyingStore = new RocksDBWindowStore(
+            bytesStore,
             false,
             WINDOW_SIZE);
         final TimeWindowedDeserializer<String> keyDeserializer = new 
TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE);
         keyDeserializer.setIsChangelogTopic(true);
         cacheListener = new 
CachingKeyValueStoreTest.CacheFlushListenerStub<>(keyDeserializer, new 
StringDeserializer());
-        cachingStore = new CachingWindowStore(windowStore, WINDOW_SIZE, 
SEGMENT_INTERVAL);
+        cachingStore = new CachingWindowStore(underlyingStore, WINDOW_SIZE, 
SEGMENT_INTERVAL);
         cachingStore.setFlushListener(cacheListener, false);
         cache = new ThreadCache(new LogContext("testCache "), 
MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
-        topic = "topic";
         context = new InternalMockProcessorContext(TestUtils.tempDirectory(), 
null, null, null, cache);
-        context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 
0, 0, topic, null));
+        context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 
0, 0, TOPIC, null));
         cachingStore.init(context, cachingStore);
     }
 
@@ -123,7 +126,7 @@ public class CachingWindowStoreTest {
 
         builder.addStateStore(storeBuilder);
 
-        builder.stream(topic,
+        builder.stream(TOPIC,
             Consumed.with(Serdes.String(), Serdes.String()))
             .transform(() -> new Transformer<String, String, KeyValue<String, 
String>>() {
                 private WindowStore<String, String> store;
@@ -181,7 +184,7 @@ public class CachingWindowStoreTest {
         final Instant initialWallClockTime = Instant.ofEpochMilli(0L);
         final TopologyTestDriver driver = new 
TopologyTestDriver(builder.build(), streamsConfiguration, initialWallClockTime);
 
-        final TestInputTopic<String, String> inputTopic = 
driver.createInputTopic(topic,
+        final TestInputTopic<String, String> inputTopic = 
driver.createInputTopic(TOPIC,
             Serdes.String().serializer(),
             Serdes.String().serializer(),
             initialWallClockTime,
@@ -336,7 +339,7 @@ public class CachingWindowStoreTest {
     public void shouldFlushEvictedItemsIntoUnderlyingStore() {
         final int added = addItemsToCache();
         // all dirty entries should have been flushed
-        final KeyValueIterator<Bytes, byte[]> iter = underlying.fetch(
+        final KeyValueIterator<Bytes, byte[]> iter = bytesStore.fetch(
             Bytes.wrap("0".getBytes(StandardCharsets.UTF_8)),
             DEFAULT_TIMESTAMP,
             DEFAULT_TIMESTAMP);
@@ -453,7 +456,7 @@ public class CachingWindowStoreTest {
     @Test
     public void shouldIterateCacheAndStore() {
         final Bytes key = Bytes.wrap("1".getBytes());
-        underlying.put(WindowKeySchema.toStoreKeyBinary(key, 
DEFAULT_TIMESTAMP, 0), "a".getBytes());
+        bytesStore.put(WindowKeySchema.toStoreKeyBinary(key, 
DEFAULT_TIMESTAMP, 0), "a".getBytes());
         cachingStore.put(key, bytesValue("b"), DEFAULT_TIMESTAMP + 
WINDOW_SIZE);
         final WindowStoreIterator<byte[]> fetch =
             cachingStore.fetch(bytesKey("1"), ofEpochMilli(DEFAULT_TIMESTAMP), 
ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE));
@@ -465,7 +468,7 @@ public class CachingWindowStoreTest {
     @Test
     public void shouldIterateCacheAndStoreKeyRange() {
         final Bytes key = Bytes.wrap("1".getBytes());
-        underlying.put(WindowKeySchema.toStoreKeyBinary(key, 
DEFAULT_TIMESTAMP, 0), "a".getBytes());
+        bytesStore.put(WindowKeySchema.toStoreKeyBinary(key, 
DEFAULT_TIMESTAMP, 0), "a".getBytes());
         cachingStore.put(key, bytesValue("b"), DEFAULT_TIMESTAMP + 
WINDOW_SIZE);
 
         final KeyValueIterator<Windowed<Bytes>, byte[]> fetchRange =
@@ -623,6 +626,72 @@ public class CachingWindowStoreTest {
             + "Note that the built-in numerical serdes do not follow this for 
negative numbers"));
     }
 
+    @Test
+    public void shouldCloseWrappedStoreAfterErrorDuringCacheFlush() {
+        setUpCloseTests();
+        cache.flush(CACHE_NAMESPACE);
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on flush"));
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.replay(underlyingStore);
+
+        try {
+            cachingStore.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(underlyingStore);
+        }
+    }
+
+    @Test
+    public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() {
+        setUpCloseTests();
+        cache.flush(CACHE_NAMESPACE);
+        cache.close(CACHE_NAMESPACE);
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on close"));
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.replay(underlyingStore);
+
+        try {
+            cachingStore.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(underlyingStore);
+        }
+    }
+
+    @Test
+    public void shouldCloseCacheAfterErrorDuringStateStoreClose() {
+        setUpCloseTests();
+        EasyMock.reset(cache);
+        cache.flush(CACHE_NAMESPACE);
+        cache.close(CACHE_NAMESPACE);
+        EasyMock.replay(cache);
+        EasyMock.reset(underlyingStore);
+        underlyingStore.close();
+        EasyMock.expectLastCall().andThrow(new 
NullPointerException("Simulating an error on close"));
+        EasyMock.replay(underlyingStore);
+
+        try {
+            cachingStore.close();
+        } catch (final RuntimeException exception) {
+            EasyMock.verify(cache);
+        }
+    }
+
+    /**
+     * Replaces the fixtures built in setUp() with mocks for the close() tests.
+     * underlyingStore becomes a replayed nice mock of WindowStore with stubbed name() and isOpen(),
+     * and cachingStore is rebuilt on top of it. cache becomes a nice mock of ThreadCache that is NOT
+     * replayed here; each test records its own cache expectations and then replays it.
+     */
+    private void setUpCloseTests() {
+        underlyingStore = EasyMock.createNiceMock(WindowStore.class);
+        EasyMock.expect(underlyingStore.name()).andStubReturn("store-name");
+        EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true);
+        EasyMock.replay(underlyingStore);
+        cachingStore = new CachingWindowStore(underlyingStore, WINDOW_SIZE, 
SEGMENT_INTERVAL);
+        cache = EasyMock.niceMock(ThreadCache.class);
+        context = new InternalMockProcessorContext(TestUtils.tempDirectory(), 
null, null, null, cache);
+        context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, 
null));
+        // NOTE(review): init() runs while `cache` is still in record state, so any cache calls
+        // made by init() are presumably recorded as expectations. Confirm this is intended.
+        cachingStore.init(context, cachingStore);
+    }
+
     private static KeyValue<Windowed<Bytes>, byte[]> windowedPair(final String 
key, final String value, final long timestamp) {
         return KeyValue.pair(
             new Windowed<>(bytesKey(key), new TimeWindow(timestamp, timestamp 
+ WINDOW_SIZE)),
@@ -636,7 +705,7 @@ public class CachingWindowStoreTest {
         while (cachedSize < MAX_CACHE_SIZE_BYTES) {
             final String kv = String.valueOf(i++);
             cachingStore.put(bytesKey(kv), bytesValue(kv));
-            cachedSize += memoryCacheEntrySize(kv.getBytes(), kv.getBytes(), 
topic) +
+            cachedSize += memoryCacheEntrySize(kv.getBytes(), kv.getBytes(), 
TOPIC) +
                 8 + // timestamp
                 4; // sequenceNumber
         }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
index de7c089..9a2b9fd 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
@@ -48,6 +48,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -62,8 +63,11 @@ import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 @RunWith(Parameterized.class)
@@ -202,10 +206,6 @@ public class MeteredKeyValueStoreTest {
         verify(inner);
     }
 
-    private KafkaMetric metric(final String name) {
-        return this.metrics.metric(new MetricName(name, storeLevelGroup, "", 
this.tags));
-    }
-
     @SuppressWarnings("unchecked")
     @Test
     public void shouldPutAllToInnerStoreAndRecordPutAllMetric() {
@@ -311,8 +311,44 @@ public class MeteredKeyValueStoreTest {
         assertFalse(metered.setFlushListener(null, false));
     }
 
+    /**
+     * Closing the metered store must close the inner store and remove every store-level metric
+     * (same group and tags) that is present after init().
+     */
+    @Test
+    public void shouldRemoveMetricsOnClose() {
+        // Record the expected inner.close() before init() switches the mock to replay mode.
+        inner.close();
+        expectLastCall();
+        init(); // replays "inner"
+
+        // There's always a "count" metric registered
+        assertThat(storeMetrics(), not(empty()));
+        metered.close();
+        assertThat(storeMetrics(), empty());
+        verify(inner);
+    }
+
+    /**
+     * Even when the inner store's close() throws, the metered store must still remove its
+     * store-level metrics, and the exception must reach the caller.
+     */
+    @Test
+    public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
+        // Record a failing inner.close() before init() switches the mock to replay mode.
+        inner.close();
+        expectLastCall().andThrow(new RuntimeException("Oops!"));
+        init(); // replays "inner"
+
+        assertThat(storeMetrics(), not(empty()));
+        assertThrows(RuntimeException.class, metered::close);
+        assertThat(storeMetrics(), empty());
+        verify(inner);
+    }
+
+    /** Looks up the metric registered under {@code metricName} in this test's Metrics registry. */
     private KafkaMetric metric(final MetricName metricName) {
         return this.metrics.metric(metricName);
     }
 
+    private KafkaMetric metric(final String name) {
+        return metrics.metric(new MetricName(name, storeLevelGroup, "", tags));
+    }
+
+    /**
+     * Returns the names of all currently registered metrics that belong to this store,
+     * i.e. those in the store-level group whose tags equal this store's tags.
+     */
+    private List<MetricName> storeMetrics() {
+        return metrics.metrics().keySet().stream()
+            .filter(metricName -> metricName.group().equals(storeLevelGroup))
+            .filter(metricName -> metricName.tags().equals(tags))
+            .collect(Collectors.toList());
+    }
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
index e589c5b..d1f805c 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
@@ -49,7 +49,9 @@ import org.junit.runners.Parameterized.Parameters;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
@@ -64,8 +66,11 @@ import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.verify;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 @RunWith(Parameterized.class)
@@ -339,8 +344,40 @@ public class MeteredSessionStoreTest {
         assertFalse(metered.setFlushListener(null, false));
     }
 
+    /**
+     * Closing the metered session store must close the inner store and remove every
+     * store-level metric (same group and tags) that is present after init().
+     */
+    @Test
+    public void shouldRemoveMetricsOnClose() {
+        // Record the expected inner.close() before init() switches the mock to replay mode.
+        inner.close();
+        expectLastCall();
+        init(); // replays "inner"
+
+        // There's always a "count" metric registered
+        assertThat(storeMetrics(), not(empty()));
+        metered.close();
+        assertThat(storeMetrics(), empty());
+        verify(inner);
+    }
+
+    /**
+     * Even when the inner store's close() throws, the metered session store must still remove
+     * its store-level metrics, and the exception must reach the caller.
+     */
+    @Test
+    public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
+        // Record a failing inner.close() before init() switches the mock to replay mode.
+        inner.close();
+        expectLastCall().andThrow(new RuntimeException("Oops!"));
+        init(); // replays "inner"
+
+        assertThat(storeMetrics(), not(empty()));
+        assertThrows(RuntimeException.class, metered::close);
+        assertThat(storeMetrics(), empty());
+        verify(inner);
+    }
+
+    /** Looks up the store-level metric called {@code name} carrying this store's tags. */
     private KafkaMetric metric(final String name) {
         return this.metrics.metric(new MetricName(name, storeLevelGroup, "", 
this.tags));
     }
 
+    /**
+     * Returns the names of all currently registered metrics that belong to this store,
+     * i.e. those in the store-level group whose tags equal this store's tags.
+     */
+    private List<MetricName> storeMetrics() {
+        return metrics.metrics().keySet().stream()
+            .filter(metricName -> metricName.group().equals(storeLevelGroup))
+            .filter(metricName -> metricName.tags().equals(tags))
+            .collect(Collectors.toList());
+    }
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java
index 7d05262..569203e 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java
@@ -43,10 +43,14 @@ import org.junit.runners.Parameterized.Parameters;
 
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static java.time.Instant.ofEpochMilli;
 import static java.util.Collections.singletonMap;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE;
 import static 
org.apache.kafka.test.StreamsTestUtils.getMetricByNameFilterByTags;
 import static org.easymock.EasyMock.anyObject;
@@ -57,9 +61,13 @@ import static org.easymock.EasyMock.expectLastCall;
 import static org.easymock.EasyMock.mock;
 import static org.easymock.EasyMock.replay;
 import static org.easymock.EasyMock.verify;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 @RunWith(Parameterized.class)
@@ -87,6 +95,7 @@ public class MeteredWindowStoreTest {
     private final Metrics metrics = new Metrics(new 
MetricConfig().recordLevel(Sensor.RecordingLevel.DEBUG));
     private String storeLevelGroup;
     private String threadIdTagKey;
+    private Map<String, String> tags;
 
     {
         expect(innerStoreMock.name()).andReturn(STORE_NAME).anyTimes();
@@ -121,6 +130,11 @@ public class MeteredWindowStoreTest {
             StreamsConfig.METRICS_0100_TO_24.equals(builtInMetricsVersion) ? 
STORE_LEVEL_GROUP_FROM_0100_TO_24 : STORE_LEVEL_GROUP;
         threadIdTagKey =
             StreamsConfig.METRICS_0100_TO_24.equals(builtInMetricsVersion) ? 
THREAD_ID_TAG_KEY_FROM_0100_TO_24 : THREAD_ID_TAG_KEY;
+        tags = mkMap(
+            mkEntry(threadIdTagKey, threadId),
+            mkEntry("task-id", context.taskId().toString()),
+            mkEntry(STORE_TYPE + "-state-id", STORE_NAME)
+        );
     }
 
     @Test
@@ -279,17 +293,6 @@ public class MeteredWindowStoreTest {
     }
 
     @Test
-    public void shouldCloseUnderlyingStore() {
-        innerStoreMock.close();
-        expectLastCall();
-        replay(innerStoreMock);
-
-        store.init(context, store);
-        store.close();
-        verify(innerStoreMock);
-    }
-
-    @Test
     public void shouldNotThrowNullPointerExceptionIfFetchReturnsNull() {
         expect(innerStoreMock.fetch(Bytes.wrap("a".getBytes()), 
0)).andReturn(null);
         replay(innerStoreMock);
@@ -325,4 +328,50 @@ public class MeteredWindowStoreTest {
     public void shouldNotSetFlushListenerOnWrappedNoneCachingStore() {
         assertFalse(store.setFlushListener(null, false));
     }
+
+    /** Closing the metered window store must close the wrapped (inner) store exactly as recorded. */
+    @Test
+    public void shouldCloseUnderlyingStore() {
+        // Record the expected close() and replay before init(), which runs against the mock.
+        innerStoreMock.close();
+        expectLastCall();
+        replay(innerStoreMock);
+        store.init(context, store);
+
+        store.close();
+        verify(innerStoreMock);
+    }
+
+    /**
+     * Closing the metered window store must close the inner store and remove every
+     * store-level metric (same group and tags) that is present after init().
+     */
+    @Test
+    public void shouldRemoveMetricsOnClose() {
+        innerStoreMock.close();
+        expectLastCall();
+        replay(innerStoreMock);
+        store.init(context, store);
+
+        assertThat(storeMetrics(), not(empty()));
+        store.close();
+        assertThat(storeMetrics(), empty());
+        verify(innerStoreMock);
+    }
+
+    /**
+     * Even when the inner store's close() throws, the metered window store must still remove
+     * its store-level metrics, and the exception must reach the caller.
+     */
+    @Test
+    public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() {
+        innerStoreMock.close();
+        expectLastCall().andThrow(new RuntimeException("Oops!"));
+        replay(innerStoreMock);
+        store.init(context, store);
+
+        // There's always a "count" metric registered
+        assertThat(storeMetrics(), not(empty()));
+        assertThrows(RuntimeException.class, store::close);
+        assertThat(storeMetrics(), empty());
+        verify(innerStoreMock);
+    }
+
+    /**
+     * Returns the names of all currently registered metrics that belong to this store,
+     * i.e. those in the store-level group whose tags equal the thread/task/store tags built in setUp.
+     */
+    private List<MetricName> storeMetrics() {
+        return metrics.metrics().keySet().stream()
+            .filter(metricName -> metricName.group().equals(storeLevelGroup))
+            .filter(metricName -> metricName.tags().equals(tags))
+            .collect(Collectors.toList());
+    }
 }

Reply via email to