This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 0f03fb43a0649814fd93cc29aa4d993c3bb635a7
Author: huangxingbo <[email protected]>
AuthorDate: Tue Jul 27 16:03:40 2021 +0800

    [FLINK-22911][python] Optimize the output format of RowTypeInfo and ExternalTypeInfo
    
    This closes #16611.
---
 flink-python/pyflink/common/types.py                   |  6 ++++++
 flink-python/pyflink/datastream/data_stream.py         | 17 +++++++++++++++++
 .../pyflink/datastream/tests/test_data_stream.py       |  2 +-
 flink-python/pyflink/datastream/utils.py               | 18 +++++++++++++-----
 4 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/flink-python/pyflink/common/types.py b/flink-python/pyflink/common/types.py
index 19187af..0542fb0 100644
--- a/flink-python/pyflink/common/types.py
+++ b/flink-python/pyflink/common/types.py
@@ -157,6 +157,12 @@ class Row(object):
     def _is_accumulate_msg(self):
         return self._row_kind == RowKind.UPDATE_AFTER or self._row_kind == RowKind.INSERT
 
+    @staticmethod
+    def of_kind(row_kind: RowKind, *args, **kwargs):
+        row = Row(*args, **kwargs)
+        row.set_row_kind(row_kind)
+        return row
+
     def __contains__(self, item):
         return item in self._values
 
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 04bbf7e..63d8891 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -762,6 +762,13 @@ class DataStream(object):
         """
         Transform the pickled python object into String if the output type is PickledByteArrayInfo.
         """
+        from py4j.java_gateway import get_java_class
+
+        gateway = get_gateway()
+        ExternalTypeInfo_CLASS = get_java_class(
+            gateway.jvm.org.apache.flink.table.runtime.typeutils.ExternalTypeInfo)
+        RowTypeInfo_CLASS = get_java_class(
+            gateway.jvm.org.apache.flink.api.java.typeutils.RowTypeInfo)
         output_type_info_class = self._j_data_stream.getTransformation().getOutputType().getClass()
         if output_type_info_class.isAssignableFrom(
                 Types.PICKLED_BYTE_ARRAY().get_java_type_info()
@@ -775,6 +782,16 @@ class DataStream(object):
                 self.map(python_obj_to_str_map_func,
                          output_type=Types.STRING())._j_data_stream)
             return transformed_data_stream
+        elif (output_type_info_class.isAssignableFrom(ExternalTypeInfo_CLASS) or
+              output_type_info_class.isAssignableFrom(RowTypeInfo_CLASS)):
+            def python_obj_to_str_map_func(value):
+                assert isinstance(value, Row)
+                return '{}[{}]'.format(value.get_row_kind(),
+                                       ','.join([str(item) for item in value._values]))
+            transformed_data_stream = DataStream(
+                self.map(python_obj_to_str_map_func,
+                         output_type=Types.STRING())._j_data_stream)
+            return transformed_data_stream
         else:
             return self
 
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 384b397..98efbe0 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -561,7 +561,7 @@ class DataStreamTests(object):
                                       type_info=Types.ROW([Types.STRING(), Types.INT()]))
         ds.print()
         plan = eval(str(self.env.get_execution_plan()))
-        self.assertEqual("Sink: Print to Std. Out", plan['nodes'][1]['type'])
+        self.assertEqual("Sink: Print to Std. Out", plan['nodes'][2]['type'])
 
     def test_print_with_align_output(self):
         # need to align output type before print, therefore the plan will contain three nodes
diff --git a/flink-python/pyflink/datastream/utils.py b/flink-python/pyflink/datastream/utils.py
index 10bc752..cddc369 100644
--- a/flink-python/pyflink/datastream/utils.py
+++ b/flink-python/pyflink/datastream/utils.py
@@ -19,15 +19,18 @@ import ast
 import datetime
 import pickle
 
-from pyflink.common.typeinfo import RowTypeInfo, TupleTypeInfo, Types, \
-    BasicArrayTypeInfo, \
-    PrimitiveArrayTypeInfo, MapTypeInfo, ListTypeInfo, ObjectArrayTypeInfo
+from pyflink.common import Row, RowKind
+from pyflink.common.typeinfo import (RowTypeInfo, TupleTypeInfo, Types, BasicArrayTypeInfo,
+                                     PrimitiveArrayTypeInfo, MapTypeInfo, ListTypeInfo,
+                                     ObjectArrayTypeInfo, ExternalTypeInfo)
 from pyflink.java_gateway import get_gateway
 
 
 def convert_to_python_obj(data, type_info):
     if type_info == Types.PICKLED_BYTE_ARRAY():
         return pickle.loads(data)
+    elif isinstance(type_info, ExternalTypeInfo):
+        return convert_to_python_obj(data, type_info._type_info)
     else:
         gateway = get_gateway()
         pickle_bytes = gateway.jvm.PythonBridgeUtils. \
@@ -40,18 +43,23 @@ def convert_to_python_obj(data, type_info):
                     fields.append(None)
                 else:
                     fields.append(pickled_bytes_to_python_converter(data, field_type))
-            return tuple(fields)
+            if isinstance(type_info, RowTypeInfo):
+                return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)
+            else:
+                return tuple(fields)
         else:
             return pickled_bytes_to_python_converter(pickle_bytes, type_info)
 
 
 def pickled_bytes_to_python_converter(data, field_type):
     if isinstance(field_type, RowTypeInfo):
+        row_kind = RowKind(int.from_bytes(data[0], 'little'))
         data = zip(list(data[1:]), field_type.get_field_types())
         fields = []
         for d, d_type in data:
             fields.append(pickled_bytes_to_python_converter(d, d_type))
-        return tuple(fields)
+        row = Row.of_kind(row_kind, *fields)
+        return row
     else:
         data = pickle.loads(data)
         if field_type == Types.SQL_TIME():

Reply via email to