This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 4dd5035145635847fcab30ac6a402878c76040d6
Author: Nandish Jayaram <[email protected]>
AuthorDate: Wed Mar 13 14:54:14 2019 -0700

    Deep Learning: Refactor code
    
    This commit refactors the madlib_keras code by pulling out all predict
    related code to a different file named madlib_keras_predict.py_in, and
    all keras and other required helper functions to a different file
    named madlib_keras_helper.py_in.
    
    madlib_keras.py_in still has all the code for fit(). We can change it
    based on community feedback, but we chose to have this instead of
    madlib_keras_train.py_in because we wanted our dev-check to do
    end-to-end testing by calling both fit and predict in the same test SQL
    file. All of this refactoring is to facilitate better collaboration among
    various contributors.
    
    Closes #355
    
    Co-authored-by: Ekta Khanna <[email protected]>
    Co-authored-by: Nikhil Kak <[email protected]>
---
 .../{convex => deep_learning}/madlib_keras.py_in   | 519 ++++-----------------
 .../{convex => deep_learning}/madlib_keras.sql_in  |  22 +-
 .../deep_learning/madlib_keras_helper.py_in        | 194 ++++++++
 .../deep_learning/madlib_keras_predict.py_in       |  83 ++++
 .../deep_learning/madlib_keras_wrapper.py_in       |  91 ++++
 .../test/madlib_keras.sql_in                       |   0
 .../deep_learning/test/unit_tests/plpy_mock.py_in  |  43 ++
 .../test/unit_tests/test_madlib_keras.py_in        | 138 +-----
 .../test/unit_tests/test_madlib_keras_helper.py_in | 179 +++++++
 9 files changed, 720 insertions(+), 549 deletions(-)

diff --git a/src/ports/postgres/modules/convex/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
similarity index 52%
rename from src/ports/postgres/modules/convex/madlib_keras.py_in
rename to src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 8236f34..437211d 100644
--- a/src/ports/postgres/modules/convex/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -17,24 +17,32 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import plpy
-import keras
-import numpy as np
-import time
 import datetime
+import numpy as np
 import os
-from keras.models import *
+import plpy
+import sys
+import time
+
+from keras import backend as K
+from keras import utils as keras_utils
 from keras.layers import *
+from keras.models import *
 from keras.optimizers import *
-from keras import backend as K
 from keras.regularizers import *
-from utilities.validate_args import output_tbl_valid
-from utilities.validate_args import input_tbl_valid
-from utilities.utilities import add_postfix
-from utilities.utilities import madlib_version
+
+from madlib_keras_helper import KerasWeightsSerializer
+from madlib_keras_helper import get_data_as_np_array
+from madlib_keras_wrapper import *
+
 from utilities.model_arch_info import get_input_shape
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import output_tbl_valid
 from utilities.utilities import _assert
+from utilities.utilities import add_postfix
 from utilities.utilities import is_var_valid
+from utilities.utilities import madlib_version
+
 
 def _validate_input_table(source_table, independent_varname,
                           dependent_varname):
@@ -105,11 +113,13 @@ def _validate_input_shapes(source_table, 
independent_varname, input_shape):
         if result[key_name] != input_shape[i]:
             # Construct the shape in independent varname to display meaningful
             # error msg.
