This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch release-1.15
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.15 by this push:
new cdecc21cad9 [FLINK-30168][python] Fix DataStream.execute_and_collect
to support None data and ObjectArray
cdecc21cad9 is described below
commit cdecc21cad9f78b1555a0e2f5d7f1398949e7193
Author: Dian Fu <[email protected]>
AuthorDate: Fri Jan 13 16:56:51 2023 +0800
[FLINK-30168][python] Fix DataStream.execute_and_collect to support None
data and ObjectArray
This closes #21664.
---
.../pyflink/datastream/tests/test_data_stream.py | 22 ++++++++++++++++++++--
flink-python/pyflink/datastream/utils.py | 5 ++++-
.../flink/api/common/python/PythonBridgeUtils.java | 11 ++++++++---
3 files changed, 32 insertions(+), 6 deletions(-)
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index e3f05936ed1..3355cd64bcb 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -57,8 +57,10 @@ class DataStreamTests(object):
self.test_sink.clear()
def assert_equals_sorted(self, expected, actual):
- expected.sort()
- actual.sort()
+ # otherwise, it may throw exceptions such as the following:
+ # TypeError: '<' not supported between instances of 'NoneType' and
'str'
+ expected.sort(key=lambda x: str(x))
+ actual.sort(key=lambda x: str(x))
self.assertEqual(expected, actual)
def test_basic_operations(self):
@@ -1144,6 +1146,22 @@ class StreamingModeDataStreamTests(DataStreamTests,
PyFlinkStreamingTestCase):
expected = ['+I[a, 0]', '+I[ab, 0]', '+I[c, 1]', '+I[cd, 1]', '+I[cde,
1]']
self.assert_equals_sorted(expected, results)
+ test_data = [
+ (["test", "test"], [0.0, 0.0]),
+ ([None, ], [0.0, 0.0])
+ ]
+
+ ds = self.env.from_collection(
+ test_data,
+ type_info=Types.TUPLE(
+ [Types.OBJECT_ARRAY(Types.STRING()),
Types.OBJECT_ARRAY(Types.DOUBLE())]
+ )
+ )
+ expected = test_data
+ with ds.execute_and_collect() as results:
+ actual = [result for result in results]
+ self.assert_equals_sorted(expected, actual)
+
def test_function_with_error(self):
ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1),
('e', 1)],
type_info=Types.ROW([Types.STRING(),
Types.INT()]))
diff --git a/flink-python/pyflink/datastream/utils.py
b/flink-python/pyflink/datastream/utils.py
index cddc3697f22..fe6b689fa44 100644
--- a/flink-python/pyflink/datastream/utils.py
+++ b/flink-python/pyflink/datastream/utils.py
@@ -36,7 +36,10 @@ def convert_to_python_obj(data, type_info):
pickle_bytes = gateway.jvm.PythonBridgeUtils. \
getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
if isinstance(type_info, RowTypeInfo) or isinstance(type_info,
TupleTypeInfo):
- field_data = zip(list(pickle_bytes[1:]),
type_info.get_field_types())
+ if isinstance(type_info, RowTypeInfo):
+ field_data = zip(list(pickle_bytes[1:]),
type_info.get_field_types())
+ else:
+ field_data = zip(pickle_bytes, type_info.get_field_types())
fields = []
for data, field_type in field_data:
if len(data) == 0:
diff --git
a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
index ebf4f346de3..b0bd0ab45ee 100644
---
a/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
+++
b/flink-python/src/main/java/org/apache/flink/api/common/python/PythonBridgeUtils.java
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.typeutils.ListTypeInfo;
import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
@@ -241,7 +242,7 @@ public final class PythonBridgeUtils {
Pickler pickler = new Pickler();
initialize();
if (obj == null) {
- return new byte[0];
+ return pickler.dumps(null);
} else {
if (dataType instanceof SqlTimeTypeInfo) {
SqlTimeTypeInfo<?> sqlTimeTypeInfo =
@@ -270,15 +271,19 @@ public final class PythonBridgeUtils {
}
return fieldBytes;
} else if (dataType instanceof BasicArrayTypeInfo
- || dataType instanceof PrimitiveArrayTypeInfo) {
+ || dataType instanceof PrimitiveArrayTypeInfo
+ || dataType instanceof ObjectArrayTypeInfo) {
Object[] objects;
TypeInformation<?> elementType;
if (dataType instanceof BasicArrayTypeInfo) {
objects = (Object[]) obj;
elementType = ((BasicArrayTypeInfo<?, ?>)
dataType).getComponentInfo();
- } else {
+ } else if (dataType instanceof PrimitiveArrayTypeInfo) {
objects = primitiveArrayConverter(obj, dataType);
elementType = ((PrimitiveArrayTypeInfo<?>)
dataType).getComponentType();
+ } else {
+ objects = (Object[]) obj;
+ elementType = ((ObjectArrayTypeInfo<?, ?>)
dataType).getComponentInfo();
}
List<Object> serializedElements = new
ArrayList<>(objects.length);
for (Object object : objects) {