This is an automated email from the ASF dual-hosted git repository. khannaekta pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 918256db8da56e7de690e3770c02fbf8afafb9ad Author: Orhan Kislal <[email protected]> AuthorDate: Wed Sep 23 19:53:55 2020 +0300 DL: Add a helper function to load custom top n accuracy functions JIRA: MADLIB-1452 This commit enables the top_n_accuracy metric. The current parser cannot use top_n_accuracy(k=3) format because we don't want to run eval for security reasons. Instead, we add a helper function so that the user can easily create a custom top_n_accuracy function. --- .../madlib_keras_custom_function.py_in | 85 +++++++++++- .../madlib_keras_custom_function.sql_in | 149 ++++++++++++++++++--- .../deep_learning/madlib_keras_wrapper.py_in | 10 +- .../test/madlib_keras_custom_function.sql_in | 105 +++++++++------ .../test/madlib_keras_model_averaging_e2e.sql_in | 10 +- .../test/madlib_keras_model_selection_e2e.sql_in | 10 +- .../test/unit_tests/test_madlib_keras.py_in | 7 + 7 files changed, 297 insertions(+), 79 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in index 23e16f6..e500970 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in @@ -60,7 +60,6 @@ def _validate_object(object, **kwargs): except Exception as e: plpy.error("{0}: Invalid function object".format(module_name, e)) -@MinWarning("error") def load_custom_function(object_table, object, name, description=None, **kwargs): object_table = quote_ident(object_table) _validate_object(object) @@ -74,7 +73,19 @@ def load_custom_function(object_table, object, name, description=None, **kwargs) .format(object_table, col_defs, CustomFunctionSchema.FN_NAME) plpy.execute(sql, 0) - plpy.info("{0}: Created new custom function table {1}." \ + # Using plpy.notice here as this function can be called: + # 1. Directly by the user, we do want to display to the user + # if we create a new table or later the function name that + # is added to the table + # 2. From load_top_k_accuracy_function, since plpy.info + # displays the query context when called from the function + # there is a very verbose output and cannot be suppressed with + # MinWarning decorator as INFO is always displayed irrespective + # of what the decorator sets the client_min_messages to. + # Therefore, instead we print this information as a NOTICE + # when called directly by the user and suppress it by setting + # MinWarning decorator to 'error' level in the calling function. + plpy.notice("{0}: Created new custom function table {1}." \ .format(module_name, object_table)) else: missing_cols = columns_missing_from_table(object_table, @@ -98,10 +109,9 @@ def load_custom_function(object_table, object, name, description=None, **kwargs) plpy.error("Function '{0}' already exists in {1}".format(name, object_table)) plpy.error(e) - plpy.info("{0}: Added function {1} to {2} table". + plpy.notice("{0}: Added function {1} to {2} table". format(module_name, name, object_table)) -@MinWarning("error") def delete_custom_function(object_table, id=None, name=None, **kwargs): object_table = quote_ident(object_table) input_tbl_valid(object_table, "Keras Custom Funtion") @@ -126,7 +136,7 @@ def delete_custom_function(object_table, id=None, name=None, **kwargs): res = plpy.execute(sql, 0) if res.nrows() > 0: - plpy.info("{0}: Object id {1} has been deleted from {2}.". + plpy.notice("{0}: Object id {1} has been deleted from {2}.". format(module_name, id, object_table)) else: plpy.error("{0}: Object id {1} not found".format(module_name, id)) @@ -134,7 +144,7 @@ def delete_custom_function(object_table, id=None, name=None, **kwargs): sql = "SELECT {0} FROM {1}".format(CustomFunctionSchema.FN_ID, object_table) res = plpy.execute(sql, 0) if not res: - plpy.info("{0}: Dropping empty custom keras function table " \ + plpy.notice("{0}: Dropping empty custom keras function table " \ "table {1}".format(module_name, object_table)) sql = "DROP TABLE {0}".format(object_table) plpy.execute(sql, 0) @@ -146,6 +156,27 @@ def update_builtin_metrics(builtin_metrics): builtin_metrics.append('ce') return builtin_metrics +@MinWarning("error") +def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs): + + object_table = quote_ident(object_table) + _assert(k > 0, + "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name)) + fn_name = "top_{k}_accuracy".format(**locals()) + + sql = """ + SELECT {schema_madlib}.load_custom_function(\'{object_table}\', + {schema_madlib}.top_k_categorical_acc_pickled({k}, \'{fn_name}\'), + \'{fn_name}\', + \'returns {fn_name}\'); + """.format(**locals()) + plpy.execute(sql) + # As this function allocates the name for the top_k_accuracy function, + # printing it out here so the user doesn't need to lookup for the + # newly added custom function name in the object_table + plpy.info("{0}: Added function \'{1}\' to \'{2}\' table". + format(module_name, fn_name, object_table)) + return class KerasCustomFunctionDocumentation: @staticmethod @@ -250,3 +281,45 @@ class KerasCustomFunctionDocumentation: return KerasCustomFunctionDocumentation._returnHelpMsg( schema_madlib, message, summary, usage, method) + + @staticmethod + def load_top_k_accuracy_function_help(schema_madlib, message): + method = "load_top_k_accuracy_function" + summary = """ + ---------------------------------------------------------------- + SUMMARY + ---------------------------------------------------------------- + The user can specify a custom n value for top_n_accuracy metric. + If the output table already exists, the custom function specified + will be added as a new row into the table. The output table could + thus act as a repository of Keras custom functions. + + For more details on function usage: + SELECT {schema_madlib}.{method}('usage') + """.format(**locals()) + + usage = """ + --------------------------------------------------------------------------- + USAGE + --------------------------------------------------------------------------- + SELECT {schema_madlib}.{method}( + object_table, -- VARCHAR. Output table to load custom function. + k -- INTEGER. The number of samples for top n accuracy + ); + + + --------------------------------------------------------------------------- + OUTPUT + --------------------------------------------------------------------------- + The output table produced by load_top_k_accuracy_function contains the following columns: + + 'id' -- SERIAL. Function ID. + 'name' -- TEXT PRIMARY KEY. unique function name. + 'description' -- TEXT. function description. + 'object' -- BYTEA. dill pickled function object. + + """.format(**locals()) + + return KerasCustomFunctionDocumentation._returnHelpMsg( + schema_madlib, message, summary, usage, method) + # --------------------------------------------------------------------- diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in index 01523f3..acdaa28 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in @@ -38,6 +38,7 @@ Interface and implementation are subject to change. </em> <div class="toc"><b>Contents</b><ul> <li class="level1"><a href="#load_function">Load Function</a></li> <li class="level1"><a href="#delete_function">Delete Function</a></li> +<li class="level1"><a href="#top_n_function">Top n Function</a></li> <li class="level1"><a href="#example">Examples</a></li> <li class="level1"><a href="#literature">Literature</a></li> <li class="level1"><a href="#related">Related Topics</a></li> @@ -45,10 +46,10 @@ Interface and implementation are subject to change. </em> This utility function loads custom Python functions into a table for use by deep learning algorithms. -Custom functions can be useful if, for example, you need loss functions +Custom functions can be useful if, for example, you need loss functions or metrics that are not built into the standard libraries. -The functions to be loaded must be in the form of serialized Python objects -created using Dill, which extends Python's pickle module to the majority +The functions to be loaded must be in the form of serialized Python objects +created using Dill, which extends Python's pickle module to the majority of the built-in Python types [1]. There is also a utility function to delete a function @@ -69,8 +70,8 @@ load_custom_function( <dl class="arglist"> <dt>object table</dt> <dd>VARCHAR. Table to load serialized Python objects. If this table - does not exist, it will be created. If this table already - exists, a new row is inserted into the existing table. + does not exist, it will be created. If this table already + exists, a new row is inserted into the existing table. </dd> <dt>object</dt> @@ -149,10 +150,63 @@ delete_custom_function( </dd> </dl> +@anchor top_n_function +@par Top n Function + +Load a top n function with a specific n to the custom functions table. + +<pre class="syntax"> +load_top_k_accuracy_function( + object table, + k + ) +</pre> +\b Arguments +<dl class="arglist"> + <dt>object table</dt> + <dd>VARCHAR. Table to load serialized Python objects. If this table + does not exist, it will be created. If this table already + exists, a new row is inserted into the existing table. + </dd> + + <dt>k</dt> + <dd>INTEGER. k value for the top k accuracy function. + </dd> + +</dl> + +<b>Output table</b> +<br> + The output table contains the following columns: + <table class="output"> + <tr> + <th>id</th> + <td>SERIAL. Object ID. + </td> + </tr> + <tr> + <th>name</th> + <td>TEXT PRIMARY KEY. Name of the object. + Generated with the following pattern: (sparse_,)top_(n)_accuracy. + </td> + </tr> + <tr> + <th>description</th> + <td>TEXT. Description of the object (free text). + </td> + </tr> + <tr> + <th>object</th> + <td>BYTEA. Serialized Python object stored as a PostgreSQL binary data type. + </td> + </tr> + </table> +</br> + @anchor example @par Examples --# Load object using psycopg2. Psycopg is a PostgreSQL database -adapter for the Python programming language. Note need to use the +-# Load object using psycopg2. Psycopg is a PostgreSQL database +adapter for the Python programming language. Note need to use the psycopg2.Binary() method to pass as bytes. <pre class="example"> \# import database connector psycopg2 and create connection cursor @@ -163,12 +217,12 @@ cur = conn.cursor() import dill \# custom loss def squared_error(y_true, y_pred): - import keras.backend as K + import keras.backend as K return K.square(y_pred - y_true) pb_squared_error=dill.dumps(squared_error) \# custom metric def rmse(y_true, y_pred): - import keras.backend as K + import keras.backend as K return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1)) pb_rmse=dill.dumps(rmse) \# call load function @@ -182,7 +236,7 @@ List table to see objects: SELECT id, name, description FROM test_custom_function_table ORDER BY id; </pre> <pre class="result"> - id | name | description + id | name | description ----+---------------+------------------------ 1 | squared_error | squared error 2 | rmse | root mean square error @@ -194,7 +248,7 @@ RETURNS BYTEA AS $$ import dill def squared_error(y_true, y_pred): - import keras.backend as K + import keras.backend as K return K.square(y_pred - y_true) pb_squared_error=dill.dumps(squared_error) return pb_squared_error @@ -204,7 +258,7 @@ RETURNS BYTEA AS $$ import dill def rmse(y_true, y_pred): - import keras.backend as K + import keras.backend as K return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1)) pb_rmse=dill.dumps(rmse) return pb_rmse @@ -213,13 +267,13 @@ $$ language plpythonu; Now call loader: <pre class="result"> DROP TABLE IF EXISTS custom_function_table; -SELECT madlib.load_custom_function('custom_function_table', - custom_function_squared_error(), - 'squared_error', +SELECT madlib.load_custom_function('custom_function_table', + custom_function_squared_error(), + 'squared_error', 'squared error'); -SELECT madlib.load_custom_function('custom_function_table', - custom_function_rmse(), - 'rmse', +SELECT madlib.load_custom_function('custom_function_table', + custom_function_rmse(), + 'rmse', 'root mean square error'); </pre> -# Delete an object by id: @@ -228,7 +282,7 @@ SELECT madlib.delete_custom_function( 'custom_function_table', 1); SELECT id, name, description FROM custom_function_table ORDER BY id; </pre> <pre class="result"> - id | name | description + id | name | description ----+------+------------------------ 2 | rmse | root mean square error </pre> @@ -237,7 +291,19 @@ Delete an object by name: SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse'); </pre> If all objects are deleted from the table using this function, the table itself will be dropped. - +</pre> +Load top 3 accuracy function: +<pre class="example"> +DROP TABLE IF EXISTS custom_function_table; +SELECT madlib.load_top_k_accuracy_function('custom_function_table', + 3); +SELECT id, name, description FROM custom_function_table ORDER BY id; +</pre> +<pre class="result"> + id | name | description +----+----------------+------------------------ + 1 | top_3_accuracy | returns top_3_accuracy +</pre> @anchor literature @literature @@ -323,3 +389,46 @@ RETURNS VARCHAR AS $$ return madlib_keras_custom_function.KerasCustomFunctionDocumentation.delete_custom_function_help(schema_madlib, '') $$ LANGUAGE plpythonu VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +-- Top n accuracy function +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function( + object_table VARCHAR, + k INTEGER +) RETURNS VOID AS $$ + PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function') + with AOControl(False): + madlib_keras_custom_function.load_top_k_accuracy_function(**globals()) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function( + message VARCHAR +) RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function) + return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, message) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function() +RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function) + return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, '') +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.top_k_categorical_acc_pickled( +n INTEGER, +fn_name VARCHAR +) RETURNS BYTEA AS $$ + import dill + from keras.metrics import top_k_categorical_accuracy + + def fn(Y_true, Y_pred): + return top_k_categorical_accuracy(Y_true, + Y_pred, + k = n) + fn.__name__= fn_name + pb=dill.dumps(fn) + return pb +$$ language plpythonu +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in index 780de8a..57827c5 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in @@ -217,6 +217,9 @@ def parse_and_validate_compile_params(str_of_args, additional_params=[]): opt_name, opt_args = None, None _assert('loss' in compile_dict, "loss is a required parameter for compile") + unsupported_loss_list = ['sparse_categorical_crossentropy'] + _assert(compile_dict['loss'] not in unsupported_loss_list, + "Loss function {0} is not supported.".format(compile_dict['loss'])) validate_compile_param_types(compile_dict) _validate_metrics(compile_dict) return (opt_name, opt_args, compile_dict) @@ -226,10 +229,10 @@ def _validate_metrics(compile_dict): compile_dict['metrics'] is None or type(compile_dict['metrics']) is list, "wrong input type for compile parameter metrics: multi-output model" - "and user defined metrics are not supported yet, please pass a list") + "are not supported yet, please pass a list") if 'metrics' in compile_dict and compile_dict['metrics']: unsupported_metrics_list = ['sparse_categorical_accuracy', - 'sparse_categorical_crossentropy', 'top_k_categorical_accuracy', + 'sparse_categorical_crossentropy', 'sparse_top_k_categorical_accuracy'] _assert(len(compile_dict['metrics']) == 1, "Only one metric at a time is supported.") @@ -436,6 +439,7 @@ def get_custom_functions_list(compile_params): if local_loss and (local_loss not in [a.lower() for a in builtin_losses]): custom_fn_list.append(local_loss) if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]): - custom_fn_list.append(local_metric) + if 'top_k_categorical_accuracy' not in local_metric: + custom_fn_list.append(local_metric) return custom_fn_list diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in index ddfcc8d..520b9c9 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in @@ -31,107 +31,130 @@ m4_include(`SQLCommon.m4') ) /* Test successful table creation where no table exists */ -DROP TABLE IF EXISTS test_custom_function_table; -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum'); +DROP TABLE IF EXISTS __test_custom_function_table__; +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum'); SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type') - FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass + FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass AND attname = 'id'; SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA', 'object column should be BYTEA type' ) - FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass + FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass AND attname = 'object'; SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT', 'name column should be TEXT type') - FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass + FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass AND attname = 'name'; SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT', 'description column should be TEXT type') - FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass + FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass AND attname = 'description'; /* id should be 1 */ SELECT assert(id = 1, 'Wrong id written by load_custom_function') - FROM test_custom_function_table; + FROM __test_custom_function_table__; /* Validate function object created */ SELECT assert(read_custom_function(object, 2, 3) = 5, 'Custom function should return sum of args.') - FROM test_custom_function_table; + FROM __test_custom_function_table__; /* Test custom function insertion where valid table exists */ -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1'); +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1'); SELECT assert(name = 'sum_fn', 'Custom function sum_fn found in table.') - FROM test_custom_function_table WHERE id = 1; + FROM __test_custom_function_table__ WHERE id = 1; SELECT assert(name = 'sum_fn1', 'Custom function sum_fn1 found in table.') - FROM test_custom_function_table WHERE id = 2; + FROM __test_custom_function_table__ WHERE id = 2; /* Test adding an existing function name should error out */ SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$ - SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1'); + SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1'); $TRAP$) = 1, 'Should error out for duplicate function name'); /* Test deletion by id where valid table exists */ /* Assert id exists before deleting */ SELECT assert(COUNT(id) = 1, 'id 2 should exist before deletion!') - FROM test_custom_function_table WHERE id = 2; -SELECT delete_custom_function('test_custom_function_table', 2); + FROM __test_custom_function_table__ WHERE id = 2; +SELECT delete_custom_function('__test_custom_function_table__', 2); SELECT assert(COUNT(id) = 0, 'id 2 should have been deleted!') - FROM test_custom_function_table WHERE id = 2; + FROM __test_custom_function_table__ WHERE id = 2; /* Test deletion by name where valid table exists */ -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1'); +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1'); /* Assert id exists before deleting */ SELECT assert(COUNT(id) = 1, 'function name sum_fn1 should exist before deletion!') - FROM test_custom_function_table WHERE name = 'sum_fn1'; -SELECT delete_custom_function('test_custom_function_table', 'sum_fn1'); + FROM __test_custom_function_table__ WHERE name = 'sum_fn1'; +SELECT delete_custom_function('__test_custom_function_table__', 'sum_fn1'); SELECT assert(COUNT(id) = 0, 'function name sum_fn1 should have been deleted!') - FROM test_custom_function_table WHERE name = 'sum_fn1'; + FROM __test_custom_function_table__ WHERE name = 'sum_fn1'; /* Test deleting an already deleted entry should error out */ SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$ - SELECT delete_custom_function('test_custom_function_table', 2); + SELECT delete_custom_function('__test_custom_function_table__', 2); $TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist'); /* Test delete drops the table after deleting last entry*/ -DROP TABLE IF EXISTS test_custom_function_table; -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum'); -SELECT delete_custom_function('test_custom_function_table', 1); -SELECT assert(COUNT(relname) = 0, 'Table test_custom_function_table should have been deleted.') - FROM pg_class where relname='test_custom_function_table'; +DROP TABLE IF EXISTS __test_custom_function_table__; +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum'); +SELECT delete_custom_function('__test_custom_function_table__', 1); +SELECT assert(COUNT(relname) = 0, 'Table __test_custom_function_table__ should have been deleted.') + FROM pg_class where relname='__test_custom_function_table__'; /* Test deletion where empty table exists */ -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum'); -DELETE FROM test_custom_function_table; -SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1, +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum'); +DELETE FROM __test_custom_function_table__; +SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1, 'Deleting function in an empty table should generate an exception.'); /* Test deletion where no table exists */ -DROP TABLE IF EXISTS test_custom_function_table; -SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1, +DROP TABLE IF EXISTS __test_custom_function_table__; +SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1, 'Deleting a non-existent table should raise exception.'); /* Test where invalid table exists */ -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum'); -ALTER TABLE test_custom_function_table DROP COLUMN id; -SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 2)$$) = 1, +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum'); +ALTER TABLE __test_custom_function_table__ DROP COLUMN id; +SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 2)$$) = 1, 'Deleting an invalid table should generate an exception.'); -SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1, +SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1, 'Passing an invalid table to load_custom_function() should raise exception.'); /* Test input validation */ -DROP TABLE IF EXISTS test_custom_function_table; +DROP TABLE IF EXISTS __test_custom_function_table__; SELECT assert(MADLIB_SCHEMA.trap_error($$ - SELECT load_custom_function('test_custom_function_table', custom_function_object(), NULL, NULL); + SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), NULL, NULL); $$) = 1, 'Name cannot be NULL'); SELECT assert(MADLIB_SCHEMA.trap_error($$ - SELECT load_custom_function('test_custom_function_table', NULL, 'sum_fn', NULL); + SELECT load_custom_function('__test_custom_function_table__', NULL, 'sum_fn', NULL); $$) = 1, 'Function object cannot be NULL'); SELECT assert(MADLIB_SCHEMA.trap_error($$ - SELECT load_custom_function('test_custom_function_table', 'invalid_obj'::bytea, 'sum_fn', NULL); + SELECT load_custom_function('__test_custom_function_table__', 'invalid_obj'::bytea, 'sum_fn', NULL); $$) = 1, 'Invalid custom function object'); -SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', NULL); +SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', NULL); SELECT assert(name IS NOT NULL AND description IS NULL, 'validate name is not NULL.') - FROM test_custom_function_table; + FROM __test_custom_function_table__; SELECT assert(MADLIB_SCHEMA.trap_error($$ - SELECT delete_custom_function('test_custom_function_table', NULL); + SELECT delete_custom_function('__test_custom_function_table__', NULL); $$) = 1, 'id/name cannot be NULL!'); + +/* Test top n accuracy */ + +DROP TABLE IF EXISTS __test_custom_function_table__; +SELECT load_top_k_accuracy_function('__test_custom_function_table__', 3); +SELECT load_top_k_accuracy_function('__test_custom_function_table__', 7); +SELECT load_top_k_accuracy_function('__test_custom_function_table__', 4); +SELECT load_top_k_accuracy_function('__test_custom_function_table__', 8); + +SELECT assert(count(*) = 4, 'Table __test_custom_function_table__ should have 4 entries') +FROM __test_custom_function_table__; + +SELECT assert(name = 'top_3_accuracy', 'Top 3 accuracy name is incorrect') +FROM __test_custom_function_table__ WHERE id = 1; + +SELECT assert(name = 'top_7_accuracy', 'Top 7 accuracy name is incorrect') +FROM __test_custom_function_table__ WHERE id = 2; + +SELECT assert(name = 'top_4_accuracy', 'Top 4 accuracy name is incorrect') +FROM __test_custom_function_table__ WHERE id = 3; + +SELECT assert(name = 'top_8_accuracy', 'Top 8 accuracy name is incorrect') +FROM __test_custom_function_table__ WHERE id = 4; diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in index fecd19f..b002550 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in @@ -175,12 +175,14 @@ SELECT madlib_keras_fit( 'test_custom_function_table' ); DROP TABLE if exists iris_model, iris_model_summary, iris_model_info; +-- Test for load_top_k_accuracy with a custom k value +SELECT load_top_k_accuracy_function('test_custom_function_table', 3); SELECT madlib_keras_fit( 'iris_data_packed', 'iris_model', 'iris_model_arch', 1, - $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text, + $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text, $$ batch_size=2, epochs=1, verbose=0 $$::text, 3, FALSE, NULL, 1, NULL, NULL, NULL, @@ -203,13 +205,13 @@ SELECT assert( object_table = 'test_custom_function_table' AND model_size > 0 AND madlib_version is NOT NULL AND - compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text AND + compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text AND fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND num_iterations = 3 AND metrics_compute_frequency = 1 AND num_classes = 3 AND class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND - metrics_type = '{test_custom_fn1}' AND + metrics_type = '{top_3_accuracy}' AND array_upper(training_metrics, 1) = 3 AND training_loss = '{0,0,0}' AND array_upper(metrics_elapsed_time, 1) = 3 , @@ -230,7 +232,7 @@ SELECT madlib_keras_evaluate( SELECT assert(loss >= 0 AND metric >= 0 AND - metrics_type = '{test_custom_fn1}' AND + metrics_type = '{top_3_accuracy}' AND loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out)) FROM evaluate_out; SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END; diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in index c4c0315..b9b775c 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in @@ -166,21 +166,21 @@ SELECT assert(loss >= 0 AND metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out)) FROM evaluate_out; --- TEST custom loss function +-- TEST custom loss function and DROP TABLE IF EXISTS test_custom_function_table; SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn'); -SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1'); -- Prepare model selection table with four rows DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary; +SELECT load_top_k_accuracy_function('test_custom_function_table', 4); SELECT load_model_selection_table( 'iris_model_arch', 'mst_object_table', ARRAY[1], ARRAY[ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$, - $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['test_custom_fn1']$$ + $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['top_4_accuracy']$$ ], ARRAY[ $$batch_size=16, epochs=1$$ @@ -222,7 +222,7 @@ SELECT assert( model_type = 'madlib_keras' AND model_size > 0 AND fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND - metrics_type = '{test_custom_fn1}' AND + metrics_type = '{top_4_accuracy}' AND training_metrics_final >= 0 AND training_loss_final = 0 AND training_loss = '{0,0,0}' AND @@ -259,7 +259,7 @@ SELECT madlib_keras_evaluate( SELECT assert(loss = 0 AND metric >= 0 AND - metrics_type = '{test_custom_fn1}' AND + metrics_type = '{top_4_accuracy}' AND loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out)) FROM evaluate_out; diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in index 4ccf2bd..e69bab4 100644 --- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in @@ -1092,6 +1092,13 @@ class MadlibKerasWrapperTestCase(unittest.TestCase): with self.assertRaises(plpy.PLPYException): self.subject.parse_and_validate_compile_params(test_str) + def test_parse_and_validate_compile_params_unsupported_loss_fail(self): + test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \ + "metrics=['accuracy'], loss='sparse_categorical_crossentropy'" + + with self.assertRaises(plpy.PLPYException): + self.subject.parse_and_validate_compile_params(test_str) + def test_parse_and_validate_compile_params_dict_metrics_fail(self): test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \ "loss='categorical_crossentropy', metrics={'0':'accuracy'}"
