This is an automated email from the ASF dual-hosted git repository. njayaram pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit f760e728533a87e7e19c4b157efebb7ef3441650 Author: Nandish Jayaram <[email protected]> AuthorDate: Mon Mar 18 17:31:42 2019 -0700 Deep Learning: Add documentation for helper function and refactor code. JIRA: MADLIB-1306 This commit adds user and online documentation (without examples) along with addressing other comments from the code review for PR #356. As part of code refactoring, this commit also makes deep learning a new module instead of creating new .py_in and .sql_in files in existing convex module. This will make it easier to control which databases we install it in. For instance, DL will not be supported on GPDB-4.3, so we can easily avoid installing that whole module on GPDB 4.3. This will also help in better maintenance by having future DL functions in this module rather than in convex. Closes #356 --- doc/mainpage.dox.in | 1 + src/config/Modules.yml | 2 + .../modules/convex/keras_model_arch_table.py_in | 133 ------------ .../modules/convex/keras_model_arch_table.sql_in | 39 ---- .../postgres/modules/deep_learning/__init__.py_in | 0 .../deep_learning/keras_model_arch_table.py_in | 233 +++++++++++++++++++++ .../deep_learning/keras_model_arch_table.sql_in | 166 +++++++++++++++ .../test/keras_model_arch_table.ic.sql_in | 0 .../test/keras_model_arch_table.sql_in | 23 +- .../postgres/modules/utilities/utilities.py_in | 5 + .../postgres/modules/utilities/utilities.sql_in | 16 ++ 11 files changed, 427 insertions(+), 191 deletions(-) diff --git a/doc/mainpage.dox.in b/doc/mainpage.dox.in index 67f19bc..826e8d7 100644 --- a/doc/mainpage.dox.in +++ b/doc/mainpage.dox.in @@ -291,6 +291,7 @@ Interface and implementation are subject to change. @details A collection of deep learning interfaces. 
@{ @defgroup grp_minibatch_preprocessing_dl Mini-Batch Preprocessor for Image Data + @defgroup grp_keras_model_arch Helper Function to Load Model Architectures to Table @} @defgroup grp_bayes Naive Bayes Classification @defgroup grp_sample Random Sampling diff --git a/src/config/Modules.yml b/src/config/Modules.yml index b70c8bd..f2da10f 100644 --- a/src/config/Modules.yml +++ b/src/config/Modules.yml @@ -11,6 +11,8 @@ modules: - name: convex depends: ['utilities'] - name: crf + - name: deep_learning + depends: ['utilities'] - name: elastic_net - name: glm depends: ['utilities'] diff --git a/src/ports/postgres/modules/convex/keras_model_arch_table.py_in b/src/ports/postgres/modules/convex/keras_model_arch_table.py_in deleted file mode 100644 index f6d8b41..0000000 --- a/src/ports/postgres/modules/convex/keras_model_arch_table.py_in +++ /dev/null @@ -1,133 +0,0 @@ -# coding=utf-8 -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-""" -@file keras_model_arch_table.py_in - -@brief keras model arch table management helper functions - -@namespace keras_model_arch_table -""" - -from utilities.validate_args import table_exists -from utilities.validate_args import columns_missing_from_table -from utilities.validate_args import quote_ident -from utilities.control import MinWarning -from internal.db_utils import quote_literal -from utilities.utilities import unique_string -import plpy - -class Format: - """Expected format of keras_model_arch_table. - Example uses: - - from utilities.validate_args import columns_missing_from_table - from keras_model_arch_table import Format - - # Validate names in cols list against actual table - missing_cols = columns_missing_from_table('my_arch_table', Format.col_names) - - # Get model arch from keras model arch table, without hard coding column names - sql = "SELECT {arch} FROM {table} WHERE {id} = {my_id}" - .format(arch=Format.model_arch, - table='my_arch_table', - id=Format.model_id, - my_id=1) - arch = plpy.execute(sql)[0] - - """ - col_names = ('model_id', 'model_arch', 'model_weights', '__internal_madlib_id__') - col_types = ('SERIAL PRIMARY KEY', 'JSON', 'DOUBLE PRECISION[]', 'TEXT') - (model_id, model_arch, model_weights, __internal_madlib_id__) = col_names - -@MinWarning("warning") -def _execute(sql,max_rows=0): - return plpy.execute(sql,max_rows) - -def load_keras_model(keras_model_arch_table, model_arch,**kwargs): - model_arch_table = quote_ident(keras_model_arch_table) - if not table_exists(model_arch_table): - col_defs = ','.join(map(' '.join, - zip(Format.col_names, - Format.col_types))) - - sql = "CREATE TABLE {model_arch_table} ({col_defs})" \ - .format(**locals()) - - _execute(sql) - plpy.info("Created new keras model arch table {0}." 
\ - .format(model_arch_table)) - else: - missing_cols = columns_missing_from_table(model_arch_table, Format.col_names) - if len(missing_cols) > 0: - plpy.error("Invalid keras model arch table {0}," - " missing columns: {1}".format(model_arch_table, missing_cols)) - - unique_str = unique_string(prefix_has_temp=False) - - sql = """INSERT INTO {model_arch_table} ({model_arch_col}, {internal_id_col}) - VALUES({model_arch}, '{unique_str}'); - SELECT model_id, model_arch - FROM {model_arch_table} WHERE {internal_id_col} = '{unique_str}' - """.format(model_arch_table=model_arch_table, - model_arch_col=Format.model_arch, - unique_str=unique_str, - model_arch=quote_literal(model_arch), - internal_id_col=Format.__internal_madlib_id__) - -# This code works perfectly in postgres 8.3+, but fails in Greenplum 5 with: -# ERROR: The RETURNING clause of the INSERT statement is not supported in this version -# of Greenplum Database. -# sql = """INSERT INTO {model_arch_table} -# (model_arch) -# VALUES('{model_arch}') -# RETURNING *;""".format(model_arch_table=model_arch_table, -# model_arch=quote_literal(model_arch)) - res = _execute(sql,1) - - if len(res) != 1 or res[0][Format.model_arch] != model_arch: - raise Exception("Failed to insert new row in {0} table--try again?" 
- .format(model_arch_table)) - plpy.info("Added model id {0} to {1} table".format(res[0]['model_id'], model_arch_table)) - -def delete_keras_model(keras_model_arch_table, model_id, **kwargs): - model_arch_table = quote_ident(keras_model_arch_table) - if not table_exists(model_arch_table): - plpy.error("Table {0} does not exist.".format(model_arch_table)) - - missing_cols = columns_missing_from_table(model_arch_table, Format.col_names) - if len(missing_cols) > 0: - plpy.error("Invalid keras model arch table {0}," - " missing columns: {1}".format(model_arch_table, missing_cols)) - - sql = """ - DELETE FROM {model_arch_table} WHERE model_id={model_id} - """.format(model_arch_table=model_arch_table, model_id=model_id) - res = _execute(sql) - - if res.nrows() > 0: - plpy.info("Model id {0} has been deleted from {1}.".format(model_id, model_arch_table)) - else: - plpy.info("Model id {0} not found".format(model_id)) - - sql = "SELECT model_id FROM {model_arch_table}".format(model_arch_table=model_arch_table) - res = _execute(sql) - if not res or len(res) == 0: - plpy.info("Removing empty keras model arch table {model_arch_table}".format(model_arch_table=model_arch_table)) - sql = "DROP TABLE {model_arch_table}".format(model_arch_table=model_arch_table) - _execute(sql) diff --git a/src/ports/postgres/modules/convex/keras_model_arch_table.sql_in b/src/ports/postgres/modules/convex/keras_model_arch_table.sql_in deleted file mode 100644 index b27b979..0000000 --- a/src/ports/postgres/modules/convex/keras_model_arch_table.sql_in +++ /dev/null @@ -1,39 +0,0 @@ -/* ----------------------------------------------------------------------- *//** - * - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - * - * - * @file model_arch_table.sql_in - * - * @brief SQL functions for multilayer perceptron - * @date June 2012 - * - * - *//* ----------------------------------------------------------------------- */ - -m4_include(`SQLCommon.m4') - -CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model(keras_model_arch_table VARCHAR, model_arch JSON) -RETURNS VOID AS $$ - PythonFunction(`convex',`keras_model_arch_table',`load_keras_model') - $$ LANGUAGE plpythonu VOLATILE; - -CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_keras_model(keras_model_arch_table VARCHAR, model_id INTEGER) -RETURNS VOID AS $$ - PythonFunction(`convex',`keras_model_arch_table',`delete_keras_model') -$$ LANGUAGE plpythonu VOLATILE; diff --git a/src/ports/postgres/modules/deep_learning/__init__.py_in b/src/ports/postgres/modules/deep_learning/__init__.py_in new file mode 100644 index 0000000..e69de29 diff --git a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in new file mode 100644 index 0000000..a0521ee --- /dev/null +++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.py_in @@ -0,0 +1,233 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +@file keras_model_arch_table.py_in + +@brief keras model arch table management helper functions + +@namespace keras_model_arch_table +""" + +from internal.db_utils import quote_literal +import plpy +from utilities.control import MinWarning +from utilities.utilities import get_col_name_type_sql_string +from utilities.utilities import unique_string +from utilities.validate_args import columns_missing_from_table +from utilities.validate_args import input_tbl_valid +from utilities.validate_args import quote_ident +from utilities.validate_args import table_exists + +class Format: + """Expected format of keras_model_arch_table. 
+ Example uses: + + from utilities.validate_args import columns_missing_from_table + from keras_model_arch_table import Format + + # Validate names in cols list against actual table + missing_cols = columns_missing_from_table('my_arch_table', Format.col_names) + + # Get model arch from keras model arch table, without hard coding column names + sql = "SELECT {arch} FROM {table} WHERE {id} = {my_id}" + .format(arch=Format.model_arch, + table='my_arch_table', + id=Format.model_id, + my_id=1) + arch = plpy.execute(sql)[0] + + """ + col_names = ('model_id', 'model_arch', 'model_weights', '__internal_madlib_id__') + col_types = ('SERIAL PRIMARY KEY', 'JSON', 'DOUBLE PRECISION[]', 'TEXT') + (model_id, model_arch, model_weights, __internal_madlib_id__) = col_names + +@MinWarning("warning") +def _execute(sql,max_rows=0): + return plpy.execute(sql,max_rows) + +def load_keras_model(schema_madlib, keras_model_arch_table, + model_arch, **kwargs): + model_arch_table = quote_ident(keras_model_arch_table) + if not table_exists(model_arch_table): + col_defs = get_col_name_type_sql_string(Format.col_names, + Format.col_types) + + sql = "CREATE TABLE {model_arch_table} ({col_defs})" \ + .format(**locals()) + + _execute(sql) + plpy.info("Keras Model Arch: Created new keras model arch table {0}." 
\ + .format(model_arch_table)) + else: + missing_cols = columns_missing_from_table(model_arch_table, + Format.col_names) + if len(missing_cols) > 0: + plpy.error("Keras Model Arch: Invalid keras model arch table {0}," + " missing columns: {1}".format(model_arch_table, + missing_cols)) + + unique_str = unique_string(prefix_has_temp=False) + + sql = """INSERT INTO {model_arch_table} ({model_arch_col}, {internal_id_col}) + VALUES({model_arch}, '{unique_str}'); + SELECT {model_id_col}, {model_arch_col} + FROM {model_arch_table} WHERE {internal_id_col} = '{unique_str}' + """.format(model_arch_table=model_arch_table, + model_arch_col=Format.model_arch, + unique_str=unique_str, + model_arch=quote_literal(model_arch), + model_id_col=Format.model_id, + internal_id_col=Format.__internal_madlib_id__) + res = _execute(sql,1) + + if len(res) != 1 or res[0][Format.model_arch] != model_arch: + raise Exception("Failed to insert new row in {0} table--try again?" + .format(model_arch_table)) + plpy.info("Keras Model Arch: Added model id {0} to {1} table". + format(res[0]['model_id'], model_arch_table)) + +def delete_keras_model(schema_madlib, keras_model_arch_table, + model_id, **kwargs): + model_arch_table = quote_ident(keras_model_arch_table) + input_tbl_valid(model_arch_table, "Keras Model Arch") + + missing_cols = columns_missing_from_table(model_arch_table, Format.col_names) + if len(missing_cols) > 0: + plpy.error("Keras Model Arch: Invalid keras model arch table {0}," + " missing columns: {1}".format(model_arch_table, missing_cols)) + + sql = """ + DELETE FROM {model_arch_table} WHERE model_id={model_id} + """.format(model_arch_table=model_arch_table, model_id=model_id) + res = _execute(sql) + + if res.nrows() > 0: + plpy.info("Keras Model Arch: Model id {0} has been deleted from {1}.". 
+ format(model_id, model_arch_table)) + else: + plpy.error("Keras Model Arch: Model id {0} not found".format(model_id)) + + sql = "SELECT model_id FROM {0}".format(model_arch_table) + res = _execute(sql) + if not res: + plpy.info("Keras Model Arch: Dropping empty keras model arch "\ + "table {model_arch_table}".format(model_arch_table=model_arch_table)) + sql = "DROP TABLE {0}".format(model_arch_table) + try: + _execute(sql) + except plpy.SPIError, e: + plpy.warning("Keras Model Arch: Unable to drop empty keras model "\ + "arch table {0}".format(model_arch_table)) + +class KerasModelArchDocumentation: + @staticmethod + def _returnHelpMsg(schema_madlib, message, summary, usage, method): + if not message: + return summary + elif message.lower() in ('usage', 'help', '?'): + return usage + return """ + No such option. Use "SELECT {schema_madlib}.{method}()" + for help. + """.format(**locals()) + + @staticmethod + def load_keras_model_help(schema_madlib, message): + method = "load_keras_model" + summary = """ + ---------------------------------------------------------------- + SUMMARY + ---------------------------------------------------------------- + The architecture of the model to be used in madlib_keras_train() + function must be stored in a table, the details of which must be + provided as parameters to the madlib_keras_train module. This is + a helper function to help users insert JSON blobs of Keras model + architectures into a table. + If the output table already exists, the model_arch specified will + be added as a new row into the table. The output table could thus + act as a repository of Keras model architectures. 
+ + For more details on function usage: + SELECT {schema_madlib}.{method}('usage') + """.format(**locals()) + + usage = """ + --------------------------------------------------------------------------- + USAGE + --------------------------------------------------------------------------- + SELECT {schema_madlib}.{method}( + keras_model_arch_table VARCHAR, -- Output table to load keras model arch. + model_arch JSON -- JSON of the model architecture to insert. + ); + + + --------------------------------------------------------------------------- + OUTPUT + --------------------------------------------------------------------------- + The output table produced by load_keras_model contains the following columns: + + 'model_id' -- SERIAL PRIMARY KEY. Model ID. + 'model_arch' -- JSON. JSON blob of the model architecture. + 'model_weights' -- DOUBLE PRECISION[]. weights of the model for warm start. + -- This is currently NULL. + '__internal_madlib_id__' -- TEXT. Unique id for model arch. + + """.format(**locals()) + + return KerasModelArchDocumentation._returnHelpMsg( + schema_madlib, message, summary, usage, method) +# --------------------------------------------------------------------- + + @staticmethod + def delete_keras_model_help(schema_madlib, message): + method = "delete_keras_model" + summary = """ + ---------------------------------------------------------------- + SUMMARY + ---------------------------------------------------------------- + Delete the model architecture corresponding to the provided model_id + from the model architecture repository table (keras_model_arch_table). 
+ + For more details on function usage: + SELECT {schema_madlib}.{method}('usage') + """.format(**locals()) + + usage = """ + --------------------------------------------------------------------------- + USAGE + --------------------------------------------------------------------------- + SELECT {schema_madlib}.{method}( + keras_model_arch_table VARCHAR, -- Table containing Keras model architectures. + model_id INTEGER -- The id of the model arch to be deleted. + ); + + + --------------------------------------------------------------------------- + OUTPUT + --------------------------------------------------------------------------- + This method deletes the row corresponding to the given model_id in + keras_model_arch_table. This also tries to drop the table if the table is + empty after dropping the model_id. If there are any views depending on the + table, a warning message is displayed and the table is not dropped. + + --------------------------------------------------------------------------- + """.format(**locals()) + + return KerasModelArchDocumentation._returnHelpMsg( + schema_madlib, message, summary, usage, method) diff --git a/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in new file mode 100644 index 0000000..7626107 --- /dev/null +++ b/src/ports/postgres/modules/deep_learning/keras_model_arch_table.sql_in @@ -0,0 +1,166 @@ +/* ----------------------------------------------------------------------- *//** + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + * + * @file keras_model_arch_table.sql_in + * + * @brief SQL functions for loading and deleting Keras model architectures + * @date March 2019 + * + * + *//* ----------------------------------------------------------------------- */ + +m4_include(`SQLCommon.m4') +/** +@addtogroup grp_keras_model_arch + +<div class="toc"><b>Contents</b><ul> +<li class="level1"><a href="#load_keras_model">Helper Function to Load Model Architectures to Table</a></li> +<li class="level1"><a href="#delete_keras_model">Helper Function to Delete Model Architectures from Table</a></li> +<li class="level1"><a href="#example">Examples</a></li> +</ul></div> + +The architecture of the model to be used in madlib_keras_train() +function must be stored in a table, the details of which must be +provided as parameters to the madlib_keras_train module. load_keras_model is +a helper function to help users insert JSON blobs of Keras model +architectures into a table. If the output table already exists, the model_arch +specified will be added as a new row into the table. The output table could thus +act as a repository of Keras model architectures. + +delete_keras_model can be used to delete the model architecture corresponding +to the provided model_id from the model architecture repository table (keras_model_arch_table). + +@anchor load_keras_model +<pre class="syntax"> +load_keras_model( + keras_model_arch_table, + model_arch +) +</pre> +\b Arguments +<dl class="arglist"> + <dt>keras_model_arch_table</dt> + <dd>VARCHAR. Output table to load keras model arch. + </dd> + + <dt>model_arch</dt> + <dd>JSON.
JSON of the model architecture to insert. + </dd> +</dl> + +<b>Output table</b> +<br> + The output table produced by load_keras_model contains the following columns: + <table class="output"> + <tr> + <th>model_id</th> + <td>SERIAL PRIMARY KEY. Model ID. + </td> + </tr> + <tr> + <th>model_arch</th> + <td>JSON. JSON blob of the model architecture. + </td> + </tr> + <tr> + <th>model_weights</th> + <td>DOUBLE PRECISION[]. Weights of the model for warm start. This is currently NULL. + </td> + </tr> + <tr> + <th>__internal_madlib_id__</th> + <td>TEXT. Unique id for model arch. + </td> + </tr> + </table> +<br> + +@anchor delete_keras_model +<pre class="syntax"> +delete_keras_model( + keras_model_arch_table, + model_id +) +</pre> +\b Arguments +<dl class="arglist"> + <dt>keras_model_arch_table</dt> + <dd>VARCHAR. Table containing Keras model architectures. + </dd> + + <dt>model_id</dt> + <dd>INTEGER. The id of the model arch to be deleted. + </dd> +</dl> + +@anchor example +@par Examples +-# TBD + +*/ + +-- Function to add a keras model to arch table +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model( + keras_model_arch_table VARCHAR, + model_arch JSON +) +RETURNS VOID AS $$ + PythonFunction(`deep_learning',`keras_model_arch_table',`load_keras_model') +$$ LANGUAGE plpythonu VOLATILE; + +-- Functions for online help +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model( + message VARCHAR +) RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(deep_learning, keras_model_arch_table) + return keras_model_arch_table.KerasModelArchDocumentation.load_keras_model_help(schema_madlib, message) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_keras_model() +RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(deep_learning, keras_model_arch_table) + return keras_model_arch_table.KerasModelArchDocumentation.load_keras_model_help(schema_madlib, '') +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL
DATA', `'); + +-- Function to delete a keras model from arch table +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_keras_model( + keras_model_arch_table VARCHAR, + model_id INTEGER +) +RETURNS VOID AS $$ + PythonFunction(`deep_learning',`keras_model_arch_table',`delete_keras_model') +$$ LANGUAGE plpythonu VOLATILE; + +-- Functions for online help +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_keras_model( + message VARCHAR +) RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(deep_learning, keras_model_arch_table) + return keras_model_arch_table.KerasModelArchDocumentation.delete_keras_model_help(schema_madlib, message) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); + +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.delete_keras_model() +RETURNS VARCHAR AS $$ + PythonFunctionBodyOnly(deep_learning, keras_model_arch_table) + return keras_model_arch_table.KerasModelArchDocumentation.delete_keras_model_help(schema_madlib, '') +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `'); diff --git a/src/ports/postgres/modules/convex/test/keras_model_arch_table.ic.sql_in b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.ic.sql_in similarity index 100% rename from src/ports/postgres/modules/convex/test/keras_model_arch_table.ic.sql_in rename to src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.ic.sql_in diff --git a/src/ports/postgres/modules/convex/test/keras_model_arch_table.sql_in b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in similarity index 90% rename from src/ports/postgres/modules/convex/test/keras_model_arch_table.sql_in rename to src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in index 64b34b0..e6a51cc 100644 --- a/src/ports/postgres/modules/convex/test/keras_model_arch_table.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/keras_model_arch_table.sql_in @@ -83,9 +83,6 @@ 
SELECT assert(COUNT(model_id) = 0, 'model id 3 should have been deleted!') /* Delete a second time, to make sure nothing weird happens. * It should archrt to the user that the model_id wasn't found but not * raise an exception or change anything. */ -SELECT delete_keras_model('test_keras_model_arch_table', 3); -SELECT assert(COUNT(model_id) = 0, 'model id 3 should have been deleted!') - FROM test_keras_model_arch_table WHERE model_id = 3; SELECT delete_keras_model('test_keras_model_arch_table', 1); SELECT assert(COUNT(relname) = 0, 'Table test_keras_model_arch_table should have been deleted.') FROM pg_class where relname = 'test_keras_model_arch_table'; @@ -93,25 +90,13 @@ SELECT assert(COUNT(relname) = 0, 'Table test_keras_model_arch_table should have SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}'); DELETE FROM test_keras_model_arch_table; -/* Test deletion where empty table exists */ -SELECT delete_keras_model('test_keras_model_arch_table', 3); -SELECT assert(COUNT(relname) = 0, 'Table test_keras_model_arch_table should have been deleted.') from pg_class where relname = 'test_keras_model_arch_table'; - /* Test deletion where invalid table exists */ - SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}'); ALTER TABLE test_keras_model_arch_table DROP COLUMN model_id; -CREATE FUNCTION trap_error(stmt TEXT) RETURNS INTEGER AS $$ -BEGIN - BEGIN - EXECUTE stmt; - EXCEPTION - WHEN OTHERS THEN - RETURN 1; - END; - RETURN 0; -END; -$$ LANGUAGE plpgsql; + +/* Test deletion where empty table exists */ +select assert(trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1, + 'Deleting a model in an empty table should generate an exception.'); SELECT assert(trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 1)$$) = 1, 'Deleting an invalid table should generate an exception.'); diff --git a/src/ports/postgres/modules/utilities/utilities.py_in 
b/src/ports/postgres/modules/utilities/utilities.py_in index d2f14a5..f9f1fd0 100644 --- a/src/ports/postgres/modules/utilities/utilities.py_in +++ b/src/ports/postgres/modules/utilities/utilities.py_in @@ -338,6 +338,11 @@ def _string_to_array_with_quotes(s): return elm # ------------------------------------------------------------------------ +def get_col_name_type_sql_string(colnames, coltypes): + if colnames and coltypes and len(colnames)==len(coltypes): + return ','.join(map(' '.join, zip(colnames, coltypes))) + return None +# ------------------------------------------------------------------------ def py_list_to_sql_string(array, array_type=None, long_format=None): """Convert a list to SQL array string diff --git a/src/ports/postgres/modules/utilities/utilities.sql_in b/src/ports/postgres/modules/utilities/utilities.sql_in index 7035940..c028661 100644 --- a/src/ports/postgres/modules/utilities/utilities.sql_in +++ b/src/ports/postgres/modules/utilities/utilities.sql_in @@ -525,3 +525,19 @@ RETURNS BOOLEAN AS $$ PythonFunction(utilities, utilities, is_pg_major_version_less_than) $$ LANGUAGE plpythonu VOLATILE m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `'); + +-- Function to trap errors while running sql: returns 1 if executing stmt raises an error, 0 otherwise +CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.trap_error( + stmt TEXT +) +RETURNS INTEGER AS $$ +BEGIN + BEGIN + EXECUTE stmt; + EXCEPTION + WHEN OTHERS THEN + RETURN 1; + END; + RETURN 0; +END; +$$ LANGUAGE plpgsql;
