This is an automated email from the ASF dual-hosted git repository. nkak pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 58f34704f95dfabfdc09283ee3dbc094e914c6e4 Author: Nikhil Kak <[email protected]> AuthorDate: Wed Nov 11 17:49:02 2020 -0800 Remove pg tests from fit and eval transition JIRA: MADLIB-1438 We don't really need to test for pg because nothing in any of the transition functions cares about postgres. Also, our previous way of mocking is_platform_pg wasn't working correctly. Also removed the code for postgres from get_image_count_per_seg_from_array, since current_seg_id is always passed in as 0 for pg, so just indexing into the array is good enough. For fit multiple, removed the call to is_platform_pg() while setting the gp_segment_id_col because we don't support pg for fit multiple. Co-authored-by: Ekta Khanna <[email protected]> --- .../madlib_keras_fit_multiple_model.py_in | 2 +- .../deep_learning/madlib_keras_helper.py_in | 8 ++--- .../test/unit_tests/test_madlib_keras.py_in | 34 ++++------------------ 3 files changed, 8 insertions(+), 36 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in index 9287524..8a5b2b3 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in @@ -208,7 +208,7 @@ class FitMultipleModel(): self.msts_for_schedule = self.msts random.shuffle(self.msts_for_schedule) self.grand_schedule = self.generate_schedule(self.msts_for_schedule) - self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME + self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else '' if self.warm_start: diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in index be9a1f9..cf030e1 100644 --- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in @@ -110,14 +110,10 @@ def strip_trailing_nulls_from_class_values(class_values): def get_image_count_per_seg_from_array(current_seg_id, images_per_seg): """ Get the image count from the array containing all the images - per segment. Based on the platform, we find the index of the current segment. + per segment. This function is only called from inside the transition function. """ - if is_platform_pg(): - total_images = images_per_seg[0] - else: - total_images = images_per_seg[current_seg_id] - return total_images + return images_per_seg[current_seg_id] def get_image_count_per_seg_for_minibatched_data_from_db(table_name): """ diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in index 1b0ee8d..7cdd83c 100644 --- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in @@ -99,8 +99,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase): self.module_patcher.stop() self.subject.K.clear_session() - def _test_fit_transition_first_buffer_pass(self, is_platform_pg, **kwargs): - self.subject.is_platform_pg = Mock(return_value = is_platform_pg) + def _test_fit_transition_first_buffer_pass(self, **kwargs): ending_image_count = len(self.dependent_var_int) previous_state = np.array(self.model_weights, dtype=np.float32) @@ -152,8 +151,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase): self.assertTrue(k['GD']['x_train']) self.assertTrue(k['GD']['y_train']) - def _test_fit_transition_middle_buffer_pass(self, is_platform_pg, **kwargs): - self.subject.is_platform_pg = Mock(return_value = is_platform_pg) + def _test_fit_transition_middle_buffer_pass(self, **kwargs): 
starting_image_count = len(self.dependent_var_int) ending_image_count = starting_image_count + len(self.dependent_var_int) @@ -214,8 +212,7 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase): self.assertTrue(k['GD']['x_train']) self.assertTrue(k['GD']['y_train']) - def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs): - self.subject.is_platform_pg = Mock(return_value = is_platform_pg) + def _test_fit_transition_last_buffer_pass(self, **kwargs): starting_image_count = 2*len(self.dependent_var_int) ending_image_count = starting_image_count + len(self.dependent_var_int) @@ -237,10 +234,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase): self.assertTrue((weights == multiplied_weights).all()) self.assertEqual(ending_image_count, image_count) - def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg, + def _test_internal_keras_eval_transition_first_buffer(self, last_iteration = False, **kwargs): - self.subject.is_platform_pg = Mock(return_value = is_platform_pg) ending_image_count = len(self.dependent_var_int) state = [0,0,0] @@ -260,11 +256,9 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase): self.assertAlmostEqual(self.loss * image_count, agg_loss, 4) self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4) - def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg, + def _test_internal_keras_eval_transition_last_buffer(self, last_iteration = False, **kwargs): - self.subject.is_platform_pg = Mock(return_value = is_platform_pg) - starting_image_count = 2*len(self.dependent_var_int) ending_image_count = starting_image_count + len(self.dependent_var_int) @@ -448,24 +442,6 @@ class MadlibKerasFitEvalTransitionTestCase(unittest.TestCase): self.assertTrue('x_train' not in k['GD']) self.assertTrue('y_train' not in k['GD']) - def test_fit_transition_first_buffer_pass_pg(self): - self._test_fit_transition_first_buffer_pass(True) - - def 
test_fit_transition_first_buffer_pass_gpdb(self): - self._test_fit_transition_first_buffer_pass(False) - - def test_fit_transition_middle_buffer_pass_pg(self): - self._test_fit_transition_middle_buffer_pass(True) - - def test_fit_transition_middle_buffer_pass_gpdb(self): - self._test_fit_transition_middle_buffer_pass(False) - - def test_fit_transition_last_buffer_pass_pg(self): - self._test_fit_transition_last_buffer_pass(True) - - def test_fit_transition_last_buffer_pass_gpdb(self): - self._test_fit_transition_last_buffer_pass(False) - ############### GRAPH AND SESSION TESTS ################################ def test_fit_eval_2_iterations_mcf_null_gpdb(self): kwargs = {'GD': {}}
