This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 40713de  DL: Add support for empty optimizer params
40713de is described below

commit 40713debea650398795fe36f36463024ecd2d012
Author: Orhan Kislal <okis...@pivotal.io>
AuthorDate: Fri Apr 19 13:17:15 2019 -0700

    DL: Add support for empty optimizer params
    
    Calling compile with an optimizer that has no params caused an error
    during parsing. This commit fixes the issue and adds tests to ensure it
    works as expected.
---
 .../deep_learning/madlib_keras_wrapper.py_in       | 18 ++++++++++--
 .../modules/deep_learning/test/madlib_keras.sql_in | 32 +++++++++++++++++++++-
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 211488c..4dd29d5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -140,16 +140,28 @@ def parse_optimizer(compile_dict):
     optimizers = get_optimizers()
     _assert(opt_name in optimizers,
             "model_keras error: invalid optimizer name: {0}".format(opt_name))
+
+    # If we use only the optimizer name
     if len(opt_split) == 1:
         final_args = None
+    # If we use optimizer object with no params
+    elif opt_split[1] == ')':
+        final_args = None
+    # If we give parameters to the optimizer
     else:
         opt_params = opt_split[1][:-1]
         opt_params_array = opt_params.split(',')
         opt_params_clean = map(split_and_strip, opt_params_array)
         key_value_params = { x[0] : x[1] for x in opt_params_clean}
-        final_args = { key: bool(value) if value == 'True' or value == 'False'
-                       else float(value)
-                       for key,value in key_value_params.iteritems() }
+
+        final_args = {}
+        for key,value in key_value_params.iteritems():
+            if value == 'None':
+                final_args[key] = None
+            elif value == 'True' or value == 'False':
+                final_args[key] = bool(value)
+            else:
+                final_args[key] = float(value)
     return (opt_name,final_args)
 
 # Parse the loss function.
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 08ac9cb..115af9d 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -219,7 +219,37 @@ SELECT madlib_keras_fit(
     NULL,
     'model name', 'model desc');
 
--- negative test case for passing non numeric y to fit
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer='Adam()', loss=losses.categorical_crossentropy, metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    FALSE,
+    NULL,
+    'model name', 'model desc');
+
+DROP TABLE IF EXISTS keras_out, keras_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_out',
+    'dependent_var',
+    'independent_var',
+    'model_arch',
+    1,
+    $$ optimizer=Adam(epsilon=None), loss=losses.categorical_crossentropy, metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    1,
+    FALSE,
+    NULL,
+    'model name', 'model desc');
+
+-- negative test case for passing non numeric y to fit
 -- induce failure by passing a non numeric column
 create table cifar_10_sample_val_failure as select * from cifar_10_sample_val;
 alter table cifar_10_sample_val_failure rename dependent_var to dependent_var_original;

Reply via email to