This is an automated email from the ASF dual-hosted git repository.

guozhang pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3b08deaa76 KAFKA-13785: [8/N][emit final] time-ordered session store 
(#12127)
3b08deaa76 is described below

commit 3b08deaa761c2387a41610893dc8302ab1d97338
Author: Guozhang Wang <wangg...@gmail.com>
AuthorDate: Thu May 5 16:09:16 2022 -0700

    KAFKA-13785: [8/N][emit final] time-ordered session store (#12127)
    
    Time-ordered session store implementation. I introduced 
AbstractRocksDBTimeOrderedSegmentedBytesStore as a common base class for 
RocksDBTimeOrderedSessionSegmentedBytesStore and 
RocksDBTimeOrderedWindowSegmentedBytesStore (the window-specific part of the former RocksDBTimeOrderedSegmentedBytesStore).
    
    A few minor follow-up changes:
    
    1. Avoid extra byte array allocation for fixed upper/lower range 
serialization.
    2. Rename some class names to be more consistent.
    
    Authored-by: Hao Li <1127478+lihao...@users.noreply.github.com>
    Reviewers: Guozhang Wang <wangg...@gmail.com.com>, John Roesler 
<vvcep...@apache.org>
---
 ...stractDualSchemaRocksDBSegmentedBytesStore.java |   3 -
 ...ractRocksDBTimeOrderedSegmentedBytesStore.java} | 111 ++----
 .../state/internals/PrefixedSessionKeySchemas.java | 387 +++++++++++++++++++
 .../state/internals/PrefixedWindowKeySchemas.java  |   6 +-
 ...cksDBTimeOrderedSessionSegmentedBytesStore.java | 136 +++++++
 .../internals/RocksDBTimeOrderedSessionStore.java  | 156 ++++++++
 ...ocksDBTimeOrderedWindowSegmentedBytesStore.java | 127 +++++++
 .../internals/RocksDBTimeOrderedWindowStore.java   |   4 +-
 ...IndexedTimeOrderedWindowBytesStoreSupplier.java |   4 +-
 ...ocksDbTimeOrderedSessionBytesStoreSupplier.java |  69 ++++
 .../streams/state/internals/SessionKeySchema.java  |  40 +-
 .../internals/WrappedSessionStoreIterator.java     |  12 +-
 ...ctDualSchemaRocksDBSegmentedBytesStoreTest.java | 411 ++++++++++++++++++++-
 .../state/internals/RocksDBSessionStoreTest.java   |  68 +++-
 .../RocksDBTimeOrderedSegmentedBytesStoreTest.java |  74 ----
 ...DBTimeOrderedWindowSegmentedBytesStoreTest.java | 121 ++++++
 .../state/internals/RocksDBWindowStoreTest.java    |  68 ++--
 ...xedTimeOrderedWindowBytesStoreSupplierTest.java |   8 +-
 .../state/internals/SessionKeySchemaTest.java      | 223 ++++++++---
 .../internals/TimeOrderedWindowStoreTest.java      |   4 +-
 .../state/internals/WindowKeySchemaTest.java       |  10 +-
 21 files changed, 1757 insertions(+), 285 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java
index b1044eb49c..95c1d8d8c8 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStore.java
@@ -50,7 +50,6 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStore<S extends Seg
 
     private final String name;
     protected final AbstractSegments<S> segments;
-    private final String metricScope;
     protected final KeySchema baseKeySchema;
     protected final Optional<KeySchema> indexKeySchema;
 
@@ -65,12 +64,10 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStore<S extends Seg
     private volatile boolean open;
 
     AbstractDualSchemaRocksDBSegmentedBytesStore(final String name,
-                                                 final String metricScope,
                                                  final KeySchema baseKeySchema,
                                                  final Optional<KeySchema> 
indexKeySchema,
                                                  final AbstractSegments<S> 
segments) {
         this.name = name;
-        this.metricScope = metricScope;
         this.baseKeySchema = baseKeySchema;
         this.indexKeySchema = indexKeySchema;
         this.segments = segments;
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBTimeOrderedSegmentedBytesStore.java
similarity index 65%
rename from 
streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStore.java
rename to 
streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBTimeOrderedSegmentedBytesStore.java
index e87af877fb..f7216412f0 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBTimeOrderedSegmentedBytesStore.java
@@ -16,22 +16,12 @@
  */
 package org.apache.kafka.streams.state.internals;
 
-import java.util.Collection;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Optional;
-import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
-import org.apache.kafka.streams.errors.ProcessorStateException;
-import 
org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
 import org.apache.kafka.streams.state.KeyValueIterator;
-import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
-import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
-import org.rocksdb.RocksDBException;
-import org.rocksdb.WriteBatch;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -40,13 +30,15 @@ import org.slf4j.LoggerFactory;
  * lookup for a specific key.
  *
  * Schema for first SegmentedBytesStore (base store) is as below:
- *     Key schema: | timestamp + recordkey |
+ *     Key schema: | timestamp + [timestamp] + recordkey |
  *     Value schema: | value |. Value here is determined by caller.
  *
  * Schema for second SegmentedBytesStore (index store) is as below:
- *     Key schema: | record + timestamp |
+ *     Key schema: | recordkey + timestamp + [timestamp] |
  *     Value schema: ||
  *
+ * Note: there are two timestamps when both the window end time and the window start time are stored.
+ *
  * Operations:
  *     Put: 1. Put to index store. 2. Put to base store.
  *     Delete: 1. Delete from base store. 2. Delete from index store.
@@ -59,11 +51,13 @@ import org.slf4j.LoggerFactory;
  *     Index store can be optional if we can construct the timestamp in base 
store instead of looking
  *     them up from index store.
  *
+ * @see RocksDBTimeOrderedSessionSegmentedBytesStore
+ * @see RocksDBTimeOrderedWindowSegmentedBytesStore
  */
-public class RocksDBTimeOrderedSegmentedBytesStore extends 
AbstractDualSchemaRocksDBSegmentedBytesStore<KeyValueSegment> {
+public abstract class AbstractRocksDBTimeOrderedSegmentedBytesStore extends 
AbstractDualSchemaRocksDBSegmentedBytesStore<KeyValueSegment> {
     private static final Logger LOG = 
LoggerFactory.getLogger(AbstractDualSchemaRocksDBSegmentedBytesStore.class);
 
-    private class IndexToBaseStoreIterator implements KeyValueIterator<Bytes, 
byte[]> {
+    abstract class IndexToBaseStoreIterator implements KeyValueIterator<Bytes, 
byte[]> {
         private final KeyValueIterator<Bytes, byte[]> indexIterator;
         private byte[] cachedValue;
 
@@ -95,7 +89,7 @@ public class RocksDBTimeOrderedSegmentedBytesStore extends 
AbstractDualSchemaRoc
                 if (cachedValue == null) {
                     // Key not in base store, inconsistency happened and 
remove from index.
                     indexIterator.next();
-                    
RocksDBTimeOrderedSegmentedBytesStore.this.removeIndex(key);
+                    
AbstractRocksDBTimeOrderedSegmentedBytesStore.this.removeIndex(key);
                 } else {
                     return true;
                 }
@@ -114,84 +108,19 @@ public class RocksDBTimeOrderedSegmentedBytesStore 
extends AbstractDualSchemaRoc
             return KeyValue.pair(getBaseKey(ret.key), value);
         }
 
-        private Bytes getBaseKey(final Bytes indexKey) {
-            final byte[] keyBytes = 
KeyFirstWindowKeySchema.extractStoreKeyBytes(indexKey.get());
-            final long timestamp = 
KeyFirstWindowKeySchema.extractStoreTimestamp(indexKey.get());
-            final int seqnum = 
KeyFirstWindowKeySchema.extractStoreSequence(indexKey.get());
-            return TimeFirstWindowKeySchema.toStoreKeyBinary(keyBytes, 
timestamp, seqnum);
-        }
+        abstract protected Bytes getBaseKey(final Bytes indexKey);
     }
 
-    RocksDBTimeOrderedSegmentedBytesStore(final String name,
-                                          final String metricsScope,
-                                          final long retention,
-                                          final long segmentInterval,
-                                          final boolean withIndex) {
-        super(name, metricsScope, new TimeFirstWindowKeySchema(),
-            Optional.ofNullable(withIndex ? new KeyFirstWindowKeySchema() : 
null),
+    AbstractRocksDBTimeOrderedSegmentedBytesStore(final String name,
+                                                  final String metricsScope,
+                                                  final long retention,
+                                                  final long segmentInterval,
+                                                  final KeySchema 
baseKeySchema,
+                                                  final Optional<KeySchema> 
indexKeySchema) {
+        super(name, baseKeySchema, indexKeySchema,
             new KeyValueSegments(name, metricsScope, retention, 
segmentInterval));
     }
 
-    public void put(final Bytes key, final long timestamp, final int seqnum, 
final byte[] value) {
-        final Bytes baseKey = TimeFirstWindowKeySchema.toStoreKeyBinary(key, 
timestamp, seqnum);
-        put(baseKey, value);
-    }
-
-    byte[] fetch(final Bytes key, final long timestamp, final int seqnum) {
-        return get(TimeFirstWindowKeySchema.toStoreKeyBinary(key, timestamp, 
seqnum));
-    }
-
-    @Override
-    protected KeyValue<Bytes, byte[]> getIndexKeyValue(final Bytes baseKey, 
final byte[] baseValue) {
-        final byte[] key = 
TimeFirstWindowKeySchema.extractStoreKeyBytes(baseKey.get());
-        final long timestamp = 
TimeFirstWindowKeySchema.extractStoreTimestamp(baseKey.get());
-        final int seqnum = 
TimeFirstWindowKeySchema.extractStoreSequence(baseKey.get());
-
-        return KeyValue.pair(KeyFirstWindowKeySchema.toStoreKeyBinary(key, 
timestamp, seqnum), new byte[0]);
-    }
-
-    @Override
-    Map<KeyValueSegment, WriteBatch> getWriteBatches(
-        final Collection<ConsumerRecord<byte[], byte[]>> records) {
-        // advance stream time to the max timestamp in the batch
-        for (final ConsumerRecord<byte[], byte[]> record : records) {
-            final long timestamp = 
WindowKeySchema.extractStoreTimestamp(record.key());
-            observedStreamTime = Math.max(observedStreamTime, timestamp);
-        }
-
-        final Map<KeyValueSegment, WriteBatch> writeBatchMap = new HashMap<>();
-        for (final ConsumerRecord<byte[], byte[]> record : records) {
-            final long timestamp = 
WindowKeySchema.extractStoreTimestamp(record.key());
-            final long segmentId = segments.segmentId(timestamp);
-            final KeyValueSegment segment = 
segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
-            if (segment != null) {
-                
ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(
-                    record,
-                    consistencyEnabled,
-                    position
-                );
-                try {
-                    final WriteBatch batch = 
writeBatchMap.computeIfAbsent(segment, s -> new WriteBatch());
-
-                    // Assuming changelog record is serialized using 
WindowKeySchema
-                    // from ChangeLoggingTimestampedWindowBytesStore. 
Reconstruct key/value to restore
-                    if (hasIndex()) {
-                        final byte[] indexKey = 
KeyFirstWindowKeySchema.fromNonPrefixWindowKey(record.key());
-                        // Take care of tombstone
-                        final byte[] value = record.value() == null ? null : 
new byte[0];
-                        segment.addToBatch(new KeyValue<>(indexKey, value), 
batch);
-                    }
-
-                    final byte[] baseKey = 
TimeFirstWindowKeySchema.fromNonPrefixWindowKey(record.key());
-                    segment.addToBatch(new KeyValue<>(baseKey, 
record.value()), batch);
-                } catch (final RocksDBException e) {
-                    throw new ProcessorStateException("Error restoring batch 
to store " + name(), e);
-                }
-            }
-        }
-        return writeBatchMap;
-    }
-
     @Override
     public KeyValueIterator<Bytes, byte[]> fetch(final Bytes key,
                                                  final long from,
@@ -206,6 +135,8 @@ public class RocksDBTimeOrderedSegmentedBytesStore extends 
AbstractDualSchemaRoc
         return fetch(key, from, to, false);
     }
 
+    abstract protected IndexToBaseStoreIterator 
getIndexToBaseStoreIterator(final SegmentIterator<KeyValueSegment> 
segmentIterator);
+
     KeyValueIterator<Bytes, byte[]> fetch(final Bytes key,
                                           final long from,
                                           final long to,
@@ -217,7 +148,7 @@ public class RocksDBTimeOrderedSegmentedBytesStore extends 
AbstractDualSchemaRoc
             final Bytes binaryFrom = 
indexKeySchema.get().lowerRangeFixedSize(key, from);
             final Bytes binaryTo = 
indexKeySchema.get().upperRangeFixedSize(key, to);
 
-            return new IndexToBaseStoreIterator(new SegmentIterator<>(
+            return getIndexToBaseStoreIterator(new SegmentIterator<>(
                 searchSpace.iterator(),
                 indexKeySchema.get().hasNextCondition(key, key, from, to, 
forward),
                 binaryFrom,
@@ -275,7 +206,7 @@ public class RocksDBTimeOrderedSegmentedBytesStore extends 
AbstractDualSchemaRoc
             final Bytes binaryFrom = indexKeySchema.get().lowerRange(keyFrom, 
from);
             final Bytes binaryTo = indexKeySchema.get().upperRange(keyTo, to);
 
-            return new IndexToBaseStoreIterator(new SegmentIterator<>(
+            return getIndexToBaseStoreIterator(new SegmentIterator<>(
                 searchSpace.iterator(),
                 indexKeySchema.get().hasNextCondition(keyFrom, keyTo, from, 
to, forward),
                 binaryFrom,
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedSessionKeySchemas.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedSessionKeySchemas.java
new file mode 100644
index 0000000000..c98ae83390
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedSessionKeySchemas.java
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.Window;
+import org.apache.kafka.streams.kstream.Windowed;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import org.apache.kafka.streams.kstream.internals.SessionWindow;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
+
+import static org.apache.kafka.streams.state.StateSerdes.TIMESTAMP_SIZE;
+
+public class PrefixedSessionKeySchemas {
+
+    private static final int PREFIX_SIZE = 1;
+    private static final byte TIME_FIRST_PREFIX = 0;
+    private static final byte KEY_FIRST_PREFIX = 1;
+
+    private static byte extractPrefix(final byte[] binaryBytes) {
+        return binaryBytes[0];
+    }
+
+    public static class TimeFirstSessionKeySchema implements KeySchema {
+
+        @Override
+        public Bytes upperRange(final Bytes key, final long to) {
+            if (key == null) {
+                // Put next prefix instead of null so that we can start from 
right prefix
+                // when scanning backwards
+                final byte nextPrefix = TIME_FIRST_PREFIX + 1;
+                return 
Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE).put(nextPrefix).array());
+            }
+            return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE + 2 * 
TIMESTAMP_SIZE + key.get().length)
+                .put(TIME_FIRST_PREFIX)
+                // the end timestamp can be as large as possible as long as 
it's larger than start time
+                .putLong(Long.MAX_VALUE)
+                // this is the start timestamp
+                .putLong(to)
+                .put(key.get())
+                .array());
+        }
+
+        @Override
+        public Bytes lowerRange(final Bytes key, final long from) {
+            if (key == null) {
+                return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE + 
TIMESTAMP_SIZE)
+                    .put(TIME_FIRST_PREFIX)
+                    .putLong(from)
+                    .array());
+            }
+
+            return Bytes.wrap(ByteBuffer.allocate(PREFIX_SIZE + 2 * 
TIMESTAMP_SIZE + key.get().length)
+                .put(TIME_FIRST_PREFIX)
+                .putLong(from)
+                .putLong(0L)
+                .put(key.get())
+                .array());
+        }
+
+        /**
+         * Fixed-size upper bound for {@code key}: end time Long.MAX_VALUE, start time {@code to}.
+         * @param key the key in the range
+         * @param to the latest start time
+         * @return the bound serialized as prefix, end time, start time, key bytes
+         */
+        @Override
+        public Bytes upperRangeFixedSize(final Bytes key, final long to) {
+            return toBinary(key, to, Long.MAX_VALUE);
+        }
+
+        /**
+         * Fixed-size lower bound for {@code key}: end time max(0, from), start time 0.
+         * @param key the key in the range
+         * @param from the earliest end timestamp in the range
+         * @return the bound serialized as prefix, end time, start time, key bytes
+         */
+        @Override
+        public Bytes lowerRangeFixedSize(final Bytes key, final long from) {
+            return toBinary(key, 0, Math.max(0, from));
+        }
+
+        @Override
+        public long segmentTimestamp(final Bytes key) {
+            return extractEndTimestamp(key.get());
+        }
+
+        @Override
+        public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom,
+            final Bytes binaryKeyTo, final long from, final long to, final 
boolean forward) {
+            return iterator -> {
+                while (iterator.hasNext()) {
+                    final Bytes bytes = iterator.peekNextKey();
+                    final byte prefix = extractPrefix(bytes.get());
+
+                    if (prefix != TIME_FIRST_PREFIX) {
+                        return false;
+                    }
+
+                    final Windowed<Bytes> windowedKey = from(bytes);
+                    final long endTime = windowedKey.window().end();
+                    final long startTime = windowedKey.window().start();
+
+                    // We can return false directly here since keys are sorted 
by end time and if
+                    // we get time smaller than `from`, there won't be time 
within range.
+                    if (!forward && endTime < from) {
+                        return false;
+                    }
+
+                    if ((binaryKeyFrom == null || 
windowedKey.key().compareTo(binaryKeyFrom) >= 0)
+                        && (binaryKeyTo == null || 
windowedKey.key().compareTo(binaryKeyTo) <= 0)
+                        && endTime >= from && startTime <= to) {
+                        return true;
+                    }
+                    iterator.next();
+                }
+                return false;
+            };
+        }
+
+        @Override
+        public <S extends Segment> List<S> segmentsToSearch(final Segments<S> 
segments,
+                                                            final long from,
+                                                            final long to,
+                                                            final boolean 
forward) {
+            return segments.segments(from, Long.MAX_VALUE, forward);
+        }
+
+        static long extractStartTimestamp(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getLong(PREFIX_SIZE + 
TIMESTAMP_SIZE);
+        }
+
+        static long extractEndTimestamp(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getLong(PREFIX_SIZE);
+        }
+
+        private static <K> K extractKey(final byte[] binaryKey,
+                                        final Deserializer<K> deserializer,
+                                        final String topic) {
+            return deserializer.deserialize(topic, extractKeyBytes(binaryKey));
+        }
+
+        static byte[] extractKeyBytes(final byte[] binaryKey) {
+            final byte[] bytes = new byte[binaryKey.length - 2 * 
TIMESTAMP_SIZE - PREFIX_SIZE];
+            System.arraycopy(binaryKey, PREFIX_SIZE + 2 * TIMESTAMP_SIZE, 
bytes, 0, bytes.length);
+            return bytes;
+        }
+
+        static Window extractWindow(final byte[] binaryKey) {
+            final ByteBuffer buffer = ByteBuffer.wrap(binaryKey);
+            final long start = buffer.getLong(PREFIX_SIZE + TIMESTAMP_SIZE);
+            final long end = buffer.getLong(PREFIX_SIZE);
+            return new SessionWindow(start, end);
+        }
+
+        public static Windowed<Bytes> from(final Bytes bytesKey) {
+            final byte[] binaryKey = bytesKey.get();
+            final Window window = extractWindow(binaryKey);
+            return new Windowed<>(Bytes.wrap(extractKeyBytes(binaryKey)), 
window);
+        }
+
+        public static <K> Windowed<K> from(final byte[] binaryKey,
+                                           final Deserializer<K> 
keyDeserializer,
+                                           final String topic) {
+            final K key = extractKey(binaryKey, keyDeserializer, topic);
+            final Window window = extractWindow(binaryKey);
+            return new Windowed<>(key, window);
+        }
+
+        public static <K> byte[] toBinary(final Windowed<K> sessionKey,
+                                          final Serializer<K> serializer,
+                                          final String topic) {
+            final byte[] bytes = serializer.serialize(topic, sessionKey.key());
+            return toBinary(Bytes.wrap(bytes), sessionKey.window().start(), 
sessionKey.window().end()).get();
+        }
+
+        public static Bytes toBinary(final Windowed<Bytes> sessionKey) {
+            return toBinary(sessionKey.key(), sessionKey.window().start(), 
sessionKey.window().end());
+        }
+
+        // for time prefixed schema, like the session key schema we need to 
put time stamps first, then the key
+        // and hence we need to override the write binary function with the 
write reordering
+        public static void writeBinary(final ByteBuffer buf,
+                                       final Bytes key,
+                                       final long startTime,
+                                       final long endTime) {
+            buf.putLong(endTime);
+            buf.putLong(startTime);
+            buf.put(key.get());
+        }
+
+        public static Bytes toBinary(final Bytes key,
+                                     final long startTime,
+                                     final long endTime) {
+            final ByteBuffer buf = ByteBuffer.allocate(PREFIX_SIZE + 
SessionKeySchema.keyByteLength(key));
+            buf.put(TIME_FIRST_PREFIX);
+            writeBinary(buf, key, startTime, endTime);
+            return Bytes.wrap(buf.array());
+        }
+
+        public static byte[] extractWindowBytesFromNonPrefixSessionKey(final 
byte[] binaryKey) {
+            final ByteBuffer buffer = ByteBuffer.allocate(PREFIX_SIZE + 
binaryKey.length).put(TIME_FIRST_PREFIX);
+            // Move the trailing two timestamps (end, start) in front of the key bytes
+            buffer.put(binaryKey, binaryKey.length - 2 * TIMESTAMP_SIZE, 2 * 
TIMESTAMP_SIZE);
+            buffer.put(binaryKey, 0, binaryKey.length - 2 * TIMESTAMP_SIZE);
+
+            return buffer.array();
+        }
+    }
+
+    public static class KeyFirstSessionKeySchema implements KeySchema {
+
+        @Override
+        public Bytes upperRange(final Bytes key, final long to) {
+            final Bytes noPrefixBytes = new SessionKeySchema().upperRange(key, 
to);
+            return wrapPrefix(noPrefixBytes, true);
+        }
+
+        @Override
+        public Bytes lowerRange(final Bytes key, final long from) {
+            final Bytes noPrefixBytes = new SessionKeySchema().lowerRange(key, 
from);
+            // Always emit at least the prefix byte, even if key is null
+            return wrapPrefix(noPrefixBytes, false);
+        }
+
+        @Override
+        public Bytes upperRangeFixedSize(final Bytes key, final long to) {
+            final ByteBuffer buffer = ByteBuffer.allocate(PREFIX_SIZE + 
SessionKeySchema.keyByteLength(key));
+            buffer.put(KEY_FIRST_PREFIX);
+            SessionKeySchema.writeBinary(buffer, 
SessionKeySchema.upperRangeFixedWindow(key, to));
+            return Bytes.wrap(buffer.array());
+        }
+
+        @Override
+        public Bytes lowerRangeFixedSize(final Bytes key, final long from) {
+            final ByteBuffer buffer = ByteBuffer.allocate(PREFIX_SIZE + 
SessionKeySchema.keyByteLength(key));
+            buffer.put(KEY_FIRST_PREFIX);
+            SessionKeySchema.writeBinary(buffer, 
SessionKeySchema.lowerRangeFixedWindow(key, from));
+            return Bytes.wrap(buffer.array());
+        }
+
+        @Override
+        public long segmentTimestamp(final Bytes key) {
+            return extractEndTimestamp(key.get());
+        }
+
+        @Override
+        public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom,
+                                                 final Bytes binaryKeyTo,
+                                                 final long from,
+                                                 final long to,
+                                                 final boolean forward) {
+            return iterator -> {
+                while (iterator.hasNext()) {
+                    final Bytes bytes = iterator.peekNextKey();
+                    final byte prefix = extractPrefix(bytes.get());
+
+                    if (prefix != KEY_FIRST_PREFIX) {
+                        return false;
+                    }
+
+                    final Windowed<Bytes> windowedKey = from(bytes);
+                    final long endTime = windowedKey.window().end();
+                    final long startTime = windowedKey.window().start();
+
+                    if ((binaryKeyFrom == null || 
windowedKey.key().compareTo(binaryKeyFrom) >= 0)
+                        && (binaryKeyTo == null || 
windowedKey.key().compareTo(binaryKeyTo) <= 0)
+                        && endTime >= from
+                        && startTime <= to) {
+                        return true;
+                    }
+                    iterator.next();
+                }
+                return false;
+            };
+        }
+
+        @Override
+        public <S extends Segment> List<S> segmentsToSearch(final Segments<S> 
segments,
+                                                            final long from,
+                                                            final long to,
+                                                            final boolean 
forward) {
+            return segments.segments(from, Long.MAX_VALUE, forward);
+        }
+
+        static Window extractWindow(final byte[] binaryKey) {
+            final ByteBuffer buffer = ByteBuffer.wrap(binaryKey);
+            final long start = buffer.getLong(binaryKey.length - 
TIMESTAMP_SIZE);
+            final long end = buffer.getLong(binaryKey.length - 2 * 
TIMESTAMP_SIZE);
+            return new SessionWindow(start, end);
+        }
+
+        static byte[] extractKeyBytes(final byte[] binaryKey) {
+            final byte[] bytes = new byte[binaryKey.length - 2 * 
TIMESTAMP_SIZE - PREFIX_SIZE];
+            System.arraycopy(binaryKey, PREFIX_SIZE, bytes, 0, bytes.length);
+            return bytes;
+        }
+
+        public static Windowed<Bytes> from(final Bytes bytesKey) {
+            final byte[] binaryKey = bytesKey.get();
+            final Window window = extractWindow(binaryKey);
+            return new Windowed<>(Bytes.wrap(extractKeyBytes(binaryKey)), 
window);
+        }
+
+        private static <K> K extractKey(final byte[] binaryKey,
+                                        final Deserializer<K> deserializer,
+                                        final String topic) {
+            return deserializer.deserialize(topic, extractKeyBytes(binaryKey));
+        }
+
+        public static <K> Windowed<K> from(final byte[] binaryKey,
+                                           final Deserializer<K> 
keyDeserializer,
+                                           final String topic) {
+            final K key = extractKey(binaryKey, keyDeserializer, topic);
+            final Window window = extractWindow(binaryKey);
+            return new Windowed<>(key, window);
+        }
+
+        static long extractStartTimestamp(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - 
TIMESTAMP_SIZE);
+        }
+
+        static long extractEndTimestamp(final byte[] binaryKey) {
+            return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - 2 * 
TIMESTAMP_SIZE);
+        }
+
+        public static Bytes toBinary(final Windowed<Bytes> sessionKey) {
+            return toBinary(sessionKey.key(), sessionKey.window().start(), 
sessionKey.window().end());
+        }
+
+        public static <K> byte[] toBinary(final Windowed<K> sessionKey,
+                                          final Serializer<K> serializer,
+                                          final String topic) {
+            final byte[] bytes = serializer.serialize(topic, sessionKey.key());
+            return toBinary(Bytes.wrap(bytes), sessionKey.window().start(), 
sessionKey.window().end()).get();
+        }
+
+        public static Bytes toBinary(final Bytes key,
+                                     final long startTime,
+                                     final long endTime) {
+            final ByteBuffer buf = ByteBuffer.allocate(PREFIX_SIZE + 
SessionKeySchema.keyByteLength(key));
+            buf.put(KEY_FIRST_PREFIX);
+            SessionKeySchema.writeBinary(buf, key, startTime, endTime);
+            return Bytes.wrap(buf.array());
+        }
+
+        private static Bytes wrapPrefix(final Bytes noPrefixKey, final boolean 
upperRange) {
+            // Need to scan from the prefix even if key is null, so return a single-byte (prefix-only) bound
+            if (noPrefixKey == null) {
+                final byte prefix = upperRange ? KEY_FIRST_PREFIX + 1 : 
KEY_FIRST_PREFIX;
+                final byte[] ret = ByteBuffer.allocate(PREFIX_SIZE)
+                    .put(prefix)
+                    .array();
+                return Bytes.wrap(ret);
+            }
+            final byte[] ret = ByteBuffer.allocate(PREFIX_SIZE + 
noPrefixKey.get().length)
+                .put(KEY_FIRST_PREFIX)
+                .put(noPrefixKey.get())
+                .array();
+            return Bytes.wrap(ret);
+        }
+
+        public static byte[] prefixNonPrefixSessionKey(final byte[] binaryKey) 
{
+            assert binaryKey != null;
+
+            return wrapPrefix(Bytes.wrap(binaryKey), false).get();
+        }
+    }
+}
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java
index 6304e4bd1d..47cf4b49b5 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/PrefixedWindowKeySchemas.java
@@ -45,7 +45,7 @@ public class PrefixedWindowKeySchemas {
         return binaryBytes.length > 0 && binaryBytes[0] == TIME_FIRST_PREFIX;
     }
 
-    public static class TimeFirstWindowKeySchema implements 
RocksDBSegmentedBytesStore.KeySchema {
+    public static class TimeFirstWindowKeySchema implements KeySchema {
 
         @Override
         public Bytes upperRange(final Bytes key, final long to) {
@@ -238,8 +238,6 @@ public class PrefixedWindowKeySchemas {
 
     public static class KeyFirstWindowKeySchema implements KeySchema {
 
-
-
         @Override
         public Bytes upperRange(final Bytes key, final long to) {
             final Bytes noPrefixBytes = new WindowKeySchema().upperRange(key, 
to);
@@ -267,7 +265,7 @@ public class PrefixedWindowKeySchemas {
 
         @Override
         public long segmentTimestamp(final Bytes key) {
-            return KeyFirstWindowKeySchema.extractStoreTimestamp(key.get());
+            // the segment is chosen by the timestamp embedded in the key-first store key
+            return extractStoreTimestamp(key.get());
         }
 
         @Override
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSessionSegmentedBytesStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSessionSegmentedBytesStore.java
new file mode 100644
index 0000000000..4265150eb9
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSessionSegmentedBytesStore.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.kstream.Window;
+import org.apache.kafka.streams.kstream.Windowed;
+import 
org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.KeyFirstSessionKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.TimeFirstSessionKeySchema;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.WriteBatch;
+
+/**
+ * A RocksDB backed time-ordered segmented bytes store for session key schema.
+ *
+ * <p>Base-store keys use {@link TimeFirstSessionKeySchema}. When constructed with
+ * {@code withIndex = true}, a {@link KeyFirstSessionKeySchema} index schema is also
+ * handed to the parent class, and the methods below translate between the two layouts.
+ */
+public class RocksDBTimeOrderedSessionSegmentedBytesStore extends AbstractRocksDBTimeOrderedSegmentedBytesStore {
+
+    /**
+     * Iterator adapter that rewrites key-first index keys into time-first base keys.
+     */
+    private class SessionKeySchemaIndexToBaseStoreIterator extends IndexToBaseStoreIterator {
+        SessionKeySchemaIndexToBaseStoreIterator(final KeyValueIterator<Bytes, byte[]> indexIterator) {
+            super(indexIterator);
+        }
+
+        @Override
+        protected Bytes getBaseKey(final Bytes indexKey) {
+            final byte[] indexBytes = indexKey.get();
+            final Window sessionWindow = KeyFirstSessionKeySchema.extractWindow(indexBytes);
+            final Bytes recordKey = Bytes.wrap(KeyFirstSessionKeySchema.extractKeyBytes(indexBytes));
+            return TimeFirstSessionKeySchema.toBinary(recordKey, sessionWindow.start(), sessionWindow.end());
+        }
+    }
+
+    RocksDBTimeOrderedSessionSegmentedBytesStore(final String name,
+                                                 final String metricsScope,
+                                                 final long retention,
+                                                 final long segmentInterval,
+                                                 final boolean withIndex) {
+        super(
+            name,
+            metricsScope,
+            retention,
+            segmentInterval,
+            new TimeFirstSessionKeySchema(),
+            Optional.ofNullable(withIndex ? new KeyFirstSessionKeySchema() : null)
+        );
+    }
+
+    /**
+     * Looks up the aggregate stored for a single session of {@code key}.
+     *
+     * <p>NOTE(review): the two timestamps are forwarded positionally as the
+     * (start, end) pair of {@code toBinary}, the same order {@code getBaseKey} uses
+     * with {@code window.start(), window.end()}. The parameter names suggest range
+     * bounds instead, so confirm against callers.
+     */
+    public byte[] fetchSession(final Bytes key,
+                               final long earliestSessionEndTime,
+                               final long latestSessionStartTime) {
+        final Bytes sessionKey = TimeFirstSessionKeySchema.toBinary(key, earliestSessionEndTime, latestSessionStartTime);
+        return get(sessionKey);
+    }
+
+    /** Removes the entry for the given windowed session key. */
+    public void remove(final Windowed<Bytes> key) {
+        remove(TimeFirstSessionKeySchema.toBinary(key));
+    }
+
+    /** Stores {@code aggregate} under the time-first encoding of {@code sessionKey}. */
+    public void put(final Windowed<Bytes> sessionKey, final byte[] aggregate) {
+        put(TimeFirstSessionKeySchema.toBinary(sessionKey), aggregate);
+    }
+
+    @Override
+    protected KeyValue<Bytes, byte[]> getIndexKeyValue(final Bytes baseKey, final byte[] baseValue) {
+        final byte[] baseBytes = baseKey.get();
+        final Window sessionWindow = TimeFirstSessionKeySchema.extractWindow(baseBytes);
+        final Bytes recordKey = Bytes.wrap(TimeFirstSessionKeySchema.extractKeyBytes(baseBytes));
+        final Bytes indexKey = KeyFirstSessionKeySchema.toBinary(recordKey, sessionWindow.start(), sessionWindow.end());
+        // index entries carry an empty value
+        return KeyValue.pair(indexKey, new byte[0]);
+    }
+
+    @Override
+    Map<KeyValueSegment, WriteBatch> getWriteBatches(final Collection<ConsumerRecord<byte[], byte[]>> records) {
+        // first pass: advance stream time to the largest session end time in the batch
+        for (final ConsumerRecord<byte[], byte[]> record : records) {
+            observedStreamTime = Math.max(observedStreamTime, SessionKeySchema.extractEndTimestamp(record.key()));
+        }
+
+        // second pass: route every record into the write batch of its segment
+        final Map<KeyValueSegment, WriteBatch> batchesBySegment = new HashMap<>();
+        for (final ConsumerRecord<byte[], byte[]> record : records) {
+            final long segmentId = segments.segmentId(SessionKeySchema.extractEndTimestamp(record.key()));
+            final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+            if (segment == null) {
+                // no live segment for this record; skip it
+                continue;
+            }
+            ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(record, consistencyEnabled, position);
+            try {
+                final WriteBatch batch = batchesBySegment.computeIfAbsent(segment, s -> new WriteBatch());
+                addRestoredRecordToBatch(segment, batch, record);
+            } catch (final RocksDBException e) {
+                throw new ProcessorStateException("Error restoring batch to store " + name(), e);
+            }
+        }
+        return batchesBySegment;
+    }
+
+    /**
+     * Adds one changelog record to {@code batch}. The changelog key is assumed to be
+     * serialized with the un-prefixed SessionKeySchema (from
+     * ChangeLoggingSessionBytesStore), so both the index key and the base key are
+     * rebuilt from it.
+     */
+    private void addRestoredRecordToBatch(final KeyValueSegment segment,
+                                          final WriteBatch batch,
+                                          final ConsumerRecord<byte[], byte[]> record) throws RocksDBException {
+        if (hasIndex()) {
+            final byte[] indexKey = KeyFirstSessionKeySchema.prefixNonPrefixSessionKey(record.key());
+            // propagate tombstones (null value) to the index entry as well
+            final byte[] indexValue = record.value() == null ? null : new byte[0];
+            segment.addToBatch(new KeyValue<>(indexKey, indexValue), batch);
+        }
+        final byte[] baseKey = TimeFirstSessionKeySchema.extractWindowBytesFromNonPrefixSessionKey(record.key());
+        segment.addToBatch(new KeyValue<>(baseKey, record.value()), batch);
+    }
+
+    @Override
+    protected IndexToBaseStoreIterator getIndexToBaseStoreIterator(final SegmentIterator<KeyValueSegment> segmentIterator) {
+        return new SessionKeySchemaIndexToBaseStoreIterator(segmentIterator);
+    }
+}
\ No newline at end of file
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSessionStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSessionStore.java
new file mode 100644
index 0000000000..5b72163757
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSessionStore.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Objects;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.query.PositionBound;
+import org.apache.kafka.streams.query.Query;
+import org.apache.kafka.streams.query.QueryConfig;
+import org.apache.kafka.streams.query.QueryResult;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.SessionStore;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.TimeFirstSessionKeySchema;
+
+/**
+ * {@link SessionStore} adapter over a {@link RocksDBTimeOrderedSessionSegmentedBytesStore}.
+ * Raw keys returned by the bytes store are decoded with
+ * {@link TimeFirstSessionKeySchema#from}.
+ */
+public class RocksDBTimeOrderedSessionStore
+    extends WrappedStateStore<RocksDBTimeOrderedSessionSegmentedBytesStore, Object, Object>
+    implements SessionStore<Bytes, byte[]> {
+
+    // set in init(); passed to StoreQueryUtils when answering queries
+    private StateStoreContext stateStoreContext;
+
+    RocksDBTimeOrderedSessionStore(final RocksDBTimeOrderedSessionSegmentedBytesStore store) {
+        super(store);
+        Objects.requireNonNull(store, "store is null");
+    }
+
+    @Override
+    public void init(final StateStoreContext context, final StateStore root) {
+        wrapped().init(context, root);
+        stateStoreContext = context;
+    }
+
+    @Override
+    public <R> QueryResult<R> query(final Query<R> query,
+                                    final PositionBound positionBound,
+                                    final QueryConfig config) {
+        return StoreQueryUtils.handleBasicQueries(query, positionBound, config, this, getPosition(), stateStoreContext);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes key,
+                                                                  final long earliestSessionEndTime,
+                                                                  final long latestSessionStartTime) {
+        return decodeSessions(wrapped().fetch(key, earliestSessionEndTime, latestSessionStartTime));
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes key,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        return decodeSessions(wrapped().backwardFetch(key, earliestSessionEndTime, latestSessionStartTime));
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes keyFrom,
+                                                                  final Bytes keyTo,
+                                                                  final long earliestSessionEndTime,
+                                                                  final long latestSessionStartTime) {
+        return decodeSessions(wrapped().fetch(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime));
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes keyFrom,
+                                                                          final Bytes keyTo,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        return decodeSessions(wrapped().backwardFetch(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime));
+    }
+
+    @Override
+    public byte[] fetchSession(final Bytes key,
+                               final long earliestSessionEndTime,
+                               final long latestSessionStartTime) {
+        return wrapped().fetchSession(key, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    // The key-only fetch/backwardFetch overloads below delegate with the widest
+    // search bounds, 0 and Long.MAX_VALUE.
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes key) {
+        return findSessions(key, 0, Long.MAX_VALUE);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes key) {
+        return backwardFindSessions(key, 0, Long.MAX_VALUE);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes keyFrom, final Bytes keyTo) {
+        return findSessions(keyFrom, keyTo, 0, Long.MAX_VALUE);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes keyFrom, final Bytes keyTo) {
+        return backwardFindSessions(keyFrom, keyTo, 0, Long.MAX_VALUE);
+    }
+
+    @Override
+    public void remove(final Windowed<Bytes> sessionKey) {
+        wrapped().remove(sessionKey);
+    }
+
+    @Override
+    public void put(final Windowed<Bytes> sessionKey, final byte[] aggregate) {
+        wrapped().put(sessionKey, aggregate);
+    }
+
+    /** Wraps a raw bytes-store iterator so its time-first keys surface as {@link Windowed} session keys. */
+    private static KeyValueIterator<Windowed<Bytes>, byte[]> decodeSessions(final KeyValueIterator<Bytes, byte[]> rawIterator) {
+        return new WrappedSessionStoreIterator(rawIterator, TimeFirstSessionKeySchema::from);
+    }
+}
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowSegmentedBytesStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowSegmentedBytesStore.java
new file mode 100644
index 0000000000..b44588da2b
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowSegmentedBytesStore.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import 
org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.WriteBatch;
+
+/**
+ * A RocksDB backed time-ordered segmented bytes store for window key schema.
+ *
+ * <p>Base-store keys use {@link TimeFirstWindowKeySchema}. When constructed with
+ * {@code withIndex = true}, a {@link KeyFirstWindowKeySchema} index schema is also
+ * handed to the parent class, and the methods below translate between the two layouts.
+ */
+public class RocksDBTimeOrderedWindowSegmentedBytesStore extends AbstractRocksDBTimeOrderedSegmentedBytesStore {
+
+    /**
+     * Iterator adapter that rewrites key-first index keys into time-first base keys.
+     */
+    private class WindowKeySchemaIndexToBaseStoreIterator extends IndexToBaseStoreIterator {
+        WindowKeySchemaIndexToBaseStoreIterator(final KeyValueIterator<Bytes, byte[]> indexIterator) {
+            super(indexIterator);
+        }
+
+        @Override
+        protected Bytes getBaseKey(final Bytes indexKey) {
+            final byte[] indexBytes = indexKey.get();
+            return TimeFirstWindowKeySchema.toStoreKeyBinary(
+                KeyFirstWindowKeySchema.extractStoreKeyBytes(indexBytes),
+                KeyFirstWindowKeySchema.extractStoreTimestamp(indexBytes),
+                KeyFirstWindowKeySchema.extractStoreSequence(indexBytes)
+            );
+        }
+    }
+
+    RocksDBTimeOrderedWindowSegmentedBytesStore(final String name,
+                                                final String metricsScope,
+                                                final long retention,
+                                                final long segmentInterval,
+                                                final boolean withIndex) {
+        super(
+            name,
+            metricsScope,
+            retention,
+            segmentInterval,
+            new TimeFirstWindowKeySchema(),
+            Optional.ofNullable(withIndex ? new KeyFirstWindowKeySchema() : null)
+        );
+    }
+
+    /** Stores {@code value} under the time-first key built from (key, timestamp, seqnum). */
+    public void put(final Bytes key, final long timestamp, final int seqnum, final byte[] value) {
+        put(TimeFirstWindowKeySchema.toStoreKeyBinary(key, timestamp, seqnum), value);
+    }
+
+    /** Reads the value stored under the time-first key built from (key, timestamp, seqnum). */
+    byte[] fetch(final Bytes key, final long timestamp, final int seqnum) {
+        final Bytes baseKey = TimeFirstWindowKeySchema.toStoreKeyBinary(key, timestamp, seqnum);
+        return get(baseKey);
+    }
+
+    @Override
+    protected KeyValue<Bytes, byte[]> getIndexKeyValue(final Bytes baseKey, final byte[] baseValue) {
+        final byte[] baseBytes = baseKey.get();
+        final Bytes indexKey = KeyFirstWindowKeySchema.toStoreKeyBinary(
+            TimeFirstWindowKeySchema.extractStoreKeyBytes(baseBytes),
+            TimeFirstWindowKeySchema.extractStoreTimestamp(baseBytes),
+            TimeFirstWindowKeySchema.extractStoreSequence(baseBytes)
+        );
+        // index entries carry an empty value
+        return KeyValue.pair(indexKey, new byte[0]);
+    }
+
+    @Override
+    Map<KeyValueSegment, WriteBatch> getWriteBatches(final Collection<ConsumerRecord<byte[], byte[]>> records) {
+        // first pass: advance stream time to the largest record timestamp in the batch
+        for (final ConsumerRecord<byte[], byte[]> record : records) {
+            observedStreamTime = Math.max(observedStreamTime, WindowKeySchema.extractStoreTimestamp(record.key()));
+        }
+
+        // second pass: route every record into the write batch of its segment
+        final Map<KeyValueSegment, WriteBatch> batchesBySegment = new HashMap<>();
+        for (final ConsumerRecord<byte[], byte[]> record : records) {
+            final long segmentId = segments.segmentId(WindowKeySchema.extractStoreTimestamp(record.key()));
+            final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime);
+            if (segment == null) {
+                // no live segment for this record; skip it
+                continue;
+            }
+            ChangelogRecordDeserializationHelper.applyChecksAndUpdatePosition(record, consistencyEnabled, position);
+            try {
+                final WriteBatch batch = batchesBySegment.computeIfAbsent(segment, s -> new WriteBatch());
+                addRestoredRecordToBatch(segment, batch, record);
+            } catch (final RocksDBException e) {
+                throw new ProcessorStateException("Error restoring batch to store " + name(), e);
+            }
+        }
+        return batchesBySegment;
+    }
+
+    /**
+     * Adds one changelog record to {@code batch}. The changelog key is assumed to be
+     * serialized with the un-prefixed WindowKeySchema (from
+     * ChangeLoggingTimestampedWindowBytesStore), so both the index key and the base
+     * key are rebuilt from it.
+     */
+    private void addRestoredRecordToBatch(final KeyValueSegment segment,
+                                          final WriteBatch batch,
+                                          final ConsumerRecord<byte[], byte[]> record) throws RocksDBException {
+        if (hasIndex()) {
+            final byte[] indexKey = KeyFirstWindowKeySchema.fromNonPrefixWindowKey(record.key());
+            // propagate tombstones (null value) to the index entry as well
+            final byte[] indexValue = record.value() == null ? null : new byte[0];
+            segment.addToBatch(new KeyValue<>(indexKey, indexValue), batch);
+        }
+        final byte[] baseKey = TimeFirstWindowKeySchema.fromNonPrefixWindowKey(record.key());
+        segment.addToBatch(new KeyValue<>(baseKey, record.value()), batch);
+    }
+
+    @Override
+    protected IndexToBaseStoreIterator getIndexToBaseStoreIterator(final SegmentIterator<KeyValueSegment> segmentIterator) {
+        return new WindowKeySchemaIndexToBaseStoreIterator(segmentIterator);
+    }
+}
\ No newline at end of file
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java
index 4f2587d1e0..598b3d077e 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowStore.java
@@ -33,7 +33,7 @@ import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFir
 
 
 public class RocksDBTimeOrderedWindowStore
-    extends WrappedStateStore<RocksDBTimeOrderedSegmentedBytesStore, Object, 
Object>
+    extends WrappedStateStore<RocksDBTimeOrderedWindowSegmentedBytesStore, 
Object, Object>
     implements WindowStore<Bytes, byte[]>, TimestampedBytesStore {
 
     private final boolean retainDuplicates;
@@ -43,7 +43,7 @@ public class RocksDBTimeOrderedWindowStore
     private int seqnum = 0;
 
     RocksDBTimeOrderedWindowStore(
-        final RocksDBTimeOrderedSegmentedBytesStore store,
+        final RocksDBTimeOrderedWindowSegmentedBytesStore store,
         final boolean retainDuplicates,
         final long windowSize
     ) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java
index af5417fccf..ac0b82f99f 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.java
@@ -106,7 +106,7 @@ public class 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier implements Window
         switch (windowStoreType) {
             case DEFAULT_WINDOW_STORE:
                 return new RocksDBTimeOrderedWindowStore(
-                    new RocksDBTimeOrderedSegmentedBytesStore(
+                    new RocksDBTimeOrderedWindowSegmentedBytesStore(
                         name,
                         metricsScope(),
                         retentionPeriod,
@@ -116,7 +116,7 @@ public class 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier implements Window
                     windowSize);
             case INDEXED_WINDOW_STORE:
                 return new RocksDBTimeOrderedWindowStore(
-                    new RocksDBTimeOrderedSegmentedBytesStore(
+                    new RocksDBTimeOrderedWindowSegmentedBytesStore(
                         name,
                         metricsScope(),
                         retentionPeriod,
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbTimeOrderedSessionBytesStoreSupplier.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbTimeOrderedSessionBytesStoreSupplier.java
new file mode 100644
index 0000000000..60cd710e6a
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbTimeOrderedSessionBytesStoreSupplier.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
+import org.apache.kafka.streams.state.SessionStore;
+
+/**
+ * {@link SessionBytesStoreSupplier} for the time-ordered RocksDB session store.
+ * Each call to {@link #get()} builds a new {@link RocksDBTimeOrderedSessionStore}
+ * over a new {@link RocksDBTimeOrderedSessionSegmentedBytesStore}.
+ */
+public class RocksDbTimeOrderedSessionBytesStoreSupplier implements SessionBytesStoreSupplier {
+    // Lower bound for the segment interval. Selected somewhat arbitrarily;
+    // profiling may reveal a different value is preferable.
+    private static final long MIN_SEGMENT_INTERVAL_MS = 60_000L;
+
+    private final String name;
+    private final long retentionPeriod;
+    private final boolean withIndex;
+
+    public RocksDbTimeOrderedSessionBytesStoreSupplier(final String name,
+                                                       final long retentionPeriod,
+                                                       final boolean withIndex) {
+        this.name = name;
+        this.retentionPeriod = retentionPeriod;
+        this.withIndex = withIndex;
+    }
+
+    @Override
+    public String name() {
+        return name;
+    }
+
+    @Override
+    public SessionStore<Bytes, byte[]> get() {
+        final RocksDBTimeOrderedSessionSegmentedBytesStore bytesStore = new RocksDBTimeOrderedSessionSegmentedBytesStore(
+            name,
+            metricsScope(),
+            retentionPeriod,
+            segmentIntervalMs(),
+            withIndex
+        );
+        return new RocksDBTimeOrderedSessionStore(bytesStore);
+    }
+
+    @Override
+    public String metricsScope() {
+        return "rocksdb-session";
+    }
+
+    @Override
+    public long segmentIntervalMs() {
+        // half the retention period, but never below one minute
+        return Math.max(retentionPeriod / 2, MIN_SEGMENT_INTERVAL_MS);
+    }
+
+    @Override
+    public long retentionPeriod() {
+        return retentionPeriod;
+    }
+}
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
index d4196a9ede..505bbddc80 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java
@@ -34,18 +34,30 @@ public class SessionKeySchema implements 
SegmentedBytesStore.KeySchema {
     private static final int SUFFIX_SIZE = 2 * TIMESTAMP_SIZE;
     private static final byte[] MIN_SUFFIX = new byte[SUFFIX_SIZE];
 
+    /**
+     * Returns the serialized size of a session key: the raw key bytes plus the
+     * end and start timestamps.
+     */
+    public static int keyByteLength(final Bytes key) {
+        return SUFFIX_SIZE + key.get().length;
+    }
+
+    // Upper bound for fixed-size range scans: serializes the window built by
+    // upperRangeFixedWindow (session start = to, end = Long.MAX_VALUE).
     @Override
     public Bytes upperRangeFixedSize(final Bytes key, final long to) {
-        final Windowed<Bytes> sessionKey = new Windowed<>(key, new 
SessionWindow(to, Long.MAX_VALUE));
+        final Windowed<Bytes> sessionKey = upperRangeFixedWindow(key, to);
         return SessionKeySchema.toBinary(sessionKey);
     }
 
+    /**
+     * Returns the upper search-boundary window for {@code key}: session start
+     * {@code to}, session end {@code Long.MAX_VALUE}.
+     */
+    public static <K> Windowed<K> upperRangeFixedWindow(final K key, final long to) {
+        final SessionWindow boundary = new SessionWindow(to, Long.MAX_VALUE);
+        return new Windowed<>(key, boundary);
+    }
+
+    // Lower bound for fixed-size range scans: serializes the window built by
+    // lowerRangeFixedWindow (session start = 0, end = max(0, from)).
     @Override
     public Bytes lowerRangeFixedSize(final Bytes key, final long from) {
-        final Windowed<Bytes> sessionKey = new Windowed<>(key, new 
SessionWindow(0, Math.max(0, from)));
+        final Windowed<Bytes> sessionKey = lowerRangeFixedWindow(key, from);
         return SessionKeySchema.toBinary(sessionKey);
     }
 
+    /**
+     * Returns the lower search-boundary window for {@code key}: session start 0,
+     * session end {@code from} clamped to be non-negative.
+     */
+    public static <K> Windowed<K> lowerRangeFixedWindow(final K key, final long from) {
+        final long clampedEnd = Math.max(0, from);
+        return new Windowed<>(key, new SessionWindow(0, clampedEnd));
+    }
+
     @Override
     public Bytes upperRange(final Bytes key, final long to) {
         if (key == null) {
@@ -161,11 +173,27 @@ public class SessionKeySchema implements 
SegmentedBytesStore.KeySchema {
+    // Serializes (key, startTime, endTime) into a buffer sized by keyByteLength,
+    // delegating the byte layout to writeBinary.
     public static Bytes toBinary(final Bytes key,
                                  final long startTime,
                                  final long endTime) {
-        final byte[] bytes = key.get();
-        final ByteBuffer buf = ByteBuffer.allocate(bytes.length + 2 * 
TIMESTAMP_SIZE);
-        buf.put(bytes);
+        final ByteBuffer buf = ByteBuffer.allocate(keyByteLength(key));
+        writeBinary(buf, key, startTime, endTime);
+        return Bytes.wrap(buf.array());
+    }
+
+    /** Writes the record key and session window of {@code sessionKey} into {@code buf}. */
+    public static void writeBinary(final ByteBuffer buf, final Windowed<Bytes> sessionKey) {
+        final long sessionStart = sessionKey.window().start();
+        final long sessionEnd = sessionKey.window().end();
+        writeBinary(buf, sessionKey.key(), sessionStart, sessionEnd);
+    }
+
+    /**
+     * Writes the session-key encoding into {@code buf}: the raw key bytes, then
+     * {@code endTime}, then {@code startTime} (each via {@code putLong}).
+     * {@code toBinary} sizes the buffer with {@code keyByteLength(key)}.
+     */
+    public static void writeBinary(final ByteBuffer buf,
+                                   final Bytes key,
+                                   final long startTime,
+                                   final long endTime) {
+        // Range scans look for session windows that overlap [ESET, LSST]
+        // (earliest session end time, latest session start time). Since session
+        // lengths vary, the search boundaries are defined as:
+        //   lower: [0, ESET]
+        //   upper: [LSST, INF]
+        // By putting the end time before the start time, the serialized search
+        // boundaries become (ESET, 0) and (INF, LSST).
+        buf.put(key.get());
         buf.putLong(endTime);
         buf.putLong(startTime);
-        return Bytes.wrap(buf.array());
     }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java
index ce26029af4..3a39a95966 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import java.util.function.Function;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.kstream.Windowed;
@@ -24,9 +25,16 @@ import org.apache.kafka.streams.state.KeyValueIterator;
 class WrappedSessionStoreIterator implements KeyValueIterator<Windowed<Bytes>, 
byte[]> {
 
     private final KeyValueIterator<Bytes, byte[]> bytesIterator;
+    private final Function<Bytes, Windowed<Bytes>> windowConstructor;
 
     WrappedSessionStoreIterator(final KeyValueIterator<Bytes, byte[]> 
bytesIterator) {
+        this(bytesIterator, SessionKeySchema::from);
+    }
+
+    WrappedSessionStoreIterator(final KeyValueIterator<Bytes, byte[]> 
bytesIterator,
+                                final Function<Bytes, Windowed<Bytes>> 
windowConstructor) {
         this.bytesIterator = bytesIterator;
+        this.windowConstructor = windowConstructor;
     }
 
     @Override
@@ -36,7 +44,7 @@ class WrappedSessionStoreIterator implements 
KeyValueIterator<Windowed<Bytes>, b
 
     @Override
     public Windowed<Bytes> peekNextKey() {
-        return SessionKeySchema.from(bytesIterator.peekNextKey());
+        return windowConstructor.apply(bytesIterator.peekNextKey());
     }
 
     @Override
@@ -47,6 +55,6 @@ class WrappedSessionStoreIterator implements 
KeyValueIterator<Windowed<Bytes>, b
     @Override
     public KeyValue<Windowed<Bytes>, byte[]> next() {
         final KeyValue<Bytes, byte[]> next = bytesIterator.next();
-        return KeyValue.pair(SessionKeySchema.from(next.key), next.value);
+        return KeyValue.pair(windowConstructor.apply(next.key), next.value);
     }
 }
\ No newline at end of file
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java
index e8d578d017..0641392b2a 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractDualSchemaRocksDBSegmentedBytesStoreTest.java
@@ -37,6 +37,7 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.StreamsConfig.InternalConfig;
 import org.apache.kafka.streams.kstream.Window;
 import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.kstream.internals.SessionWindow;
 import org.apache.kafka.streams.kstream.internals.TimeWindow;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import 
org.apache.kafka.streams.processor.internals.ChangelogRecordDeserializationHelper;
@@ -48,6 +49,8 @@ import 
org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
 import org.apache.kafka.streams.query.Position;
 import org.apache.kafka.streams.state.KeyValueIterator;
 import org.apache.kafka.streams.state.StateSerdes;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.KeyFirstSessionKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.TimeFirstSessionKeySchema;
 import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
 import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
 import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
@@ -98,7 +101,9 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
     private AbstractDualSchemaRocksDBSegmentedBytesStore<S> bytesStore;
     private File stateDir;
     private final Window[] windows = new Window[4];
-    private Window nextSegmentWindow;
+    private Window nextSegmentWindow, startEdgeWindow, endEdgeWindow;
+    private final long startEdgeTime = Long.MAX_VALUE - 700L;
+    private final long endEdgeTime = Long.MAX_VALUE - 600L;
 
     final long retention = 1000;
     final long segmentInterval = 60_000L;
@@ -106,6 +111,20 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
 
     @Before
     public void before() {
+        if (getBaseSchema() instanceof TimeFirstSessionKeySchema) {
+            windows[0] = new SessionWindow(10L, 10L);
+            windows[1] = new SessionWindow(500L, 1000L);
+            windows[2] = new SessionWindow(1_000L, 1_500L);
+            windows[3] = new SessionWindow(30_000L, 60_000L);
+            // All four of the previous windows will go into segment 1.
+            // The nextSegmentWindow is computed be a high enough time that 
when it gets written
+            // to the segment store, it will advance stream time past the 
first segment's retention time and
+            // expire it.
+            nextSegmentWindow = new SessionWindow(segmentInterval + retention, 
segmentInterval + retention);
+
+            startEdgeWindow = new SessionWindow(0L, startEdgeTime);
+            endEdgeWindow = new SessionWindow(endEdgeTime, Long.MAX_VALUE);
+        }
         if (getBaseSchema() instanceof TimeFirstWindowKeySchema) {
             windows[0] = timeWindowForSize(10L, windowSizeForTimeWindow);
             windows[1] = timeWindowForSize(500L, windowSizeForTimeWindow);
@@ -116,6 +135,9 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
             // to the segment store, it will advance stream time past the 
first segment's retention time and
             // expire it.
             nextSegmentWindow = timeWindowForSize(segmentInterval + retention, 
windowSizeForTimeWindow);
+
+            startEdgeWindow = timeWindowForSize(startEdgeTime, 
windowSizeForTimeWindow);
+            endEdgeWindow = timeWindowForSize(endEdgeTime, 
windowSizeForTimeWindow);
         }
 
         bytesStore = getBytesStore();
@@ -285,8 +307,370 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
         }
     }
 
+    @Test
+    public void shouldPutAndFetchEdgeSingleKey() {
+        final String keyA = "a";
+        final String keyB = "b";
+
+        final Bytes serializedKeyAStart = serializeKey(new Windowed<>(keyA, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyAEnd = serializeKey(new Windowed<>(keyA, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBStart = serializeKey(new Windowed<>(keyB, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBEnd = serializeKey(new Windowed<>(keyB, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+
+        bytesStore.put(serializedKeyAStart, serializeValue(10));
+        bytesStore.put(serializedKeyAEnd, serializeValue(50));
+        bytesStore.put(serializedKeyBStart, serializeValue(100));
+        bytesStore.put(serializedKeyBEnd, serializeValue(150));
+
+        // Can fetch start/end edge for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch start/end edge for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyB.getBytes()), startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch from 0 to max for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch from 0 to max for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyB.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldPutAndFetchEdgeKeyRange() {
+        final String keyA = "a";
+        final String keyB = "b";
+
+        final Bytes serializedKeyAStart = serializeKey(new Windowed<>(keyA, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyAEnd = serializeKey(new Windowed<>(keyA, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBStart = serializeKey(new Windowed<>(keyB, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBEnd = serializeKey(new Windowed<>(keyB, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+
+        bytesStore.put(serializedKeyAStart, serializeValue(10));
+        bytesStore.put(serializedKeyAEnd, serializeValue(50));
+        bytesStore.put(serializedKeyBStart, serializeValue(100));
+        bytesStore.put(serializedKeyBEnd, serializeValue(150));
+        // Can fetch from start/end for key range
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 
startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch from 0 to max for key range
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0L, 
Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+            assertEquals(expected, toList(values));
+        }
+
+        // KeyB should be ignored and KeyA should be included even in storage
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, Bytes.wrap(keyA.getBytes()), startEdgeTime, endEdgeTime - 
1L)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            Bytes.wrap(keyB.getBytes()), null, startEdgeTime + 1, 
endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, null, 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = bytesStore.fetch(
+            null, null, startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldPutAndBackwardFetchEdgeSingleKey() {
+        final String keyA = "a";
+        final String keyB = "b";
+
+        final Bytes serializedKeyAStart = serializeKey(new Windowed<>(keyA, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyAEnd = serializeKey(new Windowed<>(keyA, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBStart = serializeKey(new Windowed<>(keyB, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBEnd = serializeKey(new Windowed<>(keyB, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+
+        bytesStore.put(serializedKeyAStart, serializeValue(10));
+        bytesStore.put(serializedKeyAEnd, serializeValue(50));
+        bytesStore.put(serializedKeyBStart, serializeValue(100));
+        bytesStore.put(serializedKeyBEnd, serializeValue(150));
+
+        // Can fetch start/end edge for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch start/end edge for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyB.getBytes()), startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch from 0 to max for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch from 0 to max for single key
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyB.getBytes()), 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+    }
+
+    @Test
+    public void shouldPutAndBackwardFetchEdgeKeyRange() {
+        final String keyA = "a";
+        final String keyB = "b";
+
+        final Bytes serializedKeyAStart = serializeKey(new Windowed<>(keyA, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyAEnd = serializeKey(new Windowed<>(keyA, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBStart = serializeKey(new Windowed<>(keyB, 
startEdgeWindow), false,
+            Integer.MAX_VALUE);
+        final Bytes serializedKeyBEnd = serializeKey(new Windowed<>(keyB, 
endEdgeWindow), false,
+            Integer.MAX_VALUE);
+
+        bytesStore.put(serializedKeyAStart, serializeValue(10));
+        bytesStore.put(serializedKeyAEnd, serializeValue(50));
+        bytesStore.put(serializedKeyBStart, serializeValue(100));
+        bytesStore.put(serializedKeyBEnd, serializeValue(150));
+
+        // Can fetch from start/end for key range
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 
startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+            assertEquals(expected, toList(values));
+        }
+
+        // Can fetch from 0 to max for key range
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0L, 
Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+            assertEquals(expected, toList(values));
+        }
+
+        // KeyB should be ignored and KeyA should be included even in storage
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            null, Bytes.wrap(keyA.getBytes()), startEdgeTime, endEdgeTime - 
1L)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            Bytes.wrap(keyB.getBytes()), null, startEdgeTime + 1, 
endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L)
+            );
+
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            null, null, 0, Long.MAX_VALUE)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+            assertEquals(expected, toList(values));
+        }
+
+        try (final KeyValueIterator<Bytes, byte[]> values = 
bytesStore.backwardFetch(
+            null, null, startEdgeTime, endEdgeTime)) {
+
+            final List<KeyValue<Windowed<String>, Long>> expected = 
getIndexSchema() == null ? asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            ) : asList(
+                KeyValue.pair(new Windowed<>(keyB, endEdgeWindow), 150L),
+                KeyValue.pair(new Windowed<>(keyB, startEdgeWindow), 100L),
+                KeyValue.pair(new Windowed<>(keyA, endEdgeWindow), 50L),
+                KeyValue.pair(new Windowed<>(keyA, startEdgeWindow), 10L)
+            );
+            assertEquals(expected, toList(values));
+        }
+    }
+
     @Test
     public void shouldPutAndFetchWithPrefixKey() {
+        // Only for TimeFirstWindowKeySchema schema
+        if (!(getBaseSchema() instanceof TimeFirstWindowKeySchema)) {
+            return;
+        }
         final String keyA = "a";
         final String keyB = "aa";
         final String keyC = "aaa";
@@ -365,6 +749,11 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
 
     @Test
     public void shouldPutAndBackwardFetchWithPrefix() {
+        // Only for TimeFirstWindowKeySchema schema
+        if (!(getBaseSchema() instanceof TimeFirstWindowKeySchema)) {
+            return;
+        }
+
         final String keyA = "a";
         final String keyB = "aa";
         final String keyC = "aaa";
@@ -1081,10 +1470,16 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
 
     private Bytes serializeKey(final Windowed<String> key, final boolean 
changeLog, final int seq) {
         final StateSerdes<String, Long> stateSerdes = 
StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
-        if (changeLog) {
-            return WindowKeySchema.toStoreKeyBinary(key, seq, stateSerdes);
-        } else if (getBaseSchema() instanceof TimeFirstWindowKeySchema) {
+        if (getBaseSchema() instanceof TimeFirstWindowKeySchema) {
+            if (changeLog) {
+                return WindowKeySchema.toStoreKeyBinary(key, seq, stateSerdes);
+            }
             return TimeFirstWindowKeySchema.toStoreKeyBinary(key, seq, 
stateSerdes);
+        } else if (getBaseSchema() instanceof TimeFirstSessionKeySchema) {
+            if (changeLog) {
+                return Bytes.wrap(SessionKeySchema.toBinary(key, 
stateSerdes.keySerializer(), "dummy"));
+            }
+            return Bytes.wrap(TimeFirstSessionKeySchema.toBinary(key, 
stateSerdes.keySerializer(), "dummy"));
         } else {
             throw new IllegalStateException("Unrecognized serde schema");
         }
@@ -1094,6 +1489,8 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
         final StateSerdes<String, Long> stateSerdes = 
StateSerdes.withBuiltinTypes("dummy", String.class, Long.class);
         if (getIndexSchema() instanceof KeyFirstWindowKeySchema) {
             return KeyFirstWindowKeySchema.toStoreKeyBinary(key, 0, 
stateSerdes);
+        } else if (getIndexSchema() instanceof KeyFirstSessionKeySchema) {
+            return Bytes.wrap(KeyFirstSessionKeySchema.toBinary(key, 
stateSerdes.keySerializer(), "dummy"));
         } else {
             throw new IllegalStateException("Unrecognized serde schema");
         }
@@ -1119,6 +1516,12 @@ public abstract class 
AbstractDualSchemaRocksDBSegmentedBytesStoreTest<S extends
                     stateSerdes.valueDeserializer().deserialize("dummy", 
next.value)
                 );
                 results.add(deserialized);
+            } else if (getBaseSchema() instanceof TimeFirstSessionKeySchema) {
+                final KeyValue<Windowed<String>, Long> deserialized = 
KeyValue.pair(
+                    TimeFirstSessionKeySchema.from(next.key.get(), 
stateSerdes.keyDeserializer(), "dummy"),
+                    stateSerdes.valueDeserializer().deserialize("dummy", 
next.value)
+                );
+                results.add(deserialized);
             } else {
                 throw new IllegalStateException("Unrecognized serde schema");
             }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
index deabea8596..b3a749a8a3 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.state.internals;
 
+import java.util.Collection;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.streams.kstream.Windowed;
@@ -27,29 +28,78 @@ import org.apache.kafka.streams.state.SessionStore;
 import org.apache.kafka.streams.state.Stores;
 import org.junit.Test;
 
-import java.util.Arrays;
 import java.util.HashSet;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
 import static java.time.Duration.ofMillis;
+import static java.util.Arrays.asList;
 import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.test.StreamsTestUtils.valuesToSet;
 import static org.junit.Assert.assertEquals;
 
+@RunWith(Parameterized.class)
 public class RocksDBSessionStoreTest extends AbstractSessionBytesStoreTest {
 
     private static final String STORE_NAME = "rocksDB session store";
 
+    enum StoreType {
+        RocksDBSessionStore,
+        RocksDBTimeOrderedSessionStoreWithIndex,
+        RocksDBTimeOrderedSessionStoreWithoutIndex
+    }
+    @Parameter
+    public StoreType storeType;
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<Object[]> getKeySchema() {
+        return asList(new Object[][] {
+            {StoreType.RocksDBSessionStore},
+            {StoreType.RocksDBTimeOrderedSessionStoreWithIndex},
+            {StoreType.RocksDBTimeOrderedSessionStoreWithoutIndex}
+        });
+    }
+
     @Override
     <K, V> SessionStore<K, V> buildSessionStore(final long retentionPeriod,
                                                  final Serde<K> keySerde,
                                                  final Serde<V> valueSerde) {
-        return Stores.sessionStoreBuilder(
-            Stores.persistentSessionStore(
-                STORE_NAME,
-                ofMillis(retentionPeriod)),
-            keySerde,
-            valueSerde).build();
+        switch (storeType) {
+            case RocksDBSessionStore: {
+                return Stores.sessionStoreBuilder(
+                    Stores.persistentSessionStore(
+                        STORE_NAME,
+                        ofMillis(retentionPeriod)),
+                    keySerde,
+                    valueSerde).build();
+            }
+            case RocksDBTimeOrderedSessionStoreWithIndex: {
+                return Stores.sessionStoreBuilder(
+                    new RocksDbTimeOrderedSessionBytesStoreSupplier(
+                        STORE_NAME,
+                        retentionPeriod,
+                        true
+                    ),
+                    keySerde,
+                    valueSerde
+                ).build();
+            }
+            case RocksDBTimeOrderedSessionStoreWithoutIndex: {
+                return Stores.sessionStoreBuilder(
+                    new RocksDbTimeOrderedSessionBytesStoreSupplier(
+                        STORE_NAME,
+                        retentionPeriod,
+                       false
+                    ),
+                    keySerde,
+                    valueSerde
+                ).build();
+            }
+            default:
+                throw new IllegalStateException("Unknown StoreType: " + 
storeType);
+        }
     }
 
     @Test
@@ -64,7 +114,7 @@ public class RocksDBSessionStoreTest extends 
AbstractSessionBytesStoreTest {
         try (final KeyValueIterator<Windowed<String>, Long> iterator =
             sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE)
         ) {
-            assertEquals(valuesToSet(iterator), new 
HashSet<>(Arrays.asList(2L, 3L, 4L)));
+            assertEquals(valuesToSet(iterator), new HashSet<>(asList(2L, 3L, 
4L)));
         }
     }
 
@@ -72,7 +122,7 @@ public class RocksDBSessionStoreTest extends 
AbstractSessionBytesStoreTest {
     public void shouldMatchPositionAfterPut() {
         final MeteredSessionStore<String, Long> meteredSessionStore = 
(MeteredSessionStore<String, Long>) sessionStore;
         final ChangeLoggingSessionBytesStore changeLoggingSessionBytesStore = 
(ChangeLoggingSessionBytesStore) meteredSessionStore.wrapped();
-        final RocksDBSessionStore rocksDBSessionStore = (RocksDBSessionStore) 
changeLoggingSessionBytesStore.wrapped();
+        final WrappedStateStore rocksDBSessionStore = (WrappedStateStore) 
changeLoggingSessionBytesStore.wrapped();
 
         context.setRecordContext(new ProcessorRecordContext(0, 1, 0, "", new 
RecordHeaders()));
         sessionStore.put(new Windowed<String>("a", new SessionWindow(0, 0)), 
1L);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStoreTest.java
deleted file mode 100644
index 0d5b016a9e..0000000000
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedSegmentedBytesStoreTest.java
+++ /dev/null
@@ -1,74 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.kafka.streams.state.internals;
-
-import static java.util.Arrays.asList;
-
-import java.util.Collection;
-import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
-import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
-import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameter;
-
-@RunWith(Parameterized.class)
-public class RocksDBTimeOrderedSegmentedBytesStoreTest
-    extends AbstractDualSchemaRocksDBSegmentedBytesStoreTest<KeyValueSegment> {
-
-    private final static String METRICS_SCOPE = "metrics-scope";
-
-    @Parameter
-    public String name;
-
-    @Parameter(1)
-    public boolean hasIndex;
-
-    @Parameterized.Parameters(name = "{0}")
-    public static Collection<Object[]> getKeySchema() {
-        return asList(new Object[][] {
-            {"WindowSchemaWithIndex", true},
-            {"WindowSchemaWithoutIndex", false}
-        });
-    }
-
-    AbstractDualSchemaRocksDBSegmentedBytesStore<KeyValueSegment> 
getBytesStore() {
-        return new RocksDBTimeOrderedSegmentedBytesStore(
-            storeName,
-            METRICS_SCOPE,
-            retention,
-            segmentInterval,
-            hasIndex
-        );
-    }
-
-    @Override
-    KeyValueSegments newSegments() {
-        return new KeyValueSegments(storeName, METRICS_SCOPE, retention, 
segmentInterval);
-    }
-
-    @Override
-    KeySchema getBaseSchema() {
-        return new TimeFirstWindowKeySchema();
-    }
-
-    @Override
-    KeySchema getIndexSchema() {
-        return hasIndex ? new KeyFirstWindowKeySchema() : null;
-    }
-
-}
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowSegmentedBytesStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowSegmentedBytesStoreTest.java
new file mode 100644
index 0000000000..db02f5b6ff
--- /dev/null
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimeOrderedWindowSegmentedBytesStoreTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import static java.util.Arrays.asList;
+
+import java.util.Collection;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.KeyFirstSessionKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.TimeFirstSessionKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.KeyFirstWindowKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedWindowKeySchemas.TimeFirstWindowKeySchema;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class RocksDBTimeOrderedWindowSegmentedBytesStoreTest
+    extends AbstractDualSchemaRocksDBSegmentedBytesStoreTest<KeyValueSegment> {
+
+    private final static String METRICS_SCOPE = "metrics-scope";
+
+    private enum SchemaType {
+        WindowSchemaWithIndex,
+        WindowSchemaWithoutIndex,
+        SessionSchemaWithIndex,
+        SessionSchemaWithoutIndex
+    }
+
+    private boolean hasIndex;
+    private SchemaType schemaType;
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<Object[]> getKeySchema() {
+        return asList(new Object[][] {
+            {SchemaType.WindowSchemaWithIndex, true},
+            {SchemaType.WindowSchemaWithoutIndex, false},
+            {SchemaType.SessionSchemaWithIndex, true},
+            {SchemaType.SessionSchemaWithoutIndex, false}
+        });
+    }
+
+    public RocksDBTimeOrderedWindowSegmentedBytesStoreTest(final SchemaType 
schemaType, final boolean hasIndex) {
+        this.schemaType = schemaType;
+        this.hasIndex = hasIndex;
+    }
+
+
+    AbstractDualSchemaRocksDBSegmentedBytesStore<KeyValueSegment> 
getBytesStore() {
+        switch (schemaType) {
+            case WindowSchemaWithIndex:
+            case WindowSchemaWithoutIndex:
+                return new RocksDBTimeOrderedWindowSegmentedBytesStore(
+                    storeName,
+                    METRICS_SCOPE,
+                    retention,
+                    segmentInterval,
+                    hasIndex
+                );
+            case SessionSchemaWithIndex:
+            case SessionSchemaWithoutIndex:
+                return new RocksDBTimeOrderedSessionSegmentedBytesStore(
+                    storeName,
+                    METRICS_SCOPE,
+                    retention,
+                    segmentInterval,
+                    hasIndex
+                );
+            default:
+                throw new IllegalStateException("Unknown SchemaType: " + 
schemaType);
+        }
+    }
+
+    @Override
+    KeyValueSegments newSegments() {
+        return new KeyValueSegments(storeName, METRICS_SCOPE, retention, 
segmentInterval);
+    }
+
+    @Override
+    KeySchema getBaseSchema() {
+        switch (schemaType) {
+            case WindowSchemaWithIndex:
+            case WindowSchemaWithoutIndex:
+                return new TimeFirstWindowKeySchema();
+            case SessionSchemaWithIndex:
+            case SessionSchemaWithoutIndex:
+                return new TimeFirstSessionKeySchema();
+            default:
+                throw new IllegalStateException("Unknown SchemaType: " + 
schemaType);
+        }
+    }
+
+    @Override
+    KeySchema getIndexSchema() {
+        if (!hasIndex) {
+            return null;
+        }
+        switch (schemaType) {
+            case WindowSchemaWithIndex:
+                return new KeyFirstWindowKeySchema();
+            case SessionSchemaWithIndex:
+                return new KeyFirstSessionKeySchema();
+            default:
+                throw new IllegalStateException("Unknown SchemaType: " + 
schemaType);
+        }
+    }
+
+}
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
index 5abfd0667d..c0c7e963e6 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java
@@ -68,17 +68,14 @@ public class RocksDBWindowStoreTest extends 
AbstractWindowBytesStoreTest {
     }
 
     @Parameter
-    public String name;
-
-    @Parameter(1)
     public StoreType storeType;
 
     @Parameterized.Parameters(name = "{0}")
     public static Collection<Object[]> getKeySchema() {
         return asList(new Object[][] {
-            {"RocksDBWindowStore", StoreType.RocksDBWindowStore},
-            {"RocksDBTimeOrderedWindowStoreWithIndex", 
StoreType.RocksDBTimeOrderedWindowStoreWithIndex},
-            {"RocksDBTimeOrderedWindowStoreWithoutIndex", 
StoreType.RocksDBTimeOrderedWindowStoreWithoutIndex}
+            {StoreType.RocksDBWindowStore},
+            {StoreType.RocksDBTimeOrderedWindowStoreWithIndex},
+            {StoreType.RocksDBTimeOrderedWindowStoreWithoutIndex}
         });
     }
 
@@ -88,32 +85,41 @@ public class RocksDBWindowStoreTest extends 
AbstractWindowBytesStoreTest {
                                               final boolean retainDuplicates,
                                               final Serde<K> keySerde,
                                               final Serde<V> valueSerde) {
-        if (storeType == StoreType.RocksDBWindowStore) {
-            return Stores.windowStoreBuilder(
-                    Stores.persistentWindowStore(
-                        STORE_NAME,
-                        ofMillis(retentionPeriod),
-                        ofMillis(windowSize),
-                        retainDuplicates),
+
+        switch (storeType) {
+            case RocksDBWindowStore: {
+                return Stores.windowStoreBuilder(
+                        Stores.persistentWindowStore(
+                            STORE_NAME,
+                            ofMillis(retentionPeriod),
+                            ofMillis(windowSize),
+                            retainDuplicates),
+                        keySerde,
+                        valueSerde)
+                    .build();
+            }
+            case RocksDBTimeOrderedWindowStoreWithIndex: {
+                final long defaultSegmentInterval = Math.max(retentionPeriod / 
2, 60_000L);
+                return Stores.windowStoreBuilder(
+                    new 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(STORE_NAME,
+                        retentionPeriod, defaultSegmentInterval, windowSize, 
retainDuplicates,
+                        true),
                     keySerde,
-                    valueSerde)
-                .build();
-        } else if (storeType == 
StoreType.RocksDBTimeOrderedWindowStoreWithIndex) {
-            final long defaultSegmentInterval = Math.max(retentionPeriod / 2, 
60_000L);
-            return Stores.windowStoreBuilder(
-                new 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(STORE_NAME,
-                    retentionPeriod, defaultSegmentInterval, windowSize, 
retainDuplicates, true),
-                keySerde,
-                valueSerde
-            ).build();
-        } else {
-            final long defaultSegmentInterval = Math.max(retentionPeriod / 2, 
60_000L);
-            return Stores.windowStoreBuilder(
-                new 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(STORE_NAME,
-                    retentionPeriod, defaultSegmentInterval, windowSize, 
retainDuplicates, false),
-                keySerde,
-                valueSerde
-            ).build();
+                    valueSerde
+                ).build();
+            }
+            case RocksDBTimeOrderedWindowStoreWithoutIndex: {
+                final long defaultSegmentInterval = Math.max(retentionPeriod / 
2, 60_000L);
+                return Stores.windowStoreBuilder(
+                    new 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier(STORE_NAME,
+                        retentionPeriod, defaultSegmentInterval, windowSize, 
retainDuplicates,
+                        false),
+                    keySerde,
+                    valueSerde
+                ).build();
+            }
+            default:
+                throw new IllegalStateException("Unknown StoreType: " + 
storeType);
         }
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplierTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplierTest.java
index ed1bbb8fd4..fad4cc5a47 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplierTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDbIndexedTimeOrderedWindowBytesStoreSupplierTest.java
@@ -60,8 +60,8 @@ public class 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplierTest {
         final WindowStore store = 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.create("store", ofMillis(1L), 
ofMillis(1L), false, true).get();
         final StateStore wrapped = ((WrappedStateStore) store).wrapped();
         assertThat(store, instanceOf(RocksDBTimeOrderedWindowStore.class));
-        assertThat(wrapped, 
instanceOf(RocksDBTimeOrderedSegmentedBytesStore.class));
-        assertTrue(((RocksDBTimeOrderedSegmentedBytesStore) 
wrapped).hasIndex());
+        assertThat(wrapped, 
instanceOf(RocksDBTimeOrderedWindowSegmentedBytesStore.class));
+        assertTrue(((RocksDBTimeOrderedWindowSegmentedBytesStore) 
wrapped).hasIndex());
     }
 
     @Test
@@ -69,7 +69,7 @@ public class 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplierTest {
         final WindowStore store = 
RocksDbIndexedTimeOrderedWindowBytesStoreSupplier.create("store", ofMillis(1L), 
ofMillis(1L), false, false).get();
         final StateStore wrapped = ((WrappedStateStore) store).wrapped();
         assertThat(store, instanceOf(RocksDBTimeOrderedWindowStore.class));
-        assertThat(wrapped, 
instanceOf(RocksDBTimeOrderedSegmentedBytesStore.class));
-        assertFalse(((RocksDBTimeOrderedSegmentedBytesStore) 
wrapped).hasIndex());
+        assertThat(wrapped, 
instanceOf(RocksDBTimeOrderedWindowSegmentedBytesStore.class));
+        assertFalse(((RocksDBTimeOrderedWindowSegmentedBytesStore) 
wrapped).hasIndex());
     }
 }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
index 0482f01ba5..8b5391a7cb 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
@@ -17,30 +17,101 @@
 
 package org.apache.kafka.streams.state.internals;
 
+import java.util.Collection;
+import java.util.Map;
+import java.util.function.Function;
+import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.kstream.Window;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.kstream.WindowedSerdes;
 import org.apache.kafka.streams.kstream.internals.SessionWindow;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.KeyFirstSessionKeySchema;
+import 
org.apache.kafka.streams.state.internals.PrefixedSessionKeySchemas.TimeFirstSessionKeySchema;
+import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema;
 import org.apache.kafka.test.KeyValueIteratorStub;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
+import static java.util.Arrays.asList;
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.IsEqual.equalTo;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 
+@RunWith(Parameterized.class)
 public class SessionKeySchemaTest {
+    private static final Map<SchemaType, KeySchema> SCHEMA_TYPE_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, new SessionKeySchema()),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, new 
KeyFirstSessionKeySchema()),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, new 
TimeFirstSessionKeySchema())
+    );
+
+    private static final Map<SchemaType, Function<Windowed<Bytes>, Bytes>> 
WINDOW_TO_STORE_BINARY_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, SessionKeySchema::toBinary),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::toBinary),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::toBinary)
+    );
+
+    private static final Map<SchemaType, Function<byte[], Long>> 
EXTRACT_END_TS_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, 
SessionKeySchema::extractEndTimestamp),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::extractEndTimestamp),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::extractEndTimestamp)
+    );
+
+    private static final Map<SchemaType, Function<byte[], Long>> 
EXTRACT_START_TS_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, 
SessionKeySchema::extractStartTimestamp),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::extractStartTimestamp),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::extractStartTimestamp)
+    );
+
+    @FunctionalInterface
+    interface TriFunction<A, B, C, R> {
+        R apply(A a, B b, C c);
+    }
+
+    private static final Map<SchemaType, TriFunction<Windowed<String>, 
Serializer<String>, String, byte[]>> SERDE_TO_STORE_BINARY_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, SessionKeySchema::toBinary),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::toBinary),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::toBinary)
+    );
+
+    private static final Map<SchemaType, TriFunction<byte[], 
Deserializer<String>, String, Windowed<String>>> SERDE_FROM_BYTES_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, SessionKeySchema::from),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::from),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::from)
+    );
+
+    private static final Map<SchemaType, Function<Bytes, Windowed<Bytes>>> 
FROM_BYTES_MAP = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, SessionKeySchema::from),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::from),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::from)
+    );
+
+    private static final Map<SchemaType, Function<byte[], Window>> 
EXTRACT_WINDOW = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, SessionKeySchema::extractWindow),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::extractWindow),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::extractWindow)
+    );
+
+    private static final Map<SchemaType, Function<byte[], byte[]>> 
EXTRACT_KEY_BYTES = mkMap(
+        mkEntry(SchemaType.SessionKeySchema, 
SessionKeySchema::extractKeyBytes),
+        mkEntry(SchemaType.PrefixedKeyFirstSchema, 
KeyFirstSessionKeySchema::extractKeyBytes),
+        mkEntry(SchemaType.PrefixedTimeFirstSchema, 
TimeFirstSessionKeySchema::extractKeyBytes)
+    );
 
     private final String key = "key";
     private final String topic = "topic";
@@ -52,8 +123,45 @@ public class SessionKeySchemaTest {
     private final Windowed<String> windowedKey = new Windowed<>(key, window);
     private final Serde<Windowed<String>> keySerde = new 
WindowedSerdes.SessionWindowedSerde<>(serde);
 
-    private final SessionKeySchema sessionKeySchema = new SessionKeySchema();
+    private final KeySchema keySchema;
     private DelegatingPeekingKeyValueIterator<Bytes, Integer> iterator;
+    private final SchemaType schemaType;
+    private final Function<Windowed<Bytes>, Bytes> toBinary;
+    private final TriFunction<Windowed<String>, Serializer<String>, String, 
byte[]> serdeToBinary;
+    private final TriFunction<byte[], Deserializer<String>, String, 
Windowed<String>> serdeFromBytes;
+    private final Function<Bytes, Windowed<Bytes>> fromBytes;
+    private final Function<byte[], Long> extractStartTS;
+    private final Function<byte[], Long> extractEndTS;
+    private final Function<byte[], byte[]> extractKeyBytes;
+    private final Function<byte[], Window> extractWindow;
+
+    private enum SchemaType {
+        SessionKeySchema,
+        PrefixedTimeFirstSchema,
+        PrefixedKeyFirstSchema
+    }
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<Object[]> data() {
+        return asList(new Object[][] {
+            {SchemaType.SessionKeySchema},
+            {SchemaType.PrefixedTimeFirstSchema},
+            {SchemaType.PrefixedKeyFirstSchema}
+        });
+    }
+
+    public SessionKeySchemaTest(final SchemaType type) {
+        schemaType = type;
+        keySchema = SCHEMA_TYPE_MAP.get(type);
+        toBinary = WINDOW_TO_STORE_BINARY_MAP.get(schemaType);
+        serdeToBinary = SERDE_TO_STORE_BINARY_MAP.get(schemaType);
+        serdeFromBytes = SERDE_FROM_BYTES_MAP.get(schemaType);
+        fromBytes = FROM_BYTES_MAP.get(schemaType);
+        extractStartTS = EXTRACT_START_TS_MAP.get(schemaType);
+        extractEndTS = EXTRACT_END_TS_MAP.get(schemaType);
+        extractKeyBytes = EXTRACT_KEY_BYTES.get(schemaType);
+        extractWindow = EXTRACT_WINDOW.get(schemaType);
+    }
 
     @After
     public void after() {
@@ -64,44 +172,44 @@ public class SessionKeySchemaTest {
 
     @Before
     public void before() {
-        final List<KeyValue<Bytes, Integer>> keys = 
Arrays.asList(KeyValue.pair(SessionKeySchema.toBinary(new 
Windowed<>(Bytes.wrap(new byte[]{0, 0}), new SessionWindow(0, 0))), 1),
-                                                                  
KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new 
byte[]{0}), new SessionWindow(0, 0))), 2),
-                                                                  
KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 
0, 0}), new SessionWindow(0, 0))), 3),
-                                                                  
KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new 
byte[]{0}), new SessionWindow(10, 20))), 4),
-                                                                  
KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 
0}), new SessionWindow(10, 20))), 5),
-                                                                  
KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 
0, 0}), new SessionWindow(10, 20))), 6));
+        final List<KeyValue<Bytes, Integer>> keys = 
asList(KeyValue.pair(toBinary.apply(new Windowed<>(Bytes.wrap(new byte[]{0, 
0}), new SessionWindow(0, 0))), 1),
+                                                                  
KeyValue.pair(toBinary.apply(new Windowed<>(Bytes.wrap(new byte[]{0}), new 
SessionWindow(0, 0))), 2),
+                                                                  
KeyValue.pair(toBinary.apply(new Windowed<>(Bytes.wrap(new byte[]{0, 0, 0}), 
new SessionWindow(0, 0))), 3),
+                                                                  
KeyValue.pair(toBinary.apply(new Windowed<>(Bytes.wrap(new byte[]{0}), new 
SessionWindow(10, 20))), 4),
+                                                                  
KeyValue.pair(toBinary.apply(new Windowed<>(Bytes.wrap(new byte[]{0, 0}), new 
SessionWindow(10, 20))), 5),
+                                                                  
KeyValue.pair(toBinary.apply(new Windowed<>(Bytes.wrap(new byte[]{0, 0, 0}), 
new SessionWindow(10, 20))), 6));
         iterator = new DelegatingPeekingKeyValueIterator<>("foo", new 
KeyValueIteratorStub<>(keys.iterator()));
     }
 
     @Test
     public void shouldFetchExactKeysSkippingLongerKeys() {
         final Bytes key = Bytes.wrap(new byte[]{0});
-        final List<Integer> result = 
getValues(sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE, true));
-        assertThat(result, equalTo(Arrays.asList(2, 4)));
+        final List<Integer> result = getValues(keySchema.hasNextCondition(key, 
key, 0, Long.MAX_VALUE, true));
+        assertThat(result, equalTo(asList(2, 4)));
     }
 
     @Test
     public void shouldFetchExactKeySkippingShorterKeys() {
         final Bytes key = Bytes.wrap(new byte[]{0, 0});
-        final HasNextCondition hasNextCondition = 
sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE, true);
+        final HasNextCondition hasNextCondition = 
keySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE, true);
         final List<Integer> results = getValues(hasNextCondition);
