khannaekta commented on a change in pull request #522: URL: https://github.com/apache/madlib/pull/522#discussion_r533011189
########## File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in ########## @@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase): def test_fit_transition_last_buffer_pass_gpdb(self): self._test_fit_transition_last_buffer_pass(False) + ############### GRAPH AND SESSION TESTS ################################ + def test_fit_eval_2_iterations_mcf_null_gpdb(self): + kwargs = {'GD': {}} + GD = kwargs['GD'] + + ######################### fit for 2 iterations ########## + # iteration 1 + first_iter_keras_sess = self._run_fit_iteration(**kwargs) + self._assert_keras_session_same_as_gd_session(GD) + + first_iter_tf_graph = self.subject.tf.get_default_graph() + + # iteration 2 (last iteration) + last_iter_keras_sess = self._run_fit_iteration(**kwargs) + self._assert_keras_session_same_as_gd_session(GD) + + last_iter_tf_graph = self.subject.tf.get_default_graph() + + self.assertEquals(first_iter_keras_sess, last_iter_keras_sess) + self.assertEquals(first_iter_tf_graph, last_iter_tf_graph) + + ###################### eval transition for last iteration ########### + self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs) + eval_last_iter_keras_sess = self.subject.K.get_session() + eval_last_iter_tf_graph = self.subject.tf.get_default_graph() + + self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess) + self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph) + self._assert_gd_cleared(GD) + + def test_fit_eval_2_iterations_mcf_1_gpdb(self): + kwargs = {'GD': {}} + GD = kwargs['GD'] + + ######################### fit + eval for 2 iterations ########## + # iteration 1 fit + first_iter_keras_sess = self._run_fit_iteration(**kwargs) + self._assert_keras_session_same_as_gd_session(GD) + + first_iter_tf_graph = self.subject.tf.get_default_graph() + + # iteration 1 eval + self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs) + 
self._assert_keras_session_same_as_gd_session(GD) + + eval_first_iter_keras_sess = self.subject.K.get_session() + eval_first_iter_tf_graph = self.subject.tf.get_default_graph() + + self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess) + self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph) + + # iteration 2 fit (last iteration) + last_iter_keras_sess = self._run_fit_iteration(**kwargs) + self._assert_keras_session_same_as_gd_session(GD) + + last_iter_tf_graph = self.subject.tf.get_default_graph() + + self.assertEquals(first_iter_keras_sess, last_iter_keras_sess) + self.assertEquals(first_iter_tf_graph, last_iter_tf_graph) + + # iteration 2 eval (last iteration) + self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs) + + eval_last_iter_keras_sess = self.subject.K.get_session() + eval_last_iter_tf_graph = self.subject.tf.get_default_graph() + + self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess) + self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph) + self._assert_gd_cleared(GD) + + def test_fit_multiple_2_iterations(self): + kwargs = {'GD': {}} + GD = kwargs['GD'] + + ############ fit multiple for 2 iterations ########## + # iteration 1 + # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session + # and graph at the last buffer. + # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create + # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph. + # This enables us to do the not equals assert. 
+ first_iter_tf_graph = self.subject.tf.get_default_graph() + first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs) + self._assert_gd_cleared(GD) + + # iteration 2 (last iteration) + last_iter_tf_graph = self.subject.tf.get_default_graph() + last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs) + self._assert_gd_cleared(GD) + + self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess) + self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph) + + def test_eval_multiple_any_iteration(self): + # This test tests 2 things: + # 1. Calling eval_transition from fit_multiple + # 2. Calling eval_transition from evaluate directly + kwargs = {'GD': {}} + GD = kwargs['GD'] + + # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph + # It is fetched prior to calling the eval_transition as when we create a session inside + # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1. + # This enables us to do the not equals assert. 
+ eval_iter_tf_graph1 = self.subject.tf.get_default_graph() + eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs) + eval_iter_keras_sess2 = self.subject.K.get_session() + eval_iter_tf_graph2 = self.subject.tf.get_default_graph() + + self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2) + self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2) + self._assert_gd_cleared(GD) + + def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs): + self._test_internal_keras_eval_transition_first_buffer(final_iteration, + **kwargs) + self._assert_gd_is_valid(kwargs['GD']) + self._assert_keras_session_same_as_gd_session(kwargs['GD']) + + eval_first_buffer_keras_sess = kwargs['GD']['sess'] + self.assertFalse(eval_first_buffer_keras_sess._closed) + eval_first_buffer_tf_graph = self.subject.tf.get_default_graph() + + if not called_from_fit_multiple: + self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess) + self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph) + + self._test_internal_keras_eval_transition_middle_buffer(final_iteration, + **kwargs ) + self._assert_gd_is_valid(kwargs['GD']) + self._assert_keras_session_same_as_gd_session(kwargs['GD']) + self.assertFalse(eval_first_buffer_keras_sess._closed) + + self._test_internal_keras_eval_transition_last_buffer(final_iteration, + **kwargs) + if final_iteration: + self._assert_gd_cleared(kwargs['GD']) + self.assertTrue(eval_first_buffer_keras_sess._closed) + else: + self._assert_gd_is_valid(kwargs['GD']) + self.assertFalse(eval_first_buffer_keras_sess._closed) + return eval_first_buffer_keras_sess + + def _run_fit_iteration(self, **kwargs): + self._test_fit_transition_first_buffer_pass(**kwargs) + gd_first_buffer = kwargs['GD'] + self._assert_gd_is_valid(gd_first_buffer) + iter_sess = gd_first_buffer['sess'] + self.assertFalse(iter_sess._closed) + 
self._assert_keras_session_same_as_gd_session(gd_first_buffer) + + self._test_fit_transition_middle_buffer_pass(**kwargs) + gd_middle_buffer = kwargs['GD'] + self._assert_gd_is_valid(gd_middle_buffer) + self.assertFalse(iter_sess._closed) + + self._test_fit_transition_last_buffer_pass(**kwargs) + gd_last_buffer = kwargs['GD'] + self._assert_gd_is_valid(gd_last_buffer) + self.assertFalse(iter_sess._closed) + return iter_sess + + def _run_fit_multiple_iteration(self, **kwargs): + self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs) + self._assert_gd_is_valid(kwargs['GD']) + self._assert_keras_session_same_as_gd_session(kwargs['GD']) + iter_sess = kwargs['GD']['sess'] + self.assertFalse(iter_sess._closed) + + self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs) + self._assert_gd_is_valid(kwargs['GD']) + self._assert_keras_session_same_as_gd_session(kwargs['GD']) + self.assertFalse(iter_sess._closed) + + self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs) + self._assert_gd_cleared(kwargs['GD']) + self.assertTrue(iter_sess._closed) + return iter_sess Review comment: Agreed, we should re-evaluate this when we move to 2.x. Currently we hop the model, then call evaluate (this behavior will change as part of https://github.com/apache/madlib/pull/525), fetch the session/model from GD, and read the weights from the output table. This might fail because the weights and the session/model may not belong to the same model. We can probably come back to this after merging the Model Hopper Refactor PR (https://github.com/apache/madlib/pull/525). ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org