This is an automated email from the ASF dual-hosted git repository.

khannaekta pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 27f8ac96afc19d80ab4eb6a034d3b3ac29f1011f
Author: Ekta Khanna <[email protected]>
AuthorDate: Tue May 19 11:53:48 2020 -0700

    DL: Add object table info in load MST table utility
    
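    This commit adds an optional parameter `object_table` (a table storing
    Keras custom function objects) to `load_model_selection_table()`. When
    specified, the name of this object table is recorded in a new
    `object_table` column of the summary table
    `<model_selection_table>_summary` (NULL otherwise), which can then be
    passed to the fit/evaluate functions. The loss and metrics named in
    compile_params are also validated against the builtin Keras
    losses/metrics and the functions defined in the object table.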
---
 .../madlib_keras_custom_function.py_in             |   1 -
 .../madlib_keras_model_selection.py_in             |  23 ++-
 .../madlib_keras_model_selection.sql_in            |  13 +-
 .../deep_learning/madlib_keras_validator.py_in     |  38 +++-
 .../test/madlib_keras_custom_function.setup.sql_in |  41 +++++
 .../test/madlib_keras_custom_function.sql_in       |  25 +--
 .../test/madlib_keras_model_selection.sql_in       |  42 +++++
 .../test_madlib_keras_model_selection_table.py_in  | 194 +++++++++++++++++++++
 8 files changed, 349 insertions(+), 28 deletions(-)

diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 246c72d..9dcefed 100644
--- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -18,7 +18,6 @@
 
 import dill
 import plpy
-from plpy import spiexceptions
 from utilities.control import MinWarning
 from utilities.utilities import _assert
 from utilities.utilities import get_col_name_type_sql_string
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
index f3b02c7..46267f0 100644
--- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
@@ -28,6 +28,7 @@ class ModelSelectionSchema:
     MST_KEY = 'mst_key'
     MODEL_ID = ModelArchSchema.MODEL_ID
     MODEL_ARCH_TABLE = 'model_arch_table'
+    OBJECT_TABLE = 'object_table'
     COMPILE_PARAMS = 'compile_params'
     FIT_PARAMS = 'fit_params'
     col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR')
@@ -55,6 +56,7 @@ class MstLoader():
                  model_id_list,
                  compile_params_list,
                  fit_params_list,
+                 object_table=None,
                  **kwargs):
 
         self.model_arch_table = model_arch_table
@@ -62,13 +64,15 @@ class MstLoader():
         self.model_selection_summary_table = add_postfix(
             model_selection_table, "_summary")
         self.model_id_list = sorted(list(set(model_id_list)))
+        self.object_table = object_table
         MstLoaderInputValidator(
             model_arch_table=self.model_arch_table,
             model_selection_table=self.model_selection_table,
             model_selection_summary_table=self.model_selection_summary_table,
             model_id_list=self.model_id_list,
             compile_params_list=compile_params_list,
-            fit_params_list=fit_params_list
+            fit_params_list=fit_params_list,
+            object_table=object_table
         )
         self.compile_params_list = self.params_preprocessed(
             compile_params_list)
@@ -148,10 +152,12 @@ class MstLoader():
         """
         create_query = """
                         CREATE TABLE {self.model_selection_summary_table} (
-                            {model_arch_table} VARCHAR
+                            {model_arch_table} VARCHAR,
+                            {object_table} VARCHAR
                         );
                        """.format(self=self,
-                                  
model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE)
+                                  
model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE,
+                                  
object_table=ModelSelectionSchema.OBJECT_TABLE)
         with MinWarning('warning'):
             plpy.execute(create_query)
 
@@ -179,14 +185,21 @@ class MstLoader():
                                       
fit_params_col=ModelSelectionSchema.FIT_PARAMS,
                                       **locals())
             plpy.execute(insert_query)
+        if self.object_table is None:
+            object_table = 'NULL::VARCHAR'
+        else:
+            object_table = '$${0}$$'.format(self.object_table)
         insert_summary_query = """
                         INSERT INTO
                             {self.model_selection_summary_table}(
-                                {model_arch_table_name}
+                                {model_arch_table_name},
+                                {object_table_name}
                         )
                         VALUES (
-                            $${self.model_arch_table}$$
+                            $${self.model_arch_table}$$,
+                            {object_table}
                         )
                        
""".format(model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE,
+                                  
object_table_name=ModelSelectionSchema.OBJECT_TABLE,
                                   **locals())
         plpy.execute(insert_summary_query)
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
index c15757c..7903e7f 100644
--- 
a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in
@@ -426,7 +426,8 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.load_model_selection_table(
     model_selection_table   VARCHAR,
     model_id_list           INTEGER[],
     compile_params_list     VARCHAR[],
-    fit_params_list         VARCHAR[]
+    fit_params_list         VARCHAR[],
+    object_table            VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_model_selection')
     with AOControl(False):
@@ -435,3 +436,13 @@ CREATE OR REPLACE FUNCTION 
MADLIB_SCHEMA.load_model_selection_table(
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table(
+    model_arch_table        VARCHAR,
+    model_selection_table   VARCHAR,
+    model_id_list           INTEGER[],
+    compile_params_list     VARCHAR[],
+    fit_params_list         VARCHAR[]
+) RETURNS VOID AS $$
+  SELECT MADLIB_SCHEMA.load_model_selection_table($1, $2, $3, $4, $5, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git 
a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in 
b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 11730cf..a364a9e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -20,6 +20,7 @@
 import plpy
 from keras_model_arch_table import ModelArchSchema
 from model_arch_info import get_num_classes
+from madlib_keras_custom_function import CustomFunctionSchema
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
@@ -45,6 +46,8 @@ from utilities.validate_args import input_tbl_valid
 from utilities.validate_args import output_tbl_valid
 from madlib_keras_wrapper import parse_and_validate_fit_params
 from madlib_keras_wrapper import parse_and_validate_compile_params
+import keras.losses as losses
+import keras.metrics as metrics
 
 class InputValidator:
     @staticmethod
@@ -443,7 +446,8 @@ class MstLoaderInputValidator():
                  model_selection_summary_table,
                  model_id_list,
                  compile_params_list,
-                 fit_params_list
+                 fit_params_list,
+                 object_table
                  ):
         self.model_arch_table = model_arch_table
         self.model_selection_table = model_selection_table
@@ -451,6 +455,7 @@ class MstLoaderInputValidator():
         self.model_id_list = model_id_list
         self.compile_params_list = compile_params_list
         self.fit_params_list = fit_params_list
+        self.object_table = object_table
         self.module_name = 'load_model_selection_table'
         self._validate_input_args()
 
@@ -489,9 +494,36 @@ class MstLoaderInputValidator():
                     """.format(fit_params, str(e)))
         if not self.compile_params_list:
             plpy.error( "compile_params_list cannot be NULL")
+        custom_fn_name = []
+        ## Initialize builtin loss/metrics functions
+        builtin_losses = dir(losses)
+        builtin_metrics = dir(metrics)
+        # Default metrics, since it is not part of the builtin metrics list
+        builtin_metrics.append('accuracy')
+        if self.object_table is not None:
+            res = plpy.execute("SELECT {0} from 
{1}".format(CustomFunctionSchema.FN_NAME,
+                                                            self.object_table))
+            for r in res:
+                custom_fn_name.append(r[CustomFunctionSchema.FN_NAME])
         for compile_params in self.compile_params_list:
             try:
-                res = parse_and_validate_compile_params(compile_params)
+                _, _, res = parse_and_validate_compile_params(compile_params)
+                # Validating if loss/metrics function called in compile_params
+                # is either defined in object table or is a built_in keras
+                # loss/metrics function
+                error_suffix = "but input object table missing!"
+                if self.object_table is not None:
+                    error_suffix = "is not defined in object table 
'{0}'!".format(self.object_table)
+
+                _assert(res['loss'] in custom_fn_name or res['loss'] in 
builtin_losses,
+                        "custom function '{0}' used in compile params "\
+                        "{1}".format(res['loss'], error_suffix))
+                if 'metrics' in res:
+                    
_assert((len(set(res['metrics']).intersection(custom_fn_name)) > 0
+                            or 
len(set(res['metrics']).intersection(builtin_metrics)) > 0),
+                            "custom function '{0}' used in compile params " \
+                            "{1}".format(res['metrics'], error_suffix))
+
             except Exception as e:
                 plpy.error(
                     """Compile param check failed for: {0} \n
@@ -500,6 +532,8 @@ class MstLoaderInputValidator():
 
     def _validate_input_output_tables(self):
         input_tbl_valid(self.model_arch_table, self.module_name)
+        if self.object_table is not None:
+            input_tbl_valid(self.object_table, self.module_name)
         output_tbl_valid(self.model_selection_table, self.module_name)
         output_tbl_valid(self.model_selection_summary_table, self.module_name)
 
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
new file mode 100644
index 0000000..671cf07
--- /dev/null
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
@@ -0,0 +1,41 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ ** ---------------------------------------------------------------------*/
+
+---- utility for creating valid dill objects ----
+CREATE OR REPLACE FUNCTION custom_function_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_sum_fn(a, b):
+       return a+b
+
+pb=dill.dumps(test_sum_fn)
+return pb
+$$ language plpythonu;
+
+CREATE OR REPLACE FUNCTION read_custom_function(pb bytea, arg1 int, arg2 int)
+RETURNS INTEGER AS
+$$
+import dill
+obj=dill.loads(pb)
+res=obj(arg1, arg2)
+return res
+$$ language plpythonu;
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index 74f6ba2..82d5e97 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -23,25 +23,12 @@
  * Test load custom function helper functions
  * -------------------------------------------------------------------------- 
*/
 
-CREATE OR REPLACE FUNCTION custom_function_object()
-RETURNS BYTEA AS
-$$
-import dill
-def test_sum_fn(a, b):
-       return a+b
-
-pb=dill.dumps(test_sum_fn)
-return pb
-$$ language plpythonu;
-
-CREATE OR REPLACE FUNCTION read_custom_function(pb bytea, arg1 int, arg2 int)
-RETURNS INTEGER AS
-$$
-import dill
-obj=dill.loads(pb)
-res=obj(arg1, arg2)
-return res
-$$ language plpythonu;
+m4_include(`SQLCommon.m4')
+
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             
`\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
 
 /* Test successful table creation where no table exists */
 DROP TABLE IF EXISTS test_custom_function_table;
diff --git 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index e1dbe0c..fa90c86 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -26,6 +26,11 @@ m4_include(`SQLCommon.m4')
              
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
 )
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             
`\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
+
 -- MST table generation tests
 -- Valid inputs should pass and yield 6 msts in the table
 DROP TABLE IF EXISTS mst_table, mst_table_summary;
@@ -215,6 +220,43 @@ SELECT 
assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'M
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = 
false, 'Model summary output table is unlogged');
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = 
false, 'Model info output table is unlogged');
 
+-- Test for object table
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT assert(MADLIB_SCHEMA.trap_error($MAD$
+  SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_object_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', 
metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ],
+    'test_custom_function_table')
+$MAD$) = 1, 'Object table does not exist!');
+SELECT load_custom_function('test_custom_function_table', 
custom_function_object(), 'sum_fn', 'returns sum');
+
+DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_object_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', 
metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ],
+    'test_custom_function_table'
+);
+
+SELECT assert(
+        object_table = 'test_custom_function_table',
+        'Keras Fit Multiple Output Summary Validation failed when user passes 
in object_table. Actual:' || __to_char(summary))
+FROM (SELECT * FROM mst_object_table_summary) summary;
+
 -- Test when number of configs(3) equals number of segments(3)
 DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, 
iris_multiple_model_info;
 SELECT setseed(0);
diff --git 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
index 57e08a5..b911992 100644
--- 
a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
+++ 
b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
@@ -53,6 +53,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
         self.subject = self.module.MstLoader
         self.model_selection_table = 'mst_table'
         self.model_arch_table = 'model_arch_library'
+        self.object_table = 'custom_function_table'
         self.model_id_list = [1]
         self.compile_params_list = [
             """
@@ -99,6 +100,20 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
                 self.fit_params_list
             )
 
+    def test_invalid_input_args_optional_param(self):
+        self.module.MstLoaderInputValidator \
+            ._validate_input_args \
+            .side_effect = plpy.PLPYException('Invalid input args')
+        with self.assertRaises(plpy.PLPYException):
+            generate_mst = self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_id_list,
+                self.compile_params_list,
+                self.fit_params_list,
+                "invalid_table"
+            )
+
     def test_duplicate_params(self):
         self.model_id_list = [1, 1, 2]
         self.compile_params_list = [
@@ -135,6 +150,185 @@ class LoadModelSelectionTableTestCase(unittest.TestCase):
     def tearDown(self):
         self.module_patcher.stop()
 
+class MstLoaderInputValidatorTestCase(unittest.TestCase):
+    def setUp(self):
+        # The side effects of this class(writing to the output table) are not
+        # tested here. They are tested in dev-check.
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        import deep_learning.madlib_keras_validator
+        self.module = deep_learning.madlib_keras_validator
+
+        self.subject = self.module.MstLoaderInputValidator
+        self.model_selection_table = 'mst_table'
+        self.model_arch_table = 'model_arch_library'
+        self.model_arch_summary_table = 'model_arch_library_summary'
+        self.object_table = 'custom_function_table'
+        self.model_id_list = [1]
+        self.compile_params_list = [
+            """
+                loss='categorical_crossentropy',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """,
+            """
+                loss='categorical_crossentropy',
+                optimizer='Adam(lr=0.01)',
+                metrics=['accuracy']
+            """,
+            """
+                loss='categorical_crossentropy',
+                optimizer='Adam(lr=0.001)',
+                metrics=['accuracy']
+            """
+        ]
+        self.fit_params_list = [
+            "batch_size=5,epochs=1",
+            "batch_size=10,epochs=1"
+        ]
+
+    def test_validate_compile_params_no_custom_fn_table(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+
+        self.subject(
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_arch_summary_table,
+            self.model_id_list,
+            self.compile_params_list,
+            self.fit_params_list,
+            None
+        )
+
+    def test_test_validate_compile_params_custom_fn_table(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                              {'name': 'custom_fn2'}]]
+        self.subject(
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_arch_summary_table,
+            self.model_id_list,
+            self.compile_params_list,
+            self.fit_params_list,
+            self.object_table
+        )
+
+    def test_test_validate_compile_params_valid_custom_fn(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+        self.compile_params_list_valid_custom_fn = [
+            """
+                loss='custom_fn1',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """
+        ]
+        self.subject(
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_arch_summary_table,
+            self.model_id_list,
+            self.compile_params_list_valid_custom_fn,
+            self.fit_params_list,
+            self.object_table
+        )
+
+    def 
test_test_validate_compile_params_valid_custom_fn_missing_obj_tbl(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+        self.compile_params_list_valid_custom_fn = [
+            """
+                loss='custom_fn1',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """
+        ]
+
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_arch_summary_table,
+                self.model_id_list,
+                self.compile_params_list_valid_custom_fn,
+                self.fit_params_list,
+                None
+            )
+        self.assertIn("object table missing", str(error.exception).lower())
+
+    def test_test_validate_compile_params_missing_loss_fn(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+        self.compile_params_list_invalid_loss_fn = [
+            """
+                loss='invalid_loss',
+                optimizer='Adam(lr=0.1)',
+                metrics=['accuracy']
+            """
+        ]
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_arch_summary_table,
+                self.model_id_list,
+                self.compile_params_list_invalid_loss_fn,
+                self.fit_params_list,
+                self.object_table
+            )
+        self.assertIn("invalid_loss", str(error.exception).lower())
+
+    def test_test_validate_compile_params_missing_metric_fn(self):
+        self.subject._validate_input_output_tables = Mock()
+        self.subject._validate_model_ids = Mock()
+        self.subject.parse_and_validate_fit_params = Mock()
+        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
+                                               {'name': 'custom_fn2'}]]
+
+        self.compile_params_list_invalid_metric_fn = [
+            """
+                loss='custom_fn1',
+                optimizer='Adam(lr=0.1)',
+                metrics=['invalid_metrics']
+            """
+        ]
+        with self.assertRaises(plpy.PLPYException) as error:
+            self.subject(
+                self.model_selection_table,
+                self.model_arch_table,
+                self.model_arch_summary_table,
+                self.model_id_list,
+                self.compile_params_list_invalid_metric_fn,
+                self.fit_params_list,
+                self.object_table
+            )
+        self.assertIn("invalid_metrics", str(error.exception).lower())
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
 if __name__ == '__main__':
     unittest.main()
 # ---------------------------------------------------------------------

Reply via email to