Repository: flink
Updated Branches:
  refs/heads/master 0651876ae -> eeac022f0


[FLINK-8679][State Backends] Ensure that RocksDBKeyedStateBackend.getKeys() filters 
keys by namespace

This closes #5518.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/eeac022f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/eeac022f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/eeac022f

Branch: refs/heads/master
Commit: eeac022f0538e0979e6ad4eb06a2d1031cbd0146
Parents: 0651876
Author: sihuazhou <summerle...@163.com>
Authored: Sat Feb 17 01:26:12 2018 +0800
Committer: Stefan Richter <s.rich...@data-artisans.com>
Committed: Fri Feb 23 10:57:36 2018 +0100

----------------------------------------------------------------------
 .../state/heap/CopyOnWriteStateTable.java       |   1 +
 .../runtime/state/StateBackendTestBase.java     |  46 ++++--
 .../streaming/state/AbstractRocksDBState.java   |  64 +-------
 .../state/RocksDBKeySerializationUtils.java     | 141 +++++++++++++++++
 .../state/RocksDBKeyedStateBackend.java         | 104 +++++++++++--
 .../state/RocksDBKeySerializationUtilsTest.java | 100 ++++++++++++
 .../state/RocksDBRocksIteratorWrapperTest.java  | 152 +++++++++++++++++++
 7 files changed, 523 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
