This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new cc90f0bc3b7 KAFKA-20173: Metered layer of KV-stores needs to pass 
Headers (#21684)
cc90f0bc3b7 is described below

commit cc90f0bc3b7cad9f62084e7fda2223da3ab2fd01
Author: Matthias J. Sax <[email protected]>
AuthorDate: Wed Mar 11 15:53:54 2026 -0700

    KAFKA-20173: Metered layer of KV-stores needs to pass Headers (#21684)
    
    Updates the metered KV-stores layer to pass the context headers into
    serdes. Simplifies the code with some refactoring.
    
    Reviewers: Uladzislau Blok <[email protected]>, Alieh Saeedi
    <[email protected]>, Bill Bejeck <[email protected]>, TengYao Chi
    <[email protected]>
---
 .../state/internals/MeteredKeyValueStore.java      | 125 ++++++++-------
 .../internals/MeteredTimestampedKeyValueStore.java |  78 +++++-----
 ...MeteredTimestampedKeyValueStoreWithHeaders.java |   2 +-
 .../internals/MeteredVersionedKeyValueStore.java   |  11 +-
 .../streams/state/KeyValueStoreTestDriver.java     |  11 +-
 .../kafka/streams/state/StateSerdesTest.java       |  11 +-
 .../state/internals/MeteredKeyValueStoreTest.java  |   1 +
 .../MeteredTimestampedKeyValueStoreTest.java       |   3 +-
 ...redTimestampedKeyValueStoreWithHeadersTest.java |   1 +
 .../MeteredVersionedKeyValueStoreTest.java         |   1 +
 .../processor/api/MockProcessorContext.java        | 170 +++++++++++++++------
 .../streams/test/MockProcessorContextAPITest.java  |  76 ++++++++-
 .../test/MockProcessorContextStateStoreTest.java   |  33 ++--
 13 files changed, 343 insertions(+), 180 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
index 6ad53f0974d..0535e2f89e9 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java
@@ -17,7 +17,6 @@
 package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serializer;
@@ -64,7 +63,6 @@ import java.util.function.Function;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
-import static org.apache.kafka.streams.state.internals.Utils.keyBytes;
 
 /**
  * A Metered {@link KeyValueStore} wrapper that is used for recording 
operation metrics, and hence its
@@ -75,8 +73,6 @@ import static 
org.apache.kafka.streams.state.internals.Utils.keyBytes;
  * @param <K>
  * @param <V>
  */
-// TODO: replace with new method in follow-up PR of KIP-1271
-@SuppressWarnings("deprecation")
 public class MeteredKeyValueStore<K, V>
     extends WrappedStateStore<KeyValueStore<Bytes, byte[]>, K, V>
     implements KeyValueStore<K, V>, MeteredStateStore {
@@ -120,11 +116,13 @@ public class MeteredKeyValueStore<K, V>
             )
         );
 
-    MeteredKeyValueStore(final KeyValueStore<Bytes, byte[]> inner,
-                         final String metricsScope,
-                         final Time time,
-                         final Serde<K> keySerde,
-                         final Serde<V> valueSerde) {
+    MeteredKeyValueStore(
+        final KeyValueStore<Bytes, byte[]> inner,
+        final String metricsScope,
+        final Time time,
+        final Serde<K> keySerde,
+        final Serde<V> valueSerde
+    ) {
         super(inner);
         this.metricsScope = metricsScope;
         this.time = time != null ? time : Time.SYSTEM;
@@ -133,9 +131,8 @@ public class MeteredKeyValueStore<K, V>
     }
 
     @Override
-    public void init(final StateStoreContext stateStoreContext,
-                     final StateStore root) {
-        internalContext = stateStoreContext instanceof 
InternalProcessorContext ? (InternalProcessorContext<?, ?>) stateStoreContext : 
null;
+    public void init(final StateStoreContext stateStoreContext, final 
StateStore root) {
+        internalContext = (InternalProcessorContext<?, ?>) stateStoreContext;
         taskId = stateStoreContext.taskId();
         initStoreSerde(stateStoreContext);
         streamsMetrics = (StreamsMetricsImpl) stateStoreContext.metrics();
@@ -159,9 +156,18 @@ public class MeteredKeyValueStore<K, V>
         deleteSensor = StateStoreMetrics.deleteSensor(taskId.toString(), 
metricsScope, name(), streamsMetrics);
         e2eLatencySensor = 
StateStoreMetrics.e2ELatencySensor(taskId.toString(), metricsScope, name(), 
streamsMetrics);
         iteratorDurationSensor = 
StateStoreMetrics.iteratorDurationSensor(taskId.toString(), metricsScope, 
name(), streamsMetrics);
-        StateStoreMetrics.addNumOpenIteratorsGauge(taskId.toString(), 
metricsScope, name(), streamsMetrics,
-                (config, now) -> numOpenIterators.sum());
-        StateStoreMetrics.addOldestOpenIteratorGauge(taskId.toString(), 
metricsScope, name(), streamsMetrics,
+        StateStoreMetrics.addNumOpenIteratorsGauge(
+            taskId.toString(),
+            metricsScope,
+            name(),
+            streamsMetrics,
+            (config, now) -> numOpenIterators.sum()
+        );
+        StateStoreMetrics.addOldestOpenIteratorGauge(
+            taskId.toString(),
+            metricsScope,
+            name(),
+            streamsMetrics,
             (config, now) -> {
                 try {
                     final Iterator<MeteredIterator> iter = 
openIterators.iterator();
@@ -169,7 +175,8 @@ public class MeteredKeyValueStore<K, V>
                 } catch (final NoSuchElementException e) {
                     return 0L;
                 }
-            });
+            }
+        );
     }
 
     @Override
@@ -185,24 +192,33 @@ public class MeteredKeyValueStore<K, V>
         final String storeName = name();
         final String changelogTopic = 
ProcessorContextUtils.changelogFor(context, storeName, Boolean.FALSE);
         serdes = StoreSerdeInitializer.prepareStoreSerde(
-            context, storeName, changelogTopic, keySerde, valueSerde, 
this::prepareValueSerdeForStore);
+            context,
+            storeName,
+            changelogTopic,
+            keySerde,
+            valueSerde,
+            this::prepareValueSerdeForStore
+        );
     }
 
     @SuppressWarnings("unchecked")
     @Override
