This is an automated email from the ASF dual-hosted git repository. njayaram pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 5a2cea7aaf0253d466f13e720f32b9533a2682ad Author: Nandish Jayaram <[email protected]> AuthorDate: Wed Apr 17 11:11:06 2019 -0700 DL: Add Postgres support JIRA: MADLIB-1311 This commit decouples GPDB specific code to enable Postgres support too for Deep Learning modules. There is a weird failure happening when `import keras` and `array_scalar_mult` are called one after another. We have opened MADLIB-1326 to track this issue. Closes #371 Co-authored-by: Orhan Kislal <[email protected]> --- .../modules/deep_learning/madlib_keras.py_in | 90 +++++++++++++--------- .../modules/deep_learning/madlib_keras.sql_in | 2 +- .../deep_learning/madlib_keras_wrapper.py_in | 6 +- .../modules/deep_learning/test/madlib_keras.sql_in | 39 ++++++++-- 4 files changed, 92 insertions(+), 45 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in index bbbcfb4..8d4e384 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in @@ -46,6 +46,7 @@ from keras_model_arch_table import Format from utilities.model_arch_info import get_input_shape from utilities.model_arch_info import get_num_classes +from utilities.utilities import is_platform_pg from utilities.utilities import madlib_version from utilities.validate_args import get_col_value_and_type @@ -102,39 +103,46 @@ def fit(schema_madlib, source_table, model, dependent_varname, validation_set_provided = bool(validation_table) validation_aggregate_accuracy = []; validation_aggregate_loss = [] - # Compute total buffers on each segment - total_buffers_per_seg = plpy.execute( - """ SELECT gp_segment_id, count(*) AS total_buffers_per_seg - FROM {0} - GROUP BY gp_segment_id - """.format(source_table)) - seg_nums = [int(each_buffer["gp_segment_id"]) - for each_buffer in total_buffers_per_seg] + if is_platform_pg(): + total_buffers_per_seg = plpy.execute( + """ SELECT 
count(*) AS total_buffers_per_seg + FROM {0} + """.format(source_table)) + seg_nums = "[]::integer[]" + gp_segment_id_col = -1 + else: + # Compute total buffers on each segment + total_buffers_per_seg = plpy.execute( + """ SELECT gp_segment_id, count(*) AS total_buffers_per_seg + FROM {0} + GROUP BY gp_segment_id + """.format(source_table)) + seg_nums = [int(each_buffer["gp_segment_id"]) + for each_buffer in total_buffers_per_seg] + # gp_segment_id is an implicit column in GPDB tables. + gp_segment_id_col = "gp_segment_id" + total_buffers_per_seg = [int(each_buffer["total_buffers_per_seg"]) for each_buffer in total_buffers_per_seg] - # Prepare the SQL for running distributed training via UDA compile_params_to_pass = "$madlib$" + compile_params + "$madlib$" fit_params_to_pass = "$madlib$" + fit_params + "$madlib$" run_training_iteration = plpy.prepare(""" - SELECT {0}.fit_step( - {1}::REAL[], - {2}::SMALLINT[], - gp_segment_id, - {3}::INTEGER, - ARRAY{4}, - ARRAY{5}, - $MAD${6}$MAD$::TEXT, - {7}::TEXT, - {8}::TEXT, - {9}, + SELECT {schema_madlib}.fit_step( + {independent_varname}::REAL[], + {dependent_varname}::SMALLINT[], + {gp_segment_id_col}, + {num_classes}::INTEGER, + ARRAY{seg_nums}, + ARRAY{total_buffers_per_seg}, + $MAD${model_arch}$MAD$::TEXT, + {compile_params_to_pass}::TEXT, + {fit_params_to_pass}::TEXT, + {use_gpu}, $1 ) AS iteration_result - FROM {10} - """.format(schema_madlib, independent_varname, dependent_varname, - num_classes, seg_nums, total_buffers_per_seg, model_arch, - compile_params_to_pass, fit_params_to_pass, - use_gpu, source_table), ["bytea"]) + FROM {source_table} + """.format(**locals()), ["bytea"]) # Define the state for the model and loss/accuracy storage lists model_state = madlib_keras_serializer.serialize_weights( @@ -167,13 +175,10 @@ def fit(schema_madlib, source_table, model, dependent_varname, _, _, _, updated_weights = madlib_keras_serializer.deserialize_weights( model_state, model_shapes) 
master_model.set_weights(updated_weights) - evaluate_result = get_loss_acc_from_keras_eval(schema_madlib, - validation_table, - dependent_varname, - independent_varname, - compile_params_to_pass, - model_arch, model_state, - use_gpu) + evaluate_result = get_loss_acc_from_keras_eval( + schema_madlib, validation_table, dependent_varname, + independent_varname, compile_params_to_pass, model_arch, + model_state, use_gpu, gp_segment_id_col) if len(evaluate_result) < 2: plpy.error('Calling evaluate on validation data returned < 2 ' 'metrics. Expected metrics are loss and accuracy') @@ -279,7 +284,7 @@ def fit(schema_madlib, source_table, model, dependent_varname, def get_loss_acc_from_keras_eval(schema_madlib, table, dependent_varname, independent_varname, compile_params, model_arch, - model_data, use_gpu): + model_data, use_gpu, gp_segment_id_col): """ This function will call the internal keras evaluate function to get the loss and accuracy of each tuple which then gets averaged to get the final result. 
@@ -299,7 +304,7 @@ def get_loss_acc_from_keras_eval(schema_madlib, table, dependent_varname, {independent_varname}, $MAD${model_arch}$MAD$, $1, {compile_params}, - {use_gpu}, gp_segment_id)) as loss_acc + {use_gpu}, {gp_segment_id_col})) as loss_acc from {table} ) q""".format(**locals()), ["bytea"]) res = plpy.execute(evaluate_query, [model_data]) @@ -333,7 +338,16 @@ def fit_transition(state, ind_var, dep_var, current_seg_id, num_classes, start_transition = time.time() SD = kwargs['SD'] - + is_pg = False + if current_seg_id == -1: + is_pg = True + if is_pg: + # This is postgres + total_buffers = total_buffers_per_seg[0] + else: + # This is GPDB + total_buffers = total_buffers_per_seg[all_seg_ids.index( + current_seg_id)] # Configure GPUs/CPUs device_name = get_device_name_and_set_cuda_env(use_gpu, current_seg_id) @@ -378,14 +392,16 @@ def fit_transition(state, ind_var, dep_var, current_seg_id, num_classes, with K.tf.device(device_name): updated_weights = segment_model.get_weights() - total_buffers = total_buffers_per_seg[all_seg_ids.index(current_seg_id)] if SD['buffer_count'] == total_buffers: if total_buffers == 0: plpy.error('total buffers is 0') agg_loss /= total_buffers agg_accuracy /= total_buffers - clear_keras_session() + if not is_pg: + # In GPDB, each segment would have a keras session, so clear + # them after the last buffer is processed. 
+ clear_keras_session() new_model_state = madlib_keras_serializer.serialize_weights( agg_loss, agg_accuracy, SD['buffer_count'], updated_weights) diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in index 34bf2c2..295276b 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in @@ -163,7 +163,7 @@ CREATE AGGREGATE MADLIB_SCHEMA.fit_step( )( STYPE=BYTEA, SFUNC=MADLIB_SCHEMA.fit_transition, - PREFUNC=MADLIB_SCHEMA.fit_merge, + m4_ifdef(`__POSTGRESQL__', `', `prefunc=MADLIB_SCHEMA.fit_merge,') FINALFUNC=MADLIB_SCHEMA.fit_final ); diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in index 4dd29d5..2e2250e 100644 --- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in +++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in @@ -40,7 +40,11 @@ def get_device_name_and_set_cuda_env(use_gpu, seg): gpus_per_host = 4 if use_gpu: device_name = '/gpu:0' - os.environ["CUDA_VISIBLE_DEVICES"] = str(seg % gpus_per_host) + if seg == -1: + cuda_visible_dev = ','.join([i for i in range(gpus_per_host)]) + else: + cuda_visible_dev = str(seg % gpus_per_host) + os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_dev else: # cpu only device_name = '/cpu:0' os.environ["CUDA_VISIBLE_DEVICES"] = '-1' diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in index 115af9d..8b68aa9 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in @@ -18,7 +18,6 @@ * under the License. 
* *//* ---------------------------------------------------------------------*/ - drop table if exists cifar_10_sample; create table cifar_10_sample(id INTEGER, y SMALLINT, imgpath TEXT, x REAL[]); copy cifar_10_sample from stdin delimiter '|'; @@ -36,11 +35,40 @@ copy cifar_10_sample_val from stdin delimiter '|'; -- TODO Calling this function makes keras.fit fail with the exception (investigate later) -- NOTICE: Releasing segworker groups to finish aborting the transaction. -- ERROR: could not connect to segment: initialization of segworker group failed (cdbgang.c:237) ---update cifar_10_sample_val SET independent_var = array_scalar_mult(independent_var::real[], (1/255.0)::real); +-- update cifar_10_sample_val SET independent_var = array_scalar_mult(independent_var::real[], (1/255.0)::real); + +-- Prepare the minibatched data manually instead of calling +-- minibatch_preprocessor_dl since it internally calls array_scalar_mult. +-- Please refer to MADLIB-1326 for more details on the issue. DROP TABLE IF EXISTS cifar_10_sample_batched; +CREATE TABLE cifar_10_sample_batched( + buffer_id smallint, + dependent_var integer[], + independent_var real[]); +copy cifar_10_sample_batched from stdin delimiter '|'; +0|{{0,1},{1,0}}|{{{{0.792157,0.8,0.780392},{0.792157,0.8,0.780392},{0.8,0.807843,0.788235},{0.807843,0.815686,0.796079},{0.815686,0.823529,0.803922},{0.819608,0.827451,0.807843},{0.823529,0.831373,0.811765},{0.831373,0.839216,0.823529},{0.835294,0.843137,0.831373},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.843137,0.85098,0.839216},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.843137},{0.847059,0.854902,0.839216},{0.85098,0.858824,0.839216 [...] +\. 
+ DROP TABLE IF EXISTS cifar_10_sample_batched_summary; -SELECT minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 2, 255); +CREATE TABLE cifar_10_sample_batched_summary( + source_table text, + output_table text, + dependent_varname text, + independent_varname text, + dependent_vartype text, + class_values smallint[], + buffer_size integer, + normalizing_const numeric); +INSERT INTO cifar_10_sample_batched_summary values ( + 'cifar_10_sample', + 'cifar_10_sample_batched', + 'y', + 'x', + 'smallint', + ARRAY[0,1], + 2, + 255.0); DROP TABLE IF EXISTS model_arch; SELECT load_keras_model('model_arch', @@ -64,9 +92,8 @@ SELECT load_keras_model('model_arch', {"class_name": "Zeros", "config": {}}, "units": 2, "use_bias": true, "activity_regularizer": null} }], "backend": "tensorflow"}$$); --- Please do not break up the compile_params string --- It might break the assertion - +-- -- Please do not break up the compile_params string +-- -- It might break the assertion DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary; SELECT madlib_keras_fit( 'cifar_10_sample_batched',
