This is an automated email from the ASF dual-hosted git repository.

jingyimei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 0eebc49e8e50c8487c26ee272d8a2d2a9e3070e6
Author: Nikhil Kak <[email protected]>
AuthorDate: Tue Jun 4 17:41:28 2019 -0700

    DL: Create one keras session per iteration for pg in fit and evaluate
    
    JIRA: MADLIB-1356
    
    Currently for postgres, we create one keras session in
    fit/evaluate/predict and keep it open for all the iterations and then
    close it at the end.  We did a few experiments and found out that if we
    create 1 keras session per iteration, it performs slightly better.
    So this commit modifies the code to make pg behave the same as gpdb (1
    session per iteration) in fit_transition, eval_transition and
    internal_predict functions.
    
    We also changed set_keras_session to use a device_name to be consistent
    with other keras operations.
    
    This commit also fixes the following bug: If gpus are available on the
    host but the user passed in 0 for the gpus_per_host param, we still end
    up using gpu memory. This is because in the madlib_keras.fit UDF we call
    model.get_weights() to initialize the model and this function ends up
    using gpu(s) if there are any available.
    
    Closes #407
    
    Co-authored-by: Jingyi Mei <[email protected]>
---
 .../modules/deep_learning/madlib_keras.py_in       | 49 ++++++++++------------
 .../deep_learning/madlib_keras_predict.py_in       | 18 ++------
 .../deep_learning/madlib_keras_wrapper.py_in       | 17 ++++----
 .../test/unit_tests/test_madlib_keras.py_in        |  9 ++--
 4 files changed, 38 insertions(+), 55 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 0f85d9e..2b749ca 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -110,20 +110,10 @@ def fit(schema_madlib, source_table, model, 
model_arch_table,
     num_classes = get_num_classes(model_arch)
     fit_validator.validate_input_shapes(input_shape)
 
-    serialized_weights = get_initial_weights(model, model_arch_result,
-                                             warm_start)
-
-    # TODO: Refactor the pg related logic in a future PR when we think
-    # about making the fit function easier to read and maintain.
-    if is_platform_pg():
-        gp_segment_id_col = '0'
-        set_keras_session(gpus_per_host, segments_per_host)
-    else:
-        # we want to disable gpu on gpdb's master node because GPUs will only 
be used
-        # for segment nodes.
-        set_cuda_env('-1')
-        gp_segment_id_col = 'gp_segment_id'
+    gp_segment_id_col = '0' if is_platform_pg() else 'gp_segment_id'
 
+    serialized_weights = get_initial_weights(model, model_arch_result,
+                                             warm_start, gpus_per_host)
     # Compute total images on each segment
     seg_ids_train, images_per_seg_train = get_images_per_seg(source_table)
 
@@ -286,23 +276,32 @@ def fit(schema_madlib, source_table, model, 
model_arch_table,
         $2 as {1}""".format(model, ModelArchSchema.MODEL_ARCH), ["bytea", 
"json"])
     plpy.execute(create_output_table, [serialized_weights, model_arch])
 
-    if is_platform_pg():
-        clear_keras_session()
-
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
-def get_initial_weights(model_table, model_arch_result, warm_start):
+def get_initial_weights(model_table, model_arch_result, warm_start, 
gpus_per_host):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
         table, if no weights are defined there, randomly initialize it using
         keras.
+        We also need to set the cuda environment variable based on the 
platform.
+        1. For postgres, if user specifies gpus_per_host=0 which means they 
want
+        to use CPU, then we have to set CUDA_VISIBLE_DEVICES to -1 to disable 
gpu.
+        Otherwise model.get_weights() will use gpu if available.
+
+        2. For gpdb, we want to disable gpu on gpdb's master node because GPUs
+        will only be used for segment nodes.
         @args:
             @param model_table: Output model table passed in to fit.
             @param model_arch_result: Dict containing model architecture info.
             @param warm_start: Boolean flag indicating warm start or not.
     """
+    if is_platform_pg():
+        _ = get_device_name_and_set_cuda_env(gpus_per_host, None)
+    else:
+        _ = get_device_name_and_set_cuda_env(0, None)
+
     if warm_start:
         serialized_weights = plpy.execute("""
             SELECT model_data FROM {0}
@@ -310,10 +309,10 @@ def get_initial_weights(model_table, model_arch_result, 
warm_start):
     else:
         serialized_weights = model_arch_result[ModelArchSchema.MODEL_WEIGHTS]
         if not serialized_weights:
-            master_model = model_from_json(
+            model = model_from_json(
                 model_arch_result[ModelArchSchema.MODEL_ARCH])
             serialized_weights = madlib_keras_serializer.serialize_nd_weights(
-                master_model.get_weights())
+                model.get_weights())
     return serialized_weights
 
 def get_source_summary_table_dict(fit_validator):
@@ -443,8 +442,7 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
                                                    current_seg_id)
     # Set up system if this is the first buffer on segment'
     if not state:
-        if not is_platform_pg():
-            set_keras_session(gpus_per_host, segments_per_host)
+        set_keras_session(device_name, gpus_per_host, segments_per_host)
         segment_model = model_from_json(model_architecture)
         compile_and_set_weights(segment_model, compile_params, device_name,
                                 prev_serialized_weights)
@@ -492,10 +490,9 @@ def fit_transition(state, dependent_var, independent_var, 
model_architecture,
         # with the total number of images here instead of the merge function.
         # The merge function only deals with aggregating them.
         updated_weights = [ total_images * w for w in updated_weights ]
-        if not is_platform_pg():
             # In GPDB, each segment would have a keras session, so clear
             # them after the last buffer is processed.
-            clear_keras_session()
+        clear_keras_session()
     elif agg_image_count > total_images:
         plpy.error('Processed {0} images, but there were supposed to be only 
{1}!'
                    .format(agg_image_count, total_images))
@@ -651,8 +648,7 @@ def internal_keras_eval_transition(state, dependent_var, 
independent_var,
     agg_loss, agg_metric, agg_image_count = state
 
     if not agg_image_count:
-        if not is_platform_pg():
-            set_keras_session(gpus_per_host, segments_per_host)
+        set_keras_session(device_name, gpus_per_host, segments_per_host)
         model = model_from_json(model_architecture)
         compile_and_set_weights(model, compile_params, device_name,
                                 serialized_weights)
@@ -692,8 +688,7 @@ def internal_keras_eval_transition(state, dependent_var, 
independent_var,
 
     if agg_image_count == total_images:
         SD.pop('segment_model', None)
-        if not is_platform_pg():
-            clear_keras_session()
+        clear_keras_session()
     elif agg_image_count > total_images:
         plpy.error("Evaluated too many images.")
 
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index deff6cf..4fe7430 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -81,13 +81,6 @@ def predict(schema_madlib, model_table, test_table, id_col,
     images_per_seg_test = 
get_images_per_seg_for_non_minibatched_data(test_table)
     segments_per_host = get_segments_per_host()
 
-    if is_platform_pg():
-        set_keras_session(gpus_per_host, segments_per_host)
-    else:
-        # we want to disable gpu on gpdb's master node because GPUs will only 
be used
-        # for segment nodes.
-        set_cuda_env('-1')
-
     predict_query = plpy.prepare("""
         CREATE TABLE {output_table} AS
         SELECT {id_col}, {prediction_select_clause}
@@ -110,8 +103,6 @@ def predict(schema_madlib, model_table, test_table, id_col,
         """.format(**locals()), ["text", "bytea"])
     plpy.execute(predict_query, [model_arch, model_data])
 
-    if is_platform_pg():
-        clear_keras_session()
 
 def get_images_per_seg_for_non_minibatched_data(table_name):
     """
@@ -156,8 +147,7 @@ def internal_keras_predict(independent_var, 
model_architecture, model_data,
         device_name = get_device_name_and_set_cuda_env(gpus_per_host,
                                                        current_seg_id)
         if model_key not in SD:
-            if not is_platform_pg():
-                set_keras_session(gpus_per_host, segments_per_host)
+            set_keras_session(device_name, gpus_per_host, segments_per_host)
             model = model_from_json(model_architecture)
             model_shapes = get_model_shapes(model)
             set_model_weights(model, device_name, model_data, model_shapes)
@@ -199,12 +189,10 @@ def internal_keras_predict(independent_var, 
model_architecture, model_data,
         if SD[row_count_key] == total_images:
             SD.pop(model_key, None)
             SD.pop(row_count_key, None)
-            if not is_platform_pg():
-                clear_keras_session()
+            clear_keras_session()
         return result
     except Exception as ex:
         SD.pop(model_key, None)
         SD.pop(row_count_key, None)
-        if not is_platform_pg():
-            clear_keras_session()
+        clear_keras_session()
         plpy.error(ex)
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 37c1b73..dcdf8a0 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -72,14 +72,15 @@ def get_device_name_and_set_cuda_env(gpus_per_host, seg):
         set_cuda_env('-1')
     return device_name
 
-def set_keras_session(gpus_per_host, segments_per_host):
-    config = K.tf.ConfigProto()
-    if gpus_per_host > 0:
-        memory_fraction = get_gpu_memory_fraction(gpus_per_host, 
segments_per_host)
-        config.gpu_options.allow_growth = False
-        config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
-    session = K.tf.Session(config=config)
-    K.set_session(session)
+def set_keras_session(device_name, gpus_per_host, segments_per_host):
+    with K.tf.device(device_name):
+        config = K.tf.ConfigProto()
+        if gpus_per_host > 0:
+            memory_fraction = get_gpu_memory_fraction(gpus_per_host, 
segments_per_host)
+            config.gpu_options.allow_growth = False
+            config.gpu_options.per_process_gpu_memory_fraction = 
memory_fraction
+        session = K.tf.Session(config=config)
+        K.set_session(session)
 
 def get_gpu_memory_fraction(gpus_per_host, segments_per_host):
     """
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 8253eaa..bf9da51 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -100,7 +100,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         # weights should not be modified yet
         self.assertTrue((self.model_weights == weights).all())
         # set_session must be not be called in transition func for PG
-        self.assertEqual(0, self.subject.K.set_session.call_count)
+        self.assertEqual(1, self.subject.K.set_session.call_count)
         # Clear session and sess.close must not get called for the first buffer
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
@@ -203,7 +203,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.K.set_session.call_count)
         # Clear session and sess.close must get called for the last buffer in 
gpdb,
         #  but not in postgres
-        self.assertEqual(0, self.subject.clear_keras_session.call_count)
+        self.assertEqual(1, self.subject.clear_keras_session.call_count)
 
     def test_fit_transition_last_buffer_pass_gpdb(self):
         #TODO should we mock tensorflow's close_session and keras'
@@ -1110,7 +1110,7 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
 
         self.assertEqual(ending_image_count, image_count)
         # Call set_session once for gpdb (but not for postgres)
-        self.assertEqual(0 if is_platform_pg else 1, 
self.subject.K.set_session.call_count)
+        self.assertEqual(1, self.subject.K.set_session.call_count)
         # loss and accuracy should be unchanged
         self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
         self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
@@ -1186,8 +1186,7 @@ class MadlibKerasEvaluationTestCase(unittest.TestCase):
         self.assertAlmostEqual(self.accuracy * ending_image_count, 
agg_accuracy, 4)
         # Clear session and sess.close must get called for the last buffer in 
gpdb,
         #  but not in postgres
-        self.assertEqual(0 if is_platform_pg else 1,
-                         self.subject.clear_keras_session.call_count)
+        self.assertEqual(1, self.subject.clear_keras_session.call_count)
 
     def test_internal_keras_eval_transition_first_buffer_pg(self):
         self._test_internal_keras_eval_transition_first_buffer(True)

Reply via email to