njayaram2 commented on a change in pull request #356: Keras model arch table helper functions for keras_fit() URL: https://github.com/apache/madlib/pull/356#discussion_r266102836
########## File path: src/ports/postgres/modules/convex/keras_model_arch_table.py_in ########## @@ -0,0 +1,132 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +@file keras_model_arch_table.py_in + +@brief keras model arch table management helper functions + +@namespace keras_model_arch_table +""" + +from utilities.validate_args import table_exists +from utilities.validate_args import columns_missing_from_table +from utilities.validate_args import quote_ident +from utilities.control import MinWarning +from internal.db_utils import quote_literal +from utilities.utilities import unique_string +import plpy + +class Format: + """Expected format of keras_model_arch_table. + Example uses: + + from utilities.validate_args import columns_missing_from_table + from keras_model_arch_table import Format + + # Validate names in cols list against actual table + missing_cols = columns_missing_from_table('my_arch_table', Format.col_names) + + # Get model arch from keras model arch table, without hard coding column names + sql = "SELECT {arch} FROM {table} WHERE {id} = {my_id}" + .format(arch=Format.model_arch, + table='my_arch_table', + id=Format.model_id, + my_id=1) + arch = plpy.execute(sql)[0] + + """ + col_names = ('model_id', 'model_arch', 'model_weights', '__internal_madlib_id__') + col_types = ('SERIAL PRIMARY KEY', 'JSON', 'DOUBLE PRECISION[]', 'TEXT') + (model_id, model_arch, model_weights, __internal_madlib_id__) = col_names + +@MinWarning("warning") +def _execute(sql,max_rows=0): + return plpy.execute(sql,max_rows) + +def load_keras_model(model_arch_table, model_arch,**kwargs): + model_arch_table = quote_ident(model_arch_table) + if not table_exists(model_arch_table): + col_defs = ','.join(map(' '.join, + zip(Format.col_names, + Format.col_types))) Review comment: This is probably a good candidate to refactor out as a generic helper function in `utilities.utilities.py_in` file. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services