This is an automated email from the ASF dual-hosted git repository.

jingyimei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit bb2737425d764d29ba57d98a292d5589b8464900
Author: Nikhil Kak <[email protected]>
AuthorDate: Thu Jun 6 17:29:02 2019 -0700

    DL: refactor image count related code
    
    JIRA: MADLIB-1356
    
    This commit does the following:
    1. Removed check for total image count == 0 or > agg_image_count in
    fit_transition, because those cases will never happen.
    
    2. Refactored duplicated code into a new function
    get_image_count_per_seg_from_array
    to get the total image count in fit, predict and evaluate.
    
    3. Renamed and moved image count related functions to helper file and
    added more comments.
    
    Closes #407
    
    Co-authored-by: Jingyi Mei <[email protected]>
---
 .../modules/deep_learning/madlib_keras.py_in       |  72 ++--------
 .../deep_learning/madlib_keras_helper.py_in        |  81 +++++++++++
 .../deep_learning/madlib_keras_predict.py_in       |  44 +-----
 .../test/unit_tests/test_madlib_keras.py_in        | 151 +++------------------
 4 files changed, 112 insertions(+), 236 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 2b749ca..8187d8d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -36,15 +36,8 @@ from keras.models import *
 from keras.optimizers import *
 from keras.regularizers import *
 import madlib_keras_serializer
-from madlib_keras_helper import MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
-from madlib_keras_helper import MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
-from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
-from madlib_keras_helper import INDEPENDENT_VARNAME_COLNAME
-from madlib_keras_helper import CLASS_VALUES_COLNAME
-from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
-from madlib_keras_helper import NORMALIZING_CONST_COLNAME
-from madlib_keras_validator import FitInputValidator
-from madlib_keras_validator import EvaluateInputValidator
+from madlib_keras_helper import *
+from madlib_keras_validator import *
 from madlib_keras_wrapper import *
 from keras_model_arch_table import ModelArchSchema
 
