mjsax commented on code in PR #21734:
URL: https://github.com/apache/kafka/pull/21734#discussion_r2934548673


##########
streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java:
##########
@@ -138,4 +237,72 @@ private <R> QueryResult<R> runRangeQuery(final Query<R> 
query,
             return (QueryResult<R>) rawResult;
         }
     }
+
+    private class MeteredSessionStoreWithHeadersIterator
+        implements KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>, 
MeteredIterator {
+
+        private final KeyValueIterator<Windowed<Bytes>, byte[]> iter;
+        private final long startNs;
+        private final long startTimestampMs;
+        private KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> cachedNext;
+
+        private MeteredSessionStoreWithHeadersIterator(final 
KeyValueIterator<Windowed<Bytes>, byte[]> iter) {
+            this.iter = iter;
+            this.startNs = time.nanoseconds();
+            this.startTimestampMs = time.milliseconds();
+            numOpenIterators.increment();
+            openIterators.add(this);
+        }
+
+        @Override
+        public long startTimestamp() {
+            return startTimestampMs;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return cachedNext != null || iter.hasNext();
+        }
+
+        @Override
+        public KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> next() {
+            if (cachedNext != null) {
+                final KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> 
result = cachedNext;
+                cachedNext = null;
+                return result;
+            }
+
+            final KeyValue<Windowed<Bytes>, byte[]> next = iter.next();
+            if (next == null) {

Review Comment:
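   I think we should remove this `next == null` check: per the `Iterator.next()` 
contract, an exhausted iterator throws `NoSuchElementException` instead of 
returning `null`. Cf. my comments on https://github.com/apache/kafka/pull/21736.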



##########
streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java:
##########
@@ -806,4 +808,264 @@ public void shouldDelegateUnknownQueryToWrappedStore() {
 
         assertTrue(result.isFailure());
     }
+
+    // --- Tests verifying headers from value are used to deserialize keys ---
+
+    private static final Headers HEADERS = new RecordHeaders().add("key1", 
"value1".getBytes());
+    private static final AggregationWithHeaders<String> AGG_WITH_HEADERS = 
AggregationWithHeaders.make(VALUE, HEADERS);
+    private static final byte[] SERIALIZED_VALUE = new 
AggregationWithHeadersSerializer<>(Serdes.String().serializer())
+        .serialize(CHANGELOG_TOPIC, AGG_WITH_HEADERS);
+
+    @SuppressWarnings("unchecked")
+    private MeteredSessionStoreWithHeaders<String, String> 
createStoreWithMockSerdes(
+        final Serde<String> keySerde,
+        final Serde<AggregationWithHeaders<String>> valueSerde
+    ) {
+        final Deserializer<String> keyDeserializer = mock(Deserializer.class);
+        final Serializer<String> keySerializer = mock(Serializer.class);
+        final Deserializer<AggregationWithHeaders<String>> valueDeserializer = 
mock(Deserializer.class);
+
+        lenient().when(keySerde.deserializer()).thenReturn(keyDeserializer);
+        lenient().when(keySerde.serializer()).thenReturn(keySerializer);
+        
lenient().when(valueSerde.deserializer()).thenReturn(valueDeserializer);
+
+        lenient().when(keySerializer.serialize(any(), 
any(RecordHeaders.class), any())).thenReturn(KEY.getBytes());
+
+        lenient().when(valueDeserializer.deserialize(any(), 
any(RecordHeaders.class), eq(SERIALIZED_VALUE)))
+            .thenReturn(AGG_WITH_HEADERS);
+
+        lenient().when(keyDeserializer.deserialize(any(), eq(HEADERS), 
eq(KEY.getBytes())))
+            .thenReturn(KEY);
+
+        final MeteredSessionStoreWithHeaders<String, String> mockStore = new 
MeteredSessionStoreWithHeaders<>(
+            innerStore,
+            STORE_TYPE,
+            keySerde,
+            valueSerde,
+            new MockTime()
+        );
+        mockStore.init(context, mockStore);
+        return mockStore;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFetch() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final Serde<AggregationWithHeaders<String>> valueSerde = 
mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = 
createStoreWithMockSerdes(keySerde, valueSerde);
+
+        when(innerStore.fetch(any(Bytes.class)))
+            .thenReturn(new KeyValueIteratorStub<>(
+                List.of(KeyValue.pair(WINDOWED_KEY_BYTES, 
SERIALIZED_VALUE)).iterator()));
+
+        final KeyValueIterator<Windowed<String>, 
AggregationWithHeaders<String>> iterator = store.fetch(KEY);
+
+        assertTrue(iterator.hasNext());
+        final KeyValue<Windowed<String>, AggregationWithHeaders<String>> 
result = iterator.next();

Review Comment:
   Should we insert an `assertEquals(KEY, iterator.peekNextKey().key());` before 
this call? (Same for the other tests below.)
   
   cf https://github.com/apache/kafka/pull/21736



##########
streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java:
##########
@@ -806,4 +808,264 @@ public void shouldDelegateUnknownQueryToWrappedStore() {
 
         assertTrue(result.isFailure());
     }
+
+    // --- Tests verifying headers from value are used to deserialize keys ---
+
+    private static final Headers HEADERS = new RecordHeaders().add("key1", 
"value1".getBytes());
+    private static final AggregationWithHeaders<String> AGG_WITH_HEADERS = 
AggregationWithHeaders.make(VALUE, HEADERS);
+    private static final byte[] SERIALIZED_VALUE = new 
AggregationWithHeadersSerializer<>(Serdes.String().serializer())
+        .serialize(CHANGELOG_TOPIC, AGG_WITH_HEADERS);
+
+    @SuppressWarnings("unchecked")
+    private MeteredSessionStoreWithHeaders<String, String> 
createStoreWithMockSerdes(
+        final Serde<String> keySerde,
+        final Serde<AggregationWithHeaders<String>> valueSerde
+    ) {
+        final Deserializer<String> keyDeserializer = mock(Deserializer.class);
+        final Serializer<String> keySerializer = mock(Serializer.class);
+        final Deserializer<AggregationWithHeaders<String>> valueDeserializer = 
mock(Deserializer.class);
+
+        lenient().when(keySerde.deserializer()).thenReturn(keyDeserializer);
+        lenient().when(keySerde.serializer()).thenReturn(keySerializer);
+        
lenient().when(valueSerde.deserializer()).thenReturn(valueDeserializer);
+
+        lenient().when(keySerializer.serialize(any(), 
any(RecordHeaders.class), any())).thenReturn(KEY.getBytes());
+
+        lenient().when(valueDeserializer.deserialize(any(), 
any(RecordHeaders.class), eq(SERIALIZED_VALUE)))
+            .thenReturn(AGG_WITH_HEADERS);
+
+        lenient().when(keyDeserializer.deserialize(any(), eq(HEADERS), 
eq(KEY.getBytes())))
+            .thenReturn(KEY);
+
+        final MeteredSessionStoreWithHeaders<String, String> mockStore = new 
MeteredSessionStoreWithHeaders<>(
+            innerStore,
+            STORE_TYPE,
+            keySerde,
+            valueSerde,
+            new MockTime()
+        );
+        mockStore.init(context, mockStore);
+        return mockStore;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFetch() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final Serde<AggregationWithHeaders<String>> valueSerde = 
mock(Serde.class);

Review Comment:
   It seems we don't use `valueSerde` below; we only pass it into 
`createStoreWithMockSerdes(...)`. Can't we create it inside 
`createStoreWithMockSerdes(...)` directly and remove the parameter?



##########
streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java:
##########
@@ -138,4 +237,72 @@ private <R> QueryResult<R> runRangeQuery(final Query<R> 
query,
             return (QueryResult<R>) rawResult;
         }
     }
+
+    private class MeteredSessionStoreWithHeadersIterator
+        implements KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>, 
MeteredIterator {
+
+        private final KeyValueIterator<Windowed<Bytes>, byte[]> iter;
+        private final long startNs;
+        private final long startTimestampMs;
+        private KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> cachedNext;
+
+        private MeteredSessionStoreWithHeadersIterator(final 
KeyValueIterator<Windowed<Bytes>, byte[]> iter) {
+            this.iter = iter;
+            this.startNs = time.nanoseconds();
+            this.startTimestampMs = time.milliseconds();
+            numOpenIterators.increment();
+            openIterators.add(this);
+        }
+
+        @Override
+        public long startTimestamp() {
+            return startTimestampMs;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return cachedNext != null || iter.hasNext();
+        }
+
+        @Override
+        public KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> next() {
+            if (cachedNext != null) {
+                final KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> 
result = cachedNext;
+                cachedNext = null;
+                return result;
+            }
+
+            final KeyValue<Windowed<Bytes>, byte[]> next = iter.next();
+            if (next == null) {
+                return null;
+            }
+
+            final AggregationWithHeaders<AGG> value = 
serdes.valueFrom(next.value, new RecordHeaders());
+            final Headers headers = value != null ? value.headers() : new 
RecordHeaders();
+            final K key = serdes.keyFrom(next.key.key().get(), headers);
+            final Windowed<K> windowedKey = new Windowed<>(key, 
next.key.window());
+            return KeyValue.pair(windowedKey, value);
+        }
+
+        @Override
+        public void close() {
+            try {
+                iter.close();
+            } finally {
+                final long duration = time.nanoseconds() - startNs;
+                fetchSensor.record(duration);
+                iteratorDurationSensor.record(duration);
+                numOpenIterators.decrement();
+                openIterators.remove(this);
+            }
+        }
+
+        @Override
+        public Windowed<K> peekNextKey() {
+            if (cachedNext == null) {
+                cachedNext = next();
+            }
+            return cachedNext == null ? null : cachedNext.key;

Review Comment:
   Similarly, I would remove the `null` check on `cachedNext` here (cf. 
https://github.com/apache/kafka/pull/21736).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to