reductionista commented on a change in pull request #525: URL: https://github.com/apache/madlib/pull/525#discussion_r558695581
########## File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in ########## @@ -0,0 +1,845 @@ +m4_include(`SQLCommon.m4') +m4_changequote(<<<,>>>) +m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres +,<<< +m4_changequote(<!,!>) + +-- =================== Setup & Initialization for FitMultiple tests ======================== +-- +-- For fit multiple, we test end-to-end functionality along with performance elsewhere. +-- They take a long time to run. Including similar tests here would probably not be worth +-- the extra time added to dev-check. +-- +-- Instead, we just want to unit test different python functions in the FitMultiple class. +-- However, most of the important behavior we need to test requires access to an actual +-- Greenplum database... mostly, we want to make sure that the models hop around to the +-- right segments in the right order. Therefore, the unit tests are here, as a part of +-- dev-check. we mock fit_transition() and some validation functions in FitMultiple, but +-- do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to +-- get through to gpdb. We also want to mock the number of segments, so we can test what +-- the model hopping behavior will be for a large cluster, even though dev-check should be +-- able to run on a single dev host. + +\i m4_regexp(MODULE_PATHNAME, + <!\(.*\)libmadlib\.so!>, + <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!> +) + +-- Mock version() function to convince the InputValidator this is the real madlib schema +CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS +$$ + SELECT MADLIB_SCHEMA.version(); +$$ LANGUAGE sql IMMUTABLE; + +-- Call this first to initialize the FitMultiple object, before anything else happens. +-- Pass a real mst table and source table, rest of FitMultipleModel() constructor params +-- are filled in. 
They can be overridden later, before test functions are called, if necessary. +CREATE OR REPLACE FUNCTION init_fit_mult( + source_table VARCHAR, + model_selection_table VARCHAR +) RETURNS VOID AS +$$ + import sys + from mock import Mock, patch + + PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model) + schema_madlib = 'madlib_installcheck_deep_learning' + + GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel( + schema_madlib, + source_table, + 'orig_model_out', + model_selection_table, + 1 + ) + +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA); + +CREATE OR REPLACE FUNCTION test_init_schedule( + schedule_table VARCHAR +) RETURNS BOOLEAN AS +$$ + fit_mult = GD['fit_mult'] + fit_mult.schedule_tbl = schedule_table + + plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table)) + if fit_mult.init_schedule_tbl(): + err_msg = None + else: + err_msg = 'FitMultiple.init_schedule_tbl() returned False' + + return err_msg +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA); + +CREATE OR REPLACE FUNCTION test_rotate_schedule( + schedule_table VARCHAR +) RETURNS VOID AS +$$ + fit_mult = GD['fit_mult'] + + if fit_mult.schedule_tbl != schedule_table: + fit_mult.init_schedule_tbl() + + fit_mult.rotate_schedule_tbl() + +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA); + +-- Mock fit_transition function, for testing +-- madlib_keras_fit_multiple_model() python code +CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model( + dependent_var BYTEA, + independent_var BYTEA, + dependent_var_shape INTEGER[], + independent_var_shape INTEGER[], + model_architecture TEXT, + compile_params TEXT, + fit_params TEXT, + dist_key INTEGER, + dist_key_mapping INTEGER[], + current_seg_id INTEGER, + segments_per_host INTEGER, + images_per_seg INTEGER[], + accessible_gpus_for_seg INTEGER[], + 
serialized_weights BYTEA, + is_final_training_call BOOLEAN, + use_caching BOOLEAN, + custom_function_map BYTEA +) RETURNS BYTEA AS +$$ + param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping', + 'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call', + 'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params', + 'independent_var_shape', 'use_caching' ] + + num_calls = 1 + if 'transition_function_params' in GD: + if dist_key in GD['transition_function_params']: + if not 'reset' in GD['transition_function_params'][dist_key]: + num_calls = GD['transition_function_params'][dist_key]['num_calls'] + num_calls += 1 + + g = globals() + params = dict() + + for k in param_keys: + params[k] = g[k] + + params['dependent_var'] = len(dependent_var) if dependent_var else 0 + params['independent_var'] = len(independent_var) if independent_var else 0 + params['num_calls'] = num_calls + + if not 'transition_function_params' in GD: + GD['transition_function_params'] = dict() + GD['transition_function_params'][dist_key] = params + + # compute simulated seg_id ( current_seg_id is the actual seg_id ) + seg_id = dist_key_mapping.index( dist_key ) + + if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg [ seg_id ]: + return None + else: + GD['transition_function_params'][dist_key]['reset'] = True + return serialized_weights +$$ LANGUAGE plpythonu VOLATILE; + +CREATE OR REPLACE FUNCTION validate_transition_function_params( + current_seg_id INTEGER, + segments_per_host INTEGER, + images_per_seg INTEGER[], + expected_num_calls INTEGER, + expected_dist_key INTEGER, + expected_is_final_training_call BOOLEAN, + expected_dist_key_mapping INTEGER[], + dependent_var_len INTEGER, + independent_var_len INTEGER, + use_caching BOOLEAN +) RETURNS TEXT AS +$$ + err_msg = "transition function was not called on segment " + + if 'transition_function_params' not in GD: + 
return err_msg.format(current_seg_id) + elif expected_dist_key not in GD['transition_function_params']: + return err_msg + " for __dist_key__ = {}".format(expected_dist_key) + actual = GD['transition_function_params'][expected_dist_key] + + err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model: + Actual={}, Expected={}""" + + validation_map = { + 'current_seg_id' : current_seg_id, + 'segments_per_host' : segments_per_host, + 'num_calls' : expected_num_calls, + 'is_final_training_call' : expected_is_final_training_call, + 'dist_key' : expected_dist_key, + 'dependent_var' : dependent_var_len, + 'independent_var' : independent_var_len, + 'use_caching' : use_caching + } + + for param, expected in validation_map.items(): + if actual[param] != expected: + return err_msg.format( + param, + actual[param], + expected + ) + + return 'PASS' # actual params match expected params +$$ LANGUAGE plpythonu VOLATILE; + +-- Helper to rotate an array of ints +CREATE OR REPLACE FUNCTION rotate_keys( + keys INTEGER[] +) RETURNS INTEGER[] +AS $$ + return keys[-1:] + keys[:-1] +$$ LANGUAGE plpythonu IMMUTABLE; + +CREATE OR REPLACE FUNCTION reverse_rotate_keys( + keys INTEGER[] +) RETURNS INTEGER[] +AS $$ + return keys[1:] + keys[:1] +$$ LANGUAGE plpythonu IMMUTABLE; + +CREATE OR REPLACE FUNCTION setup_model_tables( + input_table TEXT, + output_table TEXT, + cached_source_table TEXT +) RETURNS TEXT AS +$$ + fit_mult = GD['fit_mult'] + + fit_mult.model_input_tbl = input_table + fit_mult.model_output_tbl = output_table + fit_mult.cached_source_table = cached_source_table + + plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table)) + plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table)) + fit_mult.init_model_output_tbl() + q = """ + UPDATE {model_out} -- Reduce size of model for faster tests + SET ( model_weights, model_arch, compile_params, fit_params ) + = ( mst_key::TEXT::BYTEA, + ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON, + 'c' || 
mst_key::TEXT, + 'f' || mst_key::TEXT + ) + WHERE mst_key IS NOT NULL; + """.format(model_out=fit_mult.model_output_tbl) + plpy.execute(q) +$$ LANGUAGE plpythonu VOLATILE; + +-- Updates dist keys in src table and internal fit_mult class variables +-- num_data_segs can be larger than actual number of segments, since this +-- is just for simulated testing. This will also write to expected_distkey_mappings_tbl +-- which can be used for validating dist key mappings and images per seg later. +CREATE OR REPLACE FUNCTION update_dist_keys( + src_table TEXT, + num_data_segs INTEGER, + num_models INTEGER, + expected_distkey_mappings_tbl TEXT +) RETURNS VOID AS +$$ + redist_cmd = """ + UPDATE {src_table} + SET __dist_key__ = (buffer_id % {num_data_segs}) + """.format(**globals()) + plpy.execute(redist_cmd) + + fit_mult = GD['fit_mult'] + + q = """ + SELECT SUM(independent_var_shape[1]) AS image_count, + __dist_key__ + FROM {src_table} + GROUP BY __dist_key__ + ORDER BY __dist_key__ + """.format(**globals()) + res = plpy.execute(q) + + images_per_seg = [ int(r['image_count']) for r in res ] + dist_keys = [ int(r['__dist_key__']) for r in res ] + num_dist_keys = len(dist_keys) + + fit_mult.source_table = src_table + fit_mult.max_dist_key = sorted(dist_keys)[-1] + fit_mult.images_per_seg_train = images_per_seg + fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys + fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys + fit_mult.segments_per_host = num_data_segs + + fit_mult.msts_for_schedule = fit_mult.msts[:num_models] + if num_models < num_dist_keys: + fit_mult.msts_for_schedule += [None] * \ + (num_dist_keys - num_models) + fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\ + for mst in fit_mult.msts_for_schedule ] + fit_mult.num_msts = num_models + + fit_mult.extra_dist_keys = [] + for i in range(num_models - num_dist_keys): + fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i) + fit_mult.all_dist_keys = fit_mult.dist_key_mapping + 
fit_mult.extra_dist_keys + + create_distkey_map_tbl_cmd = """ + DROP TABLE IF EXISTS {exp_table}; + CREATE TABLE {exp_table} AS + SELECT + ARRAY( -- map of dist_keys to seg_ids from source table + SELECT __dist_key__ + FROM {fm.source_table} + GROUP BY __dist_key__ + ORDER BY __dist_key__ -- This would be gp_segment_id if it weren't a simulation + ) AS expected_dist_key_mapping, + ARRAY{fm.images_per_seg_train} AS expected_images_per_seg, + {num_data_segs} AS segments_per_host, + __dist_key__ + FROM {fm.source_table} + GROUP BY __dist_key__ + DISTRIBUTED BY (__dist_key__); + """.format( + fm=fit_mult, + num_data_segs=num_data_segs, + exp_table=expected_distkey_mappings_tbl + ) + plpy.execute(create_distkey_map_tbl_cmd) +$$ LANGUAGE plpythonu VOLATILE; + +CREATE OR REPLACE FUNCTION test_run_training( + source_table TEXT, + hop INTEGER, + is_very_first_hop BOOLEAN, + is_final_training_call BOOLEAN, + use_caching BOOLEAN +) RETURNS VOID AS +$$ + fit_mult = GD['fit_mult'] + + # Each time we start a new test, clear out stats + # like num_calls from GD so we don't end up validating + # against old results + if 'transition_function_params' in GD: + del GD['transition_function_params'] + + fit_mult.source_tbl = source_table + fit_mult.is_very_first_hop = is_very_first_hop + fit_mult.is_final_training_call = is_final_training_call + if use_caching != fit_mult.use_caching: + fit_mult.udf_plan = None # Otherwise it will execute the wrong + # query when use_caching changes! + fit_mult.use_caching = use_caching + + fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop) +$$ LANGUAGE plpythonu VOLATILE +m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA); + +m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<! 
+ +-- format() function for dynamic SQL was introduced in gpdb6, +-- so for gpdb5 we need to define a simple version of it +CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[]) +RETURNS TEXT AS +$$ + import re + from internal.db_utils import quote_nullable + from utilities.utilities import quote_ident + + seps = re.findall(r'%.', s) + pieces = re.split(r'%.', s) + + res = pieces[0] + for i in range(len(seps)): + if vars[i] is None: + raise ValueError('Not enough params passed to format({})'.format(s)) + else: + c = seps[i][1] + if c == '%': + quoted = '%' + elif c == 's': + quoted = vars[i] + elif c == 'L': + quoted = quote_nullable(vars[i]) + elif c == 'I': + quoted = quote_ident(vars[i]) + else: + raise ValueError('{} in format({}) not recognized'.format(seps[i],s)) + res += quoted + res += pieces[i+1] + return res +$$ LANGUAGE plpythonu IMMUTABLE; + +CREATE OR REPLACE FUNCTION format(s TEXT, + v1 INTEGER[], + v2 INTEGER[], + v3 TEXT, + v4 TEXT, + v5 INTEGER[], + v6 TEXT, + v7 INTEGER[] +) +RETURNS TEXT AS +$$ + SELECT format($1, + $2::TEXT, + $3::TEXT, + $4, + $5, + $6::TEXT, + $7, + $8::TEXT + ); +$$ LANGUAGE SQL; + +CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY) +RETURNS TEXT AS +$$ + --SELECT format($1, $2::TEXT[]) + SELECT $2; +$$ LANGUAGE sql IMMUTABLE; + +!>) -- m4_endif DBMS_VERSION_MAJOR <= 5 + +CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT) +RETURNS VOID AS +$$ +DECLARE + actual INTEGER[]; + expected INTEGER[]; + res RECORD; +BEGIN + EXECUTE format( + 'SELECT ARRAY(' || + 'SELECT mst_key FROM %I ' || + 'ORDER BY __dist_key__)', + output_tbl + ) INTO actual; + + EXECUTE format( + 'SELECT mst_keys FROM %I', + expected_tbl + ) INTO expected; + + EXECUTE format( Review comment: You're right, this is simpler and more readable. I probably would have just done it this way initially, if I'd realized gpdb5 didn't have format(). 
After a series of refactors, it ended up a lot more complicated than it needed to be. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org