reductionista commented on a change in pull request #516:
URL: https://github.com/apache/madlib/pull/516#discussion_r489805349
##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
##########
@@ -375,6 +375,39 @@ SELECT assert(
'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
FROM (SELECT * FROM iris_multiple_model_summary) summary;
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_one_hot_encoded_packed',
+ 'iris_multiple_model',
+ 'mst_table_4row',
+ 3,
+ FALSE, NULL, NULL, NULL, NULL, NULL,
+ TRUE
+);
+
+SELECT assert(
+ model_arch_table = 'iris_model_arch' AND
+ validation_table is NULL AND
+ model_info = 'iris_multiple_model_info' AND
+ source_table = 'iris_data_one_hot_encoded_packed' AND
+ model = 'iris_multiple_model' AND
+ model_selection_table = 'mst_table_4row' AND
+ object_table IS NULL AND
+ dependent_varname = 'class_one_hot_encoded' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ dependent_vartype = 'integer[]' AND
+ num_classes = NULL AND
+ class_values = NULL AND
+ normalizing_const = 1 AND
+ metrics_iters = ARRAY[3],
+ 'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
Review comment:
Many of these tests look like they have a lot in common with the
non-caching tests. They will be much easier to maintain if we factor the
shared setup and assertions into a common function and call it once for
caching=True and once for caching=False.
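Something along these lines would work (just a rough sketch; the helper
name and body here are illustrative, not code from this PR):

```sql
-- Hypothetical shared helper: run fit_multiple once and leave the output
-- tables in place for the summary/info assertions that both variants share.
CREATE OR REPLACE FUNCTION run_fit_multiple_summary_test(use_caching BOOLEAN)
RETURNS VOID AS $$
BEGIN
    DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary,
                         iris_multiple_model_info;
    PERFORM madlib_keras_fit_multiple_model(
        'iris_data_one_hot_encoded_packed',
        'iris_multiple_model',
        'mst_table_4row',
        3,
        FALSE, NULL, NULL, NULL, NULL, NULL,
        use_caching
    );
    -- the shared assert(...) checks on iris_multiple_model_summary go here
END;
$$ LANGUAGE plpgsql;

SELECT run_fit_multiple_summary_test(FALSE);  -- existing non-caching path
SELECT run_fit_multiple_summary_test(TRUE);   -- caching path
```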
##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
##########
@@ -580,6 +741,44 @@ FROM (SELECT count(*) cnt FROM iris_multiple_model_info
WHERE compile_params = $MAD$loss='categorical_crossentropy',
optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed',
+ 'iris_multiple_model',
+ 'mst_table_4row',
+ 3,
+ FALSE, NULL, NULL, NULL, NULL, NULL,
+ TRUE
+);
Review comment:
Will these tests pass if the cluster has more than 3 segments (or fewer)?
We don't know what platform future MADlib contributors might want to test
this on, or what sort of pipelines it will need to pass in.
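For example (a sketch only, not from this PR), the dev-check could read the
segment count from the catalog instead of hard-coding 3, and size the mst
table relative to it:

```sql
-- Number of primary segments on the cluster under test (Greenplum catalog).
SELECT count(*) AS num_segments
FROM gp_segment_configuration
WHERE role = 'p' AND content >= 0;
```

The "more configs than segments" case could then be built with
num_segments + 1 configs rather than assuming a 3-segment cluster.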
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
##########
@@ -1506,13 +1507,17 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
segments_per_host INTEGER,
images_per_seg INTEGER[],
use_gpus BOOLEAN,
- accessible_gpus_for_seg INTEGER[],
+ accessible_gpus_for_seg INTEGER[],
prev_serialized_weights BYTEA,
- is_final_iteration BOOLEAN,
+ is_final_training_call BOOLEAN,
+ use_caching BOOLEAN,
custom_function_map BYTEA
) RETURNS BYTEA AS $$
PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
- return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+ if use_caching:
+ return madlib_keras.fit_multiple_transition(**globals())
Review comment:
I notice the SQL function `fit_transition_multiple_model` calls the
Python function `fit_multiple_transition` if caching is enabled, otherwise
`fit_transition`. I think we could improve the names here to make it clearer
what's going on.
Maybe rename the SQL function to `fit_multiple_transition`, and then the new
Python function can be `fit_multiple_transition_caching`? (Then if we ever add
a caching version of regular fit it could be `fit_transition_caching`.)
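Roughly like this (a sketch of the proposed names only; the argument list is
elided and this is not the PR's actual code):

```sql
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_multiple_transition(
    -- ... same argument list as today's fit_transition_multiple_model ...
) RETURNS BYTEA AS $$
PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
    if use_caching:
        return madlib_keras.fit_multiple_transition_caching(**globals())
    else:
        return madlib_keras.fit_transition(is_multiple_model = True, **globals())
$$ LANGUAGE plpythonu;
```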
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -523,14 +523,101 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
images_per_seg)
is_last_row = agg_image_count == total_images
return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
- agg_image_count, total_images)
+ agg_image_count, total_images)
if is_last_row:
if is_final_iteration or is_multiple_model:
SD_STORE.clear_SD(SD)
clear_keras_session(sess)
return return_state
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+ independent_var_shape, model_architecture,
+ compile_params, fit_params, dist_key, dist_key_mapping,
+ current_seg_id, segments_per_host, images_per_seg, use_gpus,
+ accessible_gpus_for_seg, prev_serialized_weights,
+ is_final_training_call, custom_function_map=None, **kwargs):
+ """
+ This transition function is called when caching is called for
+ madlib_keras_fit_multiple_model().
+ The input params: dependent_var, independent_var are passed in
+ as None and dependent_var_shape, independent_var_shape as [0]
+ for all hops except the very firt hop
+ Some things to note in this function are:
+ - prev_serialized_weights can be passed in as None for the
+ very first hop and the final training call
+ - x_train, y_train and cache_set is cleared from SD for
+ final_training_call = TRUE
+ """
+ if not state:
+ agg_image_count = 0
+ else:
+ agg_image_count = float(state)
+
+ SD = kwargs['SD']
+ is_cache_set = 'cache_set' in SD
+
+ # Prepare the data
+ if is_cache_set:
+ if 'x_train' not in SD or 'y_train' not in SD:
+ plpy.error("cache not populated properly.")
+ total_images = None
+ is_last_row = True
+ else:
+ if not independent_var or not dependent_var:
+ return state
+ if 'x_train' not in SD:
+ SD['x_train'] = list()
+ SD['y_train'] = list()
+ agg_image_count += independent_var_shape[0]
+ total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+ images_per_seg)
+ is_last_row = agg_image_count == total_images
+ if is_last_row:
+ SD['cache_set'] = True
+ x_train_current = np_array_float32(independent_var, independent_var_shape)
+ y_train_current = np_array_int16(dependent_var, dependent_var_shape)
+ SD['x_train'].append(x_train_current)
+ SD['y_train'].append(y_train_current)
+
+ # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
+ # But if the weights are None, we do not want to set any model. So early return in that case
+ if prev_serialized_weights is None:
+ if is_final_training_call:
+ del SD['x_train']
+ del SD['y_train']
+ del SD['cache_set']
+ return float(agg_image_count)
Review comment:
I thought this function was supposed to return the state (a list); why
is it returning a float?
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -523,14 +523,101 @@ def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
images_per_seg)
is_last_row = agg_image_count == total_images
return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
- agg_image_count, total_images)
+ agg_image_count, total_images)
if is_last_row:
if is_final_iteration or is_multiple_model:
SD_STORE.clear_SD(SD)
clear_keras_session(sess)
return return_state
+def fit_multiple_transition(state, dependent_var, independent_var, dependent_var_shape,
+ independent_var_shape, model_architecture,
+ compile_params, fit_params, dist_key, dist_key_mapping,
+ current_seg_id, segments_per_host, images_per_seg, use_gpus,
+ accessible_gpus_for_seg, prev_serialized_weights,
+ is_final_training_call, custom_function_map=None, **kwargs):
+ """
+ This transition function is called when caching is called for
+ madlib_keras_fit_multiple_model().
+ The input params: dependent_var, independent_var are passed in
+ as None and dependent_var_shape, independent_var_shape as [0]
+ for all hops except the very firt hop
Review comment:
Minor typo: "firt" should be "first".
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]