-        assertThat(results, equalTo(Arrays.asList(1, 5)));
+        assertThat(results, equalTo(asList(1, 5)));
     }
 
     @Test
     public void shouldFetchAllKeysUsingNullKeys() {
-        final HasNextCondition hasNextCondition = 
sessionKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, true);
+        final HasNextCondition hasNextCondition = 
keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE, true);
         final List<Integer> results = getValues(hasNextCondition);
-        assertThat(results, equalTo(Arrays.asList(1, 2, 3, 4, 5, 6)));
+        assertThat(results, equalTo(asList(1, 2, 3, 4, 5, 6)));
     }
     
     @Test
     public void testUpperBoundWithLargeTimestamps() {
-        final Bytes upper = sessionKeySchema.upperRange(Bytes.wrap(new 
byte[]{0xA, 0xB, 0xC}), Long.MAX_VALUE);
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[]{0xA, 
0xB, 0xC}), Long.MAX_VALUE);
 
         assertThat(
             "shorter key with max timestamp should be in range",
-            upper.compareTo(SessionKeySchema.toBinary(
+            upper.compareTo(toBinary.apply(
                 new Windowed<>(
                     Bytes.wrap(new byte[]{0xA}),
                     new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))
@@ -110,7 +218,7 @@ public class SessionKeySchemaTest {
 
         assertThat(
             "shorter key with max timestamp should be in range",
-            upper.compareTo(SessionKeySchema.toBinary(
+            upper.compareTo(toBinary.apply(
                 new Windowed<>(
                     Bytes.wrap(new byte[]{0xA, 0xB}),
                     new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))
@@ -118,18 +226,26 @@ public class SessionKeySchemaTest {
             )) >= 0
         );
 
-        assertThat(upper, equalTo(SessionKeySchema.toBinary(
-            new Windowed<>(Bytes.wrap(new byte[]{0xA}), new 
SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))))
-        );
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            assertThat(upper, equalTo(toBinary.apply(
+                new Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}),
+                    new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))))
+            );
+        } else {
+            assertThat(upper, equalTo(toBinary.apply(
+                new Windowed<>(Bytes.wrap(new byte[]{0xA}),
+                    new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))))
+            );
+        }
     }
 
     @Test
     public void testUpperBoundWithKeyBytesLargerThanFirstTimestampByte() {
-        final Bytes upper = sessionKeySchema.upperRange(Bytes.wrap(new 
byte[]{0xA, (byte) 0x8F, (byte) 0x9F}), Long.MAX_VALUE);
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[]{0xA, 
(byte) 0x8F, (byte) 0x9F}), Long.MAX_VALUE);
 
         assertThat(
             "shorter key with max timestamp should be in range",
-            upper.compareTo(SessionKeySchema.toBinary(
+            upper.compareTo(toBinary.apply(
                 new Windowed<>(
                     Bytes.wrap(new byte[]{0xA, (byte) 0x8F}),
                     new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))
@@ -137,40 +253,53 @@ public class SessionKeySchemaTest {
             ) >= 0
         );
 
-        assertThat(upper, equalTo(SessionKeySchema.toBinary(
+        assertThat(upper, equalTo(toBinary.apply(
             new Windowed<>(Bytes.wrap(new byte[]{0xA, (byte) 0x8F, (byte) 
0x9F}), new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))))
         );
     }
 
     @Test
     public void testUpperBoundWithZeroTimestamp() {
-        final Bytes upper = sessionKeySchema.upperRange(Bytes.wrap(new 
byte[]{0xA, 0xB, 0xC}), 0);
-
-        assertThat(upper, equalTo(SessionKeySchema.toBinary(
-            new Windowed<>(Bytes.wrap(new byte[]{0xA}), new SessionWindow(0, 
Long.MAX_VALUE))))
-        );
+        final Bytes upper = keySchema.upperRange(Bytes.wrap(new byte[]{0xA, 
0xB, 0xC}), 0);
+        final Function<Windowed<Bytes>, Bytes> toBinary = 
WINDOW_TO_STORE_BINARY_MAP.get(schemaType);
+
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            assertThat(upper, equalTo(toBinary.apply(
+                new Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new 
SessionWindow(0, Long.MAX_VALUE))))
+            );
+        } else {
+            assertThat(upper, equalTo(toBinary.apply(
+                new Windowed<>(Bytes.wrap(new byte[]{0xA}), new 
SessionWindow(0, Long.MAX_VALUE))))
+            );
+        }
     }
 
     @Test
     public void testLowerBoundWithZeroTimestamp() {
-        final Bytes lower = sessionKeySchema.lowerRange(Bytes.wrap(new 
byte[]{0xA, 0xB, 0xC}), 0);
-        assertThat(lower, equalTo(SessionKeySchema.toBinary(new 
Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new SessionWindow(0, 0)))));
+        final Bytes lower = keySchema.lowerRange(Bytes.wrap(new byte[]{0xA, 
0xB, 0xC}), 0);
+        assertThat(lower, equalTo(toBinary.apply(new Windowed<>(Bytes.wrap(new 
byte[]{0xA, 0xB, 0xC}), new SessionWindow(0, 0)))));
     }
 
     @Test
     public void testLowerBoundMatchesTrailingZeros() {
-        final Bytes lower = sessionKeySchema.lowerRange(Bytes.wrap(new 
byte[]{0xA, 0xB, 0xC}), Long.MAX_VALUE);
+        final Bytes lower = keySchema.lowerRange(Bytes.wrap(new byte[]{0xA, 
0xB, 0xC}), Long.MAX_VALUE);
 
         assertThat(
             "appending zeros to key should still be in range",
-            lower.compareTo(SessionKeySchema.toBinary(
+            lower.compareTo(toBinary.apply(
                 new Windowed<>(
                     Bytes.wrap(new byte[]{0xA, 0xB, 0xC, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0}),
                     new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE))
             )) < 0
         );
 
-        assertThat(lower, equalTo(SessionKeySchema.toBinary(new 
Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new SessionWindow(0, 0)))));
+        if (schemaType == SchemaType.PrefixedTimeFirstSchema) {
+            assertThat(lower, equalTo(toBinary.apply(
+                new Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new 
SessionWindow(0, Long.MAX_VALUE)))));
+        } else {
+            assertThat(lower, equalTo(toBinary.apply(
+                new Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new 
SessionWindow(0, 0)))));
+        }
     }
 
     @Test
