This is an automated email from the ASF dual-hosted git repository.

bertty pushed a commit to branch python-platform
in repository https://gitbox.apache.org/repos/asf/incubator-wayang.git

commit d60c1701cc6b9ba4554eb1bf72f8007a2e933261
Author: Bertty Contreras-Rojas <[email protected]>
AuthorDate: Fri Apr 8 16:12:29 2022 +0200

    [WAYANG-#8] add PyFlatMapOperator and correction of types capture
    
    Signed-off-by: bertty <[email protected]>
---
 python/src/pywy/dataquanta.py                      |  15 ++-
 python/src/pywy/operators/sink.py                  |   8 +-
 python/src/pywy/operators/unary.py                 |   8 --
 python/src/pywy/platforms/python/execution.py      |   8 ++
 python/src/pywy/platforms/python/mappings.py       |   1 +
 .../src/pywy/platforms/python/operator/__init__.py |   2 +
 .../platforms/python/operator/py_sink_textfile.py  |   8 +-
 .../platforms/python/operator/py_unary_flatmap.py  |  49 ++++++++
 .../pywy/tests/integration/python_platform_test.py | 124 +++++++++++++++------
 .../pywy/tests/unit/dataquanta/dataquanta_test.py  |  26 +++--
 python/src/pywy/types.py                           |  11 +-
 11 files changed, 198 insertions(+), 62 deletions(-)

diff --git a/python/src/pywy/dataquanta.py b/python/src/pywy/dataquanta.py
index 956d815e..5fc1e2df 100644
--- a/python/src/pywy/dataquanta.py
+++ b/python/src/pywy/dataquanta.py
@@ -64,8 +64,19 @@ class DataQuanta(GenericTco):
     def flatmap(self: "DataQuanta[In]", f: FlatmapFunction) -> "DataQuanta[IterableOut]":
         return DataQuanta(self.context, self._connect(FlatmapOperator(f)))
 
-    def store_textfile(self: "DataQuanta[In]", path: str):
-        last: List[SinkOperator] = [cast(SinkOperator, self._connect(TextFileSink(path, self.operator.outputSlot[0])))]
+    def store_textfile(self: "DataQuanta[In]", path: str, end_line: str = None):
+        last: List[SinkOperator] = [
+            cast(
+                SinkOperator,
+                self._connect(
+                    TextFileSink(
+                        path,
+                        self.operator.outputSlot[0],
+                        end_line
+                    )
+                )
+            )
+        ]
         plan = PywyPlan(self.context.plugins, last)
 
         plug = self.context.plugins.pop()
diff --git a/python/src/pywy/operators/sink.py b/python/src/pywy/operators/sink.py
index 1f78a63c..8a80ccfe 100644
--- a/python/src/pywy/operators/sink.py
+++ b/python/src/pywy/operators/sink.py
@@ -23,12 +23,16 @@ class SinkUnaryOperator(SinkOperator):
 
 
 class TextFileSink(SinkUnaryOperator):
-
     path: str
+    end_line: str
 
-    def __init__(self, path: str, input_type: GenericTco):
+    def __init__(self, path: str, input_type: GenericTco, end_line: str = None):
         super().__init__('TextFile', input_type)
         self.path = path
+        if input_type != str and end_line is None:
+            self.end_line = '\n'
+        else:
+            self.end_line = end_line
 
     def __str__(self):
         return super().__str__()
diff --git a/python/src/pywy/operators/unary.py b/python/src/pywy/operators/unary.py
index 03d6f118..339480b0 100644
--- a/python/src/pywy/operators/unary.py
+++ b/python/src/pywy/operators/unary.py
@@ -1,4 +1,3 @@
-from itertools import chain
 from pywy.operators.base import PywyOperator
 from pywy.types import (
                             GenericTco,
@@ -68,13 +67,6 @@ class FlatmapOperator(UnaryToUnaryOperator):
         super().__init__("Flatmap", types[0], types[1])
         self.fm_function = fm_function
 
-    # TODO remove wrapper
-    def getWrapper(self):
-        udf = self.fm_function
-        def func(iterator):
-            return chain.from_iterable(map(udf, iterator))
-        return func
-
     def __str__(self):
         return super().__str__()
 
diff --git a/python/src/pywy/platforms/python/execution.py b/python/src/pywy/platforms/python/execution.py
index bab62ecc..4fb2ca2b 100644
--- a/python/src/pywy/platforms/python/execution.py
+++ b/python/src/pywy/platforms/python/execution.py
@@ -2,6 +2,7 @@ from pywy.graph.types import WGraphOfOperator, NodeOperator
 from pywy.core import ChannelDescriptor
 from pywy.core import Executor
 from pywy.core import PywyPlan
+from pywy.operators import TextFileSource
 from pywy.platforms.python.channels import PY_ITERATOR_CHANNEL_DESCRIPTOR
 from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator
 
@@ -17,6 +18,7 @@ class PyExecutor(Executor):
 
         # TODO get this information by a configuration and ideally by the context
         descriptor_default: ChannelDescriptor = PY_ITERATOR_CHANNEL_DESCRIPTOR
+        files_pool = []
 
         def execute(op_current: NodeOperator, op_next: NodeOperator):
             if op_current is None:
@@ -66,4 +68,10 @@ class PyExecutor(Executor):
 
             py_next.inputChannel = py_current.outputChannel
 
+            if isinstance(py_current, TextFileSource):
+                files_pool.append(py_current.outputChannel[0].provide_iterable())
+
         graph.traversal(graph.starting_nodes, execute)
+        # close the files used during the execution
+        for f in files_pool:
+            f.close()
diff --git a/python/src/pywy/platforms/python/mappings.py b/python/src/pywy/platforms/python/mappings.py
index 50a6ddeb..e46865cb 100644
--- a/python/src/pywy/platforms/python/mappings.py
+++ b/python/src/pywy/platforms/python/mappings.py
@@ -8,4 +8,5 @@ PYWY_OPERATOR_MAPPINGS.add_mapping(PyFilterOperator())
 PYWY_OPERATOR_MAPPINGS.add_mapping(PyTextFileSourceOperator())
 PYWY_OPERATOR_MAPPINGS.add_mapping(PyTextFileSinkOperator())
 PYWY_OPERATOR_MAPPINGS.add_mapping(PyMapOperator())
+PYWY_OPERATOR_MAPPINGS.add_mapping(PyFlatmapOperator())
 
diff --git a/python/src/pywy/platforms/python/operator/__init__.py b/python/src/pywy/platforms/python/operator/__init__.py
index 51b6f409..438a9692 100644
--- a/python/src/pywy/platforms/python/operator/__init__.py
+++ b/python/src/pywy/platforms/python/operator/__init__.py
@@ -1,6 +1,7 @@
 from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator
 from pywy.platforms.python.operator.py_unary_filter import PyFilterOperator
 from pywy.platforms.python.operator.py_unary_map import PyMapOperator
+from pywy.platforms.python.operator.py_unary_flatmap import PyFlatmapOperator
 from pywy.platforms.python.operator.py_source_textfile import PyTextFileSourceOperator
 from pywy.platforms.python.operator.py_sink_textfile import PyTextFileSinkOperator
 
@@ -10,4 +11,5 @@ __ALL__ = [
     PyTextFileSourceOperator,
     PyTextFileSinkOperator,
     PyMapOperator,
+    PyFlatmapOperator,
 ]
diff --git a/python/src/pywy/platforms/python/operator/py_sink_textfile.py b/python/src/pywy/platforms/python/operator/py_sink_textfile.py
index 801387cd..61528349 100644
--- a/python/src/pywy/platforms/python/operator/py_sink_textfile.py
+++ b/python/src/pywy/platforms/python/operator/py_sink_textfile.py
@@ -15,7 +15,8 @@ class PyTextFileSinkOperator(TextFileSink, PyExecutionOperator):
     def __init__(self, origin: TextFileSink = None):
         path = None if origin is None else origin.path
         type_class = None if origin is None else origin.inputSlot[0]
-        super().__init__(path, type_class)
+        end_line = None if origin is None else origin.end_line
+        super().__init__(path, type_class, end_line)
 
     def execute(self, inputs: List[Type[CH_T]], outputs: List[Type[CH_T]]):
         self.validate_channels(inputs, outputs)
@@ -23,12 +24,13 @@ class PyTextFileSinkOperator(TextFileSink, PyExecutionOperator):
             file = open(self.path, 'w')
             py_in_iter_channel: PyIteratorChannel = inputs[0]
             iterable = py_in_iter_channel.provide_iterable()
-            if self.inputSlot[0] == str:
+
+            if self.inputSlot[0] == str and self.end_line is None:
                 for element in iterable:
                     file.write(element)
             else:
                 for element in iterable:
-                    file.write("{}\n".format(str(element)))
+                    file.write("{}{}".format(str(element), self.end_line))
             file.close()
 
         else:
diff --git a/python/src/pywy/platforms/python/operator/py_unary_flatmap.py b/python/src/pywy/platforms/python/operator/py_unary_flatmap.py
new file mode 100644
index 00000000..d842b54a
--- /dev/null
+++ b/python/src/pywy/platforms/python/operator/py_unary_flatmap.py
@@ -0,0 +1,49 @@
+from itertools import chain
+from typing import Set, List, Type
+
+from pywy.core.channel import CH_T
+from pywy.operators.unary import FlatmapOperator
+from pywy.platforms.python.operator.py_execution_operator import PyExecutionOperator
+from pywy.platforms.python.channels import (
+                                                ChannelDescriptor,
+                                                PyIteratorChannel,
+                                                PY_ITERATOR_CHANNEL_DESCRIPTOR,
+                                                PY_CALLABLE_CHANNEL_DESCRIPTOR,
+                                                PyCallableChannel
+                                            )
+
+
+class PyFlatmapOperator(FlatmapOperator, PyExecutionOperator):
+
+    def __init__(self, origin: FlatmapOperator = None):
+        fm_function = None if origin is None else origin.fm_function
+        super().__init__(fm_function)
+
+    def execute(self, inputs: List[Type[CH_T]], outputs: List[Type[CH_T]]):
+        self.validate_channels(inputs, outputs)
+        udf = self.fm_function
+        if isinstance(inputs[0], PyIteratorChannel):
+            py_in_iter_channel: PyIteratorChannel = inputs[0]
+            py_out_iter_channel: PyIteratorChannel = outputs[0]
+            py_out_iter_channel.accept_iterable(chain.from_iterable(map(udf, py_in_iter_channel.provide_iterable())))
+        elif isinstance(inputs[0], PyCallableChannel):
+            py_in_call_channel: PyCallableChannel = inputs[0]
+            py_out_call_channel: PyCallableChannel = outputs[0]
+
+            def fm_func(iterator):
+                return chain.from_iterable(map(udf, iterator))
+
+            py_out_call_channel.accept_callable(
+                PyCallableChannel.concatenate(
+                    fm_func,
+                    py_in_call_channel.provide_callable()
+                )
+            )
+        else:
+            raise Exception("Channel Type does not supported")
+
+    def get_input_channeldescriptors(self) -> Set[ChannelDescriptor]:
+        return {PY_ITERATOR_CHANNEL_DESCRIPTOR, PY_CALLABLE_CHANNEL_DESCRIPTOR}
+
+    def get_output_channeldescriptors(self) -> Set[ChannelDescriptor]:
+        return {PY_ITERATOR_CHANNEL_DESCRIPTOR, PY_CALLABLE_CHANNEL_DESCRIPTOR}
diff --git a/python/src/pywy/tests/integration/python_platform_test.py b/python/src/pywy/tests/integration/python_platform_test.py
index d2ec4f22..48dc4558 100644
--- a/python/src/pywy/tests/integration/python_platform_test.py
+++ b/python/src/pywy/tests/integration/python_platform_test.py
@@ -1,12 +1,16 @@
+import logging
 import os
 import unittest
 import tempfile
-from typing import List
+from itertools import chain
+from typing import List, Iterable
 
 from pywy.config import RC_TEST_DIR as ROOT
 from pywy.dataquanta import WayangContext
 from pywy.plugins import PYTHON
 
+logger = logging.getLogger(__name__)
+
 
 class TestIntegrationPythonPlatform(unittest.TestCase):
 
@@ -14,60 +18,110 @@ class TestIntegrationPythonPlatform(unittest.TestCase):
 
     def setUp(self):
         self.file_10e0 = "{}/10e0MB.input".format(ROOT)
-        pass
 
-    def test_grep(self):
+    @staticmethod
+    def seed_small_grep(validation_file):
         def pre(a: str) -> bool:
             return 'six' in a
 
         fd, path_tmp = tempfile.mkstemp()
 
-        WayangContext() \
+        dq = WayangContext() \
             .register(PYTHON) \
-            .textfile(self.file_10e0) \
-            .filter(pre) \
-            .store_textfile(path_tmp)
-
-        lines_filter: List[str]
-        with open(self.file_10e0, 'r') as f:
-            lines_filter = list(filter(pre, f.readlines()))
-            selectivity = len(list(lines_filter))
+            .textfile(validation_file) \
+            .filter(pre)
+
+        return dq, path_tmp, pre
+
+    def validate_files(self,
+                       validation_file,
+                       outputed_file,
+                       read_and_convert_validation,
+                       read_and_convert_outputed,
+                       delete_outputed=True,
+                       print_variable=False):
+        lines_filter: List[int]
+        with open(validation_file, 'r') as f:
+            lines_filter = list(read_and_convert_validation(f))
+            selectivity = len(lines_filter)
 
-        lines_platform: List[str]
-        with open(path_tmp, 'r') as fp:
-            lines_platform = fp.readlines()
+        lines_platform: List[int]
+        with open(outputed_file, 'r') as fp:
+            lines_platform = list(read_and_convert_outputed(fp))
             elements = len(lines_platform)
-        os.remove(path_tmp)
+
+        if delete_outputed:
+            os.remove(outputed_file)
+
+        if print_variable:
+            logger.info(f"{lines_platform=}")
+            logger.info(f"{lines_filter=}")
+            logger.info(f"{elements=}")
+            logger.info(f"{selectivity=}")
 
         self.assertEqual(selectivity, elements)
         self.assertEqual(lines_filter, lines_platform)
 
+    def test_grep(self):
+
+        dq, path_tmp, pre = self.seed_small_grep(self.file_10e0)
+
+        dq.store_textfile(path_tmp)
+
+        def convert_validation(file):
+            return filter(pre, file.readlines())
+
+        def convert_outputed(file):
+            return file.readlines()
+
+        self.validate_files(
+            self.file_10e0,
+            path_tmp,
+            convert_validation,
+            convert_outputed
+        )
+
     def test_dummy_map(self):
-        def pre(a: str) -> bool:
-            return 'six' in a
 
         def convert(a: str) -> int:
             return len(a)
 
-        fd, path_tmp = tempfile.mkstemp()
+        dq, path_tmp, pre = self.seed_small_grep(self.file_10e0)
 
-        WayangContext() \
-            .register(PYTHON) \
-            .textfile(self.file_10e0) \
-            .filter(pre) \
-            .map(convert) \
+        dq.map(convert) \
             .store_textfile(path_tmp)
 
-        lines_filter: List[int]
-        with open(self.file_10e0, 'r') as f:
-            lines_filter = list(map(convert, filter(pre, f.readlines())))
-            selectivity = len(list(lines_filter))
+        def convert_validation(file):
+            return map(convert, filter(pre, file.readlines()))
 
-        lines_platform: List[int]
-        with open(path_tmp, 'r') as fp:
-            lines_platform = list(map(lambda x: int(x), fp.readlines()))
-            elements = len(lines_platform)
-        os.remove(path_tmp)
+        def convert_outputed(file):
+            return map(lambda x: int(x), file.read().splitlines())
 
-        self.assertEqual(selectivity, elements)
-        self.assertEqual(lines_filter, lines_platform)
+        self.validate_files(
+            self.file_10e0,
+            path_tmp,
+            convert_validation,
+            convert_outputed
+        )
+
+    def test_dummy_flatmap(self):
+        def fm_func(string: str) -> Iterable[str]:
+            return string.strip().split(" ")
+
+        dq, path_tmp, pre = self.seed_small_grep(self.file_10e0)
+
+        dq.flatmap(fm_func) \
+            .store_textfile(path_tmp, '\n')
+
+        def convert_validation(file):
+            return chain.from_iterable(map(fm_func, filter(pre, file.readlines())))
+
+        def convert_outputed(file):
+            return file.read().splitlines()
+
+        self.validate_files(
+            self.file_10e0,
+            path_tmp,
+            convert_validation,
+            convert_outputed
+        )
diff --git a/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py b/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py
index 9739307b..2d0b8b30 100644
--- a/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py
+++ b/python/src/pywy/tests/unit/dataquanta/dataquanta_test.py
@@ -1,10 +1,12 @@
 import unittest
-from typing import Tuple, Callable
+from typing import Tuple, Callable, Iterable
 from unittest.mock import Mock
 
 from pywy.dataquanta import WayangContext
 from pywy.dataquanta import DataQuanta
+from pywy.exception import PywyException
 from pywy.operators import *
+from pywy.types import FlatmapFunction
 
 
 class TestUnitCoreTranslator(unittest.TestCase):
@@ -110,13 +112,16 @@ class TestUnitCoreTranslator(unittest.TestCase):
     def test_flatmap_lambda(self):
         (operator, dq) = self.build_seed()
         func: Callable = lambda x: x.split(" ")
-        flatted = dq.flatmap(func)
-        self.validate_flatmap(flatted, operator)
+        try:
+            flatted = dq.flatmap(func)
+            self.validate_flatmap(flatted, operator)
+        except PywyException as e:
+            self.assertTrue("the return for the FlatmapFunction is not Iterable" in str(e))
 
     def test_flatmap_func(self):
         (operator, dq) = self.build_seed()
 
-        def fmfunc(i: str) -> str:
+        def fmfunc(i: str) -> Iterable[str]:
             for x in range(len(i)):
                 yield str(x)
 
@@ -126,9 +131,10 @@ class TestUnitCoreTranslator(unittest.TestCase):
     def test_flatmap_func_lambda(self):
         (operator, dq) = self.build_seed()
 
-        def fmfunc(i):
-            for x in range(len(i)):
-                yield str(x)
-
-        flatted = dq.flatmap(lambda x: fmfunc(x))
-        self.validate_flatmap(flatted, operator)
+        try:
+            fm_func_lambda: Callable[[str], Iterable[str]] = lambda i: [str(x) for x in range(len(i))]
+            flatted = dq.flatmap(fm_func_lambda)
+            self.assertRaises("the current implementation does not support lambdas")
+            # self.validate_flatmap(flatted, operator)
+        except PywyException as e:
+            self.assertTrue("the return for the FlatmapFunction is not Iterable" in str(e))
diff --git a/python/src/pywy/types.py b/python/src/pywy/types.py
index f39d4528..98131ca8 100644
--- a/python/src/pywy/types.py
+++ b/python/src/pywy/types.py
@@ -1,4 +1,4 @@
-from typing import (Generic, TypeVar, Callable, Hashable, Iterable)
+from typing import (Generic, TypeVar, Callable, Hashable, Iterable, Type)
 from inspect import signature
 from pywy.exception import PywyException
 
@@ -73,5 +73,12 @@ def get_type_flatmap_function(call: FlatmapFunction) -> (type, type):
             )
         )
 
+    if type(sig.return_annotation) != type(Iterable):
+        raise PywyException(
+            "the return for the FlatmapFunction is not Iterable, {}".format(
+                str(sig.return_annotation)
+            )
+        )
+
     keys = list(sig.parameters.keys())
-    return sig.parameters[keys[0]].annotation, sig.return_annotation
+    return sig.parameters[keys[0]].annotation, sig.return_annotation.__args__[0]

Reply via email to