dianfu commented on a change in pull request #13739:
URL: https://github.com/apache/flink/pull/13739#discussion_r509978593



##########
File path: 
flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
##########
@@ -863,6 +965,84 @@ private static StateRequestHandler getStateRequestHandler(
                        }
                }
 
+               private BeamFnApi.StateGetResponse.Builder 
handleMapIterateRequest(
+                               MapState<ByteArrayWrapper, byte[]> mapState,
+                               IterateType iterateType,
+                               ByteArrayWrapper iteratorToken) throws 
Exception {
+                       final Iterator iterator;
+                       if (iteratorToken == null) {
+                               switch (iterateType) {
+                                       case ITEMS:
+                                       case VALUES:
+                                               iterator = mapState.iterator();
+                                               break;
+                                       case KEYS:
+                                               iterator = 
mapState.keys().iterator();
+                                               break;
+                                       default:
+                                               throw new 
RuntimeException("Unsupported iterate type: " + iterateType);
+                               }
+                       } else {
+                               iterator = 
mapStateIteratorCache.get(iteratorToken);
+                               if (iterator == null) {
+                                       throw new RuntimeException("The cached 
iterator not exist!");

Review comment:
       ```suggestion
                                        throw new RuntimeException("The cached 
iterator does not exist!");
   ```

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -301,6 +380,124 @@ def _convert_to_cache_key(state_key):
         return state_key.SerializeToString()
 
 
+class RemovableIterator(collections.Iterator):
+
+    def __init__(self, internal_map_state, iterate_type):
+        self._internal_map_state = internal_map_state
+        self._mod_count = internal_map_state._mod_count
+        self._underlying = iter(self._internal_map_state._write_cache.items())
+        self._underlying_is_write_cache = True
+        self._underlying_is_read_cache = False
+        self._iterator_token = None
+        self._cached_map_state = None
+        self._last_key = None
+        self._iterate_type = iterate_type
+        self._removed_keys = set()
+        if self._iterate_type == IterateType.KEYS:
+            self._get_from_cache = self._get_key_from_cache
+            self._get_from_data = self._get_key_from_data
+        elif self._iterate_type == IterateType.VALUES:
+            self._get_from_cache = self._get_value_from_cache
+            self._get_from_data = self._get_value_from_data
+        else:
+            self._get_from_cache = self._get_item_from_cache
+            self._get_from_data = self._get_item_from_data
+
+    def __next__(self):
+        self._check_modification()
+        if self._underlying_is_write_cache:
+            # Iterate the data in write cache firstly
+            try:
+                key, existed_and_value = next(self._underlying)
+                while not existed_and_value[0]:
+                    key, existed_and_value = next(self._underlying)
+                return self._get_from_cache(key, existed_and_value)
+            except StopIteration:
+                self._underlying = self._next_batch()
+                self._underlying_is_write_cache = False
+        if self._underlying_is_read_cache:
+            # If the read cache contains all data (except the data in write 
cache) of the map state
+            self._last_key, existed_and_value = next(self._underlying)
+            while not existed_and_value[0] or \
+                    self._last_key in self._internal_map_state._write_cache:
+                key, existed_and_value = next(self._underlying)
+            return self._get_from_cache(self._last_key, existed_and_value)
+        else:
+            try:
+                # The data is from Java side
+                next_value = self._get_from_data(next(self._underlying))
+                while self._last_key in self._internal_map_state._write_cache:
+                    next_value = self._get_from_data(next(self._underlying))
+                return next_value
+            except StopIteration:
+                if self._iterator_token is None:
+                    raise
+                else:
+                    self._underlying = self._next_batch()
+                    return self.__next__()
+
+    def remove(self):
+        """
+        Remove the the last element returned by this iterator.
+        """
+        if self._last_key is None:
+            raise Exception("You need to call the '__next__' method before 
calling "
+                            "this method.")
+        self._check_modification()
+        # Bypass the 'remove' method of the map state to avoid triggering the 
commit of the write
+        # cache.
+        if self._internal_map_state._cleared:
+            del self._internal_map_state._write_cache[self._last_key]
+            self._mod_count += 1

Review comment:
       Could we refactor this a bit to remove the duplicated code?

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -301,6 +380,124 @@ def _convert_to_cache_key(state_key):
         return state_key.SerializeToString()
 
 
+class RemovableIterator(collections.Iterator):
+
+    def __init__(self, internal_map_state, iterate_type):
+        self._internal_map_state = internal_map_state
+        self._mod_count = internal_map_state._mod_count
+        self._underlying = iter(self._internal_map_state._write_cache.items())
+        self._underlying_is_write_cache = True
+        self._underlying_is_read_cache = False
+        self._iterator_token = None
+        self._cached_map_state = None
+        self._last_key = None
+        self._iterate_type = iterate_type
+        self._removed_keys = set()
+        if self._iterate_type == IterateType.KEYS:
+            self._get_from_cache = self._get_key_from_cache
+            self._get_from_data = self._get_key_from_data
+        elif self._iterate_type == IterateType.VALUES:
+            self._get_from_cache = self._get_value_from_cache
+            self._get_from_data = self._get_value_from_data
+        else:
+            self._get_from_cache = self._get_item_from_cache
+            self._get_from_data = self._get_item_from_data
+
+    def __next__(self):
+        self._check_modification()
+        if self._underlying_is_write_cache:
+            # Iterate the data in write cache firstly
+            try:
+                key, existed_and_value = next(self._underlying)

Review comment:
       ```suggestion
                   self._last_key, (exists, value) = next(self._underlying)
   ```

##########
File path: flink-python/src/main/java/org/apache/flink/python/PythonOptions.java
##########
@@ -164,11 +164,22 @@
         */
        @Experimental
        public static final ConfigOption<Integer> MAP_STATE_WRITE_CACHE_SIZE = 
ConfigOptions
-               .key("python.map-state.write.cache.size")
+               .key("python.map-state.write-cache-size")
                .defaultValue(1000)
                .withDescription("The maximum number of cached write requests 
for a single Python " +
                        "MapState. The write requests will be flushed to the 
state backend (managed in " +
                        "the Java operator) when the number of cached write 
requests exceed this limit. " +
                        "Note that this is an experimental flag and might not 
be available in future " +
                        "releases.");
+
+       /**
+        * The maximum number of write requests cached in a Python MapState.
+        */
+       @Experimental
+       public static final ConfigOption<Integer> MAP_STATE_ITERATE_CACHE_SIZE 
= ConfigOptions
+               .key("python.map-state.iterate-cache-size")
+               .defaultValue(1000)
+               .withDescription("The maximum number of entries read from Java 
side when iterating a " +

Review comment:
       What about improving the doc a bit to make it clearer what this config option is used for?

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -395,6 +619,17 @@ def commit(self):
         self._write_cache.clear()
         self._cleared = False
 
+    def _iterate_next_batch(self, iterate_type, iterator_token):

Review comment:
       ```suggestion
       def next_batch(self, iterate_type, iterator_token):
   ```

##########
File path: docs/_includes/generated/python_configuration.html
##########
@@ -57,13 +57,19 @@
             <td>If set, the Python worker will configure itself to use the 
managed memory budget of the task slot. Otherwise, it will use the Off-Heap 
Memory of the task slot. In this case, users should set the Task Off-Heap 
Memory using the configuration key taskmanager.memory.task.off-heap.size.</td>
         </tr>
         <tr>
-            <td><h5>python.map-state.read.cache.size</h5></td>
+            <td><h5>python.map-state.iterate-cache-size</h5></td>

Review comment:
       Suggest renaming this option to `python.map-state.iterate-req-batch-size`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to