kaknikhil commented on a change in pull request #425: DL: Add training for multiple models URL: https://github.com/apache/madlib/pull/425#discussion_r310261399
########## File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in ########## @@ -0,0 +1,424 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import plpy +import time +import sys +# Do not remove `import keras` although it's not directly used in this file. 
+# For ex if the user passes in the optimizer as keras.optimizers.SGD instead of just +# SGD, then without this import this python file won't find the SGD module +import keras + +# from keras import backend as K +# from keras import utils as keras_utils +from keras.layers import * +from keras.models import * +from keras.optimizers import * +from keras.regularizers import * +import madlib_keras_serializer +from madlib_keras import compute_loss_and_metrics +from madlib_keras import get_initial_weights +from madlib_keras import get_segments_and_gpus +from madlib_keras import get_source_summary_table_dict +from madlib_keras import reset_cuda_env +from madlib_keras_helper import * +from madlib_keras_validator import * +from madlib_keras_wrapper import * +from keras_model_arch_table import ModelArchSchema + +from utilities.control import MinWarning +from utilities.utilities import add_postfix +from utilities.utilities import rotate +from utilities.utilities import madlib_version +from utilities.utilities import is_platform_pg + +import json +from collections import defaultdict +import random +import datetime +mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL +mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL + + +class ModelSelectionSchema: + MST_KEY = 'mst_key' + MODEL_ARCH_ID = 'model_arch_id' + COMPILE_PARAMS = 'compile_params' + FIT_PARAMS = 'fit_params' + col_types = ('SERIAL', 'INTEGER', 'VARCHAR', 'VARCHAR') + +@MinWarning("warning") +class FitMultipleModel(): + def __init__(self, schema_madlib, source_table, model_output_table, + model_arch_table, model_selection_table, num_iterations, + gpus_per_host=0, **kwargs): + + if is_platform_pg(): + plpy.error("DL: Multiple model training is not supported on Postgresql.") + self.source_table = source_table + self.model_arch_table = model_arch_table + self.model_selection_table = model_selection_table + self.model_output_table = model_output_table + if self.model_output_table: + self.model_info_table = 
add_postfix(model_output_table, '_info') + self.model_summary_table = add_postfix( + model_output_table, '_summary') + self.num_iterations = num_iterations + self.module_name = 'madlib_keras_fit_multiple_model' + self.schema_madlib = schema_madlib + self.version = madlib_version(self.schema_madlib) + self.fit_validator = FitInputValidator( + self.source_table, None, self.model_output_table, + self.model_arch_table, mb_dep_var_col, mb_indep_var_col, + self.num_iterations, 1, False) + input_tbl_valid(self.model_selection_table, self.module_name) + output_tbl_valid(self.model_info_table, self.module_name) + self.msts = self.query_msts() + self.mst_key_col = ModelSelectionSchema.MST_KEY + self.model_arch_id_col = ModelSelectionSchema.MODEL_ARCH_ID + self.compile_params_col = ModelSelectionSchema.COMPILE_PARAMS + self.fit_params_col = ModelSelectionSchema.FIT_PARAMS + self.dist_keys = self.query_dist_keys() + self.grand_schedule = self.generate_schedule() + + self.seg_ids_train, self.images_per_seg_train = \ + get_image_count_per_seg_for_minibatched_data_from_db( + self.source_table) + self.segments_per_host, self.gpus_per_host = get_segments_and_gpus( + gpus_per_host) + self.create_model_output_table() + self.train_mst_metric_eval_time = defaultdict(list) + self.train_mst_loss = defaultdict(list) + self.train_mst_metric = defaultdict(list) + + def fit_multiple_model(self): + begin_time = time.time() + # WARNING: set orca off to prevent unwanted redistribution + plpy.execute('SET optimizer TO off') + original_cuda_env = None + if CUDA_VISIBLE_DEVICES_KEY in os.environ: + original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY] + self.start_training_time = datetime.datetime.now() + self.train_multiple_model() + self.end_training_time = datetime.datetime.now() + self.insert_info_table() + self.create_model_summary_table() + plpy.execute('RESET optimizer') + plpy.info( + "End to end execution time: {}".format(time.time() - begin_time)) + reset_cuda_env(original_cuda_env) + 
+ def train_multiple_model(self): + self.weights_map = {} + for e in range(self.num_iterations): + plpy.info("Iteration: {}".format(e)) + is_final = (e == self.num_iterations - 1) + for i in range(len(self.msts)): + mst_row = [self.grand_schedule[dist_key][i] + for dist_key in self.dist_keys] + self.create_mst_schedule_table(mst_row) + self.run_training(is_final=True) + self.evaluate_train_output(e) + plpy.info(self.train_mst_metric_eval_time) + plpy.info(self.train_mst_loss) + plpy.info(self.train_mst_metric) + plpy.info(self.weights_map) + + def evaluate_train_output(self, epoch): + res = self.query_weights() + res_map = {x['mst_key']: x['weights'] for x in res if x['weights']} + weights_map_one_epoch = {} + for mst in self.msts: + model_arch = self.query_arch(mst[self.model_arch_id_col] + )[ModelArchSchema.MODEL_ARCH] + state = res_map[mst[self.mst_key_col]] + serialized_weights = \ + madlib_keras_serializer.get_serialized_1d_weights_from_state( + state) + loss_metric = compute_loss_and_metrics( + self.schema_madlib, self.source_table, "$madlib${}$madlib$".format( + mst[self.compile_params_col]), model_arch, + serialized_weights, 0, 3, self.seg_ids_train, + self.images_per_seg_train, [], [], epoch, True) + metric_eval_time, metric, loss = loss_metric + weights_map_one_epoch[mst[self.mst_key_col]] = loss_metric + self.train_mst_metric_eval_time[mst[self.mst_key_col]] \ + .append(metric_eval_time) + self.train_mst_loss[mst[self.mst_key_col]].append(loss) + self.train_mst_metric[mst[self.mst_key_col]].append(metric) + self.weights_map[epoch] = weights_map_one_epoch + + def generate_schedule(self): + """Summary + + Returns: + TYPE: Description + + Args: + dist_keys (TYPE): Description + msts (TYPE): Description + """ + grand_schedule = {} + for index, dist_key in enumerate(self.dist_keys): + grand_schedule[dist_key] = rotate(self.msts, index) + return grand_schedule + + def query_arch(self, model_arch_id): + model_arch_query = """ + SELECT {0}, {1} FROM {2} WHERE 
{3} = {4} + """.format(ModelArchSchema.MODEL_ARCH, + ModelArchSchema.MODEL_WEIGHTS, + self.model_arch_table, + ModelArchSchema.MODEL_ID, + model_arch_id) + model_arch_result = plpy.execute(model_arch_query)[0] + return model_arch_result + + def query_msts(self): + msts_query = """ + SELECT * FROM {} + ORDER BY mst_key + """.format(self.model_selection_table) + res = list(plpy.execute(msts_query)) + return res + + def query_dist_keys(self): + dist_key_query = """ + SELECT DISTINCT(dist_key) FROM {} + ORDER BY dist_key + """.format(self.source_table) + res = list(plpy.execute(dist_key_query)) + res = [x['dist_key'] for x in res] + return res + + def create_mst_schedule_table(self, mst_row): + mst_temp_query = """DROP TABLE IF EXISTS mst_current_schedule; + CREATE TABLE mst_current_schedule(model_arch_id INTEGER, + compile_params VARCHAR, + fit_params VARCHAR, + dist_key INTEGER, + mst_key INTEGER primary key) + """ + plpy.execute(mst_temp_query) + for mst, dist_key in zip(mst_row, self.dist_keys): + mst_insert_query = """ + INSERT INTO mst_current_schedule + VALUES ({}, + $madlib${}$madlib$, + $madlib${}$madlib$, + {}, + {}) + """.format(mst[self.model_arch_id_col], mst[self.compile_params_col], + mst[self.fit_params_col], dist_key, mst[self.mst_key_col]) + plpy.execute(mst_insert_query) + + def create_model_output_table(self): + output_table_create_query = """CREATE TABLE {} + (mst_key INTEGER PRIMARY KEY, + weights BYTEA, + model_arch JSON + ) + """.format(self.model_output_table) + info_table_create_query = """CREATE TABLE {} + (mst_key INTEGER PRIMARY KEY, + model_arch_id INTEGER, + compile_params TEXT, + fit_params TEXT, + model_type TEXT, + model_size DOUBLE PRECISION, + metrics_elapsed_time DOUBLE PRECISION[], + metrics_type TEXT[], + training_metrics_final DOUBLE PRECISION, + training_loss_final DOUBLE PRECISION, + training_metrics DOUBLE PRECISION[], + training_loss DOUBLE PRECISION[], + validation_metrics_final DOUBLE PRECISION, + validation_loss_final 
DOUBLE PRECISION, + validation_metrics DOUBLE PRECISION[], + validation_loss DOUBLE PRECISION[]) + """.format(self.model_info_table) + + plpy.execute(output_table_create_query) + plpy.execute(info_table_create_query) + for mst in self.msts: + model_arch_result = self.query_arch(mst[self.model_arch_id_col]) + model_arch = model_arch_result[ModelArchSchema.MODEL_ARCH] + random.seed(42) + serialized_weights = get_initial_weights(self.model_output_table, + model_arch_result, + False, + self.gpus_per_host + ) + model_size = sys.getsizeof(serialized_weights) / 1024.0 + model = model_from_json( + model_arch_result[ModelArchSchema.MODEL_ARCH]) + + serialized_state = \ + madlib_keras_serializer.serialize_state_with_nd_weights( + 0, model.get_weights()) + metrics_list = get_metrics_from_compile_param( + mst[self.compile_params_col]) + is_metrics_specified = True if metrics_list else False + metrics_type = 'ARRAY{0}'.format( + metrics_list) if is_metrics_specified else 'NULL' + + output_table_insert_query = """ + INSERT INTO {self.model_output_table}( + mst_key, weights, model_arch) + VALUES ({mst_key}, $1, $2) + """.format(self=self, + mst_key=mst[self.mst_key_col]) + output_table_insert_query_prepared = plpy.prepare( + output_table_insert_query, ["bytea", "json"]) + plpy.execute(output_table_insert_query_prepared, [ + serialized_state, json.dumps(model_arch)]) + info_table_insert_query = """ + INSERT INTO {}(mst_key, model_arch_id, compile_params, + fit_params, model_type, model_size, + metrics_type) + VALUES ({}, {}, $madlib${}$madlib$, + $madlib${}$madlib$, '{}', {}, {}) + """.format(self.model_info_table, + mst[self.mst_key_col], + mst[self.model_arch_id_col], + mst[self.compile_params_col], + mst[self.fit_params_col], + 'madlib_keras', + model_size, + metrics_type) + plpy.execute(info_table_insert_query) + + def create_model_summary_table(self): + src_summary_dict = get_source_summary_table_dict(self.fit_validator) + class_values = src_summary_dict['class_values'] + 
dep_vartype = src_summary_dict['dep_vartype'] + dependent_varname = \ + src_summary_dict['dependent_varname_in_source_table'] + independent_varname = \ + src_summary_dict['independent_varname_in_source_table'] + norm_const = src_summary_dict['norm_const'] + num_classes = len(class_values) + class_values_colname = CLASS_VALUES_COLNAME + dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME + normalizing_const_colname = NORMALIZING_CONST_COLNAME + float32_sql_type = FLOAT32_SQL_TYPE + update_query = """ + CREATE TABLE {self.model_summary_table} AS + SELECT + $MAD${self.source_table}$MAD$::TEXT AS source_table, + NULL::TEXT AS validation_table, + $MAD${self.model_output_table}$MAD$::TEXT AS model, + $MAD${self.model_info_table}$MAD$::TEXT AS model_info, + $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname, + $MAD${independent_varname}$MAD$::TEXT AS independent_varname, + $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table, + {self.num_iterations}::INTEGER AS num_iterations, + '{self.start_training_time}'::TIMESTAMP AS start_training_time, + '{self.end_training_time}'::TIMESTAMP AS end_training_time, + '{self.version}'::TEXT AS madlib_version, + {num_classes}::INTEGER AS num_classes, + ARRAY{class_values}::TEXT[] AS {class_values_colname}, + $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname}, + {norm_const}::{float32_sql_type} AS {normalizing_const_colname} + """.format(**locals()) + plpy.execute(update_query) + + def insert_info_table(self): + for mst in self.msts: + mst_key = mst[self.mst_key_col] + training_metrics, training_metrics_final, metrics_elapsed_time = \ + "NULL", "NULL", "NULL" + training_loss_final = "NULL" + training_loss = "NULL" + if mst_key in self.train_mst_metric: + training_metrics = self.train_mst_metric[mst_key] + training_metrics_final = training_metrics[-1] + metrics_elapsed_time = self.train_mst_metric_eval_time[mst_key] + training_metrics = "ARRAY{}".format(training_metrics) + metrics_elapsed_time = 
"ARRAY{}".format(metrics_elapsed_time) + training_loss = self.train_mst_loss[mst_key] + training_loss_final = training_loss[-1] + training_loss = "ARRAY{}".format(training_loss) + update_query = """ + UPDATE {} SET training_metrics_final = {}, + training_loss_final = {}, + metrics_elapsed_time = {}, + training_metrics = {}, + training_loss = {} + WHERE mst_key = {} + """.format(self.model_info_table, + training_metrics_final, + training_loss_final, + metrics_elapsed_time, + training_metrics, + training_loss, + mst_key + ) + plpy.execute(update_query) + + def run_training(self, is_final): + mst_wgh = "mst_wgh" + # TODO: fix distributed by + mst_wgh_query = """DROP TABLE IF EXISTS mst_wgh; + CREATE TEMP TABLE mst_wgh AS + SELECT mst.*, wgh.weights, model_arch.model_arch + FROM mst_current_schedule mst JOIN {} wgh ON mst.mst_key = wgh.mst_key + JOIN {} model_arch ON mst.model_arch_id = model_arch.model_id + DISTRIBUTED BY (dist_key) + """.format(self.model_output_table, self.model_arch_table) + plpy.execute(mst_wgh_query) + # WARNING weights_one_schedule.weights is the state returned by + # the step function, which is img_count and weights concatnated + mlp_uda_query = """ + UPDATE {self.model_output_table} SET weights = weights_one_schedule.weights + FROM (SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col}, + {mb_indep_var_col}, + mst_wgh.model_arch::TEXT, + mst_wgh.compile_params::TEXT, + mst_wgh.fit_params::TEXT, + iris.gp_segment_id, + ARRAY{self.seg_ids_train}, + ARRAY{self.images_per_seg_train}, + {self.gpus_per_host}, + {self.segments_per_host}, + mst_wgh.weights::BYTEA, + {is_final}::BOOLEAN + )::BYTEA AS weights, + mst_wgh.mst_key AS mst_key + FROM {self.source_table} iris JOIN mst_wgh + USING (dist_key) + GROUP BY iris.dist_key, mst_wgh.mst_key + ) weights_one_schedule + WHERE {self.model_output_table}.mst_key = weights_one_schedule.mst_key + """.format(mb_dep_var_col=mb_dep_var_col, + mb_indep_var_col=mb_indep_var_col, + is_final=is_final, 
+ self=self + ) + plpy.execute(mlp_uda_query) + plpy.execute("DROP TABLE IF EXISTS {0}".format(mst_wgh)) + return 0 Review comment: Why do we need to return 0? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
