This is an automated email from the ASF dual-hosted git repository.

nkak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 8570aa03f9f13e63b9bf7d078659b9d16c964744
Author: Nikhil Kak <[email protected]>
AuthorDate: Tue Apr 9 17:51:04 2019 -0700

    DL: Use FORMAT class for getting model_arch col names
    
    JIRA: MADLIB-1304
    
    Use FORMAT class for getting model_arch col names (previously they were
    hardcoded, which caused a discrepancy between minibatch_preprocessor_dl()
    and madlib_keras_fit()).
    
    Closes #367
    
    Co-authored-by: Jingyi Mei <[email protected]>
---
 .../deep_learning/keras_model_arch_table.py_in        | 19 ++++++++++---------
 .../postgres/modules/deep_learning/madlib_keras.py_in | 14 ++++++++------
 .../modules/deep_learning/madlib_keras_predict.py_in  | 12 +++++++-----
 .../modules/deep_learning/test/madlib_keras.sql_in    |  1 -
 4 files changed, 25 insertions(+), 21 deletions(-)

diff --git 
a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in 
b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
index a0521ee..8ed3ad6 100644
--- a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
+++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in
@@ -55,7 +55,7 @@ class Format:
     """
     col_names = ('model_id', 'model_arch', 'model_weights', 
'__internal_madlib_id__')
     col_types = ('SERIAL PRIMARY KEY', 'JSON', 'DOUBLE PRECISION[]', 'TEXT')
-    (model_id, model_arch, model_weights, __internal_madlib_id__) = col_names
+    (MODEL_ID, MODEL_ARCH, MODEL_WEIGHTS, __INTERNAL_MADLIB_ID__) = col_names
 
 @MinWarning("warning")
 def _execute(sql,max_rows=0):
@@ -89,18 +89,18 @@ def load_keras_model(schema_madlib, keras_model_arch_table,
              SELECT {model_id_col}, {model_arch_col}
                  FROM {model_arch_table} WHERE {internal_id_col} = 
'{unique_str}'
     """.format(model_arch_table=model_arch_table,
-               model_arch_col=Format.model_arch,
+               model_arch_col=Format.MODEL_ARCH,
                unique_str=unique_str,
                model_arch=quote_literal(model_arch),
-               model_id_col=Format.model_id,
-               internal_id_col=Format.__internal_madlib_id__)
+               model_id_col=Format.MODEL_ID,
+               internal_id_col=Format.__INTERNAL_MADLIB_ID__)
     res = _execute(sql,1)
 
-    if len(res) != 1 or res[0][Format.model_arch] != model_arch:
+    if len(res) != 1 or res[0][Format.MODEL_ARCH] != model_arch:
         raise Exception("Failed to insert new row in {0} table--try again?"
                        .format(model_arch_table))
     plpy.info("Keras Model Arch: Added model id {0} to {1} table".
-        format(res[0]['model_id'], model_arch_table))
+        format(res[0][Format.MODEL_ID], model_arch_table))
 
 def delete_keras_model(schema_madlib, keras_model_arch_table,
                        model_id, **kwargs):
