reductionista commented on a change in pull request #525:
URL: https://github.com/apache/madlib/pull/525#discussion_r558695581



##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
##########
@@ -0,0 +1,845 @@
+m4_include(`SQLCommon.m4')
+m4_changequote(<<<,>>>)
+m4_ifdef(<<<__POSTGRESQL__>>>, -- Skip all fit multiple tests for postgres
+,<<<
+m4_changequote(<!,!>)
+
+-- =================== Setup & Initialization for FitMultiple tests ========================
+--
+--  For fit multiple, we test end-to-end functionality along with performance elsewhere.
+--  They take a long time to run.  Including similar tests here would probably not be worth
+--  the extra time added to dev-check.
+--
+--  Instead, we just want to unit test different python functions in the FitMultiple class.
+--  However, most of the important behavior we need to test requires access to an actual
+--  Greenplum database... mostly, we want to make sure that the models hop around to the
+--  right segments in the right order.  Therefore, the unit tests are here, as a part of
+--  dev-check.  We mock fit_transition() and some validation functions in FitMultiple, but
+--  do NOT mock plpy, since most of the code we want to test is embedded SQL and needs to
+--  get through to gpdb.  We also want to mock the number of segments, so we can test what
+--  the model hopping behavior will be for a large cluster, even though dev-check should be
+--  able to run on a single dev host.
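-- [Editorial sketch, not part of this patch] Roughly, a test built on the helpers defined
-- below would drive them in an order like the following (table names are hypothetical
-- placeholders; the actual invocations would come later in the file, after the helper
-- functions exist):
--
--     SELECT init_fit_mult('iris_data_packed', 'mst_table');
--     SELECT update_dist_keys('iris_data_packed', 4, 3, 'expected_distkey_mappings');
--     SELECT setup_model_tables('model_input', 'model_output', 'cached_source');
--     SELECT test_init_schedule('current_schedule');
--     SELECT test_run_training('iris_data_packed', 0, true, false, false);   -- hop 0
--     SELECT test_rotate_schedule('current_schedule');
--     SELECT test_run_training('iris_data_packed', 1, false, false, false);  -- hop 1
--
-- with validate_transition_function_params() and validate_mst_key_order() used to check
-- the results after each hop.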
+
+\i m4_regexp(MODULE_PATHNAME,
+             <!\(.*\)libmadlib\.so!>,
+             <!\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in!>
+)
+
+-- Mock version() function to convince the InputValidator this is the real madlib schema
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.version() RETURNS VARCHAR AS
+$$
+    SELECT MADLIB_SCHEMA.version();
+$$ LANGUAGE sql IMMUTABLE;
+
+-- Call this first to initialize the FitMultiple object, before anything else happens.
+-- Pass a real mst table and source table; the rest of the FitMultipleModel() constructor params
+--  are filled in.  They can be overridden later, before test functions are called, if necessary.
+CREATE OR REPLACE FUNCTION init_fit_mult(
+    source_table            VARCHAR,
+    model_selection_table   VARCHAR
+) RETURNS VOID AS
+$$
+    import sys
+    from mock import Mock, patch
+
+    PythonFunctionBodyOnlyNoSchema(deep_learning,madlib_keras_fit_multiple_model)
+    schema_madlib = 'madlib_installcheck_deep_learning'
+
+    GD['fit_mult'] = madlib_keras_fit_multiple_model.FitMultipleModel(
+        schema_madlib,
+        source_table,
+        'orig_model_out',
+        model_selection_table,
+        1
+    )
+    
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_init_schedule(
+    schedule_table VARCHAR
+) RETURNS TEXT AS
+$$
+    fit_mult = GD['fit_mult']
+    fit_mult.schedule_tbl = schedule_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(schedule_table))
+    if fit_mult.init_schedule_tbl():
+        err_msg = None
+    else:
+        err_msg = 'FitMultiple.init_schedule_tbl() returned False'
+
+    return err_msg
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+CREATE OR REPLACE FUNCTION test_rotate_schedule(
+    schedule_table          VARCHAR
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    if fit_mult.schedule_tbl != schedule_table:
+        fit_mult.init_schedule_tbl()
+
+    fit_mult.rotate_schedule_tbl()
+
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
+
+-- Mock fit_transition function, for testing
+--  madlib_keras_fit_multiple_model() python code
+CREATE OR REPLACE FUNCTION madlib_installcheck_deep_learning.fit_transition_multiple_model(
+    dependent_var               BYTEA,
+    independent_var             BYTEA,
+    dependent_var_shape         INTEGER[],
+    independent_var_shape       INTEGER[],
+    model_architecture          TEXT,
+    compile_params              TEXT,
+    fit_params                  TEXT,
+    dist_key                    INTEGER,
+    dist_key_mapping            INTEGER[],
+    current_seg_id              INTEGER,
+    segments_per_host           INTEGER,
+    images_per_seg              INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
+    serialized_weights          BYTEA,
+    is_final_training_call      BOOLEAN,
+    use_caching                 BOOLEAN,
+    custom_function_map         BYTEA
+) RETURNS BYTEA AS
+$$
+    param_keys = [ 'compile_params', 'accessible_gpus_for_seg', 'dependent_var_shape', 'dist_key_mapping',
+                   'current_seg_id', 'segments_per_host', 'custom_function_map', 'is_final_training_call',
+                   'dist_key', 'serialized_weights', 'images_per_seg', 'model_architecture', 'fit_params',
+                   'independent_var_shape', 'use_caching' ]
+
+    num_calls = 1
+    if 'transition_function_params' in GD:
+        if dist_key in GD['transition_function_params']:
+            if not 'reset' in GD['transition_function_params'][dist_key]:
+                num_calls = GD['transition_function_params'][dist_key]['num_calls']
+                num_calls += 1
+
+    g = globals()
+    params = dict()
+
+    for k in param_keys:
+        params[k] = g[k]
+
+    params['dependent_var'] = len(dependent_var) if dependent_var else 0
+    params['independent_var'] = len(independent_var) if independent_var else 0
+    params['num_calls'] = num_calls
+
+    if not 'transition_function_params' in GD:
+        GD['transition_function_params'] = dict()
+    GD['transition_function_params'][dist_key] = params
+
+    # compute simulated seg_id ( current_seg_id is the actual seg_id )
+    seg_id = dist_key_mapping.index( dist_key )
+
+    if dependent_var_shape and dependent_var_shape[0] * num_calls < images_per_seg[ seg_id ]:
+        return None
+    else:
+        GD['transition_function_params'][dist_key]['reset'] = True
+        return serialized_weights
+$$ LANGUAGE plpythonu VOLATILE;
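-- [Editorial note, not part of this patch] The mock above records the parameters it was
-- called with in GD (keyed by dist_key) and mimics the buffer-aggregation behavior of the
-- real transition function: it returns NULL until the image count seen so far on a segment
-- (buffer rows x number of calls) reaches that segment's entry in images_per_seg, and only
-- then echoes serialized_weights back.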
+
+CREATE OR REPLACE FUNCTION validate_transition_function_params(
+    current_seg_id                       INTEGER,
+    segments_per_host                    INTEGER,
+    images_per_seg                       INTEGER[],
+    expected_num_calls                   INTEGER,
+    expected_dist_key                    INTEGER,
+    expected_is_final_training_call      BOOLEAN,
+    expected_dist_key_mapping            INTEGER[],
+    dependent_var_len                    INTEGER,
+    independent_var_len                  INTEGER,
+    use_caching                          BOOLEAN
+) RETURNS TEXT AS
+$$
+    err_msg = "transition function was not called on segment {}"
+
+    if 'transition_function_params' not in GD:
+        return err_msg.format(current_seg_id)
+    elif expected_dist_key not in GD['transition_function_params']:
+        return err_msg.format(current_seg_id) + " for __dist_key__ = {}".format(expected_dist_key)
+    actual = GD['transition_function_params'][expected_dist_key]
+
+    err_msg = """Incorrect value for {} param passed to fit_transition_multiple_model:
+       Actual={}, Expected={}"""
+
+    validation_map = {
+        'current_seg_id'         : current_seg_id,
+        'segments_per_host'      : segments_per_host,
+        'num_calls'              : expected_num_calls,
+        'is_final_training_call' : expected_is_final_training_call,
+        'dist_key'               : expected_dist_key,
+        'dependent_var'          : dependent_var_len,
+        'independent_var'        : independent_var_len,
+        'use_caching'            : use_caching
+    }
+
+    for param, expected in validation_map.items():
+        if actual[param] != expected:
+            return err_msg.format(
+                param,
+                actual[param],
+                expected
+            )
+
+    return 'PASS'  # actual params match expected params
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Helper to rotate an array of int's
+CREATE OR REPLACE FUNCTION rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[-1:] + keys[:-1]
+$$ LANGUAGE plpythonu IMMUTABLE;
+
+CREATE OR REPLACE FUNCTION reverse_rotate_keys(
+    keys    INTEGER[]
+) RETURNS INTEGER[]
+AS $$
+   return keys[1:] + keys[:1]
+$$ LANGUAGE plpythonu IMMUTABLE;
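-- [Editorial sketch, not part of this patch] These two helpers are inverses of each other,
-- shifting the key array by one position in opposite directions.  For example:
--
--     SELECT rotate_keys(ARRAY[1,2,3]);          -- {3,1,2}
--     SELECT reverse_rotate_keys(ARRAY[3,1,2]);  -- {1,2,3}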
+
+CREATE OR REPLACE FUNCTION setup_model_tables(
+    input_table TEXT,
+    output_table TEXT,
+    cached_source_table TEXT
+) RETURNS TEXT AS
+$$ 
+    fit_mult = GD['fit_mult']
+
+    fit_mult.model_input_tbl = input_table
+    fit_mult.model_output_tbl = output_table
+    fit_mult.cached_source_table = cached_source_table
+
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(output_table))
+    plpy.execute('DROP TABLE IF EXISTS {}'.format(cached_source_table))
+    fit_mult.init_model_output_tbl()
+    q = """
+        UPDATE {model_out} -- Reduce size of model for faster tests
+            SET ( model_weights, model_arch, compile_params, fit_params )
+                  = ( mst_key::TEXT::BYTEA,
+                      ( '{{ "a" : ' || mst_key::TEXT || ' }}' )::JSON,
+                      'c' || mst_key::TEXT,
+                      'f' || mst_key::TEXT 
+                    )
+        WHERE mst_key IS NOT NULL;
+    """.format(model_out=fit_mult.model_output_tbl)
+    plpy.execute(q) 
+$$ LANGUAGE plpythonu VOLATILE;
+
+-- Updates dist keys in src table and internal fit_mult class variables
+--    num_data_segs can be larger than actual number of segments, since this
+--    is just for simulated testing.  This will also write to expected_distkey_mappings_tbl
+--    which can be used for validating dist key mappings and images per seg later.
+CREATE OR REPLACE FUNCTION update_dist_keys(
+    src_table TEXT,
+    num_data_segs INTEGER,
+    num_models INTEGER,
+    expected_distkey_mappings_tbl TEXT
+) RETURNS VOID AS
+$$ 
+    redist_cmd = """
+        UPDATE {src_table}
+            SET __dist_key__ = (buffer_id % {num_data_segs})
+    """.format(**globals())
+    plpy.execute(redist_cmd)
+
+    fit_mult = GD['fit_mult']
+
+    q = """
+        SELECT SUM(independent_var_shape[1]) AS image_count,
+            __dist_key__
+        FROM {src_table}
+        GROUP BY __dist_key__
+        ORDER BY __dist_key__
+    """.format(**globals())
+    res = plpy.execute(q)
+
+    images_per_seg = [ int(r['image_count']) for r in res ]
+    dist_keys = [ int(r['__dist_key__']) for r in res ]
+    num_dist_keys = len(dist_keys)
+
+    fit_mult.source_table = src_table
+    fit_mult.max_dist_key = sorted(dist_keys)[-1]
+    fit_mult.images_per_seg_train = images_per_seg
+    fit_mult.dist_key_mapping = fit_mult.dist_keys = dist_keys
+    fit_mult.accessible_gpus_per_seg = [0] * num_dist_keys
+    fit_mult.segments_per_host = num_data_segs
+
+    fit_mult.msts_for_schedule = fit_mult.msts[:num_models]
+    if num_models < num_dist_keys:
+        fit_mult.msts_for_schedule += [None] * \
+                                 (num_dist_keys - num_models)
+    fit_mult.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                              for mst in fit_mult.msts_for_schedule ]
+    fit_mult.num_msts = num_models
+
+    fit_mult.extra_dist_keys = []
+    for i in range(num_models - num_dist_keys):
+        fit_mult.extra_dist_keys.append(fit_mult.max_dist_key + 1 + i)
+    fit_mult.all_dist_keys = fit_mult.dist_key_mapping + fit_mult.extra_dist_keys
+
+    create_distkey_map_tbl_cmd = """
+        DROP TABLE IF EXISTS {exp_table};
+        CREATE TABLE {exp_table} AS
+        SELECT
+            ARRAY(  -- map of dist_keys to seg_ids from source table
+                SELECT __dist_key__
+                FROM {fm.source_table}
+                GROUP BY __dist_key__
+                ORDER BY __dist_key__  -- This would be gp_segment_id if it weren't a simulation
+            ) AS expected_dist_key_mapping,
+            ARRAY{fm.images_per_seg_train} AS expected_images_per_seg,
+            {num_data_segs} AS segments_per_host,
+            __dist_key__
+        FROM {fm.source_table}
+        GROUP BY __dist_key__
+        DISTRIBUTED BY (__dist_key__);
+    """.format(
+            fm=fit_mult,
+            num_data_segs=num_data_segs,
+            exp_table=expected_distkey_mappings_tbl
+        )
+    plpy.execute(create_distkey_map_tbl_cmd)
+$$ LANGUAGE plpythonu VOLATILE;
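-- [Editorial note, not part of this patch] The UPDATE above assigns
-- __dist_key__ = buffer_id % num_data_segs, so with num_data_segs = 4 the buffers are
-- spread round-robin over simulated dist keys 0..3 regardless of how many real segments
-- the dev host has.  A hypothetical call simulating 4 "segments" and 3 models:
--
--     SELECT update_dist_keys('iris_data_packed', 4, 3, 'expected_distkey_mappings');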
+
+CREATE OR REPLACE FUNCTION test_run_training(
+    source_table TEXT,
+    hop INTEGER,
+    is_very_first_hop BOOLEAN,
+    is_final_training_call BOOLEAN,
+    use_caching BOOLEAN
+) RETURNS VOID AS
+$$
+    fit_mult = GD['fit_mult']
+
+    # Each time we start a new test, clear out stats
+    #   like num_calls from GD so we don't end up validating
+    #   against old results
+    if 'transition_function_params' in GD:
+        del GD['transition_function_params']
+
+    fit_mult.source_tbl = source_table
+    fit_mult.is_very_first_hop = is_very_first_hop
+    fit_mult.is_final_training_call = is_final_training_call
+    if use_caching != fit_mult.use_caching:
+        fit_mult.udf_plan = None  # Otherwise it will execute the wrong
+                                  # query when use_caching changes!
+    fit_mult.use_caching = use_caching
+
+    fit_mult.run_training(hop=hop, is_very_first_hop=is_very_first_hop)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__',MODIFIES SQL DATA);
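-- [Editorial sketch, not part of this patch] For example, two consecutive hops where the
-- second switches use_caching on, which is exactly the case where udf_plan has to be
-- discarded and rebuilt (the table name is a hypothetical placeholder):
--
--     SELECT test_run_training('iris_data_packed', 0, true,  false, false);
--     SELECT test_run_training('iris_data_packed', 1, false, false, true);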
+
+m4_ifelse(m4_eval(__DBMS_VERSION_MAJOR__ <= 6),1,<!
+
+-- format() function for dynamic SQL was introduced in gpdb6,
+--   so for gpdb5 we need to define a simple version of it
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC TEXT[])
+RETURNS TEXT AS
+$$
+    import re
+    from internal.db_utils import quote_nullable
+    from utilities.utilities import quote_ident
+
+    seps = re.findall(r'%.', s)
+    pieces = re.split(r'%.', s)
+
+    res = pieces[0]
+    for i in range(len(seps)):
+        if vars[i] is None:
+            raise ValueError('Not enough params passed to format({})'.format(s))
+        else:
+            c = seps[i][1]
+            if c == '%':
+                quoted = '%'
+            elif c == 's':
+                quoted = vars[i]
+            elif c == 'L':
+                quoted = quote_nullable(vars[i])
+            elif c == 'I':
+                quoted = quote_ident(vars[i])
+            else:
+                raise ValueError('{} in format({}) not recognized'.format(seps[i],s))
+            res += quoted
+        res += pieces[i+1]
+    return res
+$$ LANGUAGE plpythonu IMMUTABLE;
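-- [Editorial sketch, not part of this patch] This polyfill is meant to behave like the
-- gpdb6 built-in for the specifiers it handles, e.g.:
--
--     SELECT format('SELECT mst_key FROM %I WHERE v = %L AND n = %s', 'model_out', 'x', '3');
--
-- where %I is passed through quote_ident(), %L through quote_nullable(), and %s is
-- substituted verbatim.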
+
+CREATE OR REPLACE FUNCTION format(s TEXT,
+                                  v1 INTEGER[],
+                                  v2 INTEGER[],
+                                  v3 TEXT,
+                                  v4 TEXT,
+                                  v5 INTEGER[],
+                                  v6 TEXT,
+                                  v7 INTEGER[]
+)
+RETURNS TEXT AS
+$$
+    SELECT format($1,
+               $2::TEXT,
+               $3::TEXT,
+               $4,
+               $5,
+               $6::TEXT,
+               $7,
+               $8::TEXT
+        );
+$$ LANGUAGE SQL;
+
+CREATE OR REPLACE FUNCTION format(s TEXT, vars VARIADIC ANYARRAY)
+RETURNS TEXT AS
+$$
+    --SELECT format($1, $2::TEXT[])
+    SELECT $2;
+$$ LANGUAGE sql IMMUTABLE;
+
+!>) -- m4_endif DBMS_VERSION_MAJOR <= 5
+
+CREATE OR REPLACE FUNCTION validate_mst_key_order(output_tbl TEXT, expected_tbl TEXT)
+RETURNS VOID AS
+$$
+DECLARE
+    actual INTEGER[];
+    expected INTEGER[];
+    res RECORD;
+BEGIN
+    EXECUTE format(
+        'SELECT ARRAY(' ||
+            'SELECT mst_key FROM %I ' ||
+                'ORDER BY __dist_key__)',
+        output_tbl
+    ) INTO actual;
+
+    EXECUTE format(
+        'SELECT mst_keys FROM %I',
+        expected_tbl
+    ) INTO expected;
+
+    EXECUTE format(

Review comment:
       You're right, this is simpler and more readable.
   
    I probably would have just done it this way initially, if I'd realized gpdb5 didn't have format().  After a series of refactors, it ended up a lot more complicated than it needed to be.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