@@ -115,10 +108,10 @@ def fit(schema_madlib, source_table, model, 
model_arch_table,
     serialized_weights = get_initial_weights(model, model_arch_result,
                                              warm_start, gpus_per_host)
     # Compute total images on each segment
-    seg_ids_train, images_per_seg_train = get_images_per_seg(source_table)
+    seg_ids_train, images_per_seg_train = 
get_image_count_per_seg_for_minibatched_data_from_db(source_table)
 
     if validation_table:
-        seg_ids_val, images_per_seg_val = get_images_per_seg(validation_table)
+        seg_ids_val, images_per_seg_val = 
get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
 
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
@@ -397,38 +390,6 @@ def should_compute_metrics_this_iter(curr_iter, 
metrics_compute_frequency,
     return (curr_iter)%metrics_compute_frequency == 0 or \
            curr_iter == num_iterations
 
-def get_images_per_seg(source_table):
-    """
-    Compute total images in each segment, by querying source_table.  For
-    postgres, this is just the total number of images in the db.
-    :param source_table:
-    :return: Returns a string and two arrays
-    1. An array containing all the segment numbers in ascending order
-    1. An array containing the total images on each of the segments in the
-    segment array.
-    """
-
-    mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
-
-    if is_platform_pg():
-        res = plpy.execute(
-            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
-                FROM {1}
-            """.format(mb_dep_var_col, source_table))
-        images_per_seg = [int(res[0]['images_per_seg'])]
-        seg_ids = [0]
-    else:
-        images_per_seg = plpy.execute(
-            """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS 
images_per_seg
-                FROM {1}
-                GROUP BY gp_segment_id
-            """.format(mb_dep_var_col, source_table))
-        seg_ids = [int(each_segment["gp_segment_id"])
-                   for each_segment in images_per_seg]
-        images_per_seg = [int(each_segment["images_per_seg"])
-                          for each_segment in images_per_seg]
-    return seg_ids, images_per_seg
-
 def fit_transition(state, dependent_var, independent_var, model_architecture,
                    compile_params, fit_params, current_seg_id, seg_ids,
                    images_per_seg, gpus_per_host, segments_per_host,
@@ -471,17 +432,9 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
 
     with K.tf.device(device_name):
         updated_weights = segment_model.get_weights()
-    if is_platform_pg():
-        total_images = images_per_seg[0]
-    else:
-        total_images = images_per_seg[seg_ids.index(current_seg_id)]
 
-    if total_images == 0:
-        if is_platform_pg():
-            plpy.error('Total images is 0 in fit_transition')
-
-        else:
-            plpy.error('Total images is 0 in fit_transition on segment 
{0}'.format(current_seg_id))
+    total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
+                                                      images_per_seg)
 
     # Re-serialize the weights
     # Update image count, check if we are done
@@ -493,9 +446,6 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
             # In GPDB, each segment would have a keras session, so clear
             # them after the last buffer is processed.
         clear_keras_session()
-    elif agg_image_count > total_images:
-        plpy.error('Processed {0} images, but there were supposed to be only 
{1}!'
-                   .format(agg_image_count, total_images))
 
     new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
         agg_image_count, updated_weights)
@@ -576,7 +526,7 @@ def evaluate(schema_madlib, model_table, test_table, 
output_table, gpus_per_host
     metrics_type = res['metrics_type']
     compile_params = "$madlib$" + res['compile_params'] + "$madlib$"
 
-    seg_ids, images_per_seg = get_images_per_seg(test_table)
+    seg_ids, images_per_seg = 
get_image_count_per_seg_for_minibatched_data_from_db(test_table)
 
     res = plpy.execute("""
         SELECT {dependent_varname_col}, {independent_varname_col}
@@ -681,16 +631,12 @@ def internal_keras_eval_transition(state, dependent_var, 
independent_var,
     agg_loss += (image_count * loss)
     agg_metric += (image_count * metric)
 
-    if is_platform_pg():
-        total_images = images_per_seg[0]
-    else:
-        total_images = images_per_seg[seg_ids.index(current_seg_id)]
+    total_images = get_image_count_per_seg_from_array(current_seg_id, seg_ids,
+                                                      images_per_seg)
 
     if agg_image_count == total_images:
         SD.pop('segment_model', None)
         clear_keras_session()
-    elif agg_image_count > total_images:
-        plpy.error("Evaluated too many images.")
 
     state[0] = agg_loss
     state[1] = agg_metric
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 03a8399..948f2ad 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -18,6 +18,8 @@
 # under the License.
 
 import numpy as np
+from utilities.utilities import is_platform_pg
+import plpy
 
 # Prepend a dimension to np arrays using expand_dims.
 def expand_input_dims(input_data, target_type=None):
@@ -62,6 +64,85 @@ def strip_trailing_nulls_from_class_values(class_values):
         class_values = class_values[:num_of_valid_class_values]
     return class_values
 
+def get_image_count_per_seg_from_array(current_seg_id, seg_ids, 
images_per_seg):
+    """
+    Get the image count from the array containing all the images
+    per segment. Based on the platform, we find the index of the current 
segment.
+    This function is only called from inside the transition function. 
+    """
+    if is_platform_pg():
+        total_images = images_per_seg[0]
+    else:
+        total_images = images_per_seg[seg_ids.index(current_seg_id)]
+    return total_images
+
+def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
+    """
+    Query the given minibatch formatted table and return the total rows per 
segment.
+    Since we cannot pass a dictionary to the keras fit step function we create 
+    arrays out of the segment numbers and the rows per segment values.
+    This function assumes that the table is not empty.
+    :param table_name:
+    :return: Returns two arrays
+    1. An array containing all the segment numbers in ascending order
+    2. An array containing the total images on each of the segments in the
+    segment array.
+    """
+
+    mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
+
+    if is_platform_pg():
+        res = plpy.execute(
+            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
+                FROM {1}
+            """.format(mb_dep_var_col, table_name))
+        images_per_seg = [int(res[0]['images_per_seg'])]
+        seg_ids = [0]
+    else:
+        images_per_seg = plpy.execute(
+            """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS 
images_per_seg
+                FROM {1}
+                GROUP BY gp_segment_id
+            """.format(mb_dep_var_col, table_name))
+        seg_ids = [int(each_segment["gp_segment_id"])
+                   for each_segment in images_per_seg]
+        images_per_seg = [int(each_segment["images_per_seg"])
+                          for each_segment in images_per_seg]
+    return seg_ids, images_per_seg
+
+def get_image_count_per_seg_for_non_minibatched_data_from_db(table_name):
+    """
+    Query the given non minibatch formatted table and return the total rows 
per segment.
+    Since we cannot pass a dictionary to the keras fit step function we create 
arrays
+    out of the segment numbers and the rows per segment values.
+    This function assumes that the table is not empty.
+    :param table_name:
+    :return: gp segment id col name and two arrays
+    1. An array containing all the segment numbers in ascending order
+    2. An array containing the total rows for each of the segments in the
+    segment array
+    """
+    if is_platform_pg():
+        images_per_seg = plpy.execute(
+            """ SELECT count(*) AS images_per_seg
+                FROM {0}
+            """.format(table_name))
+        seg_ids = [0]
+        gp_segment_id_col = '0'
+    else:
+        # Compute total buffers on each segment
+        images_per_seg = plpy.execute(
+            """ SELECT gp_segment_id, count(*) AS images_per_seg
+                FROM {0}
+                GROUP BY gp_segment_id
+            """.format(table_name))
+        seg_ids = [int(image["gp_segment_id"]) for image in images_per_seg]
+        gp_segment_id_col = '{0}.gp_segment_id'.format(table_name)
+
+    images_per_seg = [int(image["images_per_seg"]) for image in images_per_seg]
+    return gp_segment_id_col, seg_ids, images_per_seg
+
+
 # Name of columns in model summary table.
 CLASS_VALUES_COLNAME = "class_values"
 NORMALIZING_CONST_COLNAME = "normalizing_const"
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 4fe7430..f333f04 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -26,8 +26,7 @@ from keras.layers import *
 from keras.models import *
 from keras.optimizers import *
 
-from madlib_keras_helper import expand_input_dims
-from madlib_keras_helper import strip_trailing_nulls_from_class_values
+from madlib_keras_helper import *
 from madlib_keras_validator import PredictInputValidator
 from predict_input_params import PredictParamsProcessor
 from utilities.control import MinWarning
@@ -78,7 +77,7 @@ def predict(schema_madlib, model_table, test_table, id_col,
         pred_col_type, is_response, MODULE_NAME)
 
     gp_segment_id_col, seg_ids_test, \
-    images_per_seg_test = 
get_images_per_seg_for_non_minibatched_data(test_table)
+    images_per_seg_test = 
get_image_count_per_seg_for_non_minibatched_data_from_db(test_table)
     segments_per_host = get_segments_per_host()
 
     predict_query = plpy.prepare("""
@@ -104,38 +103,6 @@ def predict(schema_madlib, model_table, test_table, id_col,
     plpy.execute(predict_query, [model_arch, model_data])
 
 
-def get_images_per_seg_for_non_minibatched_data(table_name):
-    """
-    This function queries the given table and returns the total rows per 
segment.
-    Since we cannot pass a dictionary to the keras fit step function we create 
arrays
-    out of the segment numbers and the rows per segment values.
-    This function assumes that the table is not empty.
-    :param table_name:
-    :return: gp segment id col name and two arrays
-    1. An array containing all the segment numbers in ascending order
-    2. An array containing the total rows for each of the segments in the
-    segment array
-    """
-    if is_platform_pg():
-        images_per_seg = plpy.execute(
-            """ SELECT count(*) AS images_per_seg
-                FROM {0}
-            """.format(table_name))
-        seg_ids = [0]
-        gp_segment_id_col = '0'
-    else:
-        # Compute total buffers on each segment
-        images_per_seg = plpy.execute(
-            """ SELECT gp_segment_id, count(*) AS images_per_seg
-                FROM {0}
-                GROUP BY gp_segment_id
-            """.format(table_name))
-        seg_ids = [int(image["gp_segment_id"]) for image in images_per_seg]
-        gp_segment_id_col = '{0}.gp_segment_id'.format(table_name)
-
-    images_per_seg = [int(image["images_per_seg"]) for image in images_per_seg]
-    return gp_segment_id_col, seg_ids, images_per_seg
-
 def internal_keras_predict(independent_var, model_architecture, model_data,
                            is_response, normalizing_const, current_seg_id, 
seg_ids,
                            images_per_seg, gpus_per_host, segments_per_host,
@@ -180,11 +147,8 @@ def internal_keras_predict(independent_var, 
model_architecture, model_data,
             # and not mini-batched, this list contains exactly one list in it,
             # so return back the first list in probs.
             result = probs[0]
-
-        if is_platform_pg():
-            total_images = images_per_seg[0]
-        else:
-            total_images = images_per_seg[seg_ids.index(current_seg_id)]
+        total_images = get_image_count_per_seg_from_array(current_seg_id, 
seg_ids,
+                                                          images_per_seg)
 
         if SD[row_count_key] == total_images:
             SD.pop(model_key, None)
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index bf9da51..b3e3fdd 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -75,44 +75,12 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
-    def test_fit_transition_first_buffer_pass_pg(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
-
-        #postgres
-        self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = True)
-        starting_image_count = 0
-        ending_image_count = len(self.dependent_var)
-        previous_state = np.array(self.model_weights, dtype=np.float32)
-
-        k = {'SD' : {}}
-
-        new_state = self.subject.fit_transition(
-            None, self.dependent_var, self.independent_var , 
self.model.to_json(),
-            self.compile_params, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, previous_state.tostring(), **k)
-        state = np.fromstring(new_state, dtype=np.float32)
-        image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
-        self.assertEqual(ending_image_count, image_count)
-        # weights should not be modified yet
-        self.assertTrue((self.model_weights == weights).all())
-        # set_session must be not be called in transition func for PG
-        self.assertEqual(1, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the first buffer
-        self.assertEqual(0, self.subject.clear_keras_session.call_count)
-        self.assertTrue(k['SD']['segment_model'])
-
-    def test_fit_transition_first_buffer_pass_gpdb(self):
+    def _test_fit_transition_first_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `clear_keras_session`
-
-        #postgres
         self.subject.K.set_session = Mock()
         self.subject.clear_keras_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = False)
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 0
         ending_image_count = len(self.dependent_var)
         previous_state = np.array(self.model_weights, dtype=np.float32)
@@ -135,12 +103,12 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
-    def test_fit_transition_middle_buffer_pass(self):
+    def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `clear_keras_session`
         self.subject.K.set_session = Mock()
         self.subject.clear_keras_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = False)
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = len(self.dependent_var)
         ending_image_count = starting_image_count + len(self.dependent_var)
@@ -169,12 +137,12 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # Clear session and sess.close must not get called for the middle 
buffer
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
 
-    def test_fit_transition_last_buffer_pass_pg(self):
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `clear_keras_session`
         self.subject.K.set_session = Mock()
         self.subject.clear_keras_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = True)
+        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var)
         ending_image_count = starting_image_count + len(self.dependent_var)
@@ -205,80 +173,23 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         #  but not in postgres
         self.assertEqual(1, self.subject.clear_keras_session.call_count)
 
-    def test_fit_transition_last_buffer_pass_gpdb(self):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `clear_keras_session`
-        self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
-        self.subject.is_platform_pg = Mock(return_value = False)
-
-        starting_image_count = 2*len(self.dependent_var)
-        ending_image_count = starting_image_count + len(self.dependent_var)
-
-        state = [starting_image_count]
-        state.extend(self.model_weights)
-        state = np.array(state, dtype=np.float32)
-
-        multiplied_weights = 
mult(self.total_images_per_seg[0],self.model_weights)
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model}}
-        new_state = self.subject.fit_transition(
-            state.tostring(), self.dependent_var, self.independent_var,
-            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
-
-        state = np.fromstring(new_state, dtype=np.float32)
-        image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
-        self.assertEqual(ending_image_count, image_count)
-        # weights should be multiplied by final image count
-        self.assertTrue((multiplied_weights == weights).all())
-        # set_session must get called ONLY once, when its the first buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must get called for the last buffer in 
gpdb,
-        #  but not in postgres
-        self.assertEqual(1, self.subject.clear_keras_session.call_count)
-
-    def test_fit_transition_ending_image_count_zero(self):
-        self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
-        starting_image_count = 0
-        previous_state = [starting_image_count]
-        previous_state.extend(self.model_weights)
-        previous_state = np.array(previous_state, dtype=np.float32)
-
-        k = {'SD' : {}}
-
-        total_images_per_seg = [0,1,1]
-
-        with self.assertRaises(plpy.PLPYException) as error:
-            new_state = self.subject.fit_transition(
-                None, self.dependent_var, self.independent_var , 
self.model.to_json(),
-                self.compile_params, self.fit_params, 0, self.all_seg_ids,
-                total_images_per_seg, 0, 4, previous_state.tostring(), **k)
-        self.assertIn('Total images is 0', str(error.exception))
+    def test_fit_transition_first_buffer_pass_pg(self):
+        self._test_fit_transition_first_buffer_pass(True)
 
-    def test_fit_transition_too_many_images(self):
-        self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
-        starting_image_count = 0
-        previous_state = [ starting_image_count]
-        previous_state.extend(self.model_weights)
-        previous_state = np.array(previous_state, dtype=np.float32)
+    def test_fit_transition_first_buffer_pass_gpdb(self):
+        self._test_fit_transition_first_buffer_pass(False)
 
-        k = {'SD' : {}}
+    def test_fit_transition_middle_buffer_pass_pg(self):
+        self._test_fit_transition_middle_buffer_pass(True)
 
-        total_images_per_seg = [1,1,1]
+    def test_fit_transition_middle_buffer_pass_gpdb(self):
+        self._test_fit_transition_middle_buffer_pass(False)
 
-        with self.assertRaises(plpy.PLPYException) as error:
-            new_state = self.subject.fit_transition(
-                None, self.dependent_var, self.independent_var , 
self.model.to_json(),
-                self.compile_params, self.fit_params, 0, self.all_seg_ids,
-                total_images_per_seg, 0, 4, previous_state.tostring(), **k)
+    def test_fit_transition_last_buffer_pass_pg(self):
+        self._test_fit_transition_last_buffer_pass(True)
 
-        self.assertIn('only 1', str(error.exception))
+    def test_fit_transition_last_buffer_pass_gpdb(self):
+        self._test_fit_transition_last_buffer_pass(False)
 
     def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
         k = {}
@@ -1265,32 +1176,6 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         result = self.subject.internal_keras_eval_final(None)
         self.assertEqual(result, None)
 
-    def test_internal_keras_eval_transition_too_many_images(self):
-        self.subject.K.set_session = Mock()
-        self.subject.clear_keras_session = Mock()
-
-        starting_image_count = 5
-
-        k = {'SD' : {}}
-        state = [self.loss, self.accuracy, starting_image_count]
-        state.extend(self.model_weights)
-        state = np.array(state, dtype=np.float32)
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', state.tostring())
-
-        state = [self.loss * starting_image_count, self.accuracy * 
starting_image_count, starting_image_count]
-
-        k['SD']['segment_model'] = self.model
-
-        total_images_per_seg = [10, 10, 10]
-
-        with self.assertRaises(plpy.PLPYException):
-            self.subject.internal_keras_eval_transition(
-                state, self.dependent_var , self.independent_var, 
self.model.to_json(),
-                'dummy_model_data', None, 0, self.all_seg_ids,
-                total_images_per_seg, 0, 3, **k)
-
     def test_internal_keras_eval_final_image_count_zero(self):
         input_state = [0, 0, 0]
 

Reply via email to