@@ -113,8 +113,9 @@ def delete_keras_model(schema_madlib, 
keras_model_arch_table,
                    " missing columns: {1}".format(model_arch_table, 
missing_cols))
 
     sql = """
-           DELETE FROM {model_arch_table} WHERE model_id={model_id}
-          """.format(model_arch_table=model_arch_table, model_id=model_id)
+           DELETE FROM {model_arch_table} WHERE {model_id_col}={model_id}
+          """.format(model_arch_table=model_arch_table, 
model_id_col=Format.MODEL_ID,
+                     model_id=model_id)
     res = _execute(sql)
 
     if res.nrows() > 0:
@@ -123,7 +124,7 @@ def delete_keras_model(schema_madlib, 
keras_model_arch_table,
     else:
         plpy.error("Keras Model Arch: Model id {0} not found".format(model_id))
 
-    sql = "SELECT model_id FROM {0}".format(model_arch_table)
+    sql = "SELECT {0} FROM {1}".format(Format.MODEL_ID, model_arch_table)
     res = _execute(sql)
     if not res:
         plpy.info("Keras Model Arch: Dropping empty keras model arch "\
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index c3f36ab..b883cac 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -41,6 +41,7 @@ from madlib_keras_helper import DEPENDENT_VARTYPE
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
 from madlib_keras_helper import FitInputValidator
 from madlib_keras_wrapper import *
+from keras_model_arch_table import Format
 
 from utilities.model_arch_info import get_input_shape
 from utilities.model_arch_info import get_num_classes
@@ -63,22 +64,23 @@ def fit(schema_madlib, source_table, model, 
dependent_varname,
     use_gpu = bool(use_gpu)
 
     # Get the serialized master model
-    #TODO fix hardcoding of col names
     start_deserialization = time.time()
-    model_arch_query = "SELECT model_arch, model_weights FROM {0} WHERE "\
-        "id = {1}".format(model_arch_table, model_arch_id)
+    model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
+                                        Format.MODEL_ARCH, 
Format.MODEL_WEIGHTS,
+                                        model_arch_table, Format.MODEL_ID,
+                                        model_arch_id)
     query_result = plpy.execute(model_arch_query)
     if not  query_result:
         plpy.error("no model arch found in table {0} with id {1}".format(
             model_arch_table, model_arch_id))
     query_result = query_result[0]
-    model_arch = query_result['model_arch']
+    model_arch = query_result[Format.MODEL_ARCH]
     input_shape = get_input_shape(model_arch)
     num_classes = get_num_classes(model_arch)
     fit_validator.validate_input_shapes(source_table, input_shape, 2)
     if validation_table:
         fit_validator.validate_input_shapes(validation_table, input_shape, 1)
-    model_weights_serialized = query_result['model_weights']
+    model_weights_serialized = query_result[Format.MODEL_WEIGHTS]
 
     # Convert model from json and initialize weights
     master_model = model_from_json(model_arch)
@@ -510,7 +512,7 @@ def evaluate1(schema_madlib, model_table, test_table, 
id_col, model_arch_table,
         plpy.error("no model arch found in table {0} with id {1}".format(
             model_arch_table, model_arch_id))
     query_result = query_result[0]
-    model_arch = query_result['model_arch']
+    model_arch = query_result[Format.MODEL_ARCH]
     compile_params = "$madlib$" + compile_params + "$madlib$"
 
     loss_acc = get_loss_acc_from_keras_eval(schema_madlib, test_table, 
dependent_varname,
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index d89d703..1475a0f 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -33,6 +33,7 @@ from utilities.validate_args import get_col_value_and_type
 from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 from madlib_keras_helper import CLASS_VALUES_COLNAME
+from keras_model_arch_table import Format
 
 from madlib_keras_wrapper import compile_and_set_weights
 import madlib_keras_serializer
@@ -56,16 +57,17 @@ def predict(schema_madlib, model_table, test_table, id_col,
     model_data = plpy.execute(model_data_query)[0]['model_data']
 
     model_arch_query = """
-        SELECT model_arch, model_weights
-        FROM {0}
-        WHERE id = {1}
-        """.format(model_arch_table, model_arch_id)
+        SELECT {0}, {1}
+        FROM {2}
+        WHERE {3} = {4}
+        """.format(Format.MODEL_ARCH, Format.MODEL_WEIGHTS,model_arch_table,
+                   Format.MODEL_ID, model_arch_id)
     query_result = plpy.execute(model_arch_query)
     if not  query_result or len(query_result) == 0:
         plpy.error("{0}: No model arch found in table {1} with id {2}".format(
             MODULE_NAME, model_arch_table, model_arch_id))
     query_result = query_result[0]
-    model_arch = query_result['model_arch']
+    model_arch = query_result[Format.MODEL_ARCH]
     input_shape = get_input_shape(model_arch)
     compile_params = "$madlib$" + compile_params + "$madlib$"
     model_summary_table = add_postfix(model_table, "_summary")
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index ceb3d67..7902db4 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -63,7 +63,6 @@ SELECT load_keras_model('model_arch',
        {"class_name": "Dense", "config": {"kernel_initializer": {"class_name": 
"VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0, "seed": 
null, "mode": "fan_avg"}}, "name": "dense_1", "kernel_constraint": null, 
"bias_regularizer": null, "bias_constraint": null, "activation": "softmax", 
"trainable": true, "kernel_regularizer": null, "bias_initializer":
        {"class_name": "Zeros", "config": {}}, "units": 2, "use_bias": true, 
"activity_regularizer": null}
        }], "backend": "tensorflow"}$$);
-ALTER TABLE model_arch RENAME model_id TO id;
 
 -- Please do not break up the compile_params string
 -- It might break the assertion

Reply via email to