This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 214edd0  [FLINK-21192][python] Support setting namespace in RemoteKeyedStateBackend
214edd0 is described below

commit 214edd0cc79301f6fda82dae69098a12f1e1a892
Author: huangxingbo <[email protected]>
AuthorDate: Fri Jan 29 12:58:20 2021 +0800

    [FLINK-21192][python] Support setting namespace in RemoteKeyedStateBackend
    
    This closes #14800
---
 .../pyflink/fn_execution/beam/beam_operations.py   |   2 +
 flink-python/pyflink/fn_execution/state_impl.py    | 211 +++++++++++++--------
 .../python/beam/BeamPythonFunctionRunner.java      |  49 +++--
 3 files changed, 173 insertions(+), 89 deletions(-)

diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations.py 
b/flink-python/pyflink/fn_execution/beam/beam_operations.py
index b4276ce..63ff70c 100644
--- a/flink-python/pyflink/fn_execution/beam/beam_operations.py
+++ b/flink-python/pyflink/fn_execution/beam/beam_operations.py
@@ -137,6 +137,7 @@ def _create_user_defined_function_operation(factory, 
transform_proto, consumers,
         keyed_state_backend = RemoteKeyedStateBackend(
             factory.state_handler,
             key_row_coder,
+            None,
             spec.serialized_fn.state_cache_size,
             spec.serialized_fn.map_state_read_cache_size,
             spec.serialized_fn.map_state_write_cache_size)
@@ -154,6 +155,7 @@ def _create_user_defined_function_operation(factory, 
transform_proto, consumers,
         keyed_state_backend = RemoteKeyedStateBackend(
             factory.state_handler,
             key_row_coder,
+            None,
             1000,
             1000,
             1000)
diff --git a/flink-python/pyflink/fn_execution/state_impl.py 
b/flink-python/pyflink/fn_execution/state_impl.py
index fa08485..e8c9714 100644
--- a/flink-python/pyflink/fn_execution/state_impl.py
+++ b/flink-python/pyflink/fn_execution/state_impl.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
################################################################################
 import collections
-from abc import ABC
+from abc import ABC, abstractmethod
 from enum import Enum
 from functools import partial
 
@@ -24,7 +24,7 @@ from apache_beam.coders import coder_impl
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.runners.worker.bundle_processor import 
SynchronousBagRuntimeState
 from apache_beam.transforms import userstate
-from typing import List, Tuple, Any, Iterable, Union
+from typing import List, Tuple, Any, Iterable
 
 from pyflink.datastream import ReduceFunction
 from pyflink.datastream.functions import AggregateFunction
@@ -88,43 +88,71 @@ class SynchronousKvRuntimeState(InternalKvState, ABC):
     Base Class for partitioned State implementation.
     """
 
-    def __init__(self,
-                 internal_state: Union[SynchronousBagRuntimeState,
-                                       'InternalSynchronousMapRuntimeState']):
-        self._internal_state = internal_state
+    def __init__(self, name: str, remote_state_backend: 
'RemoteKeyedStateBackend'):
+        self.name = name
+        self._remote_state_backend = remote_state_backend
+        self._internal_state = None
+        self.namespace = None
 
     def set_current_namespace(self, namespace: N) -> None:
-        raise Exception("This method will be implemented in FLINK-21192")
+        if namespace == self.namespace:
+            return
+        if self.namespace is not None:
+            self._remote_state_backend.cache_internal_state(
+                self._remote_state_backend._encoded_current_key, self)
+        self.namespace = namespace
+        self._internal_state = None
 
+    @abstractmethod
+    def get_internal_state(self):
+        pass
 
-class SynchronousValueRuntimeState(SynchronousKvRuntimeState, 
InternalValueState):
+
+class SynchronousBagKvRuntimeState(SynchronousKvRuntimeState, ABC):
+    """
+    Base Class for State implementation backed by a 
:class:`SynchronousBagRuntimeState`.
+    """
+    def __init__(self, name: str, value_coder, remote_state_backend: 
'RemoteKeyedStateBackend'):
+        super(SynchronousBagKvRuntimeState, self).__init__(name, 
remote_state_backend)
+        self._value_coder = value_coder
+
+    def get_internal_state(self):
+        if self._internal_state is None:
+            self._internal_state = 
self._remote_state_backend._get_internal_bag_state(
+                self.name, self.namespace, self._value_coder)
+        return self._internal_state
+
+
+class SynchronousValueRuntimeState(SynchronousBagKvRuntimeState, 
InternalValueState):
     """
     The runtime ValueState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
     """
 
-    def __init__(self, internal_state: SynchronousBagRuntimeState):
-        super(SynchronousValueRuntimeState, self).__init__(internal_state)
+    def __init__(self, name: str, value_coder, remote_state_backend: 
'RemoteKeyedStateBackend'):
+        super(SynchronousValueRuntimeState, self).__init__(name, value_coder, 
remote_state_backend)
 
     def value(self):
-        for i in self._internal_state.read():
+        for i in self.get_internal_state().read():
             return i
         return None
 
     def update(self, value) -> None:
+        self.get_internal_state()
         self._internal_state.clear()
         self._internal_state.add(value)
 
     def clear(self) -> None:
-        self._internal_state.clear()
+        self.get_internal_state().clear()
 
 
-class SynchronousMergingRuntimeState(SynchronousKvRuntimeState, 
InternalMergingState, ABC):
+class SynchronousMergingRuntimeState(SynchronousBagKvRuntimeState, 
InternalMergingState, ABC):
     """
     Base Class for MergingState implementation.
     """
 
-    def __init__(self, internal_state: SynchronousBagRuntimeState):
-        super(SynchronousMergingRuntimeState, self).__init__(internal_state)
+    def __init__(self, name: str, value_coder, remote_state_backend: 
'RemoteKeyedStateBackend'):
+        super(SynchronousMergingRuntimeState, self).__init__(
+            name, value_coder, remote_state_backend)
 
     def merge_namespaces(self, target: N, sources: Iterable[N]) -> None:
         raise Exception("This method will be implemented in FLINK-21631")
@@ -135,24 +163,24 @@ class 
SynchronousListRuntimeState(SynchronousMergingRuntimeState, InternalListSt
     The runtime ListState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
     """
 
-    def __init__(self, internal_state: SynchronousBagRuntimeState):
-        super(SynchronousListRuntimeState, self).__init__(internal_state)
+    def __init__(self, name: str, value_coder, remote_state_backend: 
'RemoteKeyedStateBackend'):
+        super(SynchronousListRuntimeState, self).__init__(name, value_coder, 
remote_state_backend)
 
     def add(self, v):
-        self._internal_state.add(v)
+        self.get_internal_state().add(v)
 
     def get(self):
-        return self._internal_state.read()
+        return self.get_internal_state().read()
 
     def add_all(self, values):
-        self._internal_state._added_elements.extend(values)
+        self.get_internal_state()._added_elements.extend(values)
 
     def update(self, values):
         self.clear()
         self.add_all(values)
 
     def clear(self):
-        self._internal_state.clear()
+        self.get_internal_state().clear()
 
 
 class SynchronousReducingRuntimeState(SynchronousMergingRuntimeState, 
InternalReducingState):
@@ -160,8 +188,13 @@ class 
SynchronousReducingRuntimeState(SynchronousMergingRuntimeState, InternalRe
     The runtime ReducingState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
     """
 
-    def __init__(self, internal_state: SynchronousBagRuntimeState, 
reduce_function: ReduceFunction):
-        super(SynchronousReducingRuntimeState, self).__init__(internal_state)
+    def __init__(self,
+                 name: str,
+                 value_coder,
+                 remote_state_backend: 'RemoteKeyedStateBackend',
+                 reduce_function: ReduceFunction):
+        super(SynchronousReducingRuntimeState, self).__init__(
+            name, value_coder, remote_state_backend)
         self._reduce_function = reduce_function
 
     def add(self, v):
@@ -173,12 +206,12 @@ class 
SynchronousReducingRuntimeState(SynchronousMergingRuntimeState, InternalRe
             
self._internal_state.add(self._reduce_function.reduce(current_value, v))
 
     def get(self):
-        for i in self._internal_state.read():
+        for i in self.get_internal_state().read():
             return i
         return None
 
     def clear(self):
-        self._internal_state.clear()
+        self.get_internal_state().clear()
 
 
 class SynchronousAggregatingRuntimeState(SynchronousMergingRuntimeState, 
InternalAggregatingState):
@@ -186,8 +219,13 @@ class 
SynchronousAggregatingRuntimeState(SynchronousMergingRuntimeState, Interna
     The runtime AggregatingState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
     """
 
-    def __init__(self, internal_state: SynchronousBagRuntimeState, 
agg_function: AggregateFunction):
-        super(SynchronousAggregatingRuntimeState, 
self).__init__(internal_state)
+    def __init__(self,
+                 name: str,
+                 value_coder,
+                 remote_state_backend: 'RemoteKeyedStateBackend',
+                 agg_function: AggregateFunction):
+        super(SynchronousAggregatingRuntimeState, self).__init__(
+            name, value_coder, remote_state_backend)
         self._agg_function = agg_function
 
     def add(self, v):
@@ -209,12 +247,12 @@ class 
SynchronousAggregatingRuntimeState(SynchronousMergingRuntimeState, Interna
             return self._agg_function.get_result(accumulator)
 
     def _get_accumulator(self):
-        for i in self._internal_state.read():
+        for i in self.get_internal_state().read():
             return i
         return None
 
     def clear(self):
-        self._internal_state.clear()
+        self.get_internal_state().clear()
 
 
 class CachedMapState(LRUCache):
@@ -640,9 +678,9 @@ class InternalSynchronousMapRuntimeState(object):
         self._map_state_handler = map_state_handler
         self._state_key = state_key
         self._map_key_coder = map_key_coder
-        self._map_key_coder_impl = map_key_coder._create_impl()
+        self._map_key_coder_impl = map_key_coder.get_impl()
         self._map_value_coder = map_value_coder
-        self._map_value_coder_impl = map_value_coder._create_impl()
+        self._map_value_coder_impl = map_value_coder.get_impl()
         self._write_cache = dict()
         self._max_write_cache_entries = max_write_cache_entries
         self._is_empty = None
@@ -777,38 +815,50 @@ class InternalSynchronousMapRuntimeState(object):
 
 class SynchronousMapRuntimeState(SynchronousKvRuntimeState, InternalMapState):
 
-    def __init__(self, internal_state: InternalSynchronousMapRuntimeState):
-        super(SynchronousMapRuntimeState, self).__init__(internal_state)
+    def __init__(self,
+                 name: str,
+                 map_key_coder,
+                 map_value_coder,
+                 remote_state_backend: 'RemoteKeyedStateBackend'):
+        super(SynchronousMapRuntimeState, self).__init__(name, 
remote_state_backend)
+        self._map_key_coder = map_key_coder
+        self._map_value_coder = map_value_coder
+
+    def get_internal_state(self):
+        if self._internal_state is None:
+            self._internal_state = 
self._remote_state_backend._get_internal_map_state(
+                self.name, self.namespace, self._map_key_coder, 
self._map_value_coder)
+        return self._internal_state
 
     def get(self, key):
-        return self._internal_state.get(key)
+        return self.get_internal_state().get(key)
 
     def put(self, key, value):
-        self._internal_state.put(key, value)
+        self.get_internal_state().put(key, value)
 
     def put_all(self, dict_value):
-        self._internal_state.put_all(dict_value)
+        self.get_internal_state().put_all(dict_value)
 
     def remove(self, key):
-        self._internal_state.remove(key)
+        self.get_internal_state().remove(key)
 
     def contains(self, key):
-        return self._internal_state.contains(key)
+        return self.get_internal_state().contains(key)
 
     def items(self):
-        return self._internal_state.items()
+        return self.get_internal_state().items()
 
     def keys(self):
-        return self._internal_state.keys()
+        return self.get_internal_state().keys()
 
     def values(self):
-        return self._internal_state.values()
+        return self.get_internal_state().values()
 
     def is_empty(self):
-        return self._internal_state.is_empty()
+        return self.get_internal_state().is_empty()
 
     def clear(self):
-        self._internal_state.clear()
+        self.get_internal_state().clear()
 
 
 class RemoteKeyedStateBackend(object):
@@ -819,6 +869,7 @@ class RemoteKeyedStateBackend(object):
     def __init__(self,
                  state_handler,
                  key_coder,
+                 namespace_coder,
                  state_cache_size,
                  map_state_read_cache_size,
                  map_state_write_cache_size):
@@ -827,6 +878,10 @@ class RemoteKeyedStateBackend(object):
             state_handler, map_state_read_cache_size)
         from pyflink.fn_execution.coders import FlattenRowCoder
         self._key_coder_impl = 
FlattenRowCoder(key_coder._field_coders).get_impl()
+        if namespace_coder:
+            self._namespace_coder_impl = namespace_coder.get_impl()
+        else:
+            self._namespace_coder_impl = None
         self._state_cache_size = state_cache_size
         self._map_state_write_cache_size = map_state_write_cache_size
         self._all_states = {}
@@ -853,8 +908,7 @@ class RemoteKeyedStateBackend(object):
         if name in self._all_states:
             self.validate_map_state(name, map_key_coder, map_value_coder)
             return self._all_states[name]
-        internal_map_state = self._get_internal_map_state(name, map_key_coder, 
map_value_coder)
-        map_state = SynchronousMapRuntimeState(internal_map_state)
+        map_state = SynchronousMapRuntimeState(name, map_key_coder, 
map_value_coder, self)
         self._all_states[name] = map_state
         return map_state
 
@@ -874,7 +928,7 @@ class RemoteKeyedStateBackend(object):
             if not isinstance(state, expected_type):
                 raise Exception("The state name '%s' is already in use and not 
a %s."
                                 % (name, expected_type))
-            if state._internal_state._value_coder != coder:
+            if state._value_coder != coder:
                 raise Exception("State name corrupted: %s" % name)
 
     def validate_map_state(self, name, map_key_coder, map_value_coder):
@@ -883,21 +937,22 @@ class RemoteKeyedStateBackend(object):
             if not isinstance(state, SynchronousMapRuntimeState):
                 raise Exception("The state name '%s' is already in use and not 
a map state."
                                 % name)
-            if state._internal_state._map_key_coder != map_key_coder or \
-                    state._internal_state._map_value_coder != map_value_coder:
+            if state._map_key_coder != map_key_coder or \
+                    state._map_value_coder != map_value_coder:
                 raise Exception("State name corrupted: %s" % name)
 
     def _wrap_internal_bag_state(self, name, element_coder, wrapper_type, 
wrap_method):
         if name in self._all_states:
             self.validate_state(name, element_coder, wrapper_type)
             return self._all_states[name]
-        internal_state = self._get_internal_bag_state(name, element_coder)
-        wrapped_state = wrap_method(internal_state)
+        wrapped_state = wrap_method(name, element_coder, self)
         self._all_states[name] = wrapped_state
         return wrapped_state
 
-    def _get_internal_bag_state(self, name, element_coder):
-        cached_state = self._internal_state_cache.get((name, 
self._encoded_current_key))
+    def _get_internal_bag_state(self, name, namespace, element_coder):
+        encoded_namespace = self._encode_namespace(namespace)
+        cached_state = self._internal_state_cache.get(
+            (name, self._encoded_current_key, encoded_namespace))
         if cached_state is not None:
             return cached_state
         # The created internal state would not be put into the internal state 
cache
@@ -905,17 +960,20 @@ class RemoteKeyedStateBackend(object):
         # The reason is that the state cache size may be smaller that the 
count of activated
         # state (i.e. the state with current key).
         state_spec = userstate.BagStateSpec(name, element_coder)
-        internal_state = self._create_bag_state(state_spec)
+        internal_state = self._create_bag_state(state_spec, encoded_namespace)
         return internal_state
 
-    def _get_internal_map_state(self, name, map_key_coder, map_value_coder):
-        cached_state = self._internal_state_cache.get((name, 
self._encoded_current_key))
+    def _get_internal_map_state(self, name, namespace, map_key_coder, 
map_value_coder):
+        encoded_namespace = self._encode_namespace(namespace)
+        cached_state = self._internal_state_cache.get(
+            (name, self._encoded_current_key, encoded_namespace))
         if cached_state is not None:
             return cached_state
-        internal_map_state = self._create_internal_map_state(name, 
map_key_coder, map_value_coder)
+        internal_map_state = self._create_internal_map_state(
+            name, encoded_namespace, map_key_coder, map_value_coder)
         return internal_map_state
 
-    def _create_bag_state(self, state_spec: userstate.StateSpec) \
+    def _create_bag_state(self, state_spec: userstate.StateSpec, 
encoded_namespace) \
             -> userstate.AccumulatingRuntimeState:
         if isinstance(state_spec, userstate.BagStateSpec):
             bag_state = SynchronousBagRuntimeState(
@@ -923,6 +981,7 @@ class RemoteKeyedStateBackend(object):
                 state_key=beam_fn_api_pb2.StateKey(
                     bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                         transform_id="",
+                        window=encoded_namespace,
                         user_state_id=state_spec.name,
                         key=self._encoded_current_key)),
                 value_coder=state_spec.coder)
@@ -930,12 +989,13 @@ class RemoteKeyedStateBackend(object):
         else:
             raise NotImplementedError(state_spec)
 
-    def _create_internal_map_state(self, name, map_key_coder, map_value_coder):
+    def _create_internal_map_state(self, name, encoded_namespace, 
map_key_coder, map_value_coder):
         # Currently the `beam_fn_api.proto` does not support MapState, so we 
use the
         # the `MultimapSideInput` message to mark the state as a MapState for 
now.
         state_key = beam_fn_api_pb2.StateKey(
             multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                 transform_id="",
+                window=encoded_namespace,
                 side_input_id=name,
                 key=self._encoded_current_key))
         return InternalSynchronousMapRuntimeState(
@@ -945,6 +1005,19 @@ class RemoteKeyedStateBackend(object):
             map_value_coder,
             self._map_state_write_cache_size)
 
+    def _encode_namespace(self, namespace):
+        if namespace is not None:
+            encoded_namespace = 
self._namespace_coder_impl.encode_nested(namespace)
+        else:
+            encoded_namespace = b''
+        return encoded_namespace
+
+    def cache_internal_state(self, encoded_key, internal_kv_state: 
SynchronousKvRuntimeState):
+        encoded_old_namespace = 
self._encode_namespace(internal_kv_state.namespace)
+        self._internal_state_cache.put(
+            (internal_kv_state.name, encoded_key, encoded_old_namespace),
+            internal_kv_state.get_internal_state())
+
     def set_current_key(self, key):
         if key == self._current_key:
             return
@@ -954,22 +1027,9 @@ class RemoteKeyedStateBackend(object):
         for state_name, state_obj in self._all_states.items():
             if self._state_cache_size > 0:
                 # cache old internal state
-                self._internal_state_cache.put(
-                    (state_name, encoded_old_key), state_obj._internal_state)
-            if isinstance(state_obj,
-                          (SynchronousValueRuntimeState,
-                           SynchronousListRuntimeState,
-                           SynchronousReducingRuntimeState,
-                           SynchronousAggregatingRuntimeState)):
-                state_obj._internal_state = self._get_internal_bag_state(
-                    state_name, state_obj._internal_state._value_coder)
-            elif isinstance(state_obj, SynchronousMapRuntimeState):
-                state_obj._internal_state = self._get_internal_map_state(
-                    state_name,
-                    state_obj._internal_state._map_key_coder,
-                    state_obj._internal_state._map_value_coder)
-            else:
-                raise Exception("Unknown internal state '%s': %s" % 
(state_name, state_obj))
+                self.cache_internal_state(encoded_old_key, state_obj)
+            state_obj.namespace = None
+            state_obj._internal_state = None
 
     def get_current_key(self):
         return self._current_key
@@ -988,7 +1048,8 @@ class RemoteKeyedStateBackend(object):
 
     @staticmethod
     def commit_internal_state(internal_state):
-        internal_state.commit()
+        if internal_state is not None:
+            internal_state.commit()
         # reset the status of the internal state to reuse the object cross 
bundle
         if isinstance(internal_state, SynchronousBagRuntimeState):
             internal_state._cleared = False
diff --git 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
index def9057..c629972 100644
--- 
a/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
+++ 
b/flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
@@ -206,7 +206,7 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
         this.jobOptions = Preconditions.checkNotNull(jobOptions);
         this.flinkMetricContainer = flinkMetricContainer;
         this.stateRequestHandler =
-                getStateRequestHandler(keyedStateBackend, keySerializer, 
jobOptions);
+                getStateRequestHandler(keyedStateBackend, keySerializer, null, 
jobOptions);
         this.memoryManager = memoryManager;
         this.managedMemoryFraction = managedMemoryFraction;
         this.resultTuple = new Tuple2<>();
@@ -545,12 +545,14 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
     private static StateRequestHandler getStateRequestHandler(
             KeyedStateBackend keyedStateBackend,
             TypeSerializer keySerializer,
+            TypeSerializer namespaceSerializer,
             Map<String, String> jobOptions) {
         if (keyedStateBackend == null) {
             return StateRequestHandler.unsupported();
         } else {
             assert keySerializer != null;
-            return new SimpleStateRequestHandler(keyedStateBackend, 
keySerializer, jobOptions);
+            return new SimpleStateRequestHandler(
+                    keyedStateBackend, keySerializer, namespaceSerializer, 
jobOptions);
         }
     }
 
@@ -626,6 +628,7 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                         .setData(ByteString.copyFrom(new byte[] 
{NOT_EMPTY_FLAG}));
 
         private final TypeSerializer keySerializer;
+        private final TypeSerializer namespaceSerializer;
         private final TypeSerializer<byte[]> valueSerializer;
         private final KeyedStateBackend keyedStateBackend;
 
@@ -657,6 +660,7 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
         SimpleStateRequestHandler(
                 KeyedStateBackend keyedStateBackend,
                 TypeSerializer keySerializer,
+                TypeSerializer namespaceSerializer,
                 Map<String, String> config) {
             this.keyedStateBackend = keyedStateBackend;
             TypeSerializer frameworkKeySerializer = 
keyedStateBackend.getKeySerializer();
@@ -666,6 +670,7 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                         "Currently SimpleStateRequestHandler only support row 
key!");
             }
             this.keySerializer = keySerializer;
+            this.namespaceSerializer = namespaceSerializer;
             this.valueSerializer =
                     
PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.createSerializer(
                             new ExecutionConfig());
@@ -814,12 +819,20 @@ public abstract class BeamPythonFunctionRunner implements 
PythonFunctionRunner {
                                         + "'%s' is used both as LIST state and 
'%s' state at the same time.",
                                 stateName, cachedStateDescriptor.getType()));
             }
