This is an automated email from the ASF dual-hosted git repository.

jbonofre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-liminal.git

commit 249aa74d2e038e0e6617df7503a559d7fa1a6111
Author: roei <[email protected]>
AuthorDate: Thu Jul 2 11:36:52 2020 +0300

    Fix _split_list to return exactly `num` contiguous, near-equal chunks
---
 .../kubernetes_pod_operator_with_input_output.py   | 41 ++++++++-------------
 ...st_kubernetes_pod_operator_with_input_output.py | 43 ++++++++++++++++++++++
 2 files changed, 59 insertions(+), 25 deletions(-)

diff --git a/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py b/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py
index 5833550..267010e 100644
--- a/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py
+++ b/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py
@@ -3,16 +3,11 @@ import json
 from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator
 
 
-def split_list(seq, num):
-    avg = len(seq) / float(num)
-    out = []
-    last = 0.0
-
-    while last < len(seq):
-        out.append(seq[int(last):int(last + avg)])
-        last += avg
-
-    return out
+def _split_list(seq, num):
+    """Split seq into num contiguous chunks whose sizes differ by <= 1."""
+    size, extra = divmod(len(seq), num)
+    return [seq[i * size + min(i, extra):(i + 1) * size + min(i + 1, extra)]
+            for i in range(num)]
 
 
 _IS_SPLIT_KEY = 'is_split'
@@ -27,13 +22,9 @@ class PrepareInputOperator(KubernetesPodOperator):
                  executors=1,
                  *args,
                  **kwargs):
-        namespace = kwargs['namespace']
-        image = kwargs['image']
-        name = kwargs['name']
-
-        del kwargs['namespace']
-        del kwargs['image']
-        del kwargs['name']
+        namespace = kwargs.pop('namespace')
+        image = kwargs.pop('image')
+        name = kwargs.pop('name')
 
         super().__init__(
             namespace=namespace,
@@ -74,7 +65,11 @@ class PrepareInputOperator(KubernetesPodOperator):
             self.log.info(f'Generated input: {input_dict}')
 
             if self.split_input:
-                input_splits = split_list(input_dict, self.executors)
+                input_splits = _split_list(input_dict, self.executors)
+                # _split_list returns exactly self.executors splits (trailing
+                # ones are empty when there are fewer inputs than executors).
+                numbered_splits = list(enumerate(input_splits))
+                self.log.info(f'Input splits: {numbered_splits}')
 
                 ti.xcom_push(key=_IS_SPLIT_KEY, value=True)
 
@@ -100,13 +95,9 @@ class KubernetesPodOperatorWithInputAndOutput(KubernetesPodOperator):
                  input_task_id=None,
                  *args,
                  **kwargs):
-        namespace = kwargs['namespace']
-        image = kwargs['image']
-        name = kwargs['name']
-
-        del kwargs['namespace']
-        del kwargs['image']
-        del kwargs['name']
+        namespace = kwargs.pop('namespace')
+        image = kwargs.pop('image')
+        name = kwargs.pop('name')
 
         super().__init__(
             namespace=namespace,
diff --git a/tests/runners/airflow/operators/test_kubernetes_pod_operator_with_input_output.py b/tests/runners/airflow/operators/test_kubernetes_pod_operator_with_input_output.py
new file mode 100644
index 0000000..6895382
--- /dev/null
+++ b/tests/runners/airflow/operators/test_kubernetes_pod_operator_with_input_output.py
@@ -0,0 +1,43 @@
+import itertools
+from unittest import TestCase
+
+from liminal.runners.airflow.operators.\
+    kubernetes_pod_operator_with_input_output import _split_list
+
+
+class TestSplitList(TestCase):
+    def setUp(self) -> None:
+        self.short_seq = [{f'task_{i}': f'value_{i}'} for i in range(3)]
+        self.long_seq = [{f'task_{i}': f'value_{i}'} for i in range(10)]
+
+    def test_seq_equal_num(self):
+        num = len(self.short_seq)
+        result = _split_list(self.short_seq, num)
+        expected = [[{'task_0': 'value_0'}], [{'task_1': 'value_1'}],
+                    [{'task_2': 'value_2'}]]
+        self.assertListEqual(expected, result)
+
+    def test_seq_greater_than_num(self):
+        num = 3
+        result = _split_list(self.long_seq, num)
+        n_tasks = len(self.long_seq)
+
+        min_length = min(len(i) for i in result)
+        max_length = max(len(i) for i in result)
+        flat_results = list(itertools.chain.from_iterable(result))
+
+        self.assertLessEqual(max_length - min_length, 1)
+        self.assertEqual(n_tasks, len(flat_results))
+        self.assertTrue(all({f'task_{i}': f'value_{i}'} in flat_results
+                            for i in range(n_tasks)))
+
+    def test_seq_smaller_than_num(self):
+        test_num_range = [8, 9, 10, 11, 12]
+        for num in test_num_range:
+            with self.subTest(num=num):
+                result = _split_list(self.short_seq, num)
+                self.assertEqual(len(result), num)
+                self.assertTrue(all([i] in result for i in self.short_seq))
+                self.assertEqual([[]] * (num - len(self.short_seq)),
+                                 [i for i in result if i == []])
+

Reply via email to