khannaekta commented on a change in pull request #522:
URL: https://github.com/apache/madlib/pull/522#discussion_r533011189



##########
File path: 
src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, 
last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, 
first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, 
last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple 
clears the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling the fit_transition(from fit_multiple) 
as when we create
+        # a session inside fit_transition, instead of creating a new graph it 
will use first_iter_tf_graph.
+        # This enables us to do the not equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test tests 2 things:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the 
tf session and graph
+        # It is fetched prior to calling the eval_transition as when we create 
a session inside
+        # eval_transition, instead of creating a new graph it will use 
eval_iter_tf_graph1.
+        # This enables us to do the not equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, 
True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, 
prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        
self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        
self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        
self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        
self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess

Review comment:
       Agreed, we should re-evaluate this when we move to 2.x.
   Since we already hop the model before calling evaluate (this behavior will 
change as part of https://github.com/apache/madlib/pull/525), and we fetch the 
session/model from GD but read the weights from the output table, this might 
fail: the weights and the model would not belong to the same model. 
   We can probably come back to this after merging the Model Hopper Refactor 
PR (https://github.com/apache/madlib/pull/525).
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to