This is an automated email from the ASF dual-hosted git repository. khannaekta pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 27f8ac96afc19d80ab4eb6a034d3b3ac29f1011f Author: Ekta Khanna <[email protected]> AuthorDate: Tue May 19 11:53:48 2020 -0700 DL: Add object table info in load MST table utility This commit adds an optional param `object_table` (storing keras custom function objects) to `load_model_selection_table()`. This object table (if specified) is added to the summary table named `<model_selection_table>_summary`, which can be passed to the fit/evaluate functions. --- .../madlib_keras_custom_function.py_in | 1 - .../madlib_keras_model_selection.py_in | 23 ++- .../madlib_keras_model_selection.sql_in | 13 +- .../deep_learning/madlib_keras_validator.py_in | 38 +++- .../test/madlib_keras_custom_function.setup.sql_in | 41 +++++ .../test/madlib_keras_custom_function.sql_in | 25 +-- .../test/madlib_keras_model_selection.sql_in | 42 +++++ .../test_madlib_keras_model_selection_table.py_in | 194 +++++++++++++++++++++ 8 files changed, 349 insertions(+), 28 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in index 246c72d..9dcefed 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in @@ -18,7 +18,6 @@ import dill import plpy -from plpy import spiexceptions from utilities.control import MinWarning from utilities.utilities import _assert from utilities.utilities import get_col_name_type_sql_string diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in index f3b02c7..46267f0 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in @@ -28,6 +28,7 @@ class ModelSelectionSchema: MST_KEY = 'mst_key' MODEL_ID = ModelArchSchema.MODEL_ID MODEL_ARCH_TABLE = 'model_arch_table' + OBJECT_TABLE = 'object_table' COMPILE_PARAMS = 'compile_params' FIT_PARAMS = 'fit_params' col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR') @@ -55,6 +56,7 @@ class MstLoader(): model_id_list, compile_params_list, fit_params_list, + object_table=None, **kwargs): self.model_arch_table = model_arch_table @@ -62,13 +64,15 @@ class MstLoader(): self.model_selection_summary_table = add_postfix( model_selection_table, "_summary") self.model_id_list = sorted(list(set(model_id_list))) + self.object_table = object_table MstLoaderInputValidator( model_arch_table=self.model_arch_table, model_selection_table=self.model_selection_table, model_selection_summary_table=self.model_selection_summary_table, model_id_list=self.model_id_list, compile_params_list=compile_params_list, - fit_params_list=fit_params_list + fit_params_list=fit_params_list, + object_table=object_table ) self.compile_params_list = self.params_preprocessed( compile_params_list) @@ -148,10 +152,12 @@ class MstLoader(): """ create_query = """ CREATE TABLE {self.model_selection_summary_table} ( - {model_arch_table} VARCHAR + {model_arch_table} VARCHAR, + {object_table} VARCHAR ); """.format(self=self, - model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE) + model_arch_table=ModelSelectionSchema.MODEL_ARCH_TABLE, + object_table=ModelSelectionSchema.OBJECT_TABLE) with MinWarning('warning'): plpy.execute(create_query) @@ -179,14 +185,21 @@ class MstLoader(): fit_params_col=ModelSelectionSchema.FIT_PARAMS, **locals()) plpy.execute(insert_query) + if self.object_table is None: + object_table = 'NULL::VARCHAR' + else: + object_table = '$${0}$$'.format(self.object_table) insert_summary_query = """ INSERT INTO {self.model_selection_summary_table}( - {model_arch_table_name} + {model_arch_table_name}, + {object_table_name} ) VALUES ( - $${self.model_arch_table}$$ + $${self.model_arch_table}$$, + {object_table} ) """.format(model_arch_table_name=ModelSelectionSchema.MODEL_ARCH_TABLE, + object_table_name=ModelSelectionSchema.OBJECT_TABLE, **locals()) plpy.execute(insert_summary_query) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in index c15757c..7903e7f 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.sql_in @@ -426,7 +426,8 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table( model_selection_table VARCHAR, model_id_list INTEGER[], compile_params_list VARCHAR[], - fit_params_list VARCHAR[] + fit_params_list VARCHAR[], + object_table VARCHAR ) RETURNS VOID AS $$ PythonFunctionBodyOnly(`deep_learning', `madlib_keras_model_selection') with AOControl(False): @@ -435,3 +436,13 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table( $$ LANGUAGE plpythonu VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_model_selection_table( + model_arch_table VARCHAR, + model_selection_table VARCHAR, + model_id_list INTEGER[], + compile_params_list VARCHAR[], + fit_params_list VARCHAR[] +) RETURNS VOID AS $$ + SELECT MADLIB_SCHEMA.load_model_selection_table($1, $2, $3, $4, $5, NULL); +$$ LANGUAGE sql VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in index 11730cf..a364a9e 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in @@ -20,6 +20,7 @@ import plpy from keras_model_arch_table import ModelArchSchema from model_arch_info import get_num_classes +from madlib_keras_custom_function import CustomFunctionSchema from madlib_keras_helper import CLASS_VALUES_COLNAME from madlib_keras_helper import COMPILE_PARAMS_COLNAME from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME @@ -45,6 +46,8 @@ from utilities.validate_args import input_tbl_valid from utilities.validate_args import output_tbl_valid from madlib_keras_wrapper import parse_and_validate_fit_params from madlib_keras_wrapper import parse_and_validate_compile_params +import keras.losses as losses +import keras.metrics as metrics class InputValidator: @staticmethod @@ -443,7 +446,8 @@ class MstLoaderInputValidator(): model_selection_summary_table, model_id_list, compile_params_list, - fit_params_list + fit_params_list, + object_table ): self.model_arch_table = model_arch_table self.model_selection_table = model_selection_table @@ -451,6 +455,7 @@ class MstLoaderInputValidator(): self.model_id_list = model_id_list self.compile_params_list = compile_params_list self.fit_params_list = fit_params_list + self.object_table = object_table self.module_name = 'load_model_selection_table' self._validate_input_args() @@ -489,9 +494,36 @@ class MstLoaderInputValidator(): """.format(fit_params, str(e))) if not self.compile_params_list: plpy.error( "compile_params_list cannot be NULL") + custom_fn_name = [] + ## Initialize builtin loss/metrics functions + builtin_losses = dir(losses) + builtin_metrics = dir(metrics) + # Default metrics, since it is not part of the builtin metrics list + builtin_metrics.append('accuracy') + if self.object_table is not None: + res = plpy.execute("SELECT {0} from {1}".format(CustomFunctionSchema.FN_NAME, + self.object_table)) + for r in res: + custom_fn_name.append(r[CustomFunctionSchema.FN_NAME]) for compile_params in self.compile_params_list: try: - res = parse_and_validate_compile_params(compile_params) + _, _, res = parse_and_validate_compile_params(compile_params) + # Validating if loss/metrics function called in compile_params + # is either defined in object table or is a built_in keras + # loss/metrics function + error_suffix = "but input object table missing!" + if self.object_table is not None: + error_suffix = "is not defined in object table '{0}'!".format(self.object_table) + + _assert(res['loss'] in custom_fn_name or res['loss'] in builtin_losses, + "custom function '{0}' used in compile params "\ + "{1}".format(res['loss'], error_suffix)) + if 'metrics' in res: + _assert((len(set(res['metrics']).intersection(custom_fn_name)) > 0 + or len(set(res['metrics']).intersection(builtin_metrics)) > 0), + "custom function '{0}' used in compile params " \ + "{1}".format(res['metrics'], error_suffix)) + except Exception as e: plpy.error( """Compile param check failed for: {0} \n @@ -500,6 +532,8 @@ class MstLoaderInputValidator(): def _validate_input_output_tables(self): input_tbl_valid(self.model_arch_table, self.module_name) + if self.object_table is not None: + input_tbl_valid(self.object_table, self.module_name) output_tbl_valid(self.model_selection_table, self.module_name) output_tbl_valid(self.model_selection_summary_table, self.module_name) diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in new file mode 100644 index 0000000..671cf07 --- /dev/null +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in @@ -0,0 +1,41 @@ +/* ---------------------------------------------------------------------*//** + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + ** ---------------------------------------------------------------------*/ + +---- utility for creating valid dill objects ---- +CREATE OR REPLACE FUNCTION custom_function_object() +RETURNS BYTEA AS +$$ +import dill +def test_sum_fn(a, b): + return a+b + +pb=dill.dumps(test_sum_fn) +return pb +$$ language plpythonu; + +CREATE OR REPLACE FUNCTION read_custom_function(pb bytea, arg1 int, arg2 int) +RETURNS INTEGER AS +$$ +import dill +obj=dill.loads(pb) +res=obj(arg1, arg2) +return res +$$ language plpythonu; diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in index 74f6ba2..82d5e97 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in @@ -23,25 +23,12 @@ * Test load custom function helper functions * -------------------------------------------------------------------------- */ -CREATE OR REPLACE FUNCTION custom_function_object() -RETURNS BYTEA AS -$$ -import dill -def test_sum_fn(a, b): - return a+b - -pb=dill.dumps(test_sum_fn) -return pb -$$ language plpythonu; - -CREATE OR REPLACE FUNCTION read_custom_function(pb bytea, arg1 int, arg2 int) -RETURNS INTEGER AS -$$ -import dill -obj=dill.loads(pb) -res=obj(arg1, arg2) -return res -$$ language plpythonu; +m4_include(`SQLCommon.m4') + +\i m4_regexp(MODULE_PATHNAME, + `\(.*\)libmadlib\.so', + `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in' +) /* Test successful table creation where no table exists */ DROP TABLE IF EXISTS test_custom_function_table; diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in index e1dbe0c..fa90c86 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in @@ -26,6 +26,11 @@ m4_include(`SQLCommon.m4') `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in' ) +\i m4_regexp(MODULE_PATHNAME, + `\(.*\)libmadlib\.so', + `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in' +) + -- MST table generation tests -- Valid inputs should pass and yield 6 msts in the table DROP TABLE IF EXISTS mst_table, mst_table_summary; @@ -215,6 +220,43 @@ SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'M SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged'); SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = false, 'Model info output table is unlogged'); +-- Test for object table + +DROP TABLE IF EXISTS test_custom_function_table; +SELECT assert(MADLIB_SCHEMA.trap_error($MAD$ + SELECT load_model_selection_table( + 'iris_model_arch', + 'mst_object_table', + ARRAY[1], + ARRAY[ + $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$ + ], + ARRAY[ + $$batch_size=16, epochs=1$$ + ], + 'test_custom_function_table') +$MAD$) = 1, 'Object table does not exist!'); +SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum'); + +DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary; +SELECT load_model_selection_table( + 'iris_model_arch', + 'mst_object_table', + ARRAY[1], + ARRAY[ + $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$ + ], + ARRAY[ + $$batch_size=16, epochs=1$$ + ], + 'test_custom_function_table' +); + +SELECT assert( + object_table = 'test_custom_function_table', + 'Keras Fit Multiple Output Summary Validation failed when user passes in object_table. Actual:' || __to_char(summary)) +FROM (SELECT * FROM mst_object_table_summary) summary; + -- Test when number of configs(3) equals number of segments(3) DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info; SELECT setseed(0); diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in index 57e08a5..b911992 100644 --- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in +++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in @@ -53,6 +53,7 @@ class LoadModelSelectionTableTestCase(unittest.TestCase): self.subject = self.module.MstLoader self.model_selection_table = 'mst_table' self.model_arch_table = 'model_arch_library' + self.object_table = 'custom_function_table' self.model_id_list = [1] self.compile_params_list = [ """ @@ -99,6 +100,20 @@ class LoadModelSelectionTableTestCase(unittest.TestCase): self.fit_params_list ) + def test_invalid_input_args_optional_param(self): + self.module.MstLoaderInputValidator \ + ._validate_input_args \ + .side_effect = plpy.PLPYException('Invalid input args') + with self.assertRaises(plpy.PLPYException): + generate_mst = self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_id_list, + self.compile_params_list, + self.fit_params_list, + "invalid_table" + ) + def test_duplicate_params(self): self.model_id_list = [1, 1, 2] self.compile_params_list = [ @@ -135,6 +150,185 @@ class LoadModelSelectionTableTestCase(unittest.TestCase): def tearDown(self): self.module_patcher.stop() +class MstLoaderInputValidatorTestCase(unittest.TestCase): + def setUp(self): + # The side effects of this class(writing to the output table) are not + # tested here. They are tested in dev-check. + self.plpy_mock = Mock(spec='error') + patches = { + 'plpy': plpy + } + + self.plpy_mock_execute = MagicMock() + plpy.execute = self.plpy_mock_execute + + self.module_patcher = patch.dict('sys.modules', patches) + self.module_patcher.start() + import deep_learning.madlib_keras_validator + self.module = deep_learning.madlib_keras_validator + + self.subject = self.module.MstLoaderInputValidator + self.model_selection_table = 'mst_table' + self.model_arch_table = 'model_arch_library' + self.model_arch_summary_table = 'model_arch_library_summary' + self.object_table = 'custom_function_table' + self.model_id_list = [1] + self.compile_params_list = [ + """ + loss='categorical_crossentropy', + optimizer='Adam(lr=0.1)', + metrics=['accuracy'] + """, + """ + loss='categorical_crossentropy', + optimizer='Adam(lr=0.01)', + metrics=['accuracy'] + """, + """ + loss='categorical_crossentropy', + optimizer='Adam(lr=0.001)', + metrics=['accuracy'] + """ + ] + self.fit_params_list = [ + "batch_size=5,epochs=1", + "batch_size=10,epochs=1" + ] + + def test_validate_compile_params_no_custom_fn_table(self): + self.subject._validate_input_output_tables = Mock() + self.subject._validate_model_ids = Mock() + self.subject.parse_and_validate_fit_params = Mock() + + self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_arch_summary_table, + self.model_id_list, + self.compile_params_list, + self.fit_params_list, + None + ) + + def test_test_validate_compile_params_custom_fn_table(self): + self.subject._validate_input_output_tables = Mock() + self.subject._validate_model_ids = Mock() + self.subject.parse_and_validate_fit_params = Mock() + self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'}, + {'name': 'custom_fn2'}]] + self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_arch_summary_table, + self.model_id_list, + self.compile_params_list, + self.fit_params_list, + self.object_table + ) + + def test_test_validate_compile_params_valid_custom_fn(self): + self.subject._validate_input_output_tables = Mock() + self.subject._validate_model_ids = Mock() + self.subject.parse_and_validate_fit_params = Mock() + self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'}, + {'name': 'custom_fn2'}]] + self.compile_params_list_valid_custom_fn = [ + """ + loss='custom_fn1', + optimizer='Adam(lr=0.1)', + metrics=['accuracy'] + """ + ] + self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_arch_summary_table, + self.model_id_list, + self.compile_params_list_valid_custom_fn, + self.fit_params_list, + self.object_table + ) + + def test_test_validate_compile_params_valid_custom_fn_missing_obj_tbl(self): + self.subject._validate_input_output_tables = Mock() + self.subject._validate_model_ids = Mock() + self.subject.parse_and_validate_fit_params = Mock() + self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'}, + {'name': 'custom_fn2'}]] + self.compile_params_list_valid_custom_fn = [ + """ + loss='custom_fn1', + optimizer='Adam(lr=0.1)', + metrics=['accuracy'] + """ + ] + + with self.assertRaises(plpy.PLPYException) as error: + self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_arch_summary_table, + self.model_id_list, + self.compile_params_list_valid_custom_fn, + self.fit_params_list, + None + ) + self.assertIn("object table missing", str(error.exception).lower()) + + def test_test_validate_compile_params_missing_loss_fn(self): + self.subject._validate_input_output_tables = Mock() + self.subject._validate_model_ids = Mock() + self.subject.parse_and_validate_fit_params = Mock() + self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'}, + {'name': 'custom_fn2'}]] + self.compile_params_list_invalid_loss_fn = [ + """ + loss='invalid_loss', + optimizer='Adam(lr=0.1)', + metrics=['accuracy'] + """ + ] + with self.assertRaises(plpy.PLPYException) as error: + self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_arch_summary_table, + self.model_id_list, + self.compile_params_list_invalid_loss_fn, + self.fit_params_list, + self.object_table + ) + self.assertIn("invalid_loss", str(error.exception).lower()) + + def test_test_validate_compile_params_missing_metric_fn(self): + self.subject._validate_input_output_tables = Mock() + self.subject._validate_model_ids = Mock() + self.subject.parse_and_validate_fit_params = Mock() + self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'}, + {'name': 'custom_fn2'}]] + + self.compile_params_list_invalid_metric_fn = [ + """ + loss='custom_fn1', + optimizer='Adam(lr=0.1)', + metrics=['invalid_metrics'] + """ + ] + with self.assertRaises(plpy.PLPYException) as error: + self.subject( + self.model_selection_table, + self.model_arch_table, + self.model_arch_summary_table, + self.model_id_list, + self.compile_params_list_invalid_metric_fn, + self.fit_params_list, + self.object_table + ) + self.assertIn("invalid_metrics", str(error.exception).lower()) + + def tearDown(self): + self.module_patcher.stop() + if __name__ == '__main__': unittest.main() # ---------------------------------------------------------------------