-
-            return (ListState<byte[]>)
-                    keyedStateBackend.getPartitionedState(
-                            VoidNamespace.INSTANCE,
-                            VoidNamespaceSerializer.INSTANCE,
-                            listStateDescriptor);
+            byte[] windowBytes = bagUserState.getWindow().toByteArray();
+            if (windowBytes.length != 0) {
+                bais.setBuffer(windowBytes, 0, windowBytes.length);
+                Object namespace = 
namespaceSerializer.deserialize(baisWrapper);
+                return (ListState<byte[]>)
+                        keyedStateBackend.getPartitionedState(
+                                namespace, namespaceSerializer, 
listStateDescriptor);
+            } else {
+                return (ListState<byte[]>)
+                        keyedStateBackend.getPartitionedState(
+                                VoidNamespace.INSTANCE,
+                                VoidNamespaceSerializer.INSTANCE,
+                                listStateDescriptor);
+            }
         }
 
         private CompletionStage<BeamFnApi.StateResponse.Builder> 
handleMapState(
@@ -1106,12 +1119,20 @@ public abstract class BeamPythonFunctionRunner 
implements PythonFunctionRunner {
                                         + "'%s' is used both as MAP state and 
'%s' state at the same time.",
                                 stateName, cachedStateDescriptor.getType()));
             }
-
-            return (MapState<ByteArrayWrapper, byte[]>)
-                    keyedStateBackend.getPartitionedState(
-                            VoidNamespace.INSTANCE,
-                            VoidNamespaceSerializer.INSTANCE,
-                            mapStateDescriptor);
+            byte[] windowBytes = mapUserState.getWindow().toByteArray();
+            if (windowBytes.length != 0) {
+                bais.setBuffer(windowBytes, 0, windowBytes.length);
+                Object namespace = 
namespaceSerializer.deserialize(baisWrapper);
+                return (MapState<ByteArrayWrapper, byte[]>)
+                        keyedStateBackend.getPartitionedState(
+                                namespace, namespaceSerializer, 
mapStateDescriptor);
+            } else {
+                return (MapState<ByteArrayWrapper, byte[]>)
+                        keyedStateBackend.getPartitionedState(
+                                VoidNamespace.INSTANCE,
+                                VoidNamespaceSerializer.INSTANCE,
+                                mapStateDescriptor);
+            }
         }
 
         private BeamFnApi.ProcessBundleRequest.CacheToken createCacheToken() {

Reply via email to