kaknikhil commented on a change in pull request #451: DL: Update evaluate and predict for multi model outputs
URL: https://github.com/apache/madlib/pull/451#discussion_r336617343
##########
File path: src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
##########
@@ -306,6 +308,54 @@ SELECT assert(trap_error($TRAP$madlib_keras_predict(
0);$TRAP$) = 1,
'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should have failed.');
+-- Test multi model
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+ 'model_arch',
+ 'mst_table',
+ ARRAY[1],
+ ARRAY[
+ $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
+ $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+ ],
+ ARRAY[
+ $$batch_size=5,epochs=1$$,
+ $$batch_size=10,epochs=1$$
+ ]
+);
+
+DROP TABLE if exists cifar_10_multiple_model, cifar_10_multiple_model_summary,
+ cifar_10_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+ 'cifar_10_sample_batched',
+ 'cifar_10_multiple_model',
+ 'mst_table',
+ 6,
+ 0
+);
+
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_evaluate('cifar_10_multiple_model', 'cifar_10_sample_val', 'evaluate_out', 0, 1);
Review comment:
We don't have to call evaluate here.