makemebitter commented on a change in pull request #425: DL: Add training for
multiple models
URL: https://github.com/apache/madlib/pull/425#discussion_r311304915
##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
##########
@@ -1179,3 +1180,93 @@ SELECT assert(
abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
'Transfer learning test failed because training loss and metrics don''t
match the expected value.')
FROM iris_model_first_run AS first, iris_model_transfer_summary AS second;
+
+-- Multiple models test
+DROP TABLE if exists iris_data_packed_dist, iris_data_packed_dist_summary;
+CREATE TABLE iris_data_packed_dist AS
+ SELECT *, (row_number() over())%3 AS dist_key FROM iris_data_packed;
+CREATE TABLE iris_data_packed_dist_summary AS SELECT * FROM
iris_data_packed_summary;
+
+DROP TABLE IF EXISTS mst_table;
+CREATE TABLE mst_table (mst_key INTEGER,
+ model_arch_id INTEGER,
+ compile_params VARCHAR,
+ fit_params VARCHAR,
+ unique (model_arch_id, compile_params,
fit_params));
+INSERT INTO mst_table(mst_key,
+ model_arch_id,
+ compile_params,
+ fit_params)
+ VALUES (1, 1,
+ 'loss=''categorical_crossentropy'', optimizer=''Adam(lr=0.01)'',
metrics=[''accuracy'']',
+ 'batch_size=16, epochs=1'),
+ (2, 1,
+ 'loss=''categorical_crossentropy'',
optimizer=''Adam(lr=0.001)'', metrics=[''accuracy'']',
+ 'batch_size=16, epochs=1'),
+ (3, 1,
+ 'loss=''categorical_crossentropy'',
optimizer=''Adam(lr=0.0001)'', metrics=[''accuracy'']',
+ 'batch_size=16, epochs=1');
+
+CREATE FUNCTION test_mult_models()
+RETURNS VOID AS $$
+begin
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary,
iris_multiple_model_info;
+
+PERFORM madlib_keras_fit_multiple_model(
+ 'iris_data_packed_dist',
+ 'iris_multiple_model',
+ 'iris_model_arch',
+ 'mst_table',
+ 2,
+ 0
+);
+
+PERFORM assert(
+ model_arch_table = 'iris_model_arch' AND
+ model_info = 'iris_multiple_model_info' AND
+ source_table = 'iris_data_packed_dist' AND
+ model = 'iris_multiple_model' AND
+ dependent_varname = 'class_text' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 2 AND
+ num_classes = 3 AND
+ class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+ normalizing_const = 1,
+ 'Keras Fit Multiple Output Summary Validation failed. Actual:' ||
__to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+PERFORM assert(
+ model_arch_id = 1 AND
+ model_type = 'madlib_keras' AND
+ model_size > 0 AND
+ fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
+ metrics_type = '{accuracy}' AND
+ training_metrics_final >= 0 AND
+ training_loss_final >= 0 AND
+ array_upper(training_metrics, 1) = 2 AND
+ array_upper(training_loss, 1) = 2 AND
+ array_upper(metrics_elapsed_time, 1) = 2,
+ 'Keras Fit Multiple Output Info Validation failed. Actual:' ||
__to_char(info))
+FROM (SELECT * FROM iris_multiple_model_info) info;
Review comment:
Since assert() is a function, I'm assuming it is evaluated once for each of the three rows in iris_multiple_model_info?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services