index c5f2937..4ecb0ed 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
@@ -293,6 +293,7 @@ public class CopyOnWriteStateTable<K, N, S> extends 
StateTable<K, N, S> implemen
        public Stream<K> getKeys(N namespace) {
                Iterable<StateEntry<K, N, S>> iterable = () -> iterator();
                return StreamSupport.stream(iterable.spliterator(), false)
+                       .filter(entry -> entry.getNamespace().equals(namespace))
                        .map(entry -> entry.getKey());
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 8acefa4..7838450 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -211,24 +211,52 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
        @Test
        public void testGetKeys() throws Exception {
-               final int elementsToTest = 1000;
+               final int namespace1ElementsNum = 1000;
+               final int namespace2ElementsNum = 1000;
                String fieldName = "get-keys-test";
                AbstractKeyedStateBackend<Integer> backend = 
createKeyedBackend(IntSerializer.INSTANCE);
                try {
-                       ValueState<Integer> keyedState = 
backend.getOrCreateKeyedState(
-                               VoidNamespaceSerializer.INSTANCE,
-                               new ValueStateDescriptor<>(fieldName, 
IntSerializer.INSTANCE));
-                       ((InternalValueState<VoidNamespace, Integer>) 
keyedState).setCurrentNamespace(VoidNamespace.INSTANCE);
+                       final String ns1 = "ns1";
+                       ValueState<Integer> keyedState1 = 
backend.getPartitionedState(
+                               ns1,
+                               StringSerializer.INSTANCE,
+                               new ValueStateDescriptor<>(fieldName, 
IntSerializer.INSTANCE)
+                       );
+
+                       for (int key = 0; key < namespace1ElementsNum; key++) {
+                               backend.setCurrentKey(key);
+                               keyedState1.update(key * 2);
+                       }
+
+                       final String ns2 = "ns2";
+                       ValueState<Integer> keyedState2 = 
backend.getPartitionedState(
+                               ns2,
+                               StringSerializer.INSTANCE,
+                               new ValueStateDescriptor<>(fieldName, 
IntSerializer.INSTANCE)
+                       );
 
-                       for (int key = 0; key < elementsToTest; key++) {
+                       for (int key = namespace1ElementsNum; key < 
namespace1ElementsNum + namespace2ElementsNum; key++) {
                                backend.setCurrentKey(key);
-                               keyedState.update(key * 2);
+                               keyedState2.update(key * 2);
+                       }
+
+                       // valid for namespace1
+                       try (Stream<Integer> keysStream = 
backend.getKeys(fieldName, ns1).sorted()) {
+                               PrimitiveIterator.OfInt actualIterator = 
keysStream.mapToInt(value -> value.intValue()).iterator();
+
+                               for (int expectedKey = 0; expectedKey < 
namespace1ElementsNum; expectedKey++) {
+                                       assertTrue(actualIterator.hasNext());
+                                       assertEquals(expectedKey, 
actualIterator.nextInt());
+                               }
+
+                               assertFalse(actualIterator.hasNext());
                        }
 
-                       try (Stream<Integer> keysStream = 
backend.getKeys(fieldName, VoidNamespace.INSTANCE).sorted()) {
+                       // valid for namespace2
+                       try (Stream<Integer> keysStream = 
backend.getKeys(fieldName, ns2).sorted()) {
                                PrimitiveIterator.OfInt actualIterator = 
keysStream.mapToInt(value -> value.intValue()).iterator();
 
-                               for (int expectedKey = 0; expectedKey < 
elementsToTest; expectedKey++) {
+                               for (int expectedKey = namespace1ElementsNum; 
expectedKey < namespace1ElementsNum + namespace2ElementsNum; expectedKey++) {
                                        assertTrue(actualIterator.hasNext());
                                        assertEquals(expectedKey, 
actualIterator.nextInt());
                                }

http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
----------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
index 64b6d48..6db0e86 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
@@ -95,8 +95,7 @@ public abstract class AbstractRocksDBState<K, N, S extends 
State, SD extends Sta
 
                this.keySerializationStream = new 
ByteArrayOutputStreamWithPos(128);
                this.keySerializationDataOutputView = new 
DataOutputViewStreamWrapper(keySerializationStream);
-               this.ambiguousKeyPossible = 
(backend.getKeySerializer().getLength() < 0)
-                               && (namespaceSerializer.getLength() < 0);
+               this.ambiguousKeyPossible = 
RocksDBKeySerializationUtils.isAmbiguousKeyPossible(backend.getKeySerializer(), 
namespaceSerializer);
        }
 
        // 
------------------------------------------------------------------------
@@ -158,63 +157,8 @@ public abstract class AbstractRocksDBState<K, N, S extends 
State, SD extends Sta
                Preconditions.checkNotNull(key, "No key set. This method should 
not be called outside of a keyed context.");
 
                keySerializationStream.reset();
-               writeKeyGroup(keyGroup, keySerializationDataOutputView);
-               writeKey(key, keySerializationStream, 
keySerializationDataOutputView);
-               writeNameSpace(namespace, keySerializationStream, 
keySerializationDataOutputView);
-       }
-
-       private void writeKeyGroup(
-                       int keyGroup,
-                       DataOutputView keySerializationDateDataOutputView) 
throws IOException {
-               for (int i = backend.getKeyGroupPrefixBytes(); --i >= 0;) {
-                       keySerializationDateDataOutputView.writeByte(keyGroup 
>>> (i << 3));
-               }
-       }
-
-       private void writeKey(
-                       K key,
-                       ByteArrayOutputStreamWithPos keySerializationStream,
-                       DataOutputView keySerializationDataOutputView) throws 
IOException {
-               //write key
-               int beforeWrite = keySerializationStream.getPosition();
-               backend.getKeySerializer().serialize(key, 
keySerializationDataOutputView);
-
-               if (ambiguousKeyPossible) {
-                       //write size of key
-                       writeLengthFrom(beforeWrite, keySerializationStream,
-                               keySerializationDataOutputView);
-               }
-       }
-
-       private void writeNameSpace(
-                       N namespace,
-                       ByteArrayOutputStreamWithPos keySerializationStream,
-                       DataOutputView keySerializationDataOutputView) throws 
IOException {
-               int beforeWrite = keySerializationStream.getPosition();
-               namespaceSerializer.serialize(namespace, 
keySerializationDataOutputView);
-
-               if (ambiguousKeyPossible) {
-                       //write length of namespace
-                       writeLengthFrom(beforeWrite, keySerializationStream,
-                               keySerializationDataOutputView);
-               }
-       }
-
-       private static void writeLengthFrom(
-                       int fromPosition,
-                       ByteArrayOutputStreamWithPos keySerializationStream,
-                       DataOutputView keySerializationDateDataOutputView) 
throws IOException {
-               int length = keySerializationStream.getPosition() - 
fromPosition;
-               writeVariableIntBytes(length, 
keySerializationDateDataOutputView);
-       }
-
-       private static void writeVariableIntBytes(
-                       int value,
-                       DataOutputView keySerializationDateDataOutputView)
-                       throws IOException {
-               do {
-                       keySerializationDateDataOutputView.writeByte(value);
-                       value >>>= 8;
-               } while (value != 0);
+               RocksDBKeySerializationUtils.writeKeyGroup(keyGroup, 
backend.getKeyGroupPrefixBytes(), keySerializationDataOutputView);
+               RocksDBKeySerializationUtils.writeKey(key, 
backend.getKeySerializer(), keySerializationStream, 
keySerializationDataOutputView, ambiguousKeyPossible);
+               RocksDBKeySerializationUtils.writeNameSpace(namespace, 
namespaceSerializer, keySerializationStream, keySerializationDataOutputView, 
ambiguousKeyPossible);
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtils.java
----------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtils.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtils.java
new file mode 100644
index 0000000..1987c11
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtils.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+
+/**
+ * Utils for RocksDB state serialization and deserialization.
+ */
+public class RocksDBKeySerializationUtils {
+
+       public static int readKeyGroup(int keyGroupPrefixBytes, DataInputView 
inputView) throws IOException {
+               int keyGroup = 0;
+               for (int i = 0; i < keyGroupPrefixBytes; ++i) {
+                       keyGroup <<= 8;
+                       keyGroup |= (inputView.readByte() & 0xFF);
+               }
+               return keyGroup;
+       }
+
+       public static <K> K readKey(
+               TypeSerializer<K> keySerializer,
+               ByteArrayInputStreamWithPos inputStream,
+               DataInputView inputView,
+               boolean ambiguousKeyPossible) throws IOException {
+               int beforeRead = inputStream.getPosition();
+               K key = keySerializer.deserialize(inputView);
+               if (ambiguousKeyPossible) {
+                       int length = inputStream.getPosition() - beforeRead;
+                       readVariableIntBytes(inputView, length);
+               }
+               return key;
+       }
+
+       public static <N> N readNamespace(
+               TypeSerializer<N> namespaceSerializer,
+               ByteArrayInputStreamWithPos inputStream,
+               DataInputView inputView,
+               boolean ambiguousKeyPossible) throws IOException {
+               int beforeRead = inputStream.getPosition();
+               N namespace = namespaceSerializer.deserialize(inputView);
+               if (ambiguousKeyPossible) {
+                       int length = inputStream.getPosition() - beforeRead;
+                       readVariableIntBytes(inputView, length);
+               }
+               return namespace;
+       }
+
+       public static <N> void writeNameSpace(
+               N namespace,
+               TypeSerializer<N> namespaceSerializer,
+               ByteArrayOutputStreamWithPos keySerializationStream,
+               DataOutputView keySerializationDataOutputView,
+               boolean ambiguousKeyPossible) throws IOException {
+
+               int beforeWrite = keySerializationStream.getPosition();
+               namespaceSerializer.serialize(namespace, 
keySerializationDataOutputView);
+
+               if (ambiguousKeyPossible) {
+                       //write length of namespace
+                       writeLengthFrom(beforeWrite, keySerializationStream,
+                               keySerializationDataOutputView);
+               }
+       }
+
+       public static boolean isAmbiguousKeyPossible(TypeSerializer 
keySerializer, TypeSerializer namespaceSerializer) {
+               return (keySerializer.getLength() < 0) && 
(namespaceSerializer.getLength() < 0);
+       }
+
+       public static void writeKeyGroup(
+               int keyGroup,
+               int keyGroupPrefixBytes,
+               DataOutputView keySerializationDateDataOutputView) throws 
IOException {
+               for (int i = keyGroupPrefixBytes; --i >= 0; ) {
+                       keySerializationDateDataOutputView.writeByte(keyGroup 
>>> (i << 3));
+               }
+       }
+
+       public static <K> void writeKey(
+               K key,
+               TypeSerializer<K> keySerializer,
+               ByteArrayOutputStreamWithPos keySerializationStream,
+               DataOutputView keySerializationDataOutputView,
+               boolean ambiguousKeyPossible) throws IOException {
+               //write key
+               int beforeWrite = keySerializationStream.getPosition();
+               keySerializer.serialize(key, keySerializationDataOutputView);
+
+               if (ambiguousKeyPossible) {
+                       //write size of key
+                       writeLengthFrom(beforeWrite, keySerializationStream,
+                               keySerializationDataOutputView);
+               }
+       }
+
+       private static void readVariableIntBytes(DataInputView inputView, int 
value) throws IOException {
+               do {
+                       inputView.readByte();
+                       value >>>= 8;
+               } while (value != 0);
+       }
+
+       private static void writeLengthFrom(
+               int fromPosition,
+               ByteArrayOutputStreamWithPos keySerializationStream,
+               DataOutputView keySerializationDateDataOutputView) throws 
IOException {
+               int length = keySerializationStream.getPosition() - 
fromPosition;
+               writeVariableIntBytes(length, 
keySerializationDateDataOutputView);
+       }
+
+       private static void writeVariableIntBytes(
+               int value,
+               DataOutputView keySerializationDateDataOutputView)
+               throws IOException {
+               do {
+                       keySerializationDateDataOutputView.writeByte(value);
+                       value >>>= 8;
+               } while (value != 0);
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 5507339..357a1dc 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -40,6 +40,7 @@ import org.apache.flink.core.fs.FileStatus;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
@@ -94,6 +95,8 @@ import org.rocksdb.Snapshot;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nonnull;
+
 import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
@@ -113,6 +116,8 @@ import java.util.Objects;
 import java.util.PriorityQueue;
 import java.util.Set;
 import java.util.SortedMap;
+import java.util.Spliterator;
+import java.util.Spliterators;
 import java.util.TreeMap;
 import java.util.UUID;
 import java.util.concurrent.Callable;
@@ -258,17 +263,41 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
        @Override
        public <N> Stream<K> getKeys(String state, N namespace) {
-               Tuple2<ColumnFamilyHandle, ?> columnInfo = 
kvStateInformation.get(state);
+               Tuple2<ColumnFamilyHandle, 
RegisteredKeyedBackendStateMetaInfo<?, ?>> columnInfo = 
kvStateInformation.get(state);
                if (columnInfo == null) {
                        return Stream.empty();
                }
 
+               final TypeSerializer<N> namespaceSerializer = 
(TypeSerializer<N>) columnInfo.f1.getNamespaceSerializer();
+               final ByteArrayOutputStreamWithPos namespaceOutputStream = new 
ByteArrayOutputStreamWithPos(8);
+               boolean ambiguousKeyPossible = 
RocksDBKeySerializationUtils.isAmbiguousKeyPossible(keySerializer, 
namespaceSerializer);
+               final byte[] nameSpaceBytes;
+               try {
+                       RocksDBKeySerializationUtils.writeNameSpace(
+                               namespace,
+                               namespaceSerializer,
+                               namespaceOutputStream,
+                               new 
DataOutputViewStreamWrapper(namespaceOutputStream),
+                               ambiguousKeyPossible);
+                       nameSpaceBytes = namespaceOutputStream.toByteArray();
+               } catch (IOException ex) {
+                       throw new FlinkRuntimeException("Failed to get keys 
from RocksDB state backend.", ex);
+               }
+
                RocksIterator iterator = db.newIterator(columnInfo.f0);
                iterator.seekToFirst();
 
-               Iterable<K> iterable = () -> new 
RocksIteratorWrapper<>(iterator, state, keySerializer, keyGroupPrefixBytes);
-               Stream<K> targetStream = 
StreamSupport.stream(iterable.spliterator(), false);
-               return targetStream.onClose(iterator::close);
+               final RocksIteratorWrapper<K> iteratorWrapper = new 
RocksIteratorWrapper<>(iterator, state, keySerializer, keyGroupPrefixBytes,
+                       ambiguousKeyPossible, nameSpaceBytes);
+
+               Stream<K> targetStream = 
StreamSupport.stream(Spliterators.spliteratorUnknownSize(iteratorWrapper, 
Spliterator.ORDERED), false);
+               return targetStream.onClose(iteratorWrapper::close);
+       }
+
+       @VisibleForTesting
+       ColumnFamilyHandle getColumnFamilyHandle(String state) {
+               Tuple2<ColumnFamilyHandle, ?> columnInfo = 
kvStateInformation.get(state);
+               return columnInfo != null ? columnInfo.f0 : null;
        }
 
        /**
@@ -1991,26 +2020,56 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                return count;
        }
 
-       private static class RocksIteratorWrapper<K> implements Iterator<K> {
+       /**
+        * This class is not thread safe.
+        */
+       static class RocksIteratorWrapper<K> implements Iterator<K>, 
AutoCloseable {
                private final RocksIterator iterator;
                private final String state;
                private final TypeSerializer<K> keySerializer;
                private final int keyGroupPrefixBytes;
+               private final byte[] namespaceBytes;
+               private final boolean ambiguousKeyPossible;
+               private K nextKey;
 
                public RocksIteratorWrapper(
                                RocksIterator iterator,
                                String state,
                                TypeSerializer<K> keySerializer,
-                               int keyGroupPrefixBytes) {
+                               int keyGroupPrefixBytes,
+                               boolean ambiguousKeyPossible,
+                               byte[] namespaceBytes) {
                        this.iterator = Preconditions.checkNotNull(iterator);
                        this.state = Preconditions.checkNotNull(state);
                        this.keySerializer = 
Preconditions.checkNotNull(keySerializer);
                        this.keyGroupPrefixBytes = 
Preconditions.checkNotNull(keyGroupPrefixBytes);
+                       this.namespaceBytes = 
Preconditions.checkNotNull(namespaceBytes);
+                       this.nextKey = null;
+                       this.ambiguousKeyPossible = ambiguousKeyPossible;
                }
 
                @Override
                public boolean hasNext() {
-                       return iterator.isValid();
+                       while (nextKey == null && iterator.isValid()) {
+                               try {
+                                       byte[] key = iterator.key();
+                                       if (isMatchingNameSpace(key)) {
+                                               ByteArrayInputStreamWithPos 
inputStream =
+                                                       new 
ByteArrayInputStreamWithPos(key, keyGroupPrefixBytes, key.length - 
keyGroupPrefixBytes);
+                                               DataInputViewStreamWrapper 
dataInput = new DataInputViewStreamWrapper(inputStream);
+                                               K value = 
RocksDBKeySerializationUtils.readKey(
+                                                       keySerializer,
+                                                       inputStream,
+                                                       dataInput,
+                                                       ambiguousKeyPossible);
+                                               nextKey = value;
+                                       }
+                                       iterator.next();
+                               } catch (IOException e) {
+                                       throw new FlinkRuntimeException("Failed 
to access state [" + state + "]", e);
+                               }
+                       }
+                       return nextKey != null;
                }
 
                @Override
@@ -2018,16 +2077,29 @@ public class RocksDBKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                        if (!hasNext()) {
                                throw new NoSuchElementException("Failed to 
access state [" + state + "]");
                        }
-                       try {
-                               byte[] key = iterator.key();
-                                       DataInputViewStreamWrapper dataInput = 
new DataInputViewStreamWrapper(
-                                       new ByteArrayInputStreamWithPos(key, 
keyGroupPrefixBytes, key.length - keyGroupPrefixBytes));
-                               K value = keySerializer.deserialize(dataInput);
-                               iterator.next();
-                               return value;
-                       } catch (IOException e) {
-                               throw new FlinkRuntimeException("Failed to 
access state [" + state + "]", e);
+
+                       K tmpKey = nextKey;
+                       nextKey = null;
+                       return tmpKey;
+               }
+
+               private boolean isMatchingNameSpace(@Nonnull byte[] key) {
+                       final int namespaceBytesLength = namespaceBytes.length;
+                       final int basicLength = namespaceBytesLength + 
keyGroupPrefixBytes;
+                       if (key.length >= basicLength) {
+                               for (int i = 1; i <= namespaceBytesLength; ++i) 
{
+                                       if (key[key.length - i] != 
namespaceBytes[namespaceBytesLength - i]) {
+                                               return false;
+                                       }
+                               }
+                               return true;
                        }
+                       return false;
+               }
+
+               @Override
+               public void close() {
+                       iterator.close();
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtilsTest.java
----------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtilsTest.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtilsTest.java
new file mode 100644
index 0000000..b1737ed
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBKeySerializationUtilsTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for guarding {@link RocksDBKeySerializationUtils}.
+ */
+public class RocksDBKeySerializationUtilsTest {
+
+       @Test
+       public void testIsAmbiguousKeyPossible() {
+               
Assert.assertFalse(RocksDBKeySerializationUtils.isAmbiguousKeyPossible(
+                       IntSerializer.INSTANCE, StringSerializer.INSTANCE));
+
+               
Assert.assertTrue(RocksDBKeySerializationUtils.isAmbiguousKeyPossible(
+                       StringSerializer.INSTANCE, StringSerializer.INSTANCE));
+       }
+
+       @Test
+       public void testKeyGroupSerializationAndDeserialization() throws 
Exception {
+               ByteArrayOutputStreamWithPos outputStream = new 
ByteArrayOutputStreamWithPos(8);
+               DataOutputView outputView = new 
DataOutputViewStreamWrapper(outputStream);
+
+               for (int keyGroupPrefixBytes = 1; keyGroupPrefixBytes <= 2; 
++keyGroupPrefixBytes) {
+                       for (int orgKeyGroup = 0; orgKeyGroup < 128; 
++orgKeyGroup) {
+                               outputStream.reset();
+                               
RocksDBKeySerializationUtils.writeKeyGroup(orgKeyGroup, keyGroupPrefixBytes, 
outputView);
+                               int deserializedKeyGroup = 
RocksDBKeySerializationUtils.readKeyGroup(
+                                       keyGroupPrefixBytes,
+                                       new DataInputViewStreamWrapper(new 
ByteArrayInputStreamWithPos(outputStream.toByteArray())));
+                               Assert.assertEquals(orgKeyGroup, 
deserializedKeyGroup);
+                       }
+               }
+       }
+
+       @Test
+       public void testKeySerializationAndDeserialization() throws Exception {
+               ByteArrayOutputStreamWithPos outputStream = new 
ByteArrayOutputStreamWithPos(8);
+               DataOutputView outputView = new 
DataOutputViewStreamWrapper(outputStream);
+
+               // test for key
+               for (int orgKey = 0; orgKey < 100; ++orgKey) {
+                       outputStream.reset();
+                       RocksDBKeySerializationUtils.writeKey(orgKey, 
IntSerializer.INSTANCE, outputStream, outputView, false);
+                       ByteArrayInputStreamWithPos inputStream = new 
ByteArrayInputStreamWithPos(outputStream.toByteArray());
+                       int deserializedKey = 
RocksDBKeySerializationUtils.readKey(IntSerializer.INSTANCE, inputStream, new 
DataInputViewStreamWrapper(inputStream), false);
+                       Assert.assertEquals(orgKey, deserializedKey);
+
+                       RocksDBKeySerializationUtils.writeKey(orgKey, 
IntSerializer.INSTANCE, outputStream, outputView, true);
+                       inputStream = new 
ByteArrayInputStreamWithPos(outputStream.toByteArray());
+                       deserializedKey = 
RocksDBKeySerializationUtils.readKey(IntSerializer.INSTANCE, inputStream, new 
DataInputViewStreamWrapper(inputStream), true);
+                       Assert.assertEquals(orgKey, deserializedKey);
+               }
+       }
+
+       @Test
+       public void testNamespaceSerializationAndDeserialization() throws 
Exception {
+               ByteArrayOutputStreamWithPos outputStream = new 
ByteArrayOutputStreamWithPos(8);
+               DataOutputView outputView = new 
DataOutputViewStreamWrapper(outputStream);
+
+               for (int orgNamespace = 0; orgNamespace < 100; ++orgNamespace) {
+                       outputStream.reset();
+                       
RocksDBKeySerializationUtils.writeNameSpace(orgNamespace, 
IntSerializer.INSTANCE, outputStream, outputView, false);
+                       ByteArrayInputStreamWithPos inputStream = new 
ByteArrayInputStreamWithPos(outputStream.toByteArray());
+                       int deserializedNamepsace = 
RocksDBKeySerializationUtils.readNamespace(IntSerializer.INSTANCE, inputStream, 
new DataInputViewStreamWrapper(inputStream), false);
+                       Assert.assertEquals(orgNamespace, 
deserializedNamepsace);
+
+                       
RocksDBKeySerializationUtils.writeNameSpace(orgNamespace, 
IntSerializer.INSTANCE, outputStream, outputView, true);
+                       inputStream = new 
ByteArrayInputStreamWithPos(outputStream.toByteArray());
+                       deserializedNamepsace = 
RocksDBKeySerializationUtils.readNamespace(IntSerializer.INSTANCE, inputStream, 
new DataInputViewStreamWrapper(inputStream), true);
+                       Assert.assertEquals(orgNamespace, 
deserializedNamepsace);
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/eeac022f/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBRocksIteratorWrapperTest.java
----------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBRocksIteratorWrapperTest.java
 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBRocksIteratorWrapperTest.java
new file mode 100644
index 0000000..98f937a
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBRocksIteratorWrapperTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.filesystem.FsStateBackend;
+
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.rocksdb.ColumnFamilyHandle;
+import org.rocksdb.RocksIterator;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.function.Function;
+
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for the RocksIteratorWrapper.
+ */
+public class RocksDBRocksIteratorWrapperTest {
+
+       @Rule
+       public final TemporaryFolder tmp = new TemporaryFolder();
+
+       @Test
+       public void testIterator() throws Exception{
+
+               // test for keyGroupPrefixBytes == 1 && ambiguousKeyPossible == 
false
+               testIteratorHelper(IntSerializer.INSTANCE, 
StringSerializer.INSTANCE, 128, i -> i);
+
+               // test for keyGroupPrefixBytes == 1 && ambiguousKeyPossible == 
true
+               testIteratorHelper(StringSerializer.INSTANCE, 
StringSerializer.INSTANCE, 128, i -> String.valueOf(i));
+
+               // test for keyGroupPrefixBytes == 2 && ambiguousKeyPossible == 
false
+               testIteratorHelper(IntSerializer.INSTANCE, 
StringSerializer.INSTANCE, 256, i -> i);
+
+               // test for keyGroupPrefixBytes == 2 && ambiguousKeyPossible == 
true
+               testIteratorHelper(StringSerializer.INSTANCE, 
StringSerializer.INSTANCE, 256, i -> String.valueOf(i));
+       }
+
+       <K> void testIteratorHelper(
+               TypeSerializer<K> keySerializer,
+               TypeSerializer namespaceSerializer,
+               int maxKeyGroupNumber,
+               Function<Integer, K> getKeyFunc) throws Exception {
+
+               String testStateName = "aha";
+               String namespace = "ns";
+
+               String dbPath = tmp.newFolder().getAbsolutePath();
+               String checkpointPath = tmp.newFolder().toURI().toString();
+               RocksDBStateBackend backend = new RocksDBStateBackend(new 
FsStateBackend(checkpointPath), true);
+               backend.setDbStoragePath(dbPath);
+
+               Environment env = new DummyEnvironment("TestTask", 1, 0);
+               RocksDBKeyedStateBackend<K> keyedStateBackend = 
(RocksDBKeyedStateBackend<K>) backend.createKeyedStateBackend(
+                       env,
+                       new JobID(),
+                       "Test",
+                       keySerializer,
+                       maxKeyGroupNumber,
+                       new KeyGroupRange(0, maxKeyGroupNumber - 1),
+                       mock(TaskKvStateRegistry.class));
+
+               try {
+                       keyedStateBackend.restore(null);
+                       ValueState<String> testState = 
keyedStateBackend.getPartitionedState(
+                               namespace,
+                               namespaceSerializer,
+                               new ValueStateDescriptor<String>(testStateName, 
String.class));
+
+                       // insert record
+                       for (int i = 0; i < 1000; ++i) {
+                               
keyedStateBackend.setCurrentKey(getKeyFunc.apply(i));
+                               testState.update(String.valueOf(i));
+                       }
+
+                       ByteArrayOutputStreamWithPos outputStream = new 
ByteArrayOutputStreamWithPos(8);
+                       boolean ambiguousKeyPossible = 
RocksDBKeySerializationUtils.isAmbiguousKeyPossible(keySerializer, 
namespaceSerializer);
+                       RocksDBKeySerializationUtils.writeNameSpace(
+                               namespace,
+                               namespaceSerializer,
+                               outputStream,
+                               new DataOutputViewStreamWrapper(outputStream),
+                               ambiguousKeyPossible);
+
+                       byte[] nameSpaceBytes = outputStream.toByteArray();
+
+                       try (
+                               ColumnFamilyHandle handle = 
keyedStateBackend.getColumnFamilyHandle(testStateName);
+                               RocksIterator iterator = 
keyedStateBackend.db.newIterator(handle);
+                               
RocksDBKeyedStateBackend.RocksIteratorWrapper<K> iteratorWrapper = new 
RocksDBKeyedStateBackend.RocksIteratorWrapper(
+                               iterator,
+                               testStateName,
+                               keySerializer,
+                               keyedStateBackend.getKeyGroupPrefixBytes(),
+                               ambiguousKeyPossible,
+                               nameSpaceBytes)) {
+
+                               iterator.seekToFirst();
+
+                               // valid record
+                               List<Integer> fetchedKeys = new 
ArrayList<>(1000);
+                               while (iteratorWrapper.hasNext()) {
+                                       
fetchedKeys.add(Integer.parseInt(iteratorWrapper.next().toString()));
+                               }
+
+                               fetchedKeys.sort(Comparator.comparingInt(a -> 
a));
+                               Assert.assertEquals(1000, fetchedKeys.size());
+
+                               for (int i = 0; i < 1000; ++i) {
+                                       Assert.assertEquals(i, 
fetchedKeys.get(i).intValue());
+                               }
+                       }
+               } finally {
+                       if (keyedStateBackend != null) {
+                               keyedStateBackend.dispose();
+                       }
+               }
+       }
+}

Reply via email to