reductionista commented on a change in pull request #356: Keras model arch table helper functions for keras_fit() URL: https://github.com/apache/madlib/pull/356#discussion_r266154748
########## File path: src/ports/postgres/modules/convex/keras_model_arch_table.py_in ########## @@ -0,0 +1,133 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +@file keras_model_arch_table.py_in + +@brief keras model arch table management helper functions + +@namespace keras_model_arch_table +""" + +from utilities.validate_args import table_exists +from utilities.validate_args import columns_missing_from_table +from utilities.validate_args import quote_ident +from utilities.control import MinWarning +from internal.db_utils import quote_literal +from utilities.utilities import unique_string +import plpy + +class Format: + """Expected format of keras_model_arch_table. 
+ Example uses: + + from utilities.validate_args import columns_missing_from_table + from keras_model_arch_table import Format + + # Validate names in cols list against actual table + missing_cols = columns_missing_from_table('my_arch_table', Format.col_names) + + # Get model arch from keras model arch table, without hard coding column names + sql = "SELECT {arch} FROM {table} WHERE {id} = {my_id}" + .format(arch=Format.model_arch, + table='my_arch_table', + id=Format.model_id, + my_id=1) + arch = plpy.execute(sql)[0] + + """ + col_names = ('model_id', 'model_arch', 'model_weights', '__internal_madlib_id__') + col_types = ('SERIAL PRIMARY KEY', 'JSON', 'DOUBLE PRECISION[]', 'TEXT') + (model_id, model_arch, model_weights, __internal_madlib_id__) = col_names + +@MinWarning("warning") +def _execute(sql,max_rows=0): + return plpy.execute(sql,max_rows) + +def load_keras_model(keras_model_arch_table, model_arch,**kwargs): + model_arch_table = quote_ident(keras_model_arch_table) + if not table_exists(model_arch_table): + col_defs = ','.join(map(' '.join, + zip(Format.col_names, + Format.col_types))) + + sql = "CREATE TABLE {model_arch_table} ({col_defs})" \ + .format(**locals()) + + _execute(sql) + plpy.info("Created new keras model arch table {0}." 
\ + .format(model_arch_table)) + else: + missing_cols = columns_missing_from_table(model_arch_table, Format.col_names) + if len(missing_cols) > 0: + plpy.error("Invalid keras model arch table {0}," + " missing columns: {1}".format(model_arch_table, missing_cols)) + + unique_str = unique_string(prefix_has_temp=False) + + sql = """INSERT INTO {model_arch_table} ({model_arch_col}, {internal_id_col}) + VALUES({model_arch}, '{unique_str}'); + SELECT model_id, model_arch + FROM {model_arch_table} WHERE {internal_id_col} = '{unique_str}' + """.format(model_arch_table=model_arch_table, + model_arch_col=Format.model_arch, + unique_str=unique_str, + model_arch=quote_literal(model_arch), + internal_id_col=Format.__internal_madlib_id__) + +# This code works perfectly in postgres 8.3+, but fails in Greenplum 5 with: +# ERROR: The RETURNING clause of the INSERT statement is not supported in this version +# of Greenplum Database. +# sql = """INSERT INTO {model_arch_table} +# (model_arch) +# VALUES('{model_arch}') +# RETURNING *;""".format(model_arch_table=model_arch_table, +# model_arch=quote_literal(model_arch)) + res = _execute(sql,1) + + if len(res) != 1 or res[0][Format.model_arch] != model_arch: + raise Exception("Failed to insert new row in {0} table--try again?" 
+ .format(model_arch_table)) + plpy.info("Added model id {0} to {1} table".format(res[0]['model_id'], model_arch_table)) + +def delete_keras_model(keras_model_arch_table, model_id, **kwargs): + model_arch_table = quote_ident(keras_model_arch_table) + if not table_exists(model_arch_table): + plpy.error("Table {0} does not exist.".format(model_arch_table)) + + missing_cols = columns_missing_from_table(model_arch_table, Format.col_names) + if len(missing_cols) > 0: + plpy.error("Invalid keras model arch table {0}," + " missing columns: {1}".format(model_arch_table, missing_cols)) + + sql = """ + DELETE FROM {model_arch_table} WHERE model_id={model_id} + """.format(model_arch_table=model_arch_table, model_id=model_id) + res = _execute(sql) + + if res.nrows() > 0: + plpy.info("Model id {0} has been deleted from {1}.".format(model_id, model_arch_table)) + else: + plpy.info("Model id {0} not found".format(model_id)) Review comment: Oh, actually maybe plpy.warning()? I think I like that more than plpy.error() or plpy.info(): warn the user, but don't throw an exception. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services