Repository: flink Updated Branches: refs/heads/master 946e8f648 -> 30647a2e6
[FLINK-2432] Custom serializer support This closes #962 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/30647a2e Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/30647a2e Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/30647a2e Branch: refs/heads/master Commit: 30647a2e6b74563e441947ad2f5726d9627251c2 Parents: 946e8f6 Author: zentol <ches...@apache.org> Authored: Fri Nov 13 14:38:05 2015 +0100 Committer: zentol <ches...@apache.org> Committed: Fri Nov 13 15:53:34 2015 +0100 ---------------------------------------------------------------------- .../flink/python/api/streaming/Receiver.java | 20 ++- .../flink/python/api/streaming/Sender.java | 18 ++- .../python/api/types/CustomTypeWrapper.java | 34 +++++ .../python/api/flink/connection/Collector.py | 38 ++++-- .../python/api/flink/connection/Iterator.py | 123 +++++++------------ .../api/flink/functions/CoGroupFunction.py | 8 +- .../python/api/flink/functions/Function.py | 6 +- .../api/flink/functions/GroupReduceFunction.py | 10 +- .../api/flink/functions/ReduceFunction.py | 10 +- .../flink/python/api/flink/plan/Constants.py | 7 ++ .../flink/python/api/flink/plan/Environment.py | 21 +++- .../org/apache/flink/python/api/test_main.py | 28 ++++- 12 files changed, 215 insertions(+), 108 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Receiver.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Receiver.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Receiver.java index 07698d3..a706053 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Receiver.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Receiver.java @@ -34,6 +34,7 @@ import static org.apache.flink.python.api.streaming.Sender.TYPE_NULL; import static org.apache.flink.python.api.streaming.Sender.TYPE_SHORT; import static org.apache.flink.python.api.streaming.Sender.TYPE_STRING; import static org.apache.flink.python.api.streaming.Sender.TYPE_TUPLE; +import org.apache.flink.python.api.types.CustomTypeWrapper; import org.apache.flink.util.Collector; /** @@ -192,7 +193,7 @@ public class Receiver implements Serializable { case TYPE_NULL: return null; default: - throw new IllegalArgumentException("Unknown TypeID encountered: " + type); + return new CustomTypeDeserializer(type).deserialize(); } } @@ -245,14 +246,29 @@ public class Receiver implements Serializable { case TYPE_NULL: return new NullDeserializer(); default: - throw new IllegalArgumentException("Unknown TypeID encountered: " + type); + return new CustomTypeDeserializer(type); } } private interface Deserializer<T> { public T deserialize(); + } + + private class CustomTypeDeserializer implements Deserializer<CustomTypeWrapper> { + private final byte type; + + public CustomTypeDeserializer(byte type) { + this.type = type; + } + @Override + public CustomTypeWrapper deserialize() { + int size = fileBuffer.getInt(); + byte[] data = new byte[size]; + fileBuffer.get(data); + return new CustomTypeWrapper(type, data); + } } private class BooleanDeserializer implements Deserializer<Boolean> { http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Sender.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Sender.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Sender.java index b897f12..2db1441 100644 --- a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Sender.java +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/streaming/Sender.java @@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.AbstractRichFunction; import org.apache.flink.api.java.tuple.Tuple; import static org.apache.flink.python.api.PythonPlanBinder.FLINK_TMP_DATA_DIR; import static org.apache.flink.python.api.PythonPlanBinder.MAPPED_FILE_SIZE; +import org.apache.flink.python.api.types.CustomTypeWrapper; /** * General-purpose class to write data to memory-mapped files. @@ -180,7 +181,7 @@ public class Sender implements Serializable { } private enum SupportedTypes { - TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL + TUPLE, BOOLEAN, BYTE, BYTES, CHARACTER, SHORT, INTEGER, LONG, FLOAT, DOUBLE, STRING, OTHER, NULL, CUSTOMTYPEWRAPPER } //=====Serializer=================================================================================================== @@ -231,6 +232,9 @@ public class Sender implements Serializable { case NULL: fileBuffer.put(TYPE_NULL); return new NullSerializer(); + case CUSTOMTYPEWRAPPER: + fileBuffer.put(((CustomTypeWrapper) value).getType()); + return new CustomTypeSerializer(); default: throw new IllegalArgumentException("Unknown Type encountered: " + type); } @@ -253,6 +257,18 @@ public class Sender implements Serializable { public abstract void serializeInternal(T value); } + private class CustomTypeSerializer extends Serializer<CustomTypeWrapper> { + public CustomTypeSerializer() { + super(0); + } + @Override + public void serializeInternal(CustomTypeWrapper value) { + byte[] bytes = value.getData(); + buffer = ByteBuffer.wrap(bytes); + buffer.position(bytes.length); + } + } + private class ByteSerializer extends Serializer<Byte> { public ByteSerializer() { super(1); http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/types/CustomTypeWrapper.java ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/types/CustomTypeWrapper.java b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/types/CustomTypeWrapper.java new file mode 100644 index 0000000..e16c3eb --- /dev/null +++ b/flink-libraries/flink-python/src/main/java/org/apache/flink/python/api/types/CustomTypeWrapper.java @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.apache.flink.python.api.types; + +/** + * Container for serialized python objects, generally assumed to be custom objects. + */ +public class CustomTypeWrapper { + private final byte typeID; + private final byte[] data; + + public CustomTypeWrapper(byte typeID, byte[] data) { + this.typeID = typeID; + this.data = data; + } + + public byte getType() { + return typeID; + } + + public byte[] getData() { + return data; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py index bf35756..b5674b9 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Collector.py @@ -19,6 +19,7 @@ from struct import pack import sys from flink.connection.Constants import Types +from flink.plan.Constants import _Dummy PY2 = sys.version_info[0] == 2 PY3 = sys.version_info[0] == 3 @@ -30,15 +31,16 @@ else: class Collector(object): - def __init__(self, con): + def __init__(self, con, env): self._connection = con self._serializer = None + self._env = env def _close(self): self._connection.send_end_signal() def collect(self, value): - self._serializer = _get_serializer(self._connection.write, value) + self._serializer = _get_serializer(self._connection.write, value, self._env._types) self.collect = self._collect self.collect(value) @@ -46,11 +48,11 @@ class Collector(object): self._connection.write(self._serializer.serialize(value)) -def _get_serializer(write, value): +def _get_serializer(write, value, custom_types): if isinstance(value, (list, tuple)): write(Types.TYPE_TUPLE) write(pack(">I", len(value))) - return TupleSerializer(write, value) + return TupleSerializer(write, value, custom_types) elif value is None: write(Types.TYPE_NULL) return NullSerializer() @@ -70,12 +72,25 @@ def _get_serializer(write, value): write(Types.TYPE_DOUBLE) return FloatSerializer() else: + for entry in custom_types: + if isinstance(value, entry[1]): + write(entry[0]) + return CustomTypeSerializer(entry[2]) raise Exception("Unsupported Type encountered.") +class CustomTypeSerializer(object): + def __init__(self, serializer): + self._serializer = serializer + + def serialize(self, value): + msg = self._serializer.serialize(value) + return b"".join([pack(">i",len(msg)), msg]) + + class TupleSerializer(object): - def __init__(self, write, value): - self.serializer = [_get_serializer(write, field) for field in value] + def __init__(self, write, value, custom_types): + self.serializer = [_get_serializer(write, field, custom_types) for field in value] def serialize(self, value): bits = [] @@ -117,8 +132,9 @@ class NullSerializer(object): class TypedCollector(object): - def __init__(self, con): + def __init__(self, con, env): self._connection = con + self._env = env def collect(self, value): if not isinstance(value, (list, tuple)): @@ -153,5 +169,13 @@ class TypedCollector(object): value = bytes(value) size = pack(">I", len(value)) self._connection.write(b"".join([Types.TYPE_BYTES, size, value])) + elif isinstance(value, _Dummy): + self._connection.write(pack(">i", 127)[3:]) + self._connection.write(pack(">i", 0)) else: + for entry in self._env._types: + if isinstance(value, entry[1]): + self._connection.write(entry[0]) + self._connection.write(CustomTypeSerializer(entry[2]).serialize(value)) + return raise Exception("Unsupported Type encountered.") \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py index fb0e26d..0e740cf 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/connection/Iterator.py @@ -168,21 +168,25 @@ class CoGroupIterator(object): class Iterator(defIter.Iterator): - def __init__(self, con, group=0): + def __init__(self, con, env, group=0): super(Iterator, self).__init__() self._connection = con self._init = True self._group = group self._deserializer = None + self._env = env def __next__(self): return self.next() + def _read(self, des_size): + return self._connection.read(des_size, self._group) + def next(self): if self.has_next(): if self._deserializer is None: - self._deserializer = _get_deserializer(self._group, self._connection.read) - return self._deserializer.deserialize() + self._deserializer = _get_deserializer(self._group, self._connection.read, self._env._types) + return self._deserializer.deserialize(self._read) else: raise StopIteration @@ -207,121 +211,88 @@ class DummyIterator(Iterator): return False -def _get_deserializer(group, read, type=None): +def _get_deserializer(group, read, custom_types, type=None): if type is None: type = read(1, group) - return _get_deserializer(group, read, type) + return _get_deserializer(group, read, custom_types, type) elif type == Types.TYPE_TUPLE: - return TupleDeserializer(read, group) + return TupleDeserializer(read, group, custom_types) elif type == Types.TYPE_BYTE: - return ByteDeserializer(read, group) + return ByteDeserializer() elif type == Types.TYPE_BYTES: - return ByteArrayDeserializer(read, group) + return ByteArrayDeserializer() elif type == Types.TYPE_BOOLEAN: - return BooleanDeserializer(read, group) + return BooleanDeserializer() elif type == Types.TYPE_FLOAT: - return FloatDeserializer(read, group) + return FloatDeserializer() elif type == Types.TYPE_DOUBLE: - return DoubleDeserializer(read, group) + return DoubleDeserializer() elif type == Types.TYPE_INTEGER: - return IntegerDeserializer(read, group) + return IntegerDeserializer() elif type == Types.TYPE_LONG: - return LongDeserializer(read, group) + return LongDeserializer() elif type == Types.TYPE_STRING: - return StringDeserializer(read, group) + return StringDeserializer() elif type == Types.TYPE_NULL: - return NullDeserializer(read, group) + return NullDeserializer() + else: + for entry in custom_types: + if type == entry[0]: + return entry[3] + raise Exception("Unable to find deserializer for type ID " + str(type)) class TupleDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - size = unpack(">I", self.read(4, self._group))[0] - self.deserializer = [_get_deserializer(self._group, self.read) for _ in range(size)] + def __init__(self, read, group, custom_types): + size = unpack(">I", read(4, group))[0] + self.deserializer = [_get_deserializer(group, read, custom_types) for _ in range(size)] - def deserialize(self): - return tuple([s.deserialize() for s in self.deserializer]) + def deserialize(self, read): + return tuple([s.deserialize(read) for s in self.deserializer]) class ByteDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - return unpack(">c", self.read(1, self._group))[0] + def deserialize(self, read): + return unpack(">c", read(1))[0] class ByteArrayDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - size = unpack(">i", self.read(4, self._group))[0] - return bytearray(self.read(size, self._group)) if size else bytearray(b"") + def deserialize(self, read): + size = unpack(">i", read(4))[0] + return bytearray(read(size)) if size else bytearray(b"") class BooleanDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - return unpack(">?", self.read(1, self._group))[0] + def deserialize(self, read): + return unpack(">?", read(1))[0] class FloatDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - return unpack(">f", self.read(4, self._group))[0] + def deserialize(self, read): + return unpack(">f", read(4))[0] class DoubleDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - return unpack(">d", self.read(8, self._group))[0] + def deserialize(self, read): + return unpack(">d", read(8))[0] class IntegerDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - return unpack(">i", self.read(4, self._group))[0] + def deserialize(self, read): + return unpack(">i", read(4))[0] class LongDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - return unpack(">q", self.read(8, self._group))[0] + def deserialize(self, read): + return unpack(">q", read(8))[0] class StringDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - - def deserialize(self): - length = unpack(">i", self.read(4, self._group))[0] - return self.read(length, self._group).decode("utf-8") if length else "" + def deserialize(self, read): + length = unpack(">i", read(4))[0] + return read(length).decode("utf-8") if length else "" class NullDeserializer(object): - def __init__(self, read, group): - self.read = read - self._group = group - def deserialize(self): return None http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py index db951fe..9c55787 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/CoGroupFunction.py @@ -25,13 +25,13 @@ class CoGroupFunction(Function.Function): self._keys1 = None self._keys2 = None - def _configure(self, input_file, output_file, port): + def _configure(self, input_file, output_file, port, env): self._connection = Connection.TwinBufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection, 0) - self._iterator2 = Iterator.Iterator(self._connection, 1) + self._iterator = Iterator.Iterator(self._connection, env, 0) + self._iterator2 = Iterator.Iterator(self._connection, env, 1) self._cgiter = Iterator.CoGroupIterator(self._iterator, self._iterator2, self._keys1, self._keys2) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - self._configure_chain(Collector.Collector(self._connection)) + self._configure_chain(Collector.Collector(self._connection, env)) def _run(self): collector = self._collector http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py index 5323462..4bf8b3a 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/Function.py @@ -32,11 +32,11 @@ class Function(object): self.context = None self._chain_operator = None - def _configure(self, input_file, output_file, port): + def _configure(self, input_file, output_file, port, env): self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection) + self._iterator = Iterator.Iterator(self._connection, env) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - self._configure_chain(Collector.Collector(self._connection)) + self._configure_chain(Collector.Collector(self._connection, env)) def _configure_chain(self, collector): if self._chain_operator is not None: http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py index 11bba30..23e39ab 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/GroupReduceFunction.py @@ -29,19 +29,19 @@ class GroupReduceFunction(Function.Function): self._combine = False self._values = [] - def _configure(self, input_file, output_file, port): + def _configure(self, input_file, output_file, port, env): if self._combine: self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection) - self._collector = Collector.Collector(self._connection) + self._iterator = Iterator.Iterator(self._connection, env) + self._collector = Collector.Collector(self._connection, env) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) self._run = self._run_combine else: self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection) + self._iterator = Iterator.Iterator(self._connection, env) self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) - self._configure_chain(Collector.Collector(self._connection)) + self._configure_chain(Collector.Collector(self._connection, env)) self._open() def _open(self): http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py index ffa6de0..4d19c13 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/functions/ReduceFunction.py @@ -27,21 +27,21 @@ class ReduceFunction(Function.Function): self._combine = False self._values = [] - def _configure(self, input_file, output_file, port): + def _configure(self, input_file, output_file, port, env): if self._combine: self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection) - self._collector = Collector.Collector(self._connection) + self._iterator = Iterator.Iterator(self._connection, env) + self._collector = Collector.Collector(self._connection, env) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) self._run = self._run_combine else: self._connection = Connection.BufferingTCPMappedFileConnection(input_file, output_file, port) - self._iterator = Iterator.Iterator(self._connection) + self._iterator = Iterator.Iterator(self._connection, env) if self._keys is None: self._run = self._run_allreduce else: self._group_iterator = Iterator.GroupIterator(self._iterator, self._keys) - self._configure_chain(Collector.Collector(self._connection)) + self._configure_chain(Collector.Collector(self._connection, env)) self.context = RuntimeContext.RuntimeContext(self._iterator, self._collector) def _set_grouping_keys(self, keys): http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py index f60273f..b0d79e8 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Constants.py @@ -91,6 +91,11 @@ import sys PY2 = sys.version_info[0] == 2 PY3 = sys.version_info[0] == 3 + +class _Dummy(object): + pass + + if PY2: BOOL = True INT = 1 @@ -98,9 +103,11 @@ if PY2: FLOAT = 2.5 STRING = "type" BYTES = bytearray(b"byte") + CUSTOM = _Dummy() elif PY3: BOOL = True INT = 1 FLOAT = 2.5 STRING = "type" BYTES = bytearray(b"byte") + CUSTOM = _Dummy() http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py index 236eda4..8647686 100644 --- a/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py +++ b/flink-libraries/flink-python/src/main/python/org/apache/flink/python/api/flink/plan/Environment.py @@ -22,7 +22,7 @@ from flink.plan.Constants import _Fields, _Identifier from flink.utilities import Switch import copy import sys - +from struct import pack def get_environment(): """ @@ -49,6 +49,19 @@ class Environment(object): #specials self._broadcast = [] + self._types = [] + + def register_type(self, type, serializer, deserializer): + """ + Registers the given type with this environment, allowing all operators within to + (de-)serialize objects of the given type. + + :param type: class of the objects to be (de-)serialized + :param serializer: instance of the serializer + :param deserializer: instance of the deserializer + """ + self._types.append((pack(">i",126 - len(self._types))[3:], type, serializer, deserializer)) + def read_csv(self, path, types, line_delimiter="\n", field_delimiter=','): """ Create a DataSet that represents the tuples produced by reading the given CSV file. @@ -127,7 +140,7 @@ class Environment(object): if plan_mode: output_path = sys.stdin.readline().rstrip('\n') self._connection = Connection.OneWayBusyBufferingMappedFileConnection(output_path) - self._collector = Collector.TypedCollector(self._connection) + self._collector = Collector.TypedCollector(self._connection, self) self._send_plan() self._connection._write_buffer() else: @@ -146,7 +159,7 @@ class Environment(object): operator = set[_Fields.OPERATOR] if set[_Fields.ID] == -id: operator = set[_Fields.COMBINEOP] - operator._configure(input_path, output_path, port) + operator._configure(input_path, output_path, port, self) operator._go() sys.stdout.flush() sys.stderr.flush() @@ -342,4 +355,4 @@ class Environment(object): collect(_Identifier.BROADCAST) collect(entry[_Fields.PARENT][_Fields.ID]) collect(entry[_Fields.OTHER][_Fields.ID]) - collect(entry[_Fields.NAME]) \ No newline at end of file + collect(entry[_Fields.NAME]) http://git-wip-us.apache.org/repos/asf/flink/blob/30647a2e/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py index 2116d1f..2666945 100644 --- a/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py +++ b/flink-libraries/flink-python/src/test/python/org/apache/flink/python/api/test_main.py @@ -25,7 +25,8 @@ from flink.functions.CrossFunction import CrossFunction from flink.functions.JoinFunction import JoinFunction from flink.functions.GroupReduceFunction import GroupReduceFunction from flink.functions.CoGroupFunction import CoGroupFunction -from flink.plan.Constants import INT, STRING, FLOAT, BOOL, Order +from flink.plan.Constants import INT, STRING, FLOAT, BOOL, CUSTOM, Order +import struct class Mapper(MapFunction): @@ -259,6 +260,31 @@ if __name__ == "__main__": .co_group(d5).where(0).equal_to(2).using(CoGroup(), ((INT, FLOAT, STRING, BOOL), (FLOAT, FLOAT, INT))) \ .map_partition(Verify([((1, 0.5, "hello", True), (4.4, 4.3, 1)), ((1, 0.4, "hello", False), (4.3, 4.4, 1))], "CoGroup"), STRING).output() + #Custom Serialization + class Ext(MapPartitionFunction): + def map_partition(self, iterator, collector): + for value in iterator: + collector.collect(value.value) + + class MyObj(object): + def __init__(self, i): + self.value = i + + class MySerializer(object): + def serialize(self, value): + return struct.pack(">i", value.value) + + class MyDeserializer(object): + def deserialize(self, read): + i = struct.unpack(">i", read(4))[0] + return MyObj(i) + + env.register_type(MyObj, MySerializer(), MyDeserializer()) + + env.from_elements(MyObj(2), MyObj(4)) \ + .map(Id(), CUSTOM).map_partition(Ext(), INT) \ + .map_partition(Verify([2, 4], "CustomTypeSerialization"), STRING).output() + env.set_degree_of_parallelism(1) env.execute(local=True)