This is an automated email from the ASF dual-hosted git repository. jbonofre pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-liminal.git
commit 249aa74d2e038e0e6617df7503a559d7fa1a6111 Author: roei <[email protected]> AuthorDate: Thu Jul 2 11:36:52 2020 +0300 fix split list function --- .../kubernetes_pod_operator_with_input_output.py | 41 ++++++++------------- ...st_kubernetes_pod_operator_with_input_output.py | 43 ++++++++++++++++++++++ 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py b/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py index 5833550..267010e 100644 --- a/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py +++ b/liminal/runners/airflow/operators/kubernetes_pod_operator_with_input_output.py @@ -3,16 +3,11 @@ import json from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator -def split_list(seq, num): - avg = len(seq) / float(num) - out = [] - last = 0.0 - - while last < len(seq): - out.append(seq[int(last):int(last + avg)]) - last += avg - - return out +def _split_list(seq, num): + k, m = divmod(len(seq), num) + return list( + (seq[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(num)) + ) _IS_SPLIT_KEY = 'is_split' @@ -27,13 +22,9 @@ class PrepareInputOperator(KubernetesPodOperator): executors=1, *args, **kwargs): - namespace = kwargs['namespace'] - image = kwargs['image'] - name = kwargs['name'] - - del kwargs['namespace'] - del kwargs['image'] - del kwargs['name'] + namespace = kwargs.pop('namespace') + image = kwargs.pop('image') + name = kwargs.pop('name') super().__init__( namespace=namespace, @@ -74,7 +65,11 @@ class PrepareInputOperator(KubernetesPodOperator): self.log.info(f'Generated input: {input_dict}') if self.split_input: - input_splits = split_list(input_dict, self.executors) + input_splits = _split_list(input_dict, self.executors) + numbered_splits = list( + zip(range(len(input_splits)), input_splits) + ) + self.log.info(numbered_splits) ti.xcom_push(key=_IS_SPLIT_KEY, value=True) @@ -100,13 +95,9 @@ class KubernetesPodOperatorWithInputAndOutput(KubernetesPodOperator): input_task_id=None, *args, **kwargs): - namespace = kwargs['namespace'] - image = kwargs['image'] - name = kwargs['name'] - - del kwargs['namespace'] - del kwargs['image'] - del kwargs['name'] + namespace = kwargs.pop('namespace') + image = kwargs.pop('image') + name = kwargs.pop('name') super().__init__( namespace=namespace, diff --git a/tests/runners/airflow/operators/test_kubernetes_pod_operator_with_input_output.py b/tests/runners/airflow/operators/test_kubernetes_pod_operator_with_input_output.py new file mode 100644 index 0000000..6895382 --- /dev/null +++ b/tests/runners/airflow/operators/test_kubernetes_pod_operator_with_input_output.py @@ -0,0 +1,43 @@ +import unittest +from unittest import TestCase +import itertools + +from liminal.runners.airflow.operators.\ + kubernetes_pod_operator_with_input_output import _split_list + + +class TestSplitList(TestCase): + def setUp(self) -> None: + self.short_seq = [{f'task_{i}': f'value_{i}'} for i in range(3)] + self.long_seq = [{f'task_{i}': f'value_{i}'} for i in range(10)] + + def test_seq_equal_num(self): + num = len(self.short_seq) + result = _split_list(self.short_seq, num) + expected = [[{'task_0': 'value_0'}], [{'task_1': 'value_1'}], + [{'task_2': 'value_2'}]] + self.assertListEqual(expected, result) + + def test_seq_grater_than_num(self): + num = 3 + result = _split_list(self.long_seq, num) + n_tasks = len(self.long_seq) + + min_length = min([len(i) for i in result]) + max_length = max([len(i) for i in result]) + flat_results = list(itertools.chain(*result)) + + self.assertGreaterEqual(max_length - min_length, 1) + self.assertEqual(n_tasks, len(flat_results)) + self.assertTrue(all([{f'task_{i}': f'value_{i}'} in flat_results + for i in range(n_tasks)])) + + def test_seq_smaller_than_num(self): + test_num_range = [8, 9, 10, 11, 12] + for num in test_num_range: + result = _split_list(self.short_seq, num) + self.assertEqual(len(result), num) + self.assertTrue(all([[i] in result for i in self.short_seq])) + self.assertEqual([[]] * (num - len(self.short_seq)), + [i for i in result if i == []]) +
