kaknikhil commented on a change in pull request #443: DL: Add training for multiple models
URL: https://github.com/apache/madlib/pull/443#discussion_r326738826
##########
File path:
src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
##########
@@ -421,3 +421,50 @@ class MstLoaderInputValidator():
output_tbl_valid(self.model_selection_table, self.module_name)
output_tbl_valid(self.model_selection_summary_table, self.module_name)
+
+class FitMultipleInputValidator(FitInputValidator):
+ def __init__(self, source_table, validation_table, output_model_table,
+ model_arch_table, dependent_varname,
+ independent_varname, num_iterations,
+ metrics_compute_frequency, warm_start):
+ super(FitMultipleInputValidator, self).__init__(source_table,
+ validation_table,
+ output_model_table,
+ model_arch_table,
+ None,
+ dependent_varname,
+ independent_varname,
+ num_iterations,
+
metrics_compute_frequency,
+ warm_start)
+
+ def _validate_input_args(self):
+ _assert(self.num_iterations > 0,
+ "{0}: Number of iterations cannot be <
1.".format(self.module_name))
+ _assert(self._is_valid_metrics_compute_frequency(),
+ "{0}: metrics_compute_frequency must be in the range (1 -
{1}).".format(
+ self.module_name, self.num_iterations))
+ input_tbl_valid(self.source_table, self.module_name)
+ input_tbl_valid(self.source_summary_table, self.module_name,
+ error_suffix_str="Please ensure that the source table
({0}) "
+ "has been preprocessed by "
+ "the image
preprocessor.".format(self.source_table))
+ cols_in_tbl_valid(self.source_summary_table, [CLASS_VALUES_COLNAME,
+
NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
+ 'dependent_varname',
'independent_varname'], self.module_name)
+
+ # Source table and validation tables must have the same schema
+ self._validate_input_table(self.source_table)
+ validate_bytea_var_for_minibatch(self.source_table,
+ self.dependent_varname)
+
+ self._validate_validation_table()
+ input_tbl_valid(self.model_arch_table, self.module_name)
+
+ if self.warm_start:
Review comment:
Why do we have a parameter for `warm_start`? We don't really support warm
start yet. Is this a placeholder for the future?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services