This is an automated email from the ASF dual-hosted git repository.

bertty pushed a commit to branch python-platform in repository https://gitbox.apache.org/repos/asf/incubator-wayang.git.
commit d60c1701cc6b9ba4554eb1bf72f8007a2e933261 Author: Bertty Contreras-Rojas <[email protected]> AuthorDate: Fri Apr 8 16:12:29 2022 +0200 [WAYANG-#8] add PyFlatMapOperator and correction of types capture Signed-off-by: bertty <[email protected]> --- python/src/pywy/dataquanta.py | 15 ++- python/src/pywy/operators/sink.py | 8 +- python/src/pywy/operators/unary.py | 8 -- python/src/pywy/platforms/python/execution.py | 8 ++ python/src/pywy/platforms/python/mappings.py | 1 + .../src/pywy/platforms/python/operator/__init__.py | 2 + .../platforms/python/operator/py_sink_textfile.py | 8 +- .../platforms/python/operator/py_unary_flatmap.py | 49 ++++++++ .../pywy/tests/integration/python_platform_test.py | 124 +++++++++++++++------ .../pywy/tests/unit/dataquanta/dataquanta_test.py | 26 +++-- python/src/pywy/types.py | 11 +- 11 files changed, 198 insertions(+), 62 deletions(-) diff --git a/python/src/pywy/dataquanta.py b/python/src/pywy/dataquanta.py index 956d815e..5fc1e2df 100644 --- a/python/src/pywy/dataquanta.py +++ b/python/src/pywy/dataquanta.py @@ -64,8 +64,19 @@ class DataQuanta(GenericTco): def flatmap(self: "DataQuanta[In]", f: FlatmapFunction) -> "DataQuanta[IterableOut]": return DataQuanta(self.context, self._connect(FlatmapOperator(f))) - def store_textfile(self: "DataQuanta[In]", path: str): - last: List[SinkOperator] = [cast(SinkOperator, self._connect(TextFileSink(path, self.operator.outputSlot[0])))] + def store_textfile(self: "DataQuanta[In]", path: str, end_line: str = None): + last: List[SinkOperator] = [ + cast( + SinkOperator, + self._connect( + TextFileSink( + path, + self.operator.outputSlot[0], + end_line + ) + ) + ) + ] plan = PywyPlan(self.context.plugins, last) plug = self.context.plugins.pop() diff --git a/python/src/pywy/operators/sink.py b/python/src/pywy/operators/sink.py index 1f78a63c..8a80ccfe 100644 --- a/python/src/pywy/operators/sink.py +++ b/python/src/pywy/operators/sink.py @@ -23,12 +23,16 @@ class SinkUnaryOperator(SinkOperator): 
class TextFileSink(SinkUnaryOperator): - path: str + end_line: str - def __init__(self, path: str, input_type: GenericTco): + def __init__(self, path: str, input_type: GenericTco, end_line: str = None): super().__init__('TextFile', input_type) self.path = path + if input_type != str and end_line is None: + self.end_line = '\n' + else: + self.end_line = end_line def __str__(self): return super().__str__() diff --git a/python/src/pywy/operators/unary.py b/python/src/pywy/operators/unary.py index 03d6f118..339480b0 100644 --- a/python/src/pywy/operators/unary.py +++ b/python/src/pywy/operators/unary.py @@ -1,4 +1,3 @@ -from itertools import chain from pywy.operators.base import PywyOperator from pywy.types import ( GenericTco, @@ -68,13 +67,6 @@ class FlatmapOperator(UnaryToUnaryOperator): super().__init__("Flatmap", types[0], types[1]) self.fm_function = fm_function - # TODO remove wrapper - def getWrapper(self): - udf = self.fm_function - def func(iterator): - return chain.from_iterable(map(udf, iterator)) - return func - def __str__(self): return super().__str__() diff --git a/python/src/pywy/platforms/python/execution.py b/python/src/pywy/platforms/python/execution.py index bab62ecc..4fb2ca2b 100644 --- a/python/src/pywy/platforms/python/execution.py +++ b/python/src/pywy/platforms/python/execution.py @@ -2,6 +2,7 @@ from pywy.graph.types import WGraphOfOperator, NodeOperator from pywy.core import ChannelDescriptor from pywy.core import Executor from pywy.core import PywyPlan +from pywy.operators import TextFileSource from pywy.platforms.python.channels import PY_ITERATOR_CHANNEL_DESCRIPTOR from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator @@ -17,6 +18,7 @@ class PyExecutor(Executor): # TODO get this information by a configuration and ideally by the context descriptor_default: ChannelDescriptor = PY_ITERATOR_CHANNEL_DESCRIPTOR + files_pool = [] def execute(op_current: NodeOperator, op_next: NodeOperator): if op_current is None: 
@@ -66,4 +68,10 @@ class PyExecutor(Executor): py_next.inputChannel = py_current.outputChannel + if isinstance(py_current, TextFileSource): + files_pool.append(py_current.outputChannel[0].provide_iterable()) + graph.traversal(graph.starting_nodes, execute) + # close the files used during the execution + for f in files_pool: + f.close() diff --git a/python/src/pywy/platforms/python/mappings.py b/python/src/pywy/platforms/python/mappings.py index 50a6ddeb..e46865cb 100644 --- a/python/src/pywy/platforms/python/mappings.py +++ b/python/src/pywy/platforms/python/mappings.py @@ -8,4 +8,5 @@ PYWY_OPERATOR_MAPPINGS.add_mapping(PyFilterOperator()) PYWY_OPERATOR_MAPPINGS.add_mapping(PyTextFileSourceOperator()) PYWY_OPERATOR_MAPPINGS.add_mapping(PyTextFileSinkOperator()) PYWY_OPERATOR_MAPPINGS.add_mapping(PyMapOperator()) +PYWY_OPERATOR_MAPPINGS.add_mapping(PyFlatmapOperator()) diff --git a/python/src/pywy/platforms/python/operator/__init__.py b/python/src/pywy/platforms/python/operator/__init__.py index 51b6f409..438a9692 100644 --- a/python/src/pywy/platforms/python/operator/__init__.py +++ b/python/src/pywy/platforms/python/operator/__init__.py @@ -1,6 +1,7 @@ from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator from pywy.platforms.python.operator.py_unary_filter import PyFilterOperator from pywy.platforms.python.operator.py_unary_map import PyMapOperator +from pywy.platforms.python.operator.py_unary_flatmap import PyFlatmapOperator from pywy.platforms.python.operator.py_source_textfile import PyTextFileSourceOperator from pywy.platforms.python.operator.py_sink_textfile import PyTextFileSinkOperator @@ -10,4 +11,5 @@ __ALL__ = [ PyTextFileSourceOperator, PyTextFileSinkOperator, PyMapOperator, + PyFlatmapOperator, ] diff --git a/python/src/pywy/platforms/python/operator/py_sink_textfile.py b/python/src/pywy/platforms/python/operator/py_sink_textfile.py index 801387cd..61528349 100644 --- 
a/python/src/pywy/platforms/python/operator/py_sink_textfile.py +++ b/python/src/pywy/platforms/python/operator/py_sink_textfile.py @@ -15,7 +15,8 @@ class PyTextFileSinkOperator(TextFileSink, PyExecutionOperator): def __init__(self, origin: TextFileSink = None): path = None if origin is None else origin.path type_class = None if origin is None else origin.inputSlot[0] - super().__init__(path, type_class) + end_line = None if origin is None else origin.end_line + super().__init__(path, type_class, end_line) def execute(self, inputs: List[Type[CH_T]], outputs: List[Type[CH_T]]): self.validate_channels(inputs, outputs) @@ -23,12 +24,13 @@ class PyTextFileSinkOperator(TextFileSink, PyExecutionOperator): file = open(self.path, 'w') py_in_iter_channel: PyIteratorChannel = inputs[0] iterable = py_in_iter_channel.provide_iterable() - if self.inputSlot[0] == str: + + if self.inputSlot[0] == str and self.end_line is None: for element in iterable: file.write(element) else: for element in iterable: - file.write("{}\n".format(str(element))) + file.write("{}{}".format(str(element), self.end_line)) file.close() else: diff --git a/python/src/pywy/platforms/python/operator/py_unary_flatmap.py b/python/src/pywy/platforms/python/operator/py_unary_flatmap.py new file mode 100644 index 00000000..d842b54a --- /dev/null +++ b/python/src/pywy/platforms/python/operator/py_unary_flatmap.py @@ -0,0 +1,49 @@ +from itertools import chain +from typing import Set, List, Type + +from pywy.core.channel import CH_T +from pywy.operators.unary import FlatmapOperator +from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator +from pywy.platforms.python.channels import ( + ChannelDescriptor, + PyIteratorChannel, + PY_ITERATOR_CHANNEL_DESCRIPTOR, + PY_CALLABLE_CHANNEL_DESCRIPTOR, + PyCallableChannel + ) + + +class PyFlatmapOperator(FlatmapOperator, PyExecutionOperator): + + def __init__(self, origin: FlatmapOperator = None): + fm_function = None if origin is None else 
origin.fm_function + super().__init__(fm_function) + + def execute(self, inputs: List[Type[CH_T]], outputs: List[Type[CH_T]]): + self.validate_channels(inputs, outputs) + udf = self.fm_function + if isinstance(inputs[0], PyIteratorChannel): + py_in_iter_channel: PyIteratorChannel = inputs[0] + py_out_iter_channel: PyIteratorChannel = outputs[0] + py_out_iter_channel.accept_iterable(chain.from_iterable(map(udf, py_in_iter_channel.provide_iterable()))) + elif isinstance(inputs[0], PyCallableChannel): + py_in_call_channel: PyCallableChannel = inputs[0] + py_out_call_channel: PyCallableChannel = outputs[0] + + def fm_func(iterator): + return chain.from_iterable(map(udf, iterator)) + + py_out_call_channel.accept_callable( + PyCallableChannel.concatenate( + fm_func, + py_in_call_channel.provide_callable() + ) + ) + else: + raise Exception("Channel Type does not supported") + + def get_input_channeldescriptors(self) -> Set[ChannelDescriptor]: + return {PY_ITERATOR_CHANNEL_DESCRIPTOR, PY_CALLABLE_CHANNEL_DESCRIPTOR} + + def get_output_channeldescriptors(self) -> Set[ChannelDescriptor]: + return {PY_ITERATOR_CHANNEL_DESCRIPTOR, PY_CALLABLE_CHANNEL_DESCRIPTOR} diff --git a/python/src/pywy/tests/integration/python_platform_test.py b/python/src/pywy/tests/integration/python_platform_test.py index d2ec4f22..48dc4558 100644 --- a/python/src/pywy/tests/integration/python_platform_test.py +++ b/python/src/pywy/tests/integration/python_platform_test.py @@ -1,12 +1,16 @@ +import logging import os import unittest import tempfile -from typing import List +from itertools import chain +from typing import List, Iterable from pywy.config import RC_TEST_DIR as ROOT from pywy.dataquanta import WayangContext from pywy.plugins import PYTHON +logger = logging.getLogger(__name__) + class TestIntegrationPythonPlatform(unittest.TestCase): @@ -14,60 +18,110 @@ class TestIntegrationPythonPlatform(unittest.TestCase): def setUp(self): self.file_10e0 = "{}/10e0MB.input".format(ROOT) - pass - def 
test_grep(self): + @staticmethod + def seed_small_grep(validation_file): def pre(a: str) -> bool: return 'six' in a fd, path_tmp = tempfile.mkstemp() - WayangContext() \ + dq = WayangContext() \ .register(PYTHON) \ - .textfile(self.file_10e0) \ - .filter(pre) \ - .store_textfile(path_tmp) - - lines_filter: List[str] - with open(self.file_10e0, 'r') as f: - lines_filter = list(filter(pre, f.readlines())) - selectivity = len(list(lines_filter)) + .textfile(validation_file) \ + .filter(pre) + + return dq, path_tmp, pre + + def validate_files(self, + validation_file, + outputed_file, + read_and_convert_validation, + read_and_convert_outputed, + delete_outputed=True, + print_variable=False): + lines_filter: List[int] + with open(validation_file, 'r') as f: + lines_filter = list(read_and_convert_validation(f)) + selectivity = len(lines_filter) - lines_platform: List[str] - with open(path_tmp, 'r') as fp: - lines_platform = fp.readlines() + lines_platform: List[int] + with open(outputed_file, 'r') as fp: + lines_platform = list(read_and_convert_outputed(fp)) elements = len(lines_platform) - os.remove(path_tmp) + + if delete_outputed: + os.remove(outputed_file) + + if print_variable: + logger.info(f"{lines_platform=}") + logger.info(f"{lines_filter=}") + logger.info(f"{elements=}") + logger.info(f"{selectivity=}") self.assertEqual(selectivity, elements) self.assertEqual(lines_filter, lines_platform) + def test_grep(self): + + dq, path_tmp, pre = self.seed_small_grep(self.file_10e0) + + dq.store_textfile(path_tmp) + + def convert_validation(file): + return filter(pre, file.readlines()) + + def convert_outputed(file): + return file.readlines() + + self.validate_files( + self.file_10e0, + path_tmp, + convert_validation, + convert_outputed + ) + def test_dummy_map(self): - def pre(a: str) -> bool: - return 'six' in a def convert(a: str) -> int: return len(a) - fd, path_tmp = tempfile.mkstemp() + dq, path_tmp, pre = self.seed_small_grep(self.file_10e0) - WayangContext() \ - 
.register(PYTHON) \ - .textfile(self.file_10e0) \ - .filter(pre) \ - .map(convert) \ + dq.map(convert) \ .store_textfile(path_tmp) - lines_filter: List[int] - with open(self.file_10e0, 'r') as f: - lines_filter = list(map(convert, filter(pre, f.readlines()))) - selectivity = len(list(lines_filter)) + def convert_validation(file): + return map(convert, filter(pre, file.readlines())) - lines_platform: List[int] - with open(path_tmp, 'r') as fp: - lines_platform = list(map(lambda x: int(x), fp.readlines())) - elements = len(lines_platform) - os.remove(path_tmp) + def convert_outputed(file): + return map(lambda x: int(x), file.read().splitlines()) - self.assertEqual(selectivity, elements) - self.assertEqual(lines_filter, lines_platform) + self.validate_files( + self.file_10e0, + path_tmp, + convert_validation, + convert_outputed + ) + + def test_dummy_flatmap(self): + def fm_func(string: str) -> Iterable[str]: + return string.strip().split(" ") + + dq, path_tmp, pre = self.seed_small_grep(self.file_10e0) + + dq.flatmap(fm_func) \ + .store_textfile(path_tmp, '\n') + + def convert_validation(file): + return chain.from_iterable(map(fm_func, filter(pre, file.readlines()))) + + def convert_outputed(file): + return file.read().splitlines() + + self.validate_files( + self.file_10e0, + path_tmp, + convert_validation, + convert_outputed + ) diff --git a/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py b/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py index 9739307b..2d0b8b30 100644 --- a/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py +++ b/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py @@ -1,10 +1,12 @@ import unittest -from typing import Tuple, Callable +from typing import Tuple, Callable, Iterable from unittest.mock import Mock from pywy.dataquanta import WayangContext from pywy.dataquanta import DataQuanta +from pywy.exception import PywyException from pywy.operators import * +from pywy.types import FlatmapFunction class 
TestUnitCoreTranslator(unittest.TestCase): @@ -110,13 +112,16 @@ class TestUnitCoreTranslator(unittest.TestCase): def test_flatmap_lambda(self): (operator, dq) = self.build_seed() func: Callable = lambda x: x.split(" ") - flatted = dq.flatmap(func) - self.validate_flatmap(flatted, operator) + try: + flatted = dq.flatmap(func) + self.validate_flatmap(flatted, operator) + except PywyException as e: + self.assertTrue("the return for the FlatmapFunction is not Iterable" in str(e)) def test_flatmap_func(self): (operator, dq) = self.build_seed() - def fmfunc(i: str) -> str: + def fmfunc(i: str) -> Iterable[str]: for x in range(len(i)): yield str(x) @@ -126,9 +131,10 @@ class TestUnitCoreTranslator(unittest.TestCase): def test_flatmap_func_lambda(self): (operator, dq) = self.build_seed() - def fmfunc(i): - for x in range(len(i)): - yield str(x) - - flatted = dq.flatmap(lambda x: fmfunc(x)) - self.validate_flatmap(flatted, operator) + try: + fm_func_lambda: Callable[[str], Iterable[str]] = lambda i: [str(x) for x in range(len(i))] + flatted = dq.flatmap(fm_func_lambda) + self.assertRaises("the current implementation does not support lambdas") + # self.validate_flatmap(flatted, operator) + except PywyException as e: + self.assertTrue("the return for the FlatmapFunction is not Iterable" in str(e)) diff --git a/python/src/pywy/types.py b/python/src/pywy/types.py index f39d4528..98131ca8 100644 --- a/python/src/pywy/types.py +++ b/python/src/pywy/types.py @@ -1,4 +1,4 @@ -from typing import (Generic, TypeVar, Callable, Hashable, Iterable) +from typing import (Generic, TypeVar, Callable, Hashable, Iterable, Type) from inspect import signature from pywy.exception import PywyException @@ -73,5 +73,12 @@ def get_type_flatmap_function(call: FlatmapFunction) -> (type, type): ) ) + if type(sig.return_annotation) != type(Iterable): + raise PywyException( + "the return for the FlatmapFunction is not Iterable, {}".format( + str(sig.return_annotation) + ) + ) + keys = 
list(sig.parameters.keys()) - return sig.parameters[keys[0]].annotation, sig.return_annotation + return sig.parameters[keys[0]].annotation, sig.return_annotation.__args__[0]