-    public boolean setFlushListener(final CacheFlushListener<K, V> listener,
-                                    final boolean sendOldValues) {
+    public boolean setFlushListener(final CacheFlushListener<K, V> listener, 
final boolean sendOldValues) {
         final KeyValueStore<Bytes, byte[]> wrapped = wrapped();
         if (wrapped instanceof CachedStateStore) {
             return ((CachedStateStore<byte[], byte[]>) 
wrapped).setFlushListener(
-                record -> listener.apply(
-                    record.withKey(serdes.keyFrom(record.key()))
-                        .withValue(new Change<>(
-                            record.value().newValue != null ? 
serdes.valueFrom(record.value().newValue) : null,
-                            record.value().oldValue != null ? 
serdes.valueFrom(record.value().oldValue) : null,
-                            record.value().isLatest
-                        ))
-                ),
+                record -> {
+                    final Change<byte[]> change = record.value();
+                    listener.apply(
+                        record
+                            .withKey(serdes.keyFrom(record.key(), 
record.headers()))
+                            .withValue(new Change<>(
+                                change.newValue != null ? 
serdes.valueFrom(change.newValue, record.headers()) : null,
+                                change.oldValue != null ? 
serdes.valueFrom(change.oldValue, record.headers()) : null,
+                                change.isLatest
+                            ))
+                    );
+                },
                 sendOldValues);
         }
         return false;
@@ -255,8 +271,8 @@ public class MeteredKeyValueStore<K, V>
         RangeQuery<Bytes, byte[]> rawRangeQuery;
         final ResultOrder order = typedQuery.resultOrder();
         rawRangeQuery = RangeQuery.withRange(
-                keyBytes(typedQuery.getLowerBound().orElse(null), serdes),
-                keyBytes(typedQuery.getUpperBound().orElse(null), serdes)
+            serializeKey(typedQuery.getLowerBound().orElse(null)),
+            serializeKey(typedQuery.getUpperBound().orElse(null))
         );
         if (order.equals(ResultOrder.DESCENDING)) {
             rawRangeQuery = rawRangeQuery.withDescendingKeys();
@@ -293,7 +309,7 @@ public class MeteredKeyValueStore<K, V>
         final QueryResult<R> result;
         final KeyQuery<K, V> typedKeyQuery = (KeyQuery<K, V>) query;
         final KeyQuery<Bytes, byte[]> rawKeyQuery =
-            KeyQuery.withKey(keyBytes(typedKeyQuery.getKey(), serdes));
+            KeyQuery.withKey(serializeKey(typedKeyQuery.getKey()));
         final QueryResult<byte[]> rawResult =
             wrapped().query(rawKeyQuery, positionBound, config);
         if (rawResult.isSuccess()) {
@@ -313,7 +329,7 @@ public class MeteredKeyValueStore<K, V>
     public V get(final K key) {
         Objects.requireNonNull(key, "key cannot be null");
         try {
-            return maybeMeasureLatency(() -> 
outerValue(wrapped().get(keyBytes(key, serdes))), time, getSensor);
+            return maybeMeasureLatency(() -> 
deserializeValue(wrapped().get(serializeKey(key))), time, getSensor);
         } catch (final ProcessorStateException e) {
             final String message = String.format(e.getMessage(), key);
             throw new ProcessorStateException(message, e);
@@ -325,7 +341,7 @@ public class MeteredKeyValueStore<K, V>
                     final V value) {
         Objects.requireNonNull(key, "key cannot be null");
         try {
-            maybeMeasureLatency(() -> wrapped().put(keyBytes(key, serdes), 
serdes.rawValue(value, new RecordHeaders())), time, putSensor);
+            maybeMeasureLatency(() -> wrapped().put(serializeKey(key), 
serializeValue(value)), time, putSensor);
             maybeRecordE2ELatency();
         } catch (final ProcessorStateException e) {
             final String message = String.format(e.getMessage(), key, value);
@@ -338,7 +354,7 @@ public class MeteredKeyValueStore<K, V>
                          final V value) {
         Objects.requireNonNull(key, "key cannot be null");
         final V currentValue = maybeMeasureLatency(
-            () -> outerValue(wrapped().putIfAbsent(keyBytes(key, serdes), 
serdes.rawValue(value))),
+            () -> deserializeValue(wrapped().putIfAbsent(serializeKey(key), 
serializeValue(value))),
             time,
             putIfAbsentSensor
         );
@@ -356,7 +372,7 @@ public class MeteredKeyValueStore<K, V>
     public V delete(final K key) {
         Objects.requireNonNull(key, "key cannot be null");
         try {
-            return maybeMeasureLatency(() -> 
outerValue(wrapped().delete(keyBytes(key, serdes))), time, deleteSensor);
+            return maybeMeasureLatency(() -> 
deserializeValue(wrapped().delete(serializeKey(key))), time, deleteSensor);
         } catch (final ProcessorStateException e) {
             final String message = String.format(e.getMessage(), key);
             throw new ProcessorStateException(message, e);
@@ -373,10 +389,8 @@ public class MeteredKeyValueStore<K, V>
     @Override
     public KeyValueIterator<K, V> range(final K from,
                                         final K to) {
-        final byte[] serFrom = from == null ? null : serdes.rawKey(from);
-        final byte[] serTo = to == null ? null : serdes.rawKey(to);
         return new MeteredKeyValueIterator(
-            wrapped().range(Bytes.wrap(serFrom), Bytes.wrap(serTo)),
+            wrapped().range(serializeKey(from), serializeKey(to)),
             rangeSensor
         );
     }
@@ -384,10 +398,8 @@ public class MeteredKeyValueStore<K, V>
     @Override
     public KeyValueIterator<K, V> reverseRange(final K from,
                                                final K to) {
-        final byte[] serFrom = from == null ? null : serdes.rawKey(from);
-        final byte[] serTo = to == null ? null : serdes.rawKey(to);
         return new MeteredKeyValueIterator(
-            wrapped().reverseRange(Bytes.wrap(serFrom), Bytes.wrap(serTo)),
+            wrapped().reverseRange(serializeKey(from), serializeKey(to)),
             rangeSensor
         );
     }
@@ -421,21 +433,31 @@ public class MeteredKeyValueStore<K, V>
         }
     }
 
-    protected V outerValue(final byte[] value) {
-        return value != null ? serdes.valueFrom(value, new RecordHeaders()) : 
null;
+    protected byte[] serializeValue(final V value) {
+        return value != null ? serdes.rawValue(value, 
internalContext.headers()) : null;
+    }
+
+    protected V deserializeValue(final byte[] rawValue) {
+        return rawValue != null ? serdes.valueFrom(rawValue, 
internalContext.headers()) : null;
+    }
+
+    protected Bytes serializeKey(final K key) {
+        return Bytes.wrap(serdes.rawKey(key, internalContext.headers()));
+    }
+
+    protected K deserializeKey(final byte[] rawKey) {
+        return rawKey != null ? serdes.keyFrom(rawKey, 
internalContext.headers()) : null;
     }
 
     private List<KeyValue<Bytes, byte[]>> innerEntries(final List<KeyValue<K, 
V>> from) {
         final List<KeyValue<Bytes, byte[]>> byteEntries = new ArrayList<>();
         for (final KeyValue<K, V> entry : from) {
-            
byteEntries.add(KeyValue.pair(Bytes.wrap(serdes.rawKey(entry.key)), 
serdes.rawValue(entry.value)));
+            byteEntries.add(KeyValue.pair(serializeKey(entry.key), 
serializeValue(entry.value)));
         }
         return byteEntries;
     }
 
     protected void maybeRecordE2ELatency() {
-        // Context is null if the provided context isn't an implementation of 
InternalProcessorContext.
-        // In that case, we _can't_ get the current timestamp, so we don't 
record anything.
         if (e2eLatencySensor.shouldRecord() && internalContext != null) {
             final long currentTime = time.milliseconds();
             final long e2eLatency =  currentTime - 
internalContext.recordContext().timestamp();
@@ -474,8 +496,8 @@ public class MeteredKeyValueStore<K, V>
         public KeyValue<K, V> next() {
             final KeyValue<Bytes, byte[]> keyValue = iter.next();
             return KeyValue.pair(
-                serdes.keyFrom(keyValue.key.get()),
-                outerValue(keyValue.value));
+                deserializeKey(keyValue.key.get()),
+                deserializeValue(keyValue.value));
         }
 
         @Override
@@ -493,7 +515,7 @@ public class MeteredKeyValueStore<K, V>
 
         @Override
         public K peekNextKey() {
-            return serdes.keyFrom(iter.peekNextKey().get());
+            return deserializeKey(iter.peekNextKey().get());
         }
     }
 
@@ -533,8 +555,9 @@ public class MeteredKeyValueStore<K, V>
         public KeyValue<K, V> next() {
             final KeyValue<Bytes, byte[]> keyValue = iter.next();
             return KeyValue.pair(
-                    serdes.keyFrom(keyValue.key.get()),
-                    valueDeserializer.apply(keyValue.value));
+                deserializeKey(keyValue.key.get()),
+                valueDeserializer.apply(keyValue.value)
+            );
         }
 
         @Override
@@ -552,7 +575,7 @@ public class MeteredKeyValueStore<K, V>
 
         @Override
         public K peekNextKey() {
-            return serdes.keyFrom(iter.peekNextKey().get());
+            return deserializeKey(iter.peekNextKey().get());
         }
     }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java
index 6c7f62d93de..4a4533cd804 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java
@@ -44,7 +44,6 @@ import java.util.function.Function;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
-import static org.apache.kafka.streams.state.internals.Utils.keyBytes;
 
 /**
  * A Metered {@link TimestampedKeyValueStore} wrapper that is used for 
recording operation metrics, and hence its
@@ -54,17 +53,17 @@ import static 
org.apache.kafka.streams.state.internals.Utils.keyBytes;
  * @param <K>
  * @param <V>
  */
-// TODO: replace with new method in follow-up PR of KIP-1271
-@SuppressWarnings("deprecation")
 public class MeteredTimestampedKeyValueStore<K, V>
     extends MeteredKeyValueStore<K, ValueAndTimestamp<V>> 
     implements TimestampedKeyValueStore<K, V> {
 
-    MeteredTimestampedKeyValueStore(final KeyValueStore<Bytes, byte[]> inner,
-                                    final String metricScope,
-                                    final Time time,
-                                    final Serde<K> keySerde,
-                                    final Serde<ValueAndTimestamp<V>> 
valueSerde) {
+    MeteredTimestampedKeyValueStore(
+        final KeyValueStore<Bytes, byte[]> inner,
+        final String metricScope,
+        final Time time,
+        final Serde<K> keySerde,
+        final Serde<ValueAndTimestamp<V>> valueSerde
+    ) {
         super(inner, metricScope, time, keySerde, valueSerde);
     }
 
@@ -99,29 +98,35 @@ public class MeteredTimestampedKeyValueStore<K, V>
         }
     }
 
-    public RawAndDeserializedValue<V> getWithBinary(final K key) {
+    RawAndDeserializedValue<V> getWithBinary(final K key) {
         try {
-            return maybeMeasureLatency(() -> { 
-                final byte[] serializedValue = wrapped().get(keyBytes(key, 
serdes));
-                return new RawAndDeserializedValue<>(serializedValue, 
outerValue(serializedValue));
-            }, time, getSensor);
+            return maybeMeasureLatency(
+                () -> {
+                    final byte[] rawValue = wrapped().get(serializeKey(key));
+                    return new RawAndDeserializedValue<>(rawValue, 
deserializeValue(rawValue));
+                },
+                time,
+                getSensor
+            );
         } catch (final ProcessorStateException e) {
             final String message = String.format(e.getMessage(), key);
             throw new ProcessorStateException(message, e);
         }
     }
 
-    public boolean putIfDifferentValues(final K key,
-                                        final ValueAndTimestamp<V> newValue,
-                                        final byte[] oldSerializedValue) {
+    public boolean putIfDifferentValues(
+        final K key,
+        final ValueAndTimestamp<V> newValue,
+        final byte[] oldSerializedValue
+    ) {
         try {
             return maybeMeasureLatency(
                 () -> {
-                    final byte[] newSerializedValue = 
serdes.rawValue(newValue);
-                    if 
(ValueAndTimestampSerializer.valuesAreSameAndTimeIsIncreasing(oldSerializedValue,
 newSerializedValue)) {
+                    final byte[] rawNewValue = serializeValue(newValue);
+                    if 
(ValueAndTimestampSerializer.valuesAreSameAndTimeIsIncreasing(oldSerializedValue,
 rawNewValue)) {
                         return false;
                     } else {
-                        wrapped().put(keyBytes(key, serdes), 
newSerializedValue);
+                        wrapped().put(serializeKey(key), rawNewValue);
                         return true;
                     }
                 },
@@ -135,10 +140,10 @@ public class MeteredTimestampedKeyValueStore<K, V>
     }
 
     static class RawAndDeserializedValue<ValueType> {
-        final byte[] serializedValue;
+        final byte[] rawValue;
         final ValueAndTimestamp<ValueType> value;
-        RawAndDeserializedValue(final byte[] serializedValue, final 
ValueAndTimestamp<ValueType> value) {
-            this.serializedValue = serializedValue;
+        RawAndDeserializedValue(final byte[] rawValue, final 
ValueAndTimestamp<ValueType> value) {
+            this.rawValue = rawValue;
             this.value = value;
         }
     }
@@ -175,18 +180,14 @@ public class MeteredTimestampedKeyValueStore<K, V>
         return result;
     }
 
-
-
     @SuppressWarnings("unchecked")
     private <R> QueryResult<R> runTimestampedKeyQuery(final Query<R> query,
                                                       final PositionBound 
positionBound,
                                                       final QueryConfig 
config) {
         final QueryResult<R> result;
         final TimestampedKeyQuery<K, V> typedKeyQuery = 
(TimestampedKeyQuery<K, V>) query;
-        final KeyQuery<Bytes, byte[]> rawKeyQuery =
-                KeyQuery.withKey(keyBytes(typedKeyQuery.key(), serdes));
-        final QueryResult<byte[]> rawResult =
-                wrapped().query(rawKeyQuery, positionBound, config);
+        final KeyQuery<Bytes, byte[]> rawKeyQuery = 
KeyQuery.withKey(serializeKey(typedKeyQuery.key()));
+        final QueryResult<byte[]> rawResult = wrapped().query(rawKeyQuery, 
positionBound, config);
         if (rawResult.isSuccess()) {
             final Function<byte[], ValueAndTimestamp<V>> deserializer = 
StoreQueryUtils.deserializeValue(serdes, wrapped());
             final ValueAndTimestamp<V> valueAndTimestamp = 
deserializer.apply(rawResult.getResult());
@@ -210,8 +211,8 @@ public class MeteredTimestampedKeyValueStore<K, V>
         RangeQuery<Bytes, byte[]> rawRangeQuery;
         final ResultOrder order = typedQuery.resultOrder();
         rawRangeQuery = RangeQuery.withRange(
-                keyBytes(typedQuery.lowerBound().orElse(null), serdes),
-                keyBytes(typedQuery.upperBound().orElse(null), serdes)
+            serializeKey(typedQuery.lowerBound().orElse(null)),
+            serializeKey(typedQuery.upperBound().orElse(null))
         );
         if (order.equals(ResultOrder.DESCENDING)) {
             rawRangeQuery = rawRangeQuery.withDescendingKeys();
@@ -248,8 +249,7 @@ public class MeteredTimestampedKeyValueStore<K, V>
                                              final QueryConfig config) {
         final QueryResult<R> result;
         final KeyQuery<K, V> typedKeyQuery = (KeyQuery<K, V>) query;
-        final KeyQuery<Bytes, byte[]> rawKeyQuery =
-                KeyQuery.withKey(keyBytes(typedKeyQuery.getKey(), serdes));
+        final KeyQuery<Bytes, byte[]> rawKeyQuery = 
KeyQuery.withKey(serializeKey(typedKeyQuery.getKey()));
         final QueryResult<byte[]> rawResult =
                 wrapped().query(rawKeyQuery, positionBound, config);
         if (rawResult.isSuccess()) {
@@ -276,8 +276,8 @@ public class MeteredTimestampedKeyValueStore<K, V>
         RangeQuery<Bytes, byte[]> rawRangeQuery;
         final ResultOrder order = typedQuery.resultOrder();
         rawRangeQuery = RangeQuery.withRange(
-                keyBytes(typedQuery.getLowerBound().orElse(null), serdes),
-                keyBytes(typedQuery.getUpperBound().orElse(null), serdes)
+            serializeKey(typedQuery.getLowerBound().orElse(null)),
+            serializeKey(typedQuery.getUpperBound().orElse(null))
         );
         if (order.equals(ResultOrder.DESCENDING)) {
             rawRangeQuery = rawRangeQuery.withDescendingKeys();
@@ -347,12 +347,12 @@ public class MeteredTimestampedKeyValueStore<K, V>
             final KeyValue<Bytes, byte[]> keyValue = iter.next();
             if (returnPlainValue) {
                 final V plainValue = 
valueAndTimestampDeserializer.apply(keyValue.value).value();
-                return KeyValue.pair(
-                        serdes.keyFrom(keyValue.key.get()), plainValue);
+                return KeyValue.pair(deserializeKey(keyValue.key.get()), 
plainValue);
             }
             return (KeyValue<K, V>) KeyValue.pair(
-                    serdes.keyFrom(keyValue.key.get()),
-                    valueAndTimestampDeserializer.apply(keyValue.value));
+                deserializeKey(keyValue.key.get()),
+                valueAndTimestampDeserializer.apply(keyValue.value)
+            );
         }
         @Override
         public void close() {
@@ -368,7 +368,7 @@ public class MeteredTimestampedKeyValueStore<K, V>
 
         @Override
         public K peekNextKey() {
-            return serdes.keyFrom(iter.peekNextKey().get());
+            return deserializeKey(iter.peekNextKey().get());
         }
     }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeaders.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeaders.java
index 12bb9501700..8730aadcc11 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeaders.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeaders.java
@@ -124,7 +124,7 @@ public class MeteredTimestampedKeyValueStoreWithHeaders<K, 
V>
         Objects.requireNonNull(key, "key cannot be null");
         final Headers headers = value != null ? value.headers() : new 
RecordHeaders();
         final ValueTimestampHeaders<V> currentValue = maybeMeasureLatency(
-            () -> outerValue(wrapped().putIfAbsent(keyBytes(key, headers), 
serdes.rawValue(value, headers))),
+            () -> deserializeValue(wrapped().putIfAbsent(keyBytes(key, 
headers), serdes.rawValue(value, headers))),
             time,
             putIfAbsentSensor
         );
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStore.java
index 6846f017dd9..c98c5590d7a 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStore.java
@@ -54,7 +54,6 @@ import java.util.Objects;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
-import static org.apache.kafka.streams.state.internals.Utils.keyBytes;
 
 /**
  * A metered {@link VersionedKeyValueStore} wrapper that is used for recording 
operation
@@ -146,7 +145,7 @@ public class MeteredVersionedKeyValueStore<K, V>
         public long put(final K key, final V value, final long timestamp) {
             Objects.requireNonNull(key, "key cannot be null");
             try {
-                final long validTo = maybeMeasureLatency(() -> 
inner.put(keyBytes(key, serdes), plainValueSerdes.rawValue(value), timestamp), 
time, putSensor);
+                final long validTo = maybeMeasureLatency(() -> 
inner.put(serializeKey(key), plainValueSerdes.rawValue(value), timestamp), 
time, putSensor);
                 maybeRecordE2ELatency();
                 return validTo;
             } catch (final ProcessorStateException e) {
@@ -158,7 +157,7 @@ public class MeteredVersionedKeyValueStore<K, V>
         public ValueAndTimestamp<V> get(final K key, final long asOfTimestamp) 
{
             Objects.requireNonNull(key, "key cannot be null");
             try {
-                return maybeMeasureLatency(() -> 
outerValue(inner.get(keyBytes(key, serdes), asOfTimestamp)), time, getSensor);
+                return maybeMeasureLatency(() -> 
deserializeValue(inner.get(serializeKey(key), asOfTimestamp)), time, getSensor);
             } catch (final ProcessorStateException e) {
                 final String message = String.format(e.getMessage(), key);
                 throw new ProcessorStateException(message, e);
@@ -168,7 +167,7 @@ public class MeteredVersionedKeyValueStore<K, V>
         public ValueAndTimestamp<V> delete(final K key, final long timestamp) {
             Objects.requireNonNull(key, "key cannot be null");
             try {
-                return maybeMeasureLatency(() -> 
outerValue(inner.delete(keyBytes(key, serdes), timestamp)), time, deleteSensor);
+                return maybeMeasureLatency(() -> 
deserializeValue(inner.delete(serializeKey(key), timestamp)), time, 
deleteSensor);
             } catch (final ProcessorStateException e) {
                 final String message = String.format(e.getMessage(), key);
                 throw new ProcessorStateException(message, e);
@@ -229,7 +228,7 @@ public class MeteredVersionedKeyValueStore<K, V>
                                                           final QueryConfig 
config) {
             final QueryResult<R> result;
             final VersionedKeyQuery<K, V> typedKeyQuery = 
(VersionedKeyQuery<K, V>) query;
-            VersionedKeyQuery<Bytes, byte[]> rawKeyQuery = 
VersionedKeyQuery.withKey(keyBytes(typedKeyQuery.key(), serdes));
+            VersionedKeyQuery<Bytes, byte[]> rawKeyQuery = 
VersionedKeyQuery.withKey(serializeKey(typedKeyQuery.key()));
             if (typedKeyQuery.asOfTimestamp().isPresent()) {
                 rawKeyQuery = 
rawKeyQuery.asOf(typedKeyQuery.asOfTimestamp().get());
             }
@@ -257,7 +256,7 @@ public class MeteredVersionedKeyValueStore<K, V>
             if (fromTime.compareTo(toTime) > 0) {
                 throw new IllegalArgumentException("The `fromTime` timestamp 
must be smaller than the `toTime` timestamp.");
             }
-            MultiVersionedKeyQuery<Bytes, byte[]> rawKeyQuery = 
MultiVersionedKeyQuery.withKey(keyBytes(typedKeyQuery.key(), serdes));
+            MultiVersionedKeyQuery<Bytes, byte[]> rawKeyQuery = 
MultiVersionedKeyQuery.withKey(serializeKey(typedKeyQuery.key()));
             rawKeyQuery = rawKeyQuery.fromTime(fromTime).toTime(toTime);
             if (typedKeyQuery.resultOrder().equals(ResultOrder.DESCENDING)) {
                 rawKeyQuery = rawKeyQuery.withDescendingTimestamps();
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
index 51cbe8f940d..8d2a330f7b1 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.state;
 
 import org.apache.kafka.clients.producer.MockProducer;
 import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serdes;
@@ -240,8 +241,8 @@ public class KeyValueStoreTestDriver<K, V> {
                 final byte[] keyBytes = keySerializer.serialize(topic, 
headers, key);
                 final byte[] valueBytes = valueSerializer.serialize(topic, 
headers, value);
 
-                final K keyTest = serdes.keyFrom(keyBytes);
-                final V valueTest = serdes.valueFrom(valueBytes);
+                final K keyTest = serdes.keyFrom(keyBytes, headers);
+                final V valueTest = serdes.valueFrom(valueBytes, headers);
 
                 recordCommitted(keyTest, valueTest);
             }
@@ -338,7 +339,7 @@ public class KeyValueStoreTestDriver<K, V> {
      * @see #checkForRestoredEntries(KeyValueStore)
      */
     public void addEntryToRestoreLog(final K key, final V value) {
-        restorableEntries.add(new KeyValue<>(stateSerdes.rawKey(key), 
stateSerdes.rawValue(value)));
+        restorableEntries.add(new KeyValue<>(stateSerdes.rawKey(key, new 
RecordHeaders()), stateSerdes.rawValue(value, new RecordHeaders())));
     }
 
     /**
@@ -368,8 +369,8 @@ public class KeyValueStoreTestDriver<K, V> {
         int missing = 0;
         for (final KeyValue<byte[], byte[]> kv : restorableEntries) {
             if (kv != null) {
-                final V value = store.get(stateSerdes.keyFrom(kv.key));
-                if (!Objects.equals(value, stateSerdes.valueFrom(kv.value))) {
+                final V value = store.get(stateSerdes.keyFrom(kv.key, new 
RecordHeaders()));
+                if (!Objects.equals(value, stateSerdes.valueFrom(kv.value, new 
RecordHeaders()))) {
                     ++missing;
                 }
             }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java 
b/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java
index 8af9e9c90dc..4b1bf468fc3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java
@@ -58,6 +58,7 @@ public class StateSerdesTest {
         assertThrows(NullPointerException.class, () -> 
StateSerdes.withBuiltinTypes("anyName", byte[].class, null));
     }
 
+    @SuppressWarnings("rawtypes")
     @Test
     public void shouldReturnSerdesForBuiltInKeyAndValueTypesForBuiltinTypes() {
         final Class[] supportedBuildInTypes = new Class[] {
@@ -106,10 +107,11 @@ public class StateSerdesTest {
 
     @Test
     public void shouldThrowIfIncompatibleSerdeForValue() throws 
ClassNotFoundException {
+        @SuppressWarnings("rawtypes")
         final Class myClass = Class.forName("java.lang.String");
         final StateSerdes<Object, Object> stateSerdes = new 
StateSerdes<Object, Object>("anyName", Serdes.serdeFrom(myClass), 
Serdes.serdeFrom(myClass));
         final Integer myInt = 123;
-        final Exception e = assertThrows(StreamsException.class, () -> 
stateSerdes.rawValue(myInt));
+        final Exception e = assertThrows(StreamsException.class, () -> 
stateSerdes.rawValue(myInt, new RecordHeaders()));
         assertThat(
             e.getMessage(),
             equalTo(
@@ -120,11 +122,13 @@ public class StateSerdesTest {
 
     @Test
     public void 
shouldSkipValueAndTimestampeInformationForErrorOnTimestampAndValueSerialization()
 throws ClassNotFoundException {
+        @SuppressWarnings("rawtypes")
         final Class myClass = Class.forName("java.lang.String");
+        @SuppressWarnings("rawtypes")
         final StateSerdes<Object, Object> stateSerdes =
             new StateSerdes<Object, Object>("anyName", 
Serdes.serdeFrom(myClass), new 
ValueAndTimestampSerde(Serdes.serdeFrom(myClass)));
         final Integer myInt = 123;
-        final Exception e = assertThrows(StreamsException.class, () -> 
stateSerdes.rawValue(ValueAndTimestamp.make(myInt, 0L)));
+        final Exception e = assertThrows(StreamsException.class, () -> 
stateSerdes.rawValue(ValueAndTimestamp.make(myInt, 0L), new RecordHeaders()));
         assertThat(
             e.getMessage(),
             equalTo(
@@ -135,10 +139,11 @@ public class StateSerdesTest {
 
     @Test
     public void shouldThrowIfIncompatibleSerdeForKey() throws 
ClassNotFoundException {
+        @SuppressWarnings("rawtypes")
         final Class myClass = Class.forName("java.lang.String");
         final StateSerdes<Object, Object> stateSerdes = new 
StateSerdes<Object, Object>("anyName", Serdes.serdeFrom(myClass), 
Serdes.serdeFrom(myClass));
         final Integer myInt = 123;
-        final Exception e = assertThrows(StreamsException.class, () -> 
stateSerdes.rawKey(myInt));
+        final Exception e = assertThrows(StreamsException.class, () -> 
stateSerdes.rawKey(myInt, new RecordHeaders()));
         assertThat(
             e.getMessage(),
             equalTo(
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
index cd6c77b5db6..f4a9efa0800 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java
@@ -175,6 +175,7 @@ public class MeteredKeyValueStoreTest {
         when(valueDeserializer.deserialize(topic, new RecordHeaders(), 
VALUE_BYTES)).thenReturn(VALUE);
         when(valueSerde.serializer()).thenReturn(valueSerializer);
         when(valueSerializer.serialize(topic, new RecordHeaders(), 
VALUE)).thenReturn(VALUE_BYTES);
+        when(context.headers()).thenReturn(new RecordHeaders());
         when(inner.get(KEY_BYTES)).thenReturn(VALUE_BYTES);
         metered = new MeteredKeyValueStore<>(
             inner,
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java
index 12d5d8ea3ad..7583d452d33 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java
@@ -187,6 +187,7 @@ public class MeteredTimestampedKeyValueStoreTest {
         when(valueDeserializer.deserialize(topic, new RecordHeaders(), 
VALUE_AND_TIMESTAMP_BYTES)).thenReturn(VALUE_AND_TIMESTAMP);
         when(valueSerde.serializer()).thenReturn(valueSerializer);
         when(valueSerializer.serialize(topic, new RecordHeaders(), 
VALUE_AND_TIMESTAMP)).thenReturn(VALUE_AND_TIMESTAMP_BYTES);
+        when(context.headers()).thenReturn(new RecordHeaders());
         when(inner.get(KEY_BYTES)).thenReturn(VALUE_AND_TIMESTAMP_BYTES);
         metered = new MeteredTimestampedKeyValueStore<>(
             inner,
@@ -241,7 +242,7 @@ public class MeteredTimestampedKeyValueStoreTest {
 
         final RawAndDeserializedValue<String> valueWithBinary = 
metered.getWithBinary(KEY);
         assertEquals(VALUE_AND_TIMESTAMP, valueWithBinary.value);
-        assertArrayEquals(VALUE_AND_TIMESTAMP_BYTES, 
valueWithBinary.serializedValue);
+        assertArrayEquals(VALUE_AND_TIMESTAMP_BYTES, valueWithBinary.rawValue);
     }
 
     @Test
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeadersTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeadersTest.java
index ddaa24ce184..427b495e157 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeadersTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreWithHeadersTest.java
@@ -475,6 +475,7 @@ public class MeteredTimestampedKeyValueStoreWithHeadersTest 
{
         lenient().when(valueDeserializer.deserialize(eq(topic), 
any(RecordHeaders.class), 
any(byte[].class))).thenReturn(VALUE_TIMESTAMP_HEADERS);
         when(valueSerde.serializer()).thenReturn(valueSerializer);
         lenient().when(valueSerializer.serialize(eq(topic), 
any(RecordHeaders.class), 
eq(VALUE_TIMESTAMP_HEADERS))).thenReturn(VALUE_TIMESTAMP_HEADERS_BYTES);
+        when(context.headers()).thenReturn(new RecordHeaders());
         
when(inner.get(any(Bytes.class))).thenReturn(VALUE_TIMESTAMP_HEADERS_BYTES);
         metered = new MeteredTimestampedKeyValueStoreWithHeaders<>(
             inner,
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStoreTest.java
index 9e047046345..4fcf0010d21 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredVersionedKeyValueStoreTest.java
@@ -168,6 +168,7 @@ public class MeteredVersionedKeyValueStoreTest {
         when(keySerde.serializer()).thenReturn(keySerializer);
         when(valueSerde.serializer()).thenReturn(valueSerializer);
         when(valueSerde.deserializer()).thenReturn(valueDeserializer);
+        when(context.headers()).thenReturn(new RecordHeaders());
 
         store.close();
         store = new MeteredVersionedKeyValueStore<>(
diff --git 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
index 8cea2fe3290..d44647526a8 100644
--- 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
+++ 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java
@@ -16,10 +16,13 @@
  */
 package org.apache.kafka.streams.processor.api;
 
+import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsMetrics;
@@ -33,11 +36,17 @@ import 
org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.To;
+import org.apache.kafka.streams.processor.internals.AbstractProcessorContext;
 import org.apache.kafka.streams.processor.internals.ClientUtils;
 import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.StateManager;
+import org.apache.kafka.streams.processor.internals.StreamTask;
 import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import java.io.File;
 import java.time.Duration;
@@ -493,64 +502,127 @@ public class MockProcessorContext<KForward, VForward> 
implements ProcessorContex
      * @return a {@link StateStoreContext} that delegates to this 
ProcessorContext.
      */
     public StateStoreContext getStateStoreContext() {
-        return new StateStoreContext() {
-            @Override
-            public String applicationId() {
-                return MockProcessorContext.this.applicationId();
-            }
+        return new MockContext();
+    }
 
-            @Override
-            public TaskId taskId() {
-                return MockProcessorContext.this.taskId();
-            }
+    @SuppressWarnings("unchecked")
+    private final class MockContext extends AbstractProcessorContext<Object, 
Object> implements StateStoreContext {
+        public MockContext() {
+            super(
+                new TaskId(0, 0),
+                new StreamsConfig(MockProcessorContext.this.appConfigs()),
+                (StreamsMetricsImpl) MockProcessorContext.this.metrics(),
+                new ThreadCache(new LogContext(), 0, (StreamsMetricsImpl) 
MockProcessorContext.this.metrics()));
+        }
+        @Override
+        public String applicationId() {
+            return MockProcessorContext.this.applicationId();
+        }
 
-            @Override
-            public Optional<RecordMetadata> recordMetadata() {
-                return MockProcessorContext.this.recordMetadata();
-            }
+        @Override
+        public TaskId taskId() {
+            return MockProcessorContext.this.taskId();
+        }
 
-            @Override
-            public Serde<?> keySerde() {
-                return MockProcessorContext.this.keySerde();
-            }
+        @Override
+        public Optional<RecordMetadata> recordMetadata() {
+            return MockProcessorContext.this.recordMetadata();
+        }
 
-            @Override
-            public Serde<?> valueSerde() {
-                return MockProcessorContext.this.valueSerde();
-            }
+        @Override
+        public Serde<?> keySerde() {
+            return MockProcessorContext.this.keySerde();
+        }
 
-            @Override
-            public File stateDir() {
-                return MockProcessorContext.this.stateDir();
-            }
+        @Override
+        public Serde<?> valueSerde() {
+            return MockProcessorContext.this.valueSerde();
+        }
 
-            @Override
-            public StreamsMetrics metrics() {
-                return MockProcessorContext.this.metrics();
-            }
+        @Override
+        public File stateDir() {
+            return MockProcessorContext.this.stateDir();
+        }
 
-            @Override
-            public void register(final StateStore store,
-                                 final StateRestoreCallback 
stateRestoreCallback) {
-                register(store, stateRestoreCallback, () -> { });
-            }
+        @Override
+        public StreamsMetricsImpl metrics() {
+            return (StreamsMetricsImpl) MockProcessorContext.this.metrics();
+        }
 
-            @Override
-            public void register(final StateStore store,
-                                 final StateRestoreCallback 
stateRestoreCallback,
-                                 final CommitCallback checkpoint) {
-                stateStores.put(store.name(), store);
-            }
+        @Override
+        public void register(final StateStore store,
+                             final StateRestoreCallback stateRestoreCallback) {
+            register(store, stateRestoreCallback, () -> { });
+        }
 
-            @Override
-            public Map<String, Object> appConfigs() {
-                return MockProcessorContext.this.appConfigs();
-            }
+        @Override
+        public void register(final StateStore store,
+                             final StateRestoreCallback stateRestoreCallback,
+                             final CommitCallback checkpoint) {
+            stateStores.put(store.name(), store);
+        }
 
-            @Override
-            public Map<String, Object> appConfigsWithPrefix(final String 
prefix) {
-                return MockProcessorContext.this.appConfigsWithPrefix(prefix);
-            }
-        };
+        @Override
+        public Map<String, Object> appConfigs() {
+            return MockProcessorContext.this.appConfigs();
+        }
+
+        @Override
+        public Map<String, Object> appConfigsWithPrefix(final String prefix) {
+            return MockProcessorContext.this.appConfigsWithPrefix(prefix);
+        }
+
+        // only needed for `AbstractProcessorContext` -- not exposed to the user
+
+        @SuppressWarnings("rawtypes")
+        @Override
+        public void forward(final Record record, final String childName) { }
+        @SuppressWarnings("rawtypes")
+        @Override
+        public void forward(final Record record) { }
+        @SuppressWarnings("rawtypes")
+        @Override
+        public void forward(final FixedKeyRecord record, final String 
childName) { }
+        @SuppressWarnings("rawtypes")
+        @Override
+        public void forward(final FixedKeyRecord record) { }
+        @Override
+        public Cancellable schedule(final Duration interval, final 
PunctuationType type, final Punctuator callback) {
+            return null;
+        }
+        @Override
+        public Cancellable schedule(final Instant startTime, final Duration 
interval, final PunctuationType type, final Punctuator callback) {
+            return null;
+        }
+        @Override
+        public void commit() { }
+        @Override
+        public long currentStreamTimeMs() {
+            return 0;
+        }
+        @Override
+        public void forward(final Object key, final Object value, final To to) 
{ }
+        @Override
+        public void forward(final Object key, final Object value) { }
+        @Override
+        public StateStore getStateStore(final String name) {
+            return null;
+        }
+        @Override
+        public void transitionToActive(final StreamTask streamTask, final 
RecordCollector recordCollector, final ThreadCache newCache) { }
+        @Override
+        public void transitionToStandby(final ThreadCache newCache) { }
+        @Override
+        public void registerCacheFlushListener(final String namespace, final 
ThreadCache.DirtyEntryFlushListener listener) { }
+        @Override
+        public void logChange(final String storeName, final Bytes key, final 
byte[] value, final long timestamp, final Headers headers, final Position 
position) { }
+        @Override
+        public String changelogFor(final String storeName) {
+            return "changelog";
+        }
+        @Override
+        protected StateManager stateManager() {
+            return null;
+        }
     }
 }
diff --git 
a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java
 
b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java
index f2b7f8ee68c..e863fe8509d 100644
--- 
a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java
+++ 
b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java
@@ -16,25 +16,40 @@
  */
 package org.apache.kafka.streams.test;
 
+import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.processor.Cancellable;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
+import org.apache.kafka.streams.processor.StateRestoreCallback;
+import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.To;
+import org.apache.kafka.streams.processor.api.FixedKeyRecord;
 import org.apache.kafka.streams.processor.api.MockProcessorContext;
 import 
org.apache.kafka.streams.processor.api.MockProcessorContext.CapturedForward;
 import org.apache.kafka.streams.processor.api.Processor;
 import org.apache.kafka.streams.processor.api.ProcessorContext;
 import org.apache.kafka.streams.processor.api.Record;
 import org.apache.kafka.streams.processor.api.RecordMetadata;
+import org.apache.kafka.streams.processor.internals.AbstractProcessorContext;
+import org.apache.kafka.streams.processor.internals.RecordCollector;
+import org.apache.kafka.streams.processor.internals.StateManager;
+import org.apache.kafka.streams.processor.internals.StreamTask;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.StoreBuilder;
 import org.apache.kafka.streams.state.Stores;
+import org.apache.kafka.streams.state.internals.ThreadCache;
 
 import org.junit.jupiter.api.Test;
 
 import java.io.File;
 import java.time.Duration;
+import java.time.Instant;
 import java.util.List;
 import java.util.Optional;
 import java.util.Properties;
@@ -224,7 +239,66 @@ public class MockProcessorContextAPITest {
 
         final KeyValueStore<String, Long> store = storeBuilder.build();
 
-        store.init(context.getStateStoreContext(), store);
+        store.init(
+            new AbstractProcessorContext<>(new TaskId(0, 0), new 
StreamsConfig(context.appConfigs()), (StreamsMetricsImpl) context.metrics(), 
null) {
+                @SuppressWarnings("rawtypes")
+                @Override
+                public void forward(final Record record, final String 
childName) { }
+                @SuppressWarnings("rawtypes")
+                @Override
+                public void forward(final Record record) { }
+                @SuppressWarnings("rawtypes")
+                @Override
+                public void forward(final FixedKeyRecord record, final String 
childName) { }
+                @SuppressWarnings("rawtypes")
+                @Override
+                public void forward(final FixedKeyRecord record) { }
+                @Override
+                public Cancellable schedule(final Duration interval, final 
PunctuationType type, final Punctuator callback) {
+                    return null;
+                }
+                @Override
+                public Cancellable schedule(final Instant startTime, final 
Duration interval, final PunctuationType type, final Punctuator callback) {
+                    return null;
+                }
+                @Override
+                public void commit() { }
+                @Override
+                public long currentStreamTimeMs() {
+                    return 0;
+                }
+                @Override
+                public void forward(final Object key, final Object value, 
final To to) { }
+                @Override
+                public void forward(final Object key, final Object value) { }
+                @SuppressWarnings("unchecked")
+                @Override
+                public StateStore getStateStore(final String name) {
+                    return null;
+                }
+                @Override
+                public void transitionToActive(final StreamTask streamTask, 
final RecordCollector recordCollector, final ThreadCache newCache) { }
+                @Override
+                public void transitionToStandby(final ThreadCache newCache) { }
+                @Override
+                public void registerCacheFlushListener(final String namespace, 
final ThreadCache.DirtyEntryFlushListener listener) { }
+                @Override
+                public void logChange(final String storeName, final Bytes key, 
final byte[] value, final long timestamp, final Headers headers, final Position 
position) { }
+                @Override
+                protected StateManager stateManager() {
+                    return null;
+                }
+                @Override
+                public String changelogFor(final String storeName) {
+                    return "changelog";
+                }
+                @Override
+                public void register(final StateStore store, final 
StateRestoreCallback stateRestoreCallback) {
+                    context.getStateStoreContext().register(store, 
stateRestoreCallback);
+                }
+            },
+            store
+        );
 
         processor.init(context);
 
diff --git 
a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java
 
b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java
index 29e19b4ce0a..4ff9ffe563b 100644
--- 
a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java
+++ 
b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java
@@ -22,8 +22,6 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.api.MockProcessorContext;
-import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
-import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
 import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier;
 import org.apache.kafka.streams.state.KeyValueStore;
 import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
@@ -50,9 +48,6 @@ import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkProperties;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 public class MockProcessorContextStateStoreTest {
 
@@ -157,10 +152,12 @@ public class MockProcessorContextStateStoreTest {
 
     @ParameterizedTest
     @MethodSource(value = "parameters")
-    public void shouldEitherInitOrThrow(final StoreBuilder<StateStore> builder,
-                                        final boolean timestamped,
-                                        final boolean caching,
-                                        final boolean logging) {
+    public void shouldWorkForAllStoreTypeAndSetups(
+        final StoreBuilder<StateStore> builder,
+        final boolean timestamped,
+        final boolean caching,
+        final boolean logging
+    ) {
         final File stateDir = TestUtils.tempDirectory();
         try {
             final MockProcessorContext<Void, Void> context = new 
MockProcessorContext<>(
@@ -172,20 +169,8 @@ public class MockProcessorContextStateStoreTest {
                 stateDir
             );
             final StateStore store = builder.build();
-            if (caching || logging) {
-                assertThrows(
-                    IllegalArgumentException.class,
-                    () -> store.init(context.getStateStoreContext(), store)
-                );
-            } else {
-                final InternalProcessorContext<?, ?> internalProcessorContext 
= mock(InternalProcessorContext.class);
-                
when(internalProcessorContext.taskId()).thenReturn(context.taskId());
-                when(internalProcessorContext.stateDir()).thenReturn(stateDir);
-                
when(internalProcessorContext.metrics()).thenReturn((StreamsMetricsImpl) 
context.metrics());
-                
when(internalProcessorContext.appConfigs()).thenReturn(context.appConfigs());
-                store.init(internalProcessorContext, store);
-                store.close();
-            }
+            store.init(context.getStateStoreContext(), store);
+            store.close();
         } finally {
             try {
                 Utils.delete(stateDir);
@@ -194,4 +179,4 @@ public class MockProcessorContextStateStoreTest {
             }
         }
     }
-}
+}
\ No newline at end of file

Reply via email to