@@ -197,47 +326,47 @@ public class SessionKeySchemaTest {
 
     @Test
     public void shouldConvertToBinaryAndBack() {
-        final byte[] serialized = SessionKeySchema.toBinary(windowedKey, 
serde.serializer(), "dummy");
-        final Windowed<String> result = SessionKeySchema.from(serialized, 
Serdes.String().deserializer(), "dummy");
+        final byte[] serialized = serdeToBinary.apply(windowedKey, 
serde.serializer(), "dummy");
+        final Windowed<String> result = serdeFromBytes.apply(serialized, 
Serdes.String().deserializer(), "dummy");
         assertEquals(windowedKey, result);
     }
 
     @Test
     public void shouldExtractEndTimeFromBinary() {
-        final byte[] serialized = SessionKeySchema.toBinary(windowedKey, 
serde.serializer(), "dummy");
-        assertEquals(endTime, 
SessionKeySchema.extractEndTimestamp(serialized));
+        final byte[] serialized = serdeToBinary.apply(windowedKey, 
serde.serializer(), "dummy");
+        assertEquals(endTime, (long) extractEndTS.apply(serialized));
     }
 
     @Test
     public void shouldExtractStartTimeFromBinary() {
-        final byte[] serialized = SessionKeySchema.toBinary(windowedKey, 
serde.serializer(), "dummy");
-        assertEquals(startTime, 
SessionKeySchema.extractStartTimestamp(serialized));
+        final byte[] serialized = serdeToBinary.apply(windowedKey, 
serde.serializer(), "dummy");
+        assertEquals(startTime, (long) extractStartTS.apply(serialized));
     }
 
     @Test
     public void shouldExtractWindowFromBindary() {
-        final byte[] serialized = SessionKeySchema.toBinary(windowedKey, 
serde.serializer(), "dummy");
-        assertEquals(window, SessionKeySchema.extractWindow(serialized));
+        final byte[] serialized = serdeToBinary.apply(windowedKey, 
serde.serializer(), "dummy");
+        assertEquals(window, extractWindow.apply(serialized));
     }
 
     @Test
     public void shouldExtractKeyBytesFromBinary() {
-        final byte[] serialized = SessionKeySchema.toBinary(windowedKey, 
serde.serializer(), "dummy");
-        assertArrayEquals(key.getBytes(), 
SessionKeySchema.extractKeyBytes(serialized));
+        final byte[] serialized = serdeToBinary.apply(windowedKey, 
serde.serializer(), "dummy");
+        assertArrayEquals(key.getBytes(), extractKeyBytes.apply(serialized));
     }
 
     @Test
     public void shouldExtractKeyFromBinary() {
-        final byte[] serialized = SessionKeySchema.toBinary(windowedKey, 
serde.serializer(), "dummy");
-        assertEquals(windowedKey, SessionKeySchema.from(serialized, 
serde.deserializer(), "dummy"));
+        final byte[] serialized = serdeToBinary.apply(windowedKey, 
serde.serializer(), "dummy");
+        assertEquals(windowedKey, serdeFromBytes.apply(serialized, 
serde.deserializer(), "dummy"));
     }
 
     @Test
     public void shouldExtractBytesKeyFromBinary() {
         final Bytes bytesKey = Bytes.wrap(key.getBytes());
         final Windowed<Bytes> windowedBytesKey = new Windowed<>(bytesKey, 
window);
-        final Bytes serialized = SessionKeySchema.toBinary(windowedBytesKey);
-        assertEquals(windowedBytesKey, SessionKeySchema.from(serialized));
+        final Bytes serialized = toBinary.apply(windowedBytesKey);
+        assertEquals(windowedBytesKey, fromBytes.apply(serialized));
     }
 
     private List<Integer> getValues(final HasNextCondition hasNextCondition) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java
index b9ef24d3c6..bf597fb789 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedWindowStoreTest.java
@@ -102,7 +102,7 @@ public class TimeOrderedWindowStoreTest {
     private static final String CACHE_NAMESPACE = "0_0-store-name";
 
     private InternalMockProcessorContext context;
-    private RocksDBTimeOrderedSegmentedBytesStore bytesStore;
+    private RocksDBTimeOrderedWindowSegmentedBytesStore bytesStore;
     private WindowStore<Bytes, byte[]> underlyingStore;
     private TimeOrderedCachingWindowStore cachingStore;
     private CacheFlushListenerStub<Windowed<String>, String> cacheListener;
@@ -123,7 +123,7 @@ public class TimeOrderedWindowStoreTest {
     @Before
     public void setUp() {
         baseKeySchema = new TimeFirstWindowKeySchema();
-        bytesStore = new RocksDBTimeOrderedSegmentedBytesStore("test", 
"metrics-scope", 100, SEGMENT_INTERVAL, hasIndex);
+        bytesStore = new RocksDBTimeOrderedWindowSegmentedBytesStore("test", 
"metrics-scope", 100, SEGMENT_INTERVAL, hasIndex);
         underlyingStore = new RocksDBTimeOrderedWindowStore(bytesStore, false, 
WINDOW_SIZE);
         final TimeWindowedDeserializer<String> keyDeserializer = new 
TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE);
         keyDeserializer.setIsChangelogTopic(true);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
index e9360534a8..4729a73f14 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java
@@ -130,7 +130,7 @@ public class WindowKeySchemaTest {
     final private KeySchema keySchema;
     final private Serde<Windowed<String>> keySerde = new 
WindowedSerdes.TimeWindowedSerde<>(serde, Long.MAX_VALUE);
     final private StateSerdes<String, byte[]> stateSerdes = new 
StateSerdes<>("dummy", serde, Serdes.ByteArray());
-    final private SchemaType schemaType;
+    final public SchemaType schemaType;
 
     private enum SchemaType {
         WindowKeySchema,
@@ -141,13 +141,13 @@ public class WindowKeySchemaTest {
     @Parameterized.Parameters(name = "{0}")
     public static Collection<Object[]> data() {
         return asList(new Object[][] {
-            {"WindowKeySchema", SchemaType.WindowKeySchema},
-            {"PrefixedTimeFirstSchema", SchemaType.PrefixedTimeFirstSchema},
-            {"PrefixedKeyFirstSchema", SchemaType.PrefixedKeyFirstSchema}
+            {SchemaType.WindowKeySchema},
+            {SchemaType.PrefixedTimeFirstSchema},
+            {SchemaType.PrefixedKeyFirstSchema}
         });
     }
 
-    public WindowKeySchemaTest(final String name, final SchemaType type) {
+    public WindowKeySchemaTest(final SchemaType type) {
         schemaType = type;
         keySchema = SCHEMA_TYPE_MAP.get(type);
     }

Reply via email to