This is an automated email from the ASF dual-hosted git repository. domino pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 885104d62738877ae0e98c4eccdf2e3c22f5fe68 Author: Domino Valdano <[email protected]> AuthorDate: Wed Nov 18 15:37:04 2020 -0800 DL: Add unit tests for model hopping in FitMultiple This is a set of "unit tests" implemented as a dev-check test, since they need DB access to work. Functions unit tested in deep_learning/tests/madlib_keras_fit_multiple.sql_in: FitMultiple.init_schedule() FitMultiple.rotate_schedule() FitMultiple.run_training() - Tests schedule creation and rotation - Tests that models are hopping to right segments in right order - Tests caching, # msts > segs, = , and # mst < segs Uses simulated segments instead of actual segments so the same test can be run on any size cluster. - 6 models on 3 segments - 3 models on 3 segments - 3 models on 5 segments --- .../test/madlib_keras_fit_multiple.sql_in | 786 +++++++++++++++++++++ .../test/madlib_keras_iris.setup.sql_in | 8 + src/ports/postgres/modules/internal/db_utils.py_in | 8 + .../postgres/modules/utilities/validate_args.py_in | 5 +- 4 files changed, 804 insertions(+), 3 deletions(-) diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in new file mode 100644 index 0000000..a7bd3fc --- /dev/null +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in @@ -0,0 +1,786 @@ +m4_include(`SQLCommon.m4') +m4_changequote(<<<,>>>) +m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres +,<<< +m4_changequote(<!,!>) + +-- =================== Setup & Initialization for FitMultiple tests ======================== +-- +-- For fit multiple, we test end-to-end functionality along with performance elsewhere. +-- They take a long time to run. Including similar tests here would probably not be worth +-- the extra time added to dev-check. +-- +-- Instead, we just want to unit test different python functions in the FitMultiple class. 
+-- However, most of the important behavior we need to test requires access to an actual +-- Greenplum database... mostly, we want to make sure that the models hop around to the +-- right segments in the right order. Therefore, the unit tests are here, as a part of +-- dev-check. We mock fit_transition() and some validation functions in FitMultiple, but +-- do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to +-- get through to gpdb. We also want to mock the number of segments, so we can test what +-- the model hopping behavior will be for a large cluster, even though dev-check should be +-- able to run on a single dev host. + +\i m4_regexp(MODULE_PATHNAME, + <!\(.*\)libmadlib\.so!>, + <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!> +) + +-- Mock version() function to convince the InputValidator this is the real madlib schema +CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS +$$ + SELECT MADLIB_SCHEMA.version(); +$$ LANGUAGE sql IMMUTABLE; + +-- Call this first to initialize the FitMultiple object, before anything else happens. +-- Pass a real mst table and source table; the rest of the FitMultipleModel() constructor params +-- are filled in. They can be overridden later, before test functions are called, if necessary. 
+CREATE OR REPLACE FUNCTION init_fit_mult( + source_table VARCHAR, + model_selection_table VARCHAR +) RETURNS VOID AS +$$ + import sys + from mock import Mock, patch + + PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model) + schema_madlib = 'madlib_installcheck_deep_learning' + + GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel( + schema_madlib, + source_table, + 'orig_model_out', + model_selection_table, + 1 + ) + +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA); + +CREATE OR REPLACE FUNCTION test_init_schedule( + schedule_table VARCHAR +) RETURNS BOOLEAN AS +$$ + fit_mult = GD['fit_mult'] + fit_mult.schedule_tbl = schedule_table + + plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table)) + if fit_mult.init_schedule_tbl(): + err_msg = None + else: + err_msg = 'FitMultiple.init_schedule_tbl() returned False' + + return err_msg +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA); + +CREATE OR REPLACE FUNCTION test_rotate_schedule( + schedule_table VARCHAR +) RETURNS VOID AS +$$ + fit_mult = GD['fit_mult'] + + if fit_mult.schedule_tbl != schedule_table: + fit_mult.init_schedule_tbl() + + fit_mult.rotate_schedule_tbl() + +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA); + +-- Mock fit_transition function, for testing +-- madlib_keras_fit_multiple_model() python code +CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model( + dependent_var BYTEA, + independent_var BYTEA, + dependent_var_shape INTEGER[], + independent_var_shape INTEGER[], + model_architecture TEXT, + compile_params TEXT, + fit_params TEXT, + dist_key INTEGER, + dist_key_mapping INTEGER[], + current_seg_id INTEGER, + segments_per_host INTEGER, + images_per_seg INTEGER[], + accessible_gpus_for_seg INTEGER[], + serialized_weights BYTEA, + is_final_training_call BOOLEAN, + use_caching BOOLEAN, + 
custom_function_map BYTEA +) RETURNS BYTEA AS +$$ + param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping', + 'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call', + 'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params', + 'independent_var_shape', 'use_caching' ] + + num_calls = 1 + if 'transition_function_params' in GD: + if dist_key in GD['transition_function_params']: + if not 'reset' in GD['transition_function_params'][dist_key]: + num_calls = GD['transition_function_params'][dist_key]['num_calls'] + num_calls += 1 + + g = globals() + params = dict() + + for k in param_keys: + params[k] = g[k] + + params['dependent_var'] = len(dependent_var) if dependent_var else 0 + params['independent_var'] = len(independent_var) if independent_var else 0 + params['num_calls'] = num_calls + + if not 'transition_function_params' in GD: + GD['transition_function_params'] = dict() + GD['transition_function_params'][dist_key] = params + + # compute simulated seg_id ( current_seg_id is the actual seg_id ) + seg_id = dist_key_mapping.index( dist_key ) + + if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]: + return None + else: + GD['transition_function_params'][dist_key]['reset'] = True + return serialized_weights +$$ LANGUAGE plpythonu VOLATILE; + +CREATE OR REPLACE FUNCTION validate_transition_function_params( + current_seg_id INTEGER, + segments_per_host INTEGER, + images_per_seg INTEGER[], + expected_num_calls INTEGER, + expected_dist_key INTEGER, + expected_is_final_training_call BOOLEAN, + expected_dist_key_mapping INTEGER[], + dependent_var_len INTEGER, + independent_var_len INTEGER, + use_caching BOOLEAN +) RETURNS TEXT AS +$$ + err_msg = "transition function was not called on segment {}".format(current_seg_id) + + if 'transition_function_params' not in GD: + return err_msg + elif expected_dist_key not in 
GD['transition_function_params']: + return err_msg + " for __dist_key__ = {}".format(expected_dist_key) + actual = GD['transition_function_params'][expected_dist_key] + + err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model: + Actual={}, Expected={}""" + + validation_map = { + 'current_seg_id' : current_seg_id, + 'segments_per_host' : segments_per_host, + 'num_calls' : expected_num_calls, + 'is_final_training_call' : expected_is_final_training_call, + 'dist_key' : expected_dist_key, + 'dependent_var' : dependent_var_len, + 'independent_var' : independent_var_len, + 'use_caching' : use_caching + } + + for param, expected in validation_map.items(): + if actual[param] != expected: + return err_msg.format( + param, + actual[param], + expected + ) + + return 'PASS' # actual params match expected params +$$ LANGUAGE plpythonu VOLATILE; + +-- Helper to rotate an array of int's +CREATE OR REPLACE FUNCTION rotate_keys( + keys INTEGER[] +) RETURNS INTEGER[] +AS $$ + return keys[-1:] + keys[:-1] +$$ LANGUAGE plpythonu IMMUTABLE; + +CREATE OR REPLACE FUNCTION reverse_rotate_keys( + keys INTEGER[] +) RETURNS INTEGER[] +AS $$ + return keys[1:] + keys[:1] +$$ LANGUAGE plpythonu IMMUTABLE; + +CREATE OR REPLACE FUNCTION setup_model_tables( + input_table TEXT, + output_table TEXT, + cached_source_table TEXT +) RETURNS TEXT AS +$$ + fit_mult = GD['fit_mult'] + + fit_mult.model_input_tbl = input_table + fit_mult.model_output_tbl = output_table + fit_mult.cached_source_table = cached_source_table + + plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table)) + plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table)) + fit_mult.init_model_output_tbl() + q = """ + UPDATE {model_out} -- Reduce size of model for faster tests + SET ( model_weights, model_arch, compile_params, fit_params ) + = ( mst_key::TEXT::BYTEA, + ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON, + 'c' || mst_key::TEXT, + 'f' || mst_key::TEXT + ) + WHERE mst_key IS NOT NULL; + 
""".format(model_out=fit_mult.model_output_tbl) + plpy.execute(q) +$$ LANGUAGE plpythonu VOLATILE; + +-- Updates dist keys in src table and internal fit_mult class variables +-- num_data_segs can be larger than actual number of segments, since this +-- is just for simulated testing. This will also write to expected_distkey_mappings_tbl +-- which can be used for validating dist key mappings and images per seg later. +CREATE OR REPLACE FUNCTION update_dist_keys( + src_table TEXT, + num_data_segs INTEGER, + num_models INTEGER, + expected_distkey_mappings_tbl TEXT +) RETURNS VOID AS +$$ + redist_cmd = """ + UPDATE {src_table} + SET __dist_key__ = (buffer_id % {num_data_segs}) + """.format(**globals()) + plpy.execute(redist_cmd) + + fit_mult = GD['fit_mult'] + + q = """ + SELECT SUM(independent_var_shape[1]) AS image_count, + __dist_key__ + FROM {src_table} + GROUP BY __dist_key__ + ORDER BY __dist_key__ + """.format(**globals()) + res = plpy.execute(q) + + images_per_seg = [ int(r['image_count']) for r in res ] + dist_keys = [ int(r['__dist_key__']) for r in res ] + num_dist_keys = len(dist_keys) + + fit_mult.source_table = src_table + fit_mult.max_dist_key = sorted(dist_keys)[-1] + fit_mult.images_per_seg_train = images_per_seg + fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys + fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys + fit_mult.segments_per_host = num_data_segs + + fit_mult.msts_for_schedule = fit_mult.msts[:num_models] + if num_models < num_dist_keys: + fit_mult.msts_for_schedule += [None] * \ + (num_dist_keys - num_models) + fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\ + for mst in fit_mult.msts_for_schedule ] + fit_mult.num_msts = num_models + + fit_mult.extra_dist_keys = [] + for i in range(num_models - num_dist_keys): + fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i) + fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys + + create_distkey_map_tbl_cmd = """ + DROP TABLE IF 
EXISTS {exp_table}; + CREATE TABLE {exp_table} AS + SELECT + ARRAY( -- map of dist_keys to seg_ids from source table + SELECT __dist_key__ + FROM {fm.source_table} + GROUP BY __dist_key__ + ORDER BY __dist_key__ -- This would be gp_segment_id if it weren't a simulation + ) AS expected_dist_key_mapping, + ARRAY{fm.images_per_seg_train} AS expected_images_per_seg, + {num_data_segs} AS segments_per_host, + __dist_key__ + FROM {fm.source_table} + GROUP BY __dist_key__ + DISTRIBUTED BY (__dist_key__); + """.format( + fm=fit_mult, + num_data_segs=num_data_segs, + exp_table=expected_distkey_mappings_tbl + ) + plpy.execute(create_distkey_map_tbl_cmd) +$$ LANGUAGE plpythonu VOLATILE; + +CREATE OR REPLACE FUNCTION test_run_training( + source_table TEXT, + hop INTEGER, + is_very_first_hop BOOLEAN, + is_final_training_call BOOLEAN, + use_caching BOOLEAN +) RETURNS VOID AS +$$ + fit_mult = GD['fit_mult'] + + # Each time we start a new test, clear out stats + # like num_calls from GD so we don't end up validating + # against old results + if 'transition_function_params' in GD: + del GD['transition_function_params'] + + fit_mult.source_tbl = source_table + fit_mult.is_very_first_hop = is_very_first_hop + fit_mult.is_final_training_call = is_final_training_call + if use_caching != fit_mult.use_caching: + fit_mult.udf_plan = None # Otherwise it will execute the wrong + # query when use_caching changes! 
+ fit_mult.use_caching = use_caching + + fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA); + +CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT) +RETURNS VOID AS +$$ +DECLARE + actual INTEGER[]; + expected INTEGER[]; +BEGIN + EXECUTE 'SELECT ARRAY(' || + 'SELECT mst_key FROM ' || output_tbl || ' ORDER BY __dist_key__)' + INTO actual; + + EXECUTE 'SELECT mst_keys FROM ' || expected_tbl + INTO expected; + + PERFORM assert( + actual = expected, + 'mst keys found in wrong order / wrong segments!' || + E'\nActual: ' || actual::text || + E'\nExpected: ' || expected::text + ); +END +$$ LANGUAGE PLpgSQL VOLATILE; + +-- Create mst table +DROP TABLE IF EXISTS iris_mst_table, iris_mst_table_summary; +SELECT load_model_selection_table( + 'iris_model_arch', + 'iris_mst_table', + ARRAY[1], + ARRAY[ + $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$, + $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$, + $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$ + ], + ARRAY[ + $$batch_size=5,epochs=1$$, + $$batch_size=10,epochs=1$$ + ] +); + +-- Create FitMultiple object for running test functions +SELECT init_fit_mult('iris_data_15buf_packed', 'iris_mst_table'); + +CREATE TABLE src_3segs AS + SELECT * FROM iris_data_15buf_packed + DISTRIBUTED BY (__dist_key__); + +-- Simulate 6 models on 3 segments -- +SELECT update_dist_keys('src_3segs', 3, 6, 'expected_dist_key_mappings'); + +--=== Test init_schedule_tbl() ===-- +-- ==================================================================== +-- =========== Enough setup, now for the actual tests! 
=============== +-- ==================================================================== + +SELECT test_init_schedule('current_schedule'); +SELECT assert( + s.mst_key IS NOT NULL AND m.mst_key IS NOT NULL, + 'mst_keys in schedule table created by test_init_schedule() does not match keys in mst_table' +) FROM current_schedule s FULL JOIN iris_mst_table m USING (mst_key); + +-- Save order of mst keys in schedule for tracking +DROP TABLE IF EXISTS expected_order; +CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys; + +--=== Test rotate_schedule() ===-- +SELECT test_rotate_schedule('current_schedule'); +UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); +SELECT validate_mst_key_order('current_schedule', 'expected_order'); +UPDATE expected_order SET mst_keys=reverse_rotate_keys(mst_keys); -- Undo for later + +-- Initialize model_output table, and set model_input & cached_src table names +SELECT setup_model_tables('model_input', 'model_output', 'cached_src'); +SELECT validate_mst_key_order('model_output', 'expected_order'); + +-- Order of params in run_training test function (for reference below): +-- +-- test_run_training(src_tbl, hop, is_v_first_hop, is_final_call, use_caching) + +--=== Test first hop of an iteration - no caching (# msts > # segs) ===-- +SELECT test_run_training('src_3segs', 0, False, False, False); + + -- mst_keys should not have moved + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + 3, + s.expected_images_per_seg, + 5, -- expected num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 12, -- dependent_var length + 32, -- independent_var length + False -- use_caching + 
) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +--=== Test an ordinary hop - no caching (# msts > # segs) ===-- +SELECT test_run_training('src_3segs', 1, False, False, False); +SELECT test_rotate_schedule('current_schedule'); + + -- check that mst keys rotated onto correct segments + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 5, -- expected num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 12, -- dependent_var length + 32, -- independent_var length + False -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +--=== Test final training hop - no caching (# msts > # segs) ===-- +SELECT test_run_training('src_3segs', 8, False, True, False); + + -- check that mst keys rotated onto correct segments + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 5, -- expected num_calls (per dist_key) + s.__dist_key__, + True, -- expected is_final_training_call + s.expected_dist_key_mapping, + 12, -- dependent_var length + 32, -- independent_var length + 
False -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +--=== Test very first hop - caching enabled ( # msts > # segs ) ===-- +SELECT test_run_training('src_3segs', 0, True, False, True); + + -- mst_keys should not have moved + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 5, -- expected num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 12, -- dependent_var length + 32, -- independent_var length + True -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + + -- validate that cached source table was created with proper dist keys + SELECT assert( + c.__dist_key__ IS NOT NULL AND s.__dist_key__ IS NOT NULL, + 'cached src table was not created or dist keys do not match original src table') + FROM cached_src c FULL JOIN (SELECT __dist_key__ FROM src_3segs GROUP BY __dist_key__) s USING(__dist_key__); + +-- Test ordinary hop - caching enabled ( # msts > # segs ) +SELECT test_run_training('src_3segs', 7, False, False, True); + + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 1, -- expected 
num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 0, -- dependent_var length + 0, -- independent_var length + True -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +-- Test final training hop - caching enabled ( # msts > # segs ) +SELECT test_run_training('src_3segs', 2, False, True, True); + + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- independent_var & dependent_var should have both been passed as NULL + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 1, -- expected num_calls (per dist_key) + s.__dist_key__, + True, -- expected is_final_training_call + s.expected_dist_key_mapping, + 0, -- dependent_var length + 0, -- independent_var length + True -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +--=== Simulate 3 models on 3 segments ===-- +SELECT update_dist_keys('src_3segs', 3, 3, 'expected_dist_key_mappings'); +DELETE FROM iris_mst_table WHERE ARRAY[mst_key] <@ (SELECT mst_keys FROM expected_order); +SELECT test_init_schedule('current_schedule'); +SELECT assert( + COUNT(*) = 3, + 'Wrong number of mst_keys in schedule table created by test_init_schedule()\n' || + 'Expected: 3\nActual: ' || COUNT(*)::TEXT +) FROM current_schedule; +-- Make sure none of the entries in the schedule table are NULL +-- ( this should only happen for # msts < # segs case ) +SELECT assert( + COUNT(*) = 0, + 'NULL mst_key found in schedule table created by test_init_schedule, even though # msts = # segs' +) FROM 
current_schedule WHERE mst_key IS NULL; + +-- Save new order of mst keys in schedule for tracking +DROP TABLE IF EXISTS expected_order; +CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys; + +SELECT setup_model_tables('model_input', 'model_output', 'cached_src'); + +SELECT validate_mst_key_order('model_output', 'expected_order'); +SELECT test_rotate_schedule('current_schedule'); + +-- Test ordinary hop - no caching ( # msts = # segs ) +SELECT test_run_training('src_3segs', 2, False, False, False); + + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + SELECT validate_mst_key_order('model_output', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 5, -- expected num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 12, -- dependent_var length + 32, -- independent_var length + False -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +--=== Simulate 3 models on 5 segments ( # msts < # segs ) ===-- +-- ( by updating dist keys in source table ) +CREATE TABLE src_5segs AS + SELECT * FROM iris_data_15buf_packed + DISTRIBUTED BY (__dist_key__); + +SELECT update_dist_keys('src_5segs', 5, 3, 'expected_dist_key_mappings'); +SELECT test_init_schedule('current_schedule'); +SELECT assert( + COUNT(*) = 2, + 'Wrong number NULL entries in schedule table created by test_init_schedule()\n' || + 'Expected: 2\nActual: ' || COUNT(*)::TEXT +) FROM current_schedule WHERE mst_key IS NULL; + +SELECT assert( + COUNT(*) = 3, + 'Wrong number of non-NULL entries in schedule table 
created by test_init_schedule()\n' || + 'Expected: 3\nActual: ' || COUNT(*)::TEXT +) FROM current_schedule WHERE mst_key IS NOT NULL; + +-- Save expected mst_key order +DROP TABLE IF EXISTS expected_order; +CREATE TABLE expected_order AS SELECT ARRAY(SELECT mst_key FROM current_schedule ORDER BY __dist_key__) mst_keys; + +-- Initialize model_output table, and set model_input & cached_src table names +SELECT setup_model_tables('model_input', 'model_output', 'cached_src'); + +DROP TABLE IF EXISTS model_output_ext; +CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__); +-- Make sure model_output was created with correct mst_key order +SELECT validate_mst_key_order('model_output_ext', 'expected_order'); + +--=== Test very first hop - caching enabled ( # msts < # segs ) ===-- +SELECT test_run_training('src_5segs', 0, True, False, True); +SELECT test_rotate_schedule('current_schedule'); + + -- Verify mst keys did not move + DROP TABLE IF EXISTS model_output_ext; + CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__); + SELECT validate_mst_key_order('model_output_ext', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + -- This should generate an Assertion failure if the transition function was not called for + -- any __dist_key__, even if there is no model on that segment. 
+ DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 3, -- expected num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 12, -- dependent_var length + 32, -- independent_var length + True -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +-- Test ordinary hop - caching enabled ( # msts < # segs ) + -- This should generate an Assertion failure if the transition function is + -- called for any __dist_key__ where mst_key is NULL +SELECT test_run_training('src_5segs', 1, False, False, True); + + -- verify mst_keys moved to correct segments + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + DROP TABLE IF EXISTS model_output_ext; + CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__); + SELECT validate_mst_key_order('model_output_ext', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 1, -- expected num_calls (per dist_key) + s.__dist_key__, + False, -- expected is_final_training_call + s.expected_dist_key_mapping, + 0, -- dependent_var length + 0, -- independent_var length + True -- use_caching + ) AS res, + s.__dist_key__ -- WHERE clause restricts this check to segments with models + FROM expected_dist_key_mappings s WHERE ARRAY[__dist_key__] <@ ARRAY(SELECT __dist_key__ FROM current_schedule WHERE mst_key IS NOT NULL) + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 
'PASS', res) FROM validate_params_results; + + SELECT test_rotate_schedule('current_schedule'); + +-- Test final training hop - caching enabled ( # msts < # segs ) +SELECT test_run_training('src_5segs', 5, False, True, True); + + -- verify mst_keys moved to correct segments + UPDATE expected_order SET mst_keys=rotate_keys(mst_keys); + DROP TABLE IF EXISTS model_output_ext; + CREATE TABLE model_output_ext AS SELECT c.__dist_key__, o.mst_key FROM current_schedule c LEFT JOIN model_output o USING (__dist_key__); + SELECT validate_mst_key_order('model_output_ext', 'expected_order'); + + -- verify transition func was called correct # of times with correct params + -- This should generate an Assertion failure if the transition function was not + -- called for any __dist_key__, even if there is no model on that segment. + DROP TABLE IF EXISTS validate_params_results; + CREATE TABLE validate_params_results AS + SELECT validate_transition_function_params( + s.gp_segment_id, + s.segments_per_host, + s.expected_images_per_seg, + 1, -- expected num_calls (per dist_key) + s.__dist_key__, + True, -- expected is_final_training_call + s.expected_dist_key_mapping, + 0, -- dependent_var length + 0, -- independent_var length + True -- use_caching + ) AS res, + s.__dist_key__ + FROM expected_dist_key_mappings s + DISTRIBUTED BY (__dist_key__); + SELECT assert(res = 'PASS', res) FROM validate_params_results; + +-- We don't want to hide the madlib versions of these for any other +-- test files that run afterwards +DROP FUNCTION madlib_installcheck_deep_learning.version(); +DROP FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model( + dependent_var BYTEA, + independent_var BYTEA, + dependent_var_shape INTEGER[], + independent_var_shape INTEGER[], + model_architecture TEXT, + compile_params TEXT, + fit_params TEXT, + dist_key INTEGER, + dist_key_mapping INTEGER[], + current_seg_id INTEGER, + segments_per_host INTEGER, + images_per_seg INTEGER[], + 
accessible_gpus_for_seg INTEGER[], + serialized_weights BYTEA, + is_final_training_call BOOLEAN, + use_caching BOOLEAN, + custom_function_map BYTEA +) + +>>> ) -- m4_endif postgres diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in index 389efa3..88011ef 100644 --- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in +++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in @@ -293,3 +293,11 @@ SELECT train_test_split('iris_data', -- Source table NULL, -- Sample without replacement TRUE -- Separate output tables ); + +DROP TABLE IF EXISTS iris_data_15buf_packed, iris_data_15buf_packed_summary; +SELECT training_preprocessor_dl('iris_test', -- Source table + 'iris_data_15buf_packed', -- Output table + 'class_text', -- Dependent variable + 'attributes', -- Independent variable + 2 -- buffer_size (15 buffers) + ); diff --git a/src/ports/postgres/modules/internal/db_utils.py_in b/src/ports/postgres/modules/internal/db_utils.py_in index 90c09ca..d39c845 100644 --- a/src/ports/postgres/modules/internal/db_utils.py_in +++ b/src/ports/postgres/modules/internal/db_utils.py_in @@ -85,6 +85,14 @@ def quote_literal(input_str): input_str=input_str) # ------------------------------------------------------------------------------ +def quote_nullable(input_str): + if input_str is not None: + return quote_literal(input_str) + else: + return 'NULL' + +# ------------------------------------------------------------------------------ + def is_col_1d_array(source_table, col_name): query = """ SELECT array_upper({0}, 2) IS NULL AS n_y diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in index f3d30ea..20a11c2 100644 --- a/src/ports/postgres/modules/utilities/validate_args.py_in +++ b/src/ports/postgres/modules/utilities/validate_args.py_in @@ 
-51,13 +51,12 @@ def unquote_ident(input_str): return input_str # ------------------------------------------------------------------------- - def quote_ident(input_str): """ Returns input_str with quotes added per Postgres identifier rules. - This function is available via plpy.quote_ident in PG > 9.1. We add this - function for compatibility with Greenplum. + This function is available via plpy.quote_ident in PG > 9.1 and GPDB >= 6.0. + We add this function for compatibility with older versions of Greenplum. If the input_str is a lower case string with characters in [a-z0-9_] then the string is returned as is, else a double quote is added in front and back of the string.
