khannaekta commented on a change in pull request #490: DL: Don't include
weights as part of state except for the last row.
URL: https://github.com/apache/madlib/pull/490#discussion_r393870599
##########
File path: src/ports/postgres/modules/deep_learning/madlib_keras.py_in
##########
@@ -507,51 +509,50 @@ def fit_transition(state, dependent_var,
independent_var, dependent_var_shape,
#TODO consider not doing this every time
fit_params = parse_and_validate_fit_params(fit_params)
segment_model.fit(x_train, y_train, **fit_params)
- updated_model_weights = segment_model.get_weights()
# Aggregating number of images, loss and accuracy
agg_image_count += len(x_train)
total_images =
get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
images_per_seg)
is_last_row = agg_image_count == total_images
+ return_state = get_state_to_return(segment_model, is_last_row,
is_multiple_model,
+ agg_image_count, total_images)
if is_last_row:
if is_final_iteration or is_multiple_model:
SD_STORE.clear_SD(SD)
clear_keras_session(sess)
- return get_state_to_return(is_last_row, is_multiple_model, agg_image_count,
- total_images, updated_model_weights)
+ return return_state
-def get_state_to_return(is_last_row, is_multiple_model, agg_image_count,
- total_images, updated_model_weights):
+def get_state_to_return(segment_model, is_last_row, is_multiple_model,
agg_image_count,
+ total_images):
"""
- 1. For model averaging fit_transition, the state always contains the image
count
- as well as the model weights
- 2. For fit multiple transition,
- a. The state that gets passed from one row/buffer (within the same hop)
- to the next needs to have the image_count and model weights.
image_count
- is needed to keep track of the last image for that hop.
- b. Once we get to the last row, the state only needs the model
- weights. This state is the output of the UDA for that hop. We don't
need
- the image_count here because unlike model averaging, model hopper does
- not have a merge function and there is no need to average the weights
- based on the image count.
+ 1. For both model averaging fit_transition and fit multiple transition,
the state
+ only needs to have the image count except for the last row.
+ 1. For model averaging fit_transition, the last row state must always
contains the
Review comment:
The list numbering in this docstring is out of order: item "1." appears twice (the second one should be "2.").
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services