-            input_shape_from_table = [result["n_{0}".format(i)] for i in range(
-                1, len(input_shape))]
+            input_shape_from_table = [result["n_{0}".format(i)]
+                for i in range(len(input_shape))]
             plpy.error("model_keras error: Input shape {0} in the model" \
                 " architecture does not match the input shape {1} of column" \
-                " {2} in table {3}.".format(input_shape, 
input_shape_from_table, independent_varname, source_table))
+                " {2} in table {3}.".format(
+                    input_shape, input_shape_from_table, independent_varname,
+                    source_table))
 
 def fit(schema_madlib, source_table, model, dependent_varname,
         independent_varname, model_arch_table, model_arch_id, compile_params,
@@ -128,10 +138,12 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
 
     # Get the serialized master model
     start_deserialization = time.time()
-    model_arch_query = "SELECT model_arch, model_weights FROM {0} WHERE id = 
{1}".format(model_arch_table, model_arch_id)
+    model_arch_query = "SELECT model_arch, model_weights FROM {0} WHERE "\
+        "id = {1}".format(model_arch_table, model_arch_id)
     query_result = plpy.execute(model_arch_query)
-    if not  query_result or len(query_result) == 0:
-        plpy.error("no model arch found in table {0} with id 
{1}".format(model_arch_table, model_arch_id))
+    if not  query_result:
+        plpy.error("no model arch found in table {0} with id {1}".format(
+            model_arch_table, model_arch_id))
     query_result = query_result[0]
     model_arch = query_result['model_arch']
     input_shape = get_input_shape(model_arch)
@@ -152,22 +164,18 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
 
     if model_weights_serialized:
         # If warm start from previously trained model, set weights
-        model_weights = deserialize_weights_orig(model_weights_serialized, 
model_shapes)
+        model_weights = KerasWeightsSerializer.deserialize_weights_orig(
+            model_weights_serialized, model_shapes)
         master_model.set_weights(model_weights)
 
-    end_deserialization = time.time()
-    # plpy.info("Model deserialization time: {} 
sec".format(end_deserialization - start_deserialization))
-
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
     validation_aggregate_accuracy = []; validation_aggregate_loss = []
     x_validation = None; y_validation = None
     if validation_set_provided:
-        x_validation,  y_validation = get_data_as_np_array(validation_table,
-                                                           dependent_varname,
-                                                           independent_varname,
-                                                           input_shape,
-                                                           num_classes)
+        x_validation,  y_validation = get_data_as_np_array(
+            validation_table, dependent_varname, independent_varname,
+            input_shape, num_classes)
 
     # Compute total buffers on each segment
     total_buffers_per_seg = plpy.execute(
@@ -175,8 +183,10 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
             FROM {0}
             GROUP BY gp_segment_id
         """.format(source_table))
-    seg_nums = [int(each_buffer["gp_segment_id"]) for each_buffer in 
total_buffers_per_seg]
-    total_buffers_per_seg = [int(each_buffer["total_buffers_per_seg"]) for 
each_buffer in total_buffers_per_seg]
+    seg_nums = [int(each_buffer["gp_segment_id"])
+        for each_buffer in total_buffers_per_seg]
+    total_buffers_per_seg = [int(each_buffer["total_buffers_per_seg"])
+        for each_buffer in total_buffers_per_seg]
 
     # Prepare the SQL for running distributed training via UDA
     compile_params_to_pass = "$madlib$" + compile_params + "$madlib$"
@@ -202,19 +212,20 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
                    use_gpu, source_table), ["bytea"])
 
     # Define the state for the model and loss/accuracy storage lists
-    model_state = serialize_weights(0, 0, 0, model_weights)
+    model_state = KerasWeightsSerializer.serialize_weights(
+        0, 0, 0, model_weights)
     aggregate_loss, aggregate_accuracy, aggregate_runtime = [], [], []
 
     plpy.info("Model architecture size: {}KB".format(len(model_arch)/1024))
-    plpy.info("Model state (serialized) size: 
{}MB".format(len(model_state)/1024/1024))
+    plpy.info("Model state (serialized) size: {}MB".format(
+        len(model_state)/1024/1024))
 
     # Run distributed training for specified number of iterations
     for i in range(num_iterations):
-        # prev_state = model_state
         start_iteration = time.time()
         try:
-            iteration_result = plpy.execute(run_training_iteration,
-                                            
[model_state])[0]['iteration_result']
+            iteration_result = plpy.execute(
+                run_training_iteration, [model_state])[0]['iteration_result']
         except plpy.SPIError as e:
             plpy.error('A plpy error occurred in the step function: {0}'.
                        format(str(e)))
@@ -222,11 +233,13 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
         plpy.info("Time for iteration {0}: {1} sec".
                   format(i + 1, end_iteration - start_iteration))
         aggregate_runtime.append(datetime.datetime.now())
-        avg_loss, avg_accuracy, model_state = 
deserialize_iteration_state(iteration_result)
-        plpy.info("Average loss after training iteration {0}: {1}".format(i + 
1, avg_loss))
-        plpy.info("Average accuracy after training iteration {0}: 
{1}".format(i + 1, avg_accuracy))
+        avg_loss, avg_accuracy, model_state = 
KerasWeightsSerializer.deserialize_iteration_state(iteration_result)
+        plpy.info("Average loss after training iteration {0}: {1}".format(
+            i + 1, avg_loss))
+        plpy.info("Average accuracy after training iteration {0}: {1}".format(
+            i + 1, avg_accuracy))
         if validation_set_provided:
-            _, _, _, updated_weights = deserialize_weights(model_state, 
model_shapes)
+            _, _, _, updated_weights = 
KerasWeightsSerializer.deserialize_weights(model_state, model_shapes)
             master_model.set_weights(updated_weights)
             compile_params_args = 
convert_string_of_args_to_dict(compile_params)
             master_model.compile(**compile_params_args)
@@ -254,8 +267,6 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
     if validation_aggregate_loss and len(validation_aggregate_loss) > 0:
         final_validation_loss = validation_aggregate_loss[-1]
     version = madlib_version(schema_madlib)
-    # accuracy = aggregate_accuracy[-1]
-    # loss = aggregate_loss[-1]
     create_output_summary_table = plpy.prepare("""
         CREATE TABLE {0}_summary AS
         SELECT
@@ -295,64 +306,32 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
                                  "DOUBLE PRECISION[]", "TIMESTAMP[]",
                                  "DOUBLE PRECISION", "DOUBLE PRECISION",
                                  "DOUBLE PRECISION[]", "DOUBLE PRECISION[]"])
-    plpy.execute(create_output_summary_table, [model_arch_table, model_arch_id,
-                                               "madlib_keras",
-                                               start_training_time, 
end_training_time,
-                                               source_table, validation_table,
-                                               model, dependent_varname,
-                                               independent_varname, name, 
description,
-                                               None, version, compile_params,
-                                               fit_params, num_iterations, 
num_classes,
-                                               aggregate_accuracy[-1],
-                                               aggregate_loss[-1],
-                                               aggregate_accuracy, 
aggregate_loss,
-                                               aggregate_runtime, 
final_validation_acc,
-                                               final_validation_loss,
-                                               validation_aggregate_accuracy,
-                                               validation_aggregate_loss])
+    plpy.execute(
+        create_output_summary_table,
+        [
+            model_arch_table, model_arch_id,
+            "madlib_keras",
+            start_training_time, end_training_time,
+            source_table, validation_table,
+            model, dependent_varname,
+            independent_varname, name, description,
+            sys.getsizeof(model), version, compile_params,
+            fit_params, num_iterations, num_classes,
+            aggregate_accuracy[-1],
+            aggregate_loss[-1],
+            aggregate_accuracy, aggregate_loss,
+            aggregate_runtime, final_validation_acc,
+            final_validation_loss,
+            validation_aggregate_accuracy,
+            validation_aggregate_loss
+        ]
+        )
 
     create_output_table = plpy.prepare("""
         CREATE TABLE {0} AS
         SELECT $1 as model_data""".format(model), ["bytea"])
     plpy.execute(create_output_table, [model_state])
 
-def get_device_name_for_keras(use_gpu, seg, gpus_per_host):
-    if use_gpu:
-        device_name = '/gpu:0'
-        os.environ["CUDA_VISIBLE_DEVICES"] = str(seg % gpus_per_host)
-    else: # cpu only
-        device_name = '/cpu:0'
-        os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
-
-    return device_name
-
-def set_keras_session(use_gpu):
-    config = K.tf.ConfigProto()
-    if use_gpu:
-        config.gpu_options.allow_growth = False
-        config.gpu_options.per_process_gpu_memory_fraction = 0.9
-    session = K.tf.Session(config=config)
-    K.set_session(session)
-
-def clear_keras_session():
-    sess = K.get_session()
-    K.clear_session()
-    sess.close()
-
-def compile_and_set_weights(segment_model, compile_params, device_name,
-                            previous_state):
-    model_shapes = []
-    with K.tf.device(device_name):
-        compile_params = convert_string_of_args_to_dict(compile_params)
-        segment_model.compile(**compile_params)
-        # prev_segment_model.compile(**compile_params)
-        for a in segment_model.get_weights():
-            model_shapes.append(a.shape)
-
-        agg_loss, agg_accuracy, _, model_weights = deserialize_weights(
-            previous_state, model_shapes)
-        segment_model.set_weights(model_weights)
-    # prev_model.set_weights(model_weights)
 
 def fit_transition(state, ind_var, dep_var, current_seg_id, num_classes,
                    all_seg_ids, total_buffers_per_seg, architecture,
@@ -388,32 +367,37 @@ def fit_transition(state, ind_var, dep_var, 
current_seg_id, num_classes,
         use_gpu, current_seg_id, gpus_per_host)
 
     # Set up system if this is the first buffer on segment'
+
     if not state:
         set_keras_session(use_gpu)
         segment_model = model_from_json(architecture)
+        SD['model_shapes'] = 
KerasWeightsSerializer.get_model_shapes(segment_model)
         compile_and_set_weights(segment_model, compile_params, device_name,
-                                previous_state)
+                                previous_state, SD['model_shapes'])
         SD['segment_model'] = segment_model
         SD['buffer_count'] = 0
+        agg_loss = 0
+        agg_accuracy = 0
     else:
         segment_model = SD['segment_model']
+        # Since we deserialize everytime, the transition function might be 
slightly
+        # slower
+        agg_loss, agg_accuracy, _, _ = 
KerasWeightsSerializer.deserialize_weights(
+            state, SD['model_shapes'])
 
-    agg_loss = 0
-    agg_accuracy = 0
     input_shape = get_input_shape(architecture)
 
     # Prepare the data
     x_train = np.array(ind_var, dtype='float64').reshape(
         len(ind_var), *input_shape)
     y_train = np.array(dep_var)
-    y_train = keras.utils.to_categorical(y_train, num_classes)
+    y_train = keras_utils.to_categorical(y_train, num_classes)
 
     # Fit segment model on data
     start_fit = time.time()
     with K.tf.device(device_name):
         fit_params = convert_string_of_args_to_dict(fit_params)
         history = segment_model.fit(x_train, y_train, **fit_params)
-        # loss, accuracy = prev_model.evaluate(x_train, y_train)
         loss = history.history['loss'][0]
         accuracy = history.history['acc'][0]
     end_fit = time.time()
@@ -421,8 +405,8 @@ def fit_transition(state, ind_var, dep_var, current_seg_id, 
num_classes,
     # Re-serialize the weights
     # Update buffer count, check if we are done
     SD['buffer_count'] += 1
-    updated_loss = agg_loss + loss
-    updated_accuracy = agg_accuracy + accuracy
+    agg_loss += loss
+    agg_accuracy += accuracy
 
     with K.tf.device(device_name):
         updated_weights = segment_model.get_weights()
@@ -432,14 +416,12 @@ def fit_transition(state, ind_var, dep_var, 
current_seg_id, num_classes,
         if total_buffers == 0:
             plpy.error('total buffers is 0')
 
-        updated_loss /= total_buffers
-        updated_accuracy /= total_buffers
-        # plpy.info('final buffer loss {}, accuracy {}, buffer count 
{}'.format(loss, accuracy, SD['buffer_count']))
+        agg_loss /= total_buffers
+        agg_accuracy /= total_buffers
         clear_keras_session()
 
-    new_model_state = serialize_weights(updated_loss, updated_accuracy,
-                                        SD['buffer_count'], updated_weights)
-    # new_model_state[2] += len(x_train)
+    new_model_state = KerasWeightsSerializer.serialize_weights(
+        agg_loss, agg_accuracy, SD['buffer_count'], updated_weights)
 
     del x_train
     del y_train
@@ -456,8 +438,8 @@ def fit_merge(state1, state2, **kwargs):
         return state1 or state2
 
     # Deserialize states
-    loss1, accuracy1, buffer_count1, weights1 = 
deserialize_weights_merge(state1)
-    loss2, accuracy2, buffer_count2, weights2 = 
deserialize_weights_merge(state2)
+    loss1, accuracy1, buffer_count1, weights1 = 
KerasWeightsSerializer.deserialize_weights_merge(state1)
+    loss2, accuracy2, buffer_count2, weights2 = 
KerasWeightsSerializer.deserialize_weights_merge(state2)
         # plpy.info('merge buffer loss1 {}, accuracy1 {}, buffer count1 
{}'.format(loss1, accuracy1, buffer_count1))
     # plpy.info('merge buffer loss2 {}, accuracy2 {}, buffer count2 
{}'.format(loss2, accuracy2, buffer_count2))
 
@@ -483,50 +465,16 @@ def fit_merge(state1, state2, **kwargs):
     # avg_weights = [(merge_weight1 * e1) + (merge_weight2 * e2) for e1, e2 in 
zip(weights1, weights2)]
 
     # Return the merged state
-    return serialize_weights_merge(avg_loss, avg_accuracy, total_buffers, 
avg_weights)
+    return KerasWeightsSerializer.serialize_weights_merge(
+        avg_loss, avg_accuracy, total_buffers, avg_weights)
 
 def fit_final(state, **kwargs):
     return state
 
-
-def get_data_as_np_array(table_name, y, x, input_shape, num_classes):
-    """
-
-    :param table_name: Table containing the batch of images per row
-    :param y: Column name for y
-    :param x: Column name for x
-    :param input_shape: input_shape of data in array format [L , W , C]
-    :param num_classes: num of distinct classes in y
-    :return:
-    """
-    val_data_qry = "SELECT {0}, {1} FROM {2}".format(y, x, table_name)
-    input_shape = map(int, input_shape)
-    val_data = plpy.execute(val_data_qry)
-    indep_len = len(val_data[0][x])
-    pixels_per_image = int(input_shape[0] * input_shape[1] * input_shape[2])
-    x_validation = np.ndarray((0,indep_len, pixels_per_image))
-    y_validation = np.ndarray((0,indep_len))
-    for i in range(len(val_data)):
-        x_test = np.asarray((val_data[i][x],))
-        x_test = x_test.reshape(1, indep_len, pixels_per_image)
-        y_test = np.asarray((val_data[i][y],))
-        y_test = y_test.reshape(1, indep_len)
-        x_validation=np.concatenate((x_validation, x_test))
-        y_validation=np.concatenate((y_validation, y_test))
-    num_test_examples = x_validation.shape[0]
-    x_validation = x_validation.reshape(indep_len * num_test_examples, 
*input_shape)
-    x_validation = x_validation.astype('float64')
-    y_validation = y_validation.reshape(indep_len * num_test_examples)
-
-    x_validation = x_validation.astype('float64')
-    #x_validation /= 255.0
-    y_validation = keras.utils.to_categorical(y_validation, num_classes)
-
-    return x_validation, y_validation
-
-def evaluate(schema_madlib, model_table, source_table, id_col, 
model_arch_table,
-             model_arch_id, dependent_varname, independent_varname, 
compile_params,
-             output_table, **kwargs):
+def evaluate(schema_madlib, model_table, source_table, id_col,
+             model_arch_table, model_arch_id, dependent_varname,
+             independent_varname, compile_params, output_table,
+             **kwargs):
     module_name = 'madlib_keras_evaluate'
     input_tbl_valid(source_table, module_name)
     input_tbl_valid(model_arch_table, module_name)
@@ -551,7 +499,8 @@ def evaluate(schema_madlib, model_table, source_table, 
id_col, model_arch_table,
     model_shapes = []
     for weight_arr in model.get_weights():
         model_shapes.append(weight_arr.shape)
-    _, updated_weights = deserialize_weights(model_data, model_shapes)
+    _, updated_weights = KerasWeightsSerializer.deserialize_weights(
+        model_data, model_shapes)
     model.set_weights(updated_weights)
     compile_params_args = convert_string_of_args_to_dict(compile_params)
     with K.tf.device(device_name):
@@ -574,8 +523,8 @@ def evaluate(schema_madlib, model_table, source_table, 
id_col, model_arch_table,
 
 
 def evaluate1(schema_madlib, model_table, test_table, id_col, model_arch_table,
-            model_arch_id, dependent_varname, independent_varname, 
compile_params, output_table,
-            **kwargs):
+            model_arch_id, dependent_varname, independent_varname,
+            compile_params, output_table, **kwargs):
     # module_name = 'madlib_keras_evaluate'
     # input_tbl_valid(test_table, module_name)
     # input_tbl_valid(model_arch_table, module_name)
@@ -590,7 +539,8 @@ def evaluate1(schema_madlib, model_table, test_table, 
id_col, model_arch_table,
                        "WHERE id = {1}".format(model_arch_table, model_arch_id)
     query_result = plpy.execute(model_arch_query)
     if not  query_result or len(query_result) == 0:
-        plpy.error("no model arch found in table {0} with id 
{1}".format(model_arch_table, model_arch_id))
+        plpy.error("no model arch found in table {0} with id {1}".format(
+            model_arch_table, model_arch_id))
     query_result = query_result[0]
     model_arch = query_result['model_arch']
     input_shape = get_input_shape(model_arch)
@@ -617,7 +567,8 @@ def internal_keras_evaluate(x_test, y_test, model_arch, 
model_data, input_shape,
     model_shapes = []
     for weight_arr in model.get_weights():
         model_shapes.append(weight_arr.shape)
-    _, model_weights = deserialize_weights(model_data, model_shapes)
+    _, model_weights = KerasWeightsSerializer.deserialize_weights(
+        model_data, model_shapes)
     model.set_weights(model_weights)
     with K.tf.device(device_name):
         model.compile(**compile_params)
@@ -625,274 +576,8 @@ def internal_keras_evaluate(x_test, y_test, model_arch, 
model_data, input_shape,
     x_test = np.array(x_test).reshape(len(x_test), input_shape[0], 
input_shape[1],
                                       input_shape[2])
     x_test = x_test.astype('float32')
-    y_test = keras.utils.to_categorical(np.array(y_test), 10)
+    y_test = keras_utils.to_categorical(np.array(y_test), 10)
     with K.tf.device(device_name):
         res = model.evaluate(x_test, y_test)
     plpy.info('evaluate result from internal_keras_evaluate is {}'.format(res))
     return res
-
-def print_md5_sum(obj, name):
-    import hashlib
-    m = hashlib.md5()
-    m.update(obj)
-    plpy.info('md5 sum for {} is {}'.format(name, m.hexdigest()))
-
-
-def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
-            model_arch_id, independent_varname, compile_params, output_table,
-            **kwargs):
-    module_name = 'madlib_keras_predict'
-    input_tbl_valid(test_table, module_name)
-    input_tbl_valid(model_arch_table, module_name)
-    output_tbl_valid(output_table, module_name)
-
-    # _validate_input_args(test_table, model_arch_table, output_table)
-
-    model_data_query = "SELECT model_data from {0}".format(model_table)
-    model_data = plpy.execute(model_data_query)[0]['model_data']
-
-    model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
-                       "WHERE id = {1}".format(model_arch_table, model_arch_id)
-    query_result = plpy.execute(model_arch_query)
-    if not  query_result or len(query_result) == 0:
-        plpy.error("no model arch found in table {0} with id 
{1}".format(model_arch_table, model_arch_id))
-    query_result = query_result[0]
-    model_arch = query_result['model_arch']
-    input_shape = get_input_shape(model_arch)
-    compile_params = "$madlib$" + compile_params + "$madlib$"
-    predict_query = plpy.prepare("""create table {output_table} as
-        select {id_col}, (madlib.internal_keras_predict({independent_varname},
-                                             $MAD${model_arch}$MAD$,
-                                             $1,ARRAY{input_shape},
-                                             {compile_params}))[1] as 
prediction
-        from {test_table}""".format(**locals()), ["bytea"])
-    plpy.execute(predict_query, [model_data])
-
-
-def internal_keras_predict(x_test, model_arch, model_data, input_shape, 
compile_params):
-    model = model_from_json(model_arch)
-    compile_params = convert_string_of_args_to_dict(compile_params)
-    device_name = '/cpu:0'
-    os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
-
-    with K.tf.device(device_name):
-        model.compile(**compile_params)
-
-    model_shapes = []
-    for weight_arr in model.get_weights():
-        model_shapes.append(weight_arr.shape)
-    _,_,_, model_weights = deserialize_weights(model_data, model_shapes)
-    model.set_weights(model_weights)
-    x_test = np.array(x_test).reshape(1, *input_shape)
-    x_test /= 255
-    res = model.predict_classes(x_test)
-    return res
-
-#### FUNCTIONS TO CHANGE ####
-
-# def construct_init_state(model_weights_serialized):
-#     return str([0, 0, 0] + model_weights_serialized) # format: [loss, 
accuracy, buffer_count, weights...]
-
-# def deserialize_iteration_state(iteration_result):
-#     iteration_result = eval(iteration_result)
-#     avg_loss, avg_accuracy, updated_model_state = iteration_result[0], 
iteration_result[1], iteration_result[3:]
-#     return avg_loss, avg_accuracy, str([0, 0, 0] + updated_model_state)
-
-# def deserialize_weights(model_weights_serialized, model_shapes):
-#     model_state = eval(model_weights_serialized)
-#     model_weights_serialized = model_state[3:]
-#     i, j, model_weights = 0, 0, []
-#     while j < len(model_shapes):
-#         next_pointer = i + reduce(lambda x, y: x * y, model_shapes[j])
-#         weight_arr_portion = model_weights_serialized[i:next_pointer]
-#         
model_weights.append(np.array(weight_arr_portion).reshape(model_shapes[j]))
-#         i, j = next_pointer, j + 1
-#     return model_state[2], model_weights
-
-# def serialize_weights(loss, accuracy, buffer_count, model_weights):
-#     flattened_weights = [list(w.flatten()) for w in model_weights]
-#     model_weights_serialized = sum(flattened_weights, [])
-#     return str([loss, accuracy, buffer_count] + model_weights_serialized)
-
-# def deserialize_weights_merge(model_weights_serialized):
-#     model_state = eval(model_weights_serialized)
-#     return model_state[0], model_state[1], model_state[2], model_state[3:]
-
-# def serialize_weights_merge(avg_loss, avg_accuracy, total_buffers, 
avg_weights):
-#     return str([avg_loss, avg_accuracy, total_buffers] + avg_weights)
-
-# def reset_buffers_final(state):
-#     state = eval(state)
-#     state[2] = 0
-#     return str(state)
-
-# SPLITTER SERIALIZATION: WORKS
-# def deserialize_iteration_state(iteration_result):
-#     split_state = filter(None, iteration_result.split('splitter'))
-#     new_model_string = "0splitter0splitter0splitter"
-#     for a in split_state[3:]:
-#         new_model_string += a
-#         new_model_string += 'splitter'
-#     avg_loss, avg_accuracy = split_state[0], split_state[1]
-#     return float(avg_loss), float(avg_accuracy), new_model_string
-
-# def deserialize_weights(model_state, model_shapes):
-#     split_state = filter(None, model_state.split('splitter'))
-#     j, model_weights = 0, []
-#     for a in split_state[3:]:
-#         arr = np.fromstring(a, dtype=np.float32)
-#         model_weights.append(arr.reshape(model_shapes[j]))
-#         j += 1
-#     '''For the buffer count, we first cast to float and then int because 
Python
-#     cannot cast directly from string like '3.0' to int 3'''
-#     return int(float(split_state[2])), model_weights
-
-# def serialize_weights(loss, accuracy, buffer_count, model_weights):
-#     new_model_string = str(loss) + "splitter" + str(accuracy) + "splitter" + 
str(buffer_count) + "splitter"
-#     for a in model_weights:
-#         a = np.float32(a)
-#         new_model_string += a.tostring()
-#         new_model_string += 'splitter'
-#     return new_model_string
-
-# def deserialize_weights_merge(state):
-#     split_state = filter(None, state.split('splitter'))
-#     model_weights = []
-#     for a in split_state[3:]:
-#         model_weights.append(np.fromstring(a, dtype=np.float32))
-#     return float(split_state[0]), float(split_state[1]), 
int(float(split_state[2])), model_weights
-# END SPLITTER SERIALIZATION
-
-"""
-Parameters:
-    iteration_result: the output of the step function
-Returns:
-    loss: the averaged loss from that iteration of training
-    accuracy: the averaged accuracy from that iteration of training
-    new_model_state: the stringified (serialized) state to pass in to next 
iteration
-        of step function training, represents the averaged weights from the 
last
-        iteration of training; zeros out loss, accuracy, buffer_count in this 
state
-        because the new iteration must start with fresh values
-"""
-def deserialize_iteration_state(iteration_result):
-    if not iteration_result:
-        return None
-    state = np.fromstring(iteration_result, dtype=np.float32)
-    new_model_string = np.array(state)
-    new_model_string[0], new_model_string[1], new_model_string[2] = 0, 0, 0
-    new_model_string = np.float32(new_model_string)
-    return float(state[0]), float(state[1]), new_model_string.tostring()
-
-"""
-Parameters:
-    model_state: a stringified (serialized) state containing loss, accuracy, 
buffer_count,
-        and model_weights, passed from postgres
-    model_shapes: a list of tuples containing the shapes of each element in 
keras.get_weights()
-Returns:
-    buffer_count: the buffer count from state
-    model_weights: a list of numpy arrays that can be inputted into 
keras.set_weights()
-"""
-def deserialize_weights(model_state, model_shapes):
-    if not model_state or not model_shapes:
-        return None
-    state = np.fromstring(model_state, dtype=np.float32)
-    model_weights_serialized = state[3:]
-    i, j, model_weights = 0, 0, []
-    while j < len(model_shapes):
-        next_pointer = i + reduce(lambda x, y: x * y, model_shapes[j])
-        weight_arr_portion = model_weights_serialized[i:next_pointer]
-        model_weights.append(weight_arr_portion.reshape(model_shapes[j]))
-        i, j = next_pointer, j + 1
-    return int(float(state[0])), int(float(state[1])), int(float(state[2])), 
model_weights
-
-"""
-Parameters:
-    loss, accuracy, buffer_count: float values
-    model_weights: a list of numpy arrays, what you get from 
keras.get_weights()
-Returns:
-    A stringified (serialized) state containing all these values, to be passed 
to postgres
-"""
-def serialize_weights(loss, accuracy, buffer_count, model_weights):
-    if model_weights is None:
-        return None
-    flattened_weights = [w.flatten() for w in model_weights]
-    model_weights_serialized = np.concatenate(flattened_weights)
-    new_model_string = np.array([loss, accuracy, buffer_count])
-    new_model_string = np.concatenate((new_model_string, 
model_weights_serialized))
-    new_model_string = np.float32(new_model_string)
-    return new_model_string.tostring()
-
-"""
-Parameters:
-    state: the stringified (serialized) state containing loss, accuracy, 
buffer_count, and
-        model_weights, passed from postgres to merge function
-Returns:
-    loss: the averaged loss from that iteration of training
-    accuracy: the averaged accuracy from that iteration of training
-    buffer_count: total buffer counts processed
-    model_weights: a single flattened numpy array containing all of the 
weights, flattened
-        because all we have to do is average them (so don't have to reshape)
-"""
-def deserialize_weights_merge(state):
-    if not state:
-        return None
-    state = np.fromstring(state, dtype=np.float32)
-    return float(state[0]), float(state[1]), int(float(state[2])), state[3:]
-
-"""
-Parameters:
-    loss, accuracy, buffer_count: float values
-    model_weights: a single flattened numpy array containing all of the 
weights, averaged
-        in merge function over the 2 states
-Returns:
-    A stringified (serialized) state containing all these values, to be passed 
to postgres
-"""
-def serialize_weights_merge(loss, accuracy, buffer_count, model_weights):
-    if model_weights is None:
-        return None
-    new_model_string = np.array([loss, accuracy, buffer_count])
-    new_model_string = np.concatenate((new_model_string, model_weights))
-    new_model_string = np.float32(new_model_string)
-    return new_model_string.tostring()
-
-#### OTHER FUNCTIONS ####
-
-"""
-Original deserialization for warm-start, used only to parse model received
-from query at the top of this file
-"""
-def deserialize_weights_orig(model_weights_serialized, model_shapes):
-    i, j, model_weights = 0, 0, []
-    while j < len(model_shapes):
-        next_pointer = i + reduce(lambda x, y: x * y, model_shapes[j])
-        weight_arr_portion = model_weights_serialized[i:next_pointer]
-        
model_weights.append(np.array(weight_arr_portion).reshape(model_shapes[j]))
-        i, j = next_pointer, j + 1
-    return model_weights
-
-"""
-Used to convert compile_params and fit_params to actual argument dictionaries
-"""
-def convert_string_of_args_to_dict(str_of_args):
-    """Uses parenthases matching algorithm to intelligently convert
-    a string with valid python code into an argument dictionary"""
-    stack = []
-    dual = {
-        '(' : ')',
-        '[' : ']',
-        '{' : '}',
-    }
-    result_str = ""
-    for char in str_of_args:
-        if char in dual.keys():
-            stack.append(char)
-            result_str += char
-        elif char in dual.values() and stack:
-            if dual[stack[-1]] == char:
-                stack.pop(-1)
-            result_str += char
-        elif not stack and char == "=":
-            result_str += ":"
-        else:
-            result_str += char
-    return eval('{' + result_str + '}')
diff --git a/src/ports/postgres/modules/convex/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
similarity index 93%
rename from src/ports/postgres/modules/convex/madlib_keras.sql_in
rename to src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index f65014d..6b2daf7 100644
--- a/src/ports/postgres/modules/convex/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -44,7 +44,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
     name                    VARCHAR,
     description             VARCHAR
 ) RETURNS VOID AS $$
-    PythonFunctionBodyOnly(`convex', `madlib_keras')
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
     with AOControl(False):
         madlib_keras.fit(**globals())
 $$ LANGUAGE plpythonu VOLATILE
@@ -116,7 +116,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     use_gpu                    BOOLEAN,
     previous_state             BYTEA
 ) RETURNS BYTEA AS $$
-PythonFunctionBodyOnlyNoSchema(`convex', `madlib_keras')
+PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(**globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
@@ -125,7 +125,7 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_merge(
     state1          BYTEA,
     state2          BYTEA
 ) RETURNS BYTEA AS $$
-PythonFunctionBodyOnlyNoSchema(`convex', `madlib_keras')
+PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_merge(**globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
@@ -133,7 +133,7 @@ m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_final(
     state BYTEA
 ) RETURNS BYTEA AS $$
-PythonFunctionBodyOnlyNoSchema(`convex', `madlib_keras')
+PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_final(**globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
@@ -180,9 +180,9 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.madlib_keras_predict(
     compile_params          VARCHAR,
     output_table            VARCHAR
 ) RETURNS VOID AS $$
-    PythonFunctionBodyOnly(`convex', `madlib_keras')
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
-        madlib_keras.predict(schema_madlib,
+        madlib_keras_predict.predict(schema_madlib,
                model_table,
                test_table,
                id_col,
@@ -201,9 +201,9 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_predict(
    input_shape integer[],
    compile_params TEXT
 ) RETURNS DOUBLE PRECISION[] AS $$
-    PythonFunctionBodyOnly(`convex', `madlib_keras')
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     with AOControl(False):
-        return madlib_keras.internal_keras_predict(
+        return madlib_keras_predict.internal_keras_predict(
                independent_var,
                model_architecture,
                model_data,
@@ -224,7 +224,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.madlib_keras_evaluate(
     compile_params          VARCHAR,
     output_table            VARCHAR
 ) RETURNS VOID AS $$
-    PythonFunctionBodyOnly(`convex', `madlib_keras')
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
     with AOControl(False):
         madlib_keras.evaluate(schema_madlib,
                model_table,
@@ -250,7 +250,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.madlib_keras_evaluate1(
     compile_params          VARCHAR,
     output_table            VARCHAR
 ) RETURNS VOID AS $$
-    PythonFunctionBodyOnly(`convex', `madlib_keras')
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
     with AOControl(False):
         madlib_keras.evaluate1(schema_madlib,
                model_table,
@@ -273,7 +273,7 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.internal_keras_evaluate(
    input_shape integer[],
    compile_params TEXT
 ) RETURNS DOUBLE PRECISION[] AS $$
-    PythonFunctionBodyOnly(`convex', `madlib_keras')
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
     with AOControl(False):
         return madlib_keras.internal_keras_evaluate(
                dependent_var,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
new file mode 100644
index 0000000..a198b7a
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -0,0 +1,194 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import os
+import plpy
+
+# Import needed for get_data_as_np_array()
+from keras import utils as keras_utils
+
+#######################################################################
+########### Helper functions to serialize and deserialize weights #####
+#######################################################################
class KerasWeightsSerializer:
    """Serialize/deserialize the model state exchanged with Postgres.

    The serialized format is a float32 byte string laid out as
    [loss, accuracy, buffer_count, w0, w1, ...] where w* are the flattened
    model weights.
    """

    @staticmethod
    def get_model_shapes(model):
        """Return the shape tuple of every weight array of a Keras model."""
        return [arr.shape for arr in model.get_weights()]

    @staticmethod
    def deserialize_weights(model_state, model_shapes):
        """
        Parameters:
            model_state: a stringified (serialized) state containing loss,
            accuracy, buffer_count, and model_weights, passed from postgres
            model_shapes: a list of tuples containing the shapes of each element
            in keras.get_weights()
        Returns:
            loss, accuracy, buffer_count (all as ints) and model_weights: a list
            of numpy arrays that can be passed to keras.set_weights()
        """
        if not model_state or not model_shapes:
            return None
        # frombuffer + copy replaces np.fromstring, which is deprecated and
        # removed in NumPy 2.0; the copy keeps the array writable like before
        state = np.frombuffer(model_state, dtype=np.float32).copy()
        model_weights_serialized = state[3:]
        start, model_weights = 0, []
        for shape in model_shapes:
            # np.prod also handles scalar weights (shape ()), which a
            # reduce() over an empty tuple would crash on
            size = int(np.prod(shape))
            weight_arr_portion = model_weights_serialized[start:start + size]
            model_weights.append(weight_arr_portion.reshape(shape))
            start += size
        return (int(float(state[0])), int(float(state[1])),
                int(float(state[2])), model_weights)

    @staticmethod
    def serialize_weights(loss, accuracy, buffer_count, model_weights):
        """
        Parameters:
            loss, accuracy, buffer_count: float values
            model_weights: a list of numpy arrays, what you get from
            keras.get_weights()
        Returns:
            A stringified (serialized) state containing all these values, to be
            passed to postgres
        """
        if model_weights is None:
            return None
        flattened_weights = [w.flatten() for w in model_weights]
        header = np.array([loss, accuracy, buffer_count])
        serialized = np.float32(np.concatenate([header] + flattened_weights))
        # tobytes replaces ndarray.tostring (removed in NumPy 2.0)
        return serialized.tobytes()

    @staticmethod
    def deserialize_iteration_state(iteration_result):
        """
        Parameters:
            iteration_result: the output of the step function
        Returns:
            loss: the averaged loss from that iteration of training
            accuracy: the averaged accuracy from that iteration of training
            new_model_state: the stringified (serialized) state to pass in to
            next iteration of step function training, represents the averaged
            weights from the last iteration of training; zeros out loss,
            accuracy, buffer_count in this state because the new iteration must
            start with fresh values
        """
        if not iteration_result:
            return None
        state = np.frombuffer(iteration_result, dtype=np.float32)
        new_model_string = np.array(state)  # writable copy of the buffer
        new_model_string[:3] = 0            # reset loss, accuracy, buffer_count
        new_model_string = np.float32(new_model_string)
        return float(state[0]), float(state[1]), new_model_string.tobytes()

    @staticmethod
    def deserialize_weights_merge(state):
        """
        Parameters:
            state: the stringified (serialized) state containing loss, accuracy,
                buffer_count, and model_weights, passed from postgres to the
                merge function
        Returns:
            loss: the averaged loss from that iteration of training
            accuracy: the averaged accuracy from that iteration of training
            buffer_count: total buffer counts processed
            model_weights: a single flattened numpy array containing all of the
            weights, flattened because all we have to do is average them (so
            don't have to reshape)
        """
        if not state:
            return None
        state = np.frombuffer(state, dtype=np.float32).copy()
        return float(state[0]), float(state[1]), int(float(state[2])), state[3:]

    @staticmethod
    def serialize_weights_merge(loss, accuracy, buffer_count, model_weights):
        """
        Parameters:
            loss, accuracy, buffer_count: float values
            model_weights: a single flattened numpy array containing all of the
            weights, averaged in the merge function over the 2 states
        Returns:
            A stringified (serialized) state containing all these values, to be
            passed to postgres
        """
        if model_weights is None:
            return None
        header = np.array([loss, accuracy, buffer_count])
        serialized = np.float32(np.concatenate((header, model_weights)))
        return serialized.tobytes()

    @staticmethod
    def deserialize_weights_orig(model_weights_serialized, model_shapes):
        """
        Original deserialization for warm-start, used only to parse a model
        received from the warm-start query. Unlike deserialize_weights, the
        input has no [loss, accuracy, buffer_count] header and may be a plain
        Python list.
        """
        start, model_weights = 0, []
        for shape in model_shapes:
            size = int(np.prod(shape))
            weight_arr_portion = model_weights_serialized[start:start + size]
            model_weights.append(np.array(weight_arr_portion).reshape(shape))
            start += size
        return model_weights
+
+
+#######################################################################
+########### General Helper functions  #######
+#######################################################################
+
def get_data_as_np_array(table_name, y, x, input_shape, num_classes):
    """Load a whole table of image batches into numpy arrays.

    :param table_name: Table containing the batch of images per row
    :param y: Column name for the dependent variable
    :param x: Column name for the independent variable
    :param input_shape: input_shape of data in array format [L, W, C]
    :param num_classes: number of distinct classes in y
    :return: tuple (x_validation, y_validation) where x_validation has shape
        (total_images, L, W, C) as float64 and y_validation is one-hot encoded
    """
    val_data_qry = "SELECT {0}, {1} FROM {2}".format(y, x, table_name)
    # list() so the result stays indexable on both Python 2 and 3
    input_shape = list(map(int, input_shape))
    val_data = plpy.execute(val_data_qry)
    indep_len = len(val_data[0][x])
    pixels_per_image = int(input_shape[0] * input_shape[1] * input_shape[2])
    # Collect the per-row arrays and concatenate once at the end, instead of
    # calling np.concatenate inside the loop (quadratic in the row count).
    x_rows = [np.asarray(row[x]).reshape(1, indep_len, pixels_per_image)
              for row in val_data]
    y_rows = [np.asarray(row[y]).reshape(1, indep_len) for row in val_data]
    x_validation = np.concatenate(x_rows)
    y_validation = np.concatenate(y_rows)
    num_test_examples = x_validation.shape[0]
    x_validation = x_validation.reshape(indep_len * num_test_examples,
                                        *input_shape)
    # single conversion: the original converted to float64 twice
    x_validation = x_validation.astype('float64')
    y_validation = y_validation.reshape(indep_len * num_test_examples)
    y_validation = keras_utils.to_categorical(y_validation, num_classes)

    return x_validation, y_validation
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
new file mode 100644
index 0000000..e2ed883
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -0,0 +1,83 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
import os

import plpy

import keras
from keras import backend as K
from keras.layers import *
from keras.models import *
from keras.optimizers import *
import numpy as np

from utilities.model_arch_info import get_input_shape
from utilities.validate_args import input_tbl_valid
from utilities.validate_args import output_tbl_valid

# convert_string_of_args_to_dict is defined in madlib_keras_wrapper, not in
# madlib_keras_helper; the old import raised ImportError at module load.
from madlib_keras_helper import KerasWeightsSerializer
from madlib_keras_wrapper import convert_string_of_args_to_dict
+
def predict(schema_madlib, model_table, test_table, id_col, model_arch_table,
            model_arch_id, independent_varname, compile_params, output_table,
            **kwargs):
    """Score the rows of test_table with a fitted Keras model.

    Creates output_table(id_col, prediction) by applying the weights stored in
    model_table to the architecture with id model_arch_id in model_arch_table.

    Raises (via plpy.error) when the architecture id is not found; table
    validation errors come from input_tbl_valid/output_tbl_valid.
    """
    module_name = 'madlib_keras_predict'
    input_tbl_valid(test_table, module_name)
    input_tbl_valid(model_arch_table, module_name)
    output_tbl_valid(output_table, module_name)

    model_data_query = "SELECT model_data from {0}".format(model_table)
    model_data = plpy.execute(model_data_query)[0]['model_data']

    model_arch_query = "SELECT model_arch, model_weights FROM {0} " \
                       "WHERE id = {1}".format(model_arch_table, model_arch_id)
    query_result = plpy.execute(model_arch_query)
    # an empty result list is already falsy; no separate len() check needed
    if not query_result:
        plpy.error("no model arch found in table {0} with id {1}".format(
            model_arch_table, model_arch_id))
    model_arch = query_result[0]['model_arch']
    input_shape = get_input_shape(model_arch)
    compile_params = "$madlib$" + compile_params + "$madlib$"
    # qualify with the caller-provided schema instead of hardcoding 'madlib',
    # so installs into a non-default schema keep working
    predict_query = plpy.prepare("""
        CREATE TABLE {output_table} AS
        SELECT {id_col},
               ({schema_madlib}.internal_keras_predict(
                    {independent_varname},
                    $MAD${model_arch}$MAD$,
                    $1,
                    ARRAY{input_shape},
                    {compile_params}))[1] AS prediction
        FROM {test_table}
        """.format(**locals()), ["bytea"])
    plpy.execute(predict_query, [model_data])
+
def internal_keras_predict(x_test, model_arch, model_data, input_shape,
                           compile_params):
    """Run a single-row prediction on CPU and return the predicted class.

    :param x_test: independent variable values for one row
    :param model_arch: model architecture as a JSON string
    :param model_data: serialized model state (bytea) produced by fit
    :param input_shape: image shape [L, W, C]
    :param compile_params: Keras compile arguments as a string
    :return: result of model.predict_classes() for this row
    """
    model = model_from_json(model_arch)
    compile_params = convert_string_of_args_to_dict(compile_params)
    # predictions always run on CPU; hide any GPUs from TensorFlow
    device_name = '/cpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = '-1'

    with K.tf.device(device_name):
        model.compile(**compile_params)

    # Fix: deserialize_weights is a KerasWeightsSerializer static method, not
    # a free function of this module (the old bare call raised NameError).
    model_shapes = KerasWeightsSerializer.get_model_shapes(model)
    _, _, _, model_weights = KerasWeightsSerializer.deserialize_weights(
        model_data, model_shapes)
    model.set_weights(model_weights)
    x_test = np.array(x_test).reshape(1, *input_shape)
    # NOTE(review): assumes inputs are 8-bit pixel values scaled to [0, 1]
    # here — confirm this matches the normalization used during fit
    x_test /= 255
    res = model.predict_classes(x_test)
    return res
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
new file mode 100644
index 0000000..63d7e86
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -0,0 +1,91 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import os
+import plpy
+
+from keras import backend as K
+from keras import utils as keras_utils
+from keras.optimizers import *
+
+from madlib_keras_helper import KerasWeightsSerializer
+
+#######################################################################
+########### Keras specific functions #####
+#######################################################################
def get_device_name_for_keras(use_gpu, seg, gpus_per_host):
    """Return the TF device string for this segment and pin
    CUDA_VISIBLE_DEVICES to match (-1 hides all GPUs for CPU-only runs)."""
    if not use_gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = '-1'
        return '/cpu:0'
    # round-robin the segment onto one of the host's GPUs
    os.environ["CUDA_VISIBLE_DEVICES"] = str(seg % gpus_per_host)
    return '/gpu:0'
+
def set_keras_session(use_gpu):
    """Install a fresh TensorFlow session as the Keras backend session."""
    tf_config = K.tf.ConfigProto()
    if use_gpu:
        # grab a fixed 90% slice of GPU memory up front rather than growing
        tf_config.gpu_options.allow_growth = False
        tf_config.gpu_options.per_process_gpu_memory_fraction = 0.9
    K.set_session(K.tf.Session(config=tf_config))
+
def clear_keras_session():
    """Tear down the current Keras backend session and release its resources."""
    session = K.get_session()
    K.clear_session()
    session.close()
+
def compile_and_set_weights(segment_model, compile_params, device_name,
                            previous_state, model_shapes):
    """Compile the segment's model on the given device and load the model
    weights deserialized from previous_state."""
    with K.tf.device(device_name):
        segment_model.compile(**convert_string_of_args_to_dict(compile_params))
        # deserialize_weights returns (loss, accuracy, buffer_count, weights);
        # only the weights are needed here
        weights = KerasWeightsSerializer.deserialize_weights(
            previous_state, model_shapes)[3]
        segment_model.set_weights(weights)
+
+
+"""
+Used to convert compile_params and fit_params to actual argument dictionaries
+"""
+def convert_string_of_args_to_dict(str_of_args):
+    """Uses parenthases matching algorithm to intelligently convert
+    a string with valid python code into an argument dictionary"""
+    stack = []
+    dual = {
+        '(' : ')',
+        '[' : ']',
+        '{' : '}',
+    }
+    result_str = ""
+    for char in str_of_args:
+        if char in dual.keys():
+            stack.append(char)
+            result_str += char
+        elif char in dual.values() and stack:
+            if dual[stack[-1]] == char:
+                stack.pop(-1)
+            result_str += char
+        elif not stack and char == "=":
+            result_str += ":"
+        else:
+            result_str += char
+    return eval('{' + result_str + '}')
diff --git a/src/ports/postgres/modules/convex/test/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
similarity index 100%
rename from src/ports/postgres/modules/convex/test/madlib_keras.sql_in
rename to src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/plpy_mock.py_in 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/plpy_mock.py_in
new file mode 100644
index 0000000..dd18649
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/plpy_mock.py_in
@@ -0,0 +1,43 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+m4_changequote(`<!', `!>')
def __init__(self):
    # kept for interface parity with the real plpy module object
    pass

def error(message):
    """Mimic plpy.error(): abort the caller by raising an exception."""
    raise PLPYException(message)

def execute(query):
    """Mimic plpy.execute(): no-op stand-in, returns no rows."""
    pass

def warning(query):
    """Mimic plpy.warning(): no-op stand-in."""
    pass

def info(query):
    """Mimic plpy.info(): echo the message to stdout."""
    # print() with a single argument behaves the same on Python 2 and 3;
    # the old Python-2-only `print query` statement broke Python 3 parsing
    print(query)


class PLPYException(Exception):
    """Exception raised by the mocked plpy.error()."""
    def __init__(self, message):
        super(PLPYException, self).__init__()
        self.message = message

    def __str__(self):
        return repr(self.message)
diff --git 
a/src/ports/postgres/modules/convex/test/unit_tests/test_madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
similarity index 57%
rename from 
src/ports/postgres/modules/convex/test/unit_tests/test_madlib_keras.py_in
rename to 
src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 4a2691d..fe6a1b8 100644
--- a/src/ports/postgres/modules/convex/test/unit_tests/test_madlib_keras.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -33,7 +33,6 @@ import plpy_mock as plpy
 from keras.models import *
 from keras.layers import *
 
-
 m4_changequote(`<!', `!>')
 
 class MadlibKerasFitTestCase(unittest.TestCase):
@@ -59,6 +58,10 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.compile_params = "'optimizer'=SGD(lr=0.01, decay=1e-6, 
nesterov=True), 'loss'='categorical_crossentropy', 'metrics'=['accuracy']"
         self.fit_params = "'batch_size'=1, 'epochs'=1"
         self.model_weights = [3,4,5,6]
+        self.model_shapes = []
+        for a in self.model.get_weights():
+            self.model_shapes.append(a.shape)
+
         self.loss = 1.3
         self.accuracy = 0.34
         self.all_seg_ids = [0,1,2]
@@ -90,6 +93,7 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
         self.assertEqual(1, k['SD']['buffer_count'])
         self.assertTrue(k['SD']['segment_model'])
+        self.assertTrue(k['SD']['model_shapes'])
 
     def test_fit_transition_last_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
@@ -104,8 +108,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         state = np.array(state, dtype=np.float32)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', state.tostring())
-        k = {'SD': {'buffer_count': buffer_count}}
+                                             '/cpu:0', state.tostring(), 
self.model_shapes)
+        k = {'SD': {'buffer_count': buffer_count,
+                   'model_shapes': self.model_shapes}}
         k['SD']['segment_model'] = self.model
         new_model_state = self.subject.fit_transition(
             state.tostring(), [[0.5]] , [0], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
@@ -132,8 +137,9 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         state = np.array(state, dtype=np.float32)
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', state.tostring())
-        k = {'SD': {'buffer_count': buffer_count}}
+                                             '/cpu:0', state.tostring(), 
self.model_shapes)
+        k = {'SD': {'buffer_count': buffer_count,
+                   'model_shapes': self.model_shapes}}
         k['SD']['segment_model'] = self.model
         new_model_state = self.subject.fit_transition(
             state.tostring(), [[0.5]] , [0], 1, 2, self.all_seg_ids, 
self.total_buffers_per_seg,
@@ -147,124 +153,14 @@ class MadlibKerasFitTestCase(unittest.TestCase):
         self.assertEqual(0, self.subject.clear_keras_session.call_count)
         self.assertEqual(2, k['SD']['buffer_count'])
 
-    def test_deserialize_weights_merge_null_state_returns_none(self):
-        self.assertEqual(None, self.subject.deserialize_weights_merge(None))
-
-    def test_deserialize_weights_merge_returns_not_none(self):
-        dummy_model_state = np.array([0,1,2,3,4,5,6], dtype=np.float32)
-        res = 
self.subject.deserialize_weights_merge(dummy_model_state.tostring())
-        self.assertEqual(0, res[0])
-        self.assertEqual(1, res[1])
-        self.assertEqual(2, res[2])
-        self.assertEqual([3,4,5,6], res[3].tolist())
-
-    def test_deserialize_weights_null_input_returns_none(self):
-        dummy_model_state = np.array([0,1,2,3,4,5,6], dtype=np.float32)
-        self.assertEqual(None, 
self.subject.deserialize_weights(dummy_model_state.tostring(), None))
-        self.assertEqual(None, self.subject.deserialize_weights(None, [1,2,3]))
-        self.assertEqual(None, self.subject.deserialize_weights(None, None))
-
-    def test_deserialize_weights_valid_input_returns_not_none(self):
-        dummy_model_state = np.array([0,1,2,3,4,5], dtype=np.float32)
-        dummy_model_shape = [(2, 1, 1, 1), (1,)]
-        res = self.subject.deserialize_weights(dummy_model_state.tostring(), 
dummy_model_shape)
-        self.assertEqual(0, res[0])
-        self.assertEqual(1, res[1])
-        self.assertEqual(2, res[2])
-        self.assertEqual([[[[3.0]]], [[[4.0]]]], res[3][0].tolist())
-        self.assertEqual([5], res[3][1].tolist())
-
-    def test_deserialize_weights_invalid_input_fails(self):
-        # pass an invalid state with missing model weights
-        invalid_model_state = np.array([0,1,2], dtype=np.float32)
-        dummy_model_shape = [(2, 1, 1, 1), (1,)]
-
-        # we except keras failure(ValueError) because we cannot reshape model 
weights of size 0 into shape (2,2,3,1)
-        with self.assertRaises(ValueError):
-            self.subject.deserialize_weights(invalid_model_state.tostring(), 
dummy_model_shape)
-
-        invalid_model_state = np.array([0,1,2,3,4], dtype=np.float32)
-        dummy_model_shape = [(2, 2, 3, 1), (1,)]
-        # we except keras failure(ValueError) because we cannot reshape model 
weights of size 2 into shape (2,2,3,1)
-        with self.assertRaises(ValueError):
-            self.subject.deserialize_weights(invalid_model_state.tostring(), 
dummy_model_shape)
-
-    def test_deserialize_iteration_state_none_input_returns_none(self):
-        self.assertEqual(None, self.subject.deserialize_iteration_state(None))
-
-    def test_deserialize_iteration_state_returns_valid_output(self):
-        dummy_iteration_state = np.array([0,1,2,3,4,5], dtype=np.float32)
-        res = self.subject.deserialize_iteration_state(
-            dummy_iteration_state.tostring())
-        self.assertEqual(0, res[0])
-        self.assertEqual(1, res[1])
-        self.assertEqual(res[2],
-                         np.array([0,0,0,3,4,5], dtype=np.float32).tostring())
-
-    def test_serialize_weights_none_weights_returns_none(self):
-        res = self.subject.serialize_weights(0,1,2,None)
-        self.assertEqual(None , res)
-
-    def test_serialize_weights_valid_output(self):
-        res = self.subject.serialize_weights(0,1,2,[np.array([1,3]),
-                                                    np.array([4,5])])
-        self.assertEqual(np.array([0,1,2,1,3,4,5], 
dtype=np.float32).tostring(),
-                         res)
-
-    def test_serialize_weights_merge_none_weights_returns_none(self):
-        res = self.subject.serialize_weights_merge(0,1,2,None)
-        self.assertEqual(None , res)
-
-    def test_serialize_weights_merge_valid_output(self):
-        res = self.subject.serialize_weights_merge(0,1,2,np.array([1,3,4,5]))
-        self.assertEqual(np.array([0,1,2,1,3,4,5], 
dtype=np.float32).tostring(),
-                         res)
-
-    def test_get_data_as_np_array_one_image_per_row(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
-                                               {'x': [[5,6]], 'y': 1}]
-        x_res, y_res = self.subject.get_data_as_np_array('foo','y','x', 
[1,1,2],
-                                                         3)
-        self.assertEqual(np.array([[[[1, 2]]], [[[5, 6]]]]).tolist(),
-                         x_res.tolist())
-        self.assertEqual(np.array([[1, 0, 0], [0, 1, 0]]).tolist(),
-                         y_res.tolist())
-
-    def test_get_data_as_np_array_multiple_images_per_row(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2], [3,4]], 'y': 
[0,2]},
-                                               {'x': [[5,6], [7,8]], 'y': 
[1,0]}]
-        x_res, y_res = self.subject.get_data_as_np_array('foo','y','x', 
[1,1,2],
-                                                         3)
-        self.assertEqual(np.array([[[[1,2]]], [[[3,4]]],
-                                   [[[5,6]]], [[[7,8]]]]).tolist(),
-                         x_res.tolist())
-        self.assertEqual(np.array([[1,0,0], [0,0,1] ,
-                                   [0,1,0], [1,0,0]]).tolist(),
-                         y_res.tolist())
-
-    def test_get_data_as_np_array_float_input_shape(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
-                                               {'x': [[5,6]], 'y': 1}]
-        x_res, y_res = self.subject.get_data_as_np_array('foo','y','x',
-                                                         [1.5,1.9,2.3], 3)
-        self.assertEqual(np.array([[[[1, 2]]], [[[5, 6]]]]).tolist(),
-                         x_res.tolist())
-        self.assertEqual(np.array([[1, 0, 0], [0, 1, 0]]).tolist(),
-                         y_res.tolist())
-
-    def test_get_data_as_np_array_invalid_input_shape(self):
-        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
-                                               {'x': [[5,6]], 'y': 1}]
-        # we except keras failure(ValueError) because we cannot reshape
-        # the input which is of size 2 to input shape of 1,1,3
-        with self.assertRaises(ValueError):
-            self.subject.get_data_as_np_array('foo','y','x', [1,1,3], 3)
-
     def test_get_device_name_for_keras(self):
         import os
-        self.assertEqual('/gpu:0', 
self.subject.get_device_name_for_keras(True, 1, 3))
-        self.assertEqual('/cpu:0', 
self.subject.get_device_name_for_keras(False, 1, 3))
-        self.assertEqual('-1', os.environ["CUDA_VISIBLE_DEVICES"] )
+        self.assertEqual('/gpu:0', self.subject.get_device_name_for_keras(
+            True, 1, 3))
+        self.assertEqual('1', os.environ["CUDA_VISIBLE_DEVICES"])
+        self.assertEqual('/cpu:0', self.subject.get_device_name_for_keras(
+            False, 1, 3))
+        self.assertEqual('-1', os.environ["CUDA_VISIBLE_DEVICES"])
 
     def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
         k = {}
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
new file mode 100644
index 0000000..b9feef8
--- /dev/null
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_helper.py_in
@@ -0,0 +1,179 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import numpy as np
+from os import path
+# Add the deep_learning module to the pythonpath.
+sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
+sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+
+from keras import utils as keras_utils
+
+import unittest
+from mock import *
+import plpy_mock as plpy
+
+m4_changequote(`<!', `!>')
+
+class MadlibKerasHelperTestCase(unittest.TestCase):
+    def setUp(self):
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        import madlib_keras_helper
+        self.subject = madlib_keras_helper
+
+        self.model = Sequential()
+        self.model.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
+                         input_shape=(1,1,1,), padding='same'))
+        self.model.add(Flatten())
+
+        self.compile_params = "'optimizer'=SGD(lr=0.01, decay=1e-6, 
nesterov=True), 'loss'='categorical_crossentropy', 'metrics'=['accuracy']"
+        self.fit_params = "'batch_size'=1, 'epochs'=1"
+        self.model_weights = [3,4,5,6]
+        self.loss = 1.3
+        self.accuracy = 0.34
+        self.all_seg_ids = [0,1,2]
+        self.total_buffers_per_seg = [3,3,3]
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
+    def test_deserialize_weights_merge_null_state_returns_none(self):
+        self.assertEqual(None, 
self.subject.KerasWeightsSerializer.deserialize_weights_merge(None))
+
+    def test_deserialize_weights_merge_returns_not_none(self):
+        dummy_model_state = np.array([0,1,2,3,4,5,6], dtype=np.float32)
+        res = 
self.subject.KerasWeightsSerializer.deserialize_weights_merge(dummy_model_state.tostring())
+        self.assertEqual(0, res[0])
+        self.assertEqual(1, res[1])
+        self.assertEqual(2, res[2])
+        self.assertEqual([3,4,5,6], res[3].tolist())
+
+    def test_deserialize_weights_null_input_returns_none(self):
+        dummy_model_state = np.array([0,1,2,3,4,5,6], dtype=np.float32)
+        self.assertEqual(None, 
self.subject.KerasWeightsSerializer.deserialize_weights(dummy_model_state.tostring(),
 None))
+        self.assertEqual(None, 
self.subject.KerasWeightsSerializer.deserialize_weights(None, [1,2,3]))
+        self.assertEqual(None, 
self.subject.KerasWeightsSerializer.deserialize_weights(None, None))
+
+    def test_deserialize_weights_valid_input_returns_not_none(self):
+        dummy_model_state = np.array([0,1,2,3,4,5], dtype=np.float32)
+        dummy_model_shape = [(2, 1, 1, 1), (1,)]
+        res = 
self.subject.KerasWeightsSerializer.deserialize_weights(dummy_model_state.tostring(),
 dummy_model_shape)
+        self.assertEqual(0, res[0])
+        self.assertEqual(1, res[1])
+        self.assertEqual(2, res[2])
+        self.assertEqual([[[[3.0]]], [[[4.0]]]], res[3][0].tolist())
+        self.assertEqual([5], res[3][1].tolist())
+
+    def test_deserialize_weights_invalid_input_fails(self):
+        # pass an invalid state with missing model weights
+        invalid_model_state = np.array([0,1,2], dtype=np.float32)
+        dummy_model_shape = [(2, 1, 1, 1), (1,)]
+
+        # we expect keras failure(ValueError) because we cannot reshape
+        # model weights of size 0 into shape (2,1,1,1)
+        with self.assertRaises(ValueError):
+            
self.subject.KerasWeightsSerializer.deserialize_weights(invalid_model_state.tostring(),
 dummy_model_shape)
+
+        invalid_model_state = np.array([0,1,2,3,4], dtype=np.float32)
+        dummy_model_shape = [(2, 2, 3, 1), (1,)]
+        # we expect keras failure(ValueError) because we cannot reshape
+        # model weights of size 2 into shape (2,2,3,1)
+        with self.assertRaises(ValueError):
+            
self.subject.KerasWeightsSerializer.deserialize_weights(invalid_model_state.tostring(),
 dummy_model_shape)
+
+    def test_deserialize_iteration_state_none_input_returns_none(self):
+        self.assertEqual(None, 
self.subject.KerasWeightsSerializer.deserialize_iteration_state(None))
+
+    def test_deserialize_iteration_state_returns_valid_output(self):
+        dummy_iteration_state = np.array([0,1,2,3,4,5], dtype=np.float32)
+        res = self.subject.KerasWeightsSerializer.deserialize_iteration_state(
+            dummy_iteration_state.tostring())
+        self.assertEqual(0, res[0])
+        self.assertEqual(1, res[1])
+        self.assertEqual(res[2],
+                         np.array([0,0,0,3,4,5], dtype=np.float32).tostring())
+
+    def test_serialize_weights_none_weights_returns_none(self):
+        res = self.subject.KerasWeightsSerializer.serialize_weights(0,1,2,None)
+        self.assertEqual(None , res)
+
+    def test_serialize_weights_valid_output(self):
+        res = 
self.subject.KerasWeightsSerializer.serialize_weights(0,1,2,[np.array([1,3]),
+                                                    np.array([4,5])])
+        self.assertEqual(np.array([0,1,2,1,3,4,5], 
dtype=np.float32).tostring(),
+                         res)
+
+    def test_serialize_weights_merge_none_weights_returns_none(self):
+        res = 
self.subject.KerasWeightsSerializer.serialize_weights_merge(0,1,2,None)
+        self.assertEqual(None , res)
+
+    def test_serialize_weights_merge_valid_output(self):
+        res = 
self.subject.KerasWeightsSerializer.serialize_weights_merge(0,1,2,np.array([1,3,4,5]))
+        self.assertEqual(np.array([0,1,2,1,3,4,5], 
dtype=np.float32).tostring(),
+                         res)
+
+    def test_get_data_as_np_array_one_image_per_row(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
+                                               {'x': [[5,6]], 'y': 1}]
+        x_res, y_res = self.subject.get_data_as_np_array('foo','y','x', 
[1,1,2],
+                                                         3)
+        self.assertEqual(np.array([[[[1, 2]]], [[[5, 6]]]]).tolist(),
+                         x_res.tolist())
+        self.assertEqual(np.array([[1, 0, 0], [0, 1, 0]]).tolist(),
+                         y_res.tolist())
+
+    def test_get_data_as_np_array_multiple_images_per_row(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1,2], [3,4]], 'y': 
[0,2]},
+                                               {'x': [[5,6], [7,8]], 'y': 
[1,0]}]
+        x_res, y_res = self.subject.get_data_as_np_array('foo','y','x', 
[1,1,2],
+                                                         3)
+        self.assertEqual(np.array([[[[1,2]]], [[[3,4]]],
+                                   [[[5,6]]], [[[7,8]]]]).tolist(),
+                         x_res.tolist())
+        self.assertEqual(np.array([[1,0,0], [0,0,1] ,
+                                   [0,1,0], [1,0,0]]).tolist(),
+                         y_res.tolist())
+
+    def test_get_data_as_np_array_float_input_shape(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
+                                               {'x': [[5,6]], 'y': 1}]
+        x_res, y_res = self.subject.get_data_as_np_array('foo','y','x',
+                                                         [1.5,1.9,2.3], 3)
+        self.assertEqual(np.array([[[[1, 2]]], [[[5, 6]]]]).tolist(),
+                         x_res.tolist())
+        self.assertEqual(np.array([[1, 0, 0], [0, 1, 0]]).tolist(),
+                         y_res.tolist())
+
+    def test_get_data_as_np_array_invalid_input_shape(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1,2]], 'y': 0},
+                                               {'x': [[5,6]], 'y': 1}]
+        # we expect keras failure(ValueError) because we cannot reshape
+        # the input which is of size 2 to input shape of 1,1,3
+        with self.assertRaises(ValueError):
+            self.subject.get_data_as_np_array('foo','y','x', [1,1,3], 3)

Reply via email to