This is an automated email from the ASF dual-hosted git repository.
njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
The following commit(s) were added to refs/heads/master by this push:
new 36a5a9c DL: Add tests for madlib_keras_predict
36a5a9c is described below
commit 36a5a9c5b26a2a5a05f7d9708e32fa58ca7c82aa
Author: Domino Valdano <[email protected]>
AuthorDate: Tue Apr 2 16:33:05 2019 -0700
DL: Add tests for madlib_keras_predict
dev-check tests:
- Validate table format of predict output
- Test that predicted values are in class value set
- Test that predictions are accurate
- Test to make sure we error out on batched input
Unit tests:
- Only _get_class_label for now. (Tests that the correct
class value is returned, plus plpy.error on an invalid index.)
Other changes:
- Add normalization of training data at beginning of dev-check
test file. (Currently in order for predict to work, image data
must be divided by 255 before training, to match what's done
inside the predict function.)
- Add additional error-checking in _get_class_label so that
unit test passes.
Closes #363
---
.../deep_learning/madlib_keras_predict.py_in | 17 +++++---
.../modules/deep_learning/test/madlib_keras.sql_in | 50 ++++++++++++++++++----
.../test/unit_tests/test_madlib_keras.py_in | 46 ++++++++++++++++++++
3 files changed, 98 insertions(+), 15 deletions(-)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index c484c09..6b1dbc6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -108,12 +108,15 @@ def _get_class_label(class_values, class_index):
scalar. If class_values is None, returns class_index, else returns
class_values[class_index].
"""
- if class_values:
- if class_index < len(class_values):
- return class_values[class_index]
- else:
- plpy.error("Invalid class index {0} returned from Keras predict. "\
- "Index value must be less than {1}".format(
+ if not class_values:
+ return class_index
+ elif class_index != int(class_index):
+ plpy.error("Invalid class index {0} returned from Keras predict. "\
+ "Index value must be an integer".format(
+ class_index))
+ elif class_index < 0 or class_index >= len(class_values):
+ plpy.error("Invalid class index {0} returned from Keras predict. "\
+ "Index value must be less than {1}".format(
class_index, len(class_values)))
else:
- return class_index
+ return class_values[class_index]
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 923d0b6..7865e55 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -33,7 +33,7 @@ copy cifar_10_sample from stdin delimiter '|';
DROP TABLE IF EXISTS cifar_10_sample_batched;
DROP TABLE IF EXISTS cifar_10_sample_batched_summary;
-SELECT minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 2);
+SELECT minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 2, 255);
DROP TABLE IF EXISTS model_arch;
SELECT load_keras_model('model_arch',
@@ -61,10 +61,10 @@ ALTER TABLE model_arch RENAME model_id TO id;
-- Please do not break up the compile_params string
-- It might break the assertion
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
'cifar_10_sample_batched',
- 'keras_out',
+ 'keras_saved_out',
'dependent_var',
'independent_var',
'model_arch',
@@ -83,7 +83,7 @@ SELECT assert(
end_training_time > start_training_time AND
source_table = 'cifar_10_sample_batched' AND
validation_table = 'cifar_10_sample_batched' AND
- model = 'keras_out' AND
+ model = 'keras_saved_out' AND
dependent_varname = 'dependent_var' AND
independent_varname = 'independent_var' AND
name is NULL AND
@@ -105,9 +105,13 @@ SELECT assert(
array_upper(accuracy_iter_validation, 1) = 3 AND
array_upper(loss_iter_validation, 1) = 3 ,
'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
-FROM (SELECT * FROM keras_out_summary) summary;
+FROM (SELECT * FROM keras_saved_out_summary) summary;
-SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_out) k;
+SELECT assert(accuracy_validation > 0.9999,
+ 'Validation accuracy after 3 iterations of training is only ' || __to_char(100*accuracy) || '%, should have reached 100%')
+ FROM keras_saved_out_summary;
+
+SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_saved_out) k;
-- Test for
@@ -161,10 +165,9 @@ FROM (SELECT * FROM keras_out_summary) summary;
SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_out) k;
--- Temporary predict test, to be updated as part of another jira
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
- 'keras_out',
+ 'keras_saved_out',
'cifar_10_sample',
'id',
'model_arch',
@@ -172,3 +175,34 @@ SELECT madlib_keras_predict(
'x',
$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
'cifar10_predict');
+
+-- Validate that prediction output table exists and has correct schema
+SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
+ FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+ AND attname = 'id';
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+ 'DOUBLE PRECISION', 'prediction column should be DOUBLE PRECISION type')
+ FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+ AND attname = 'prediction';
+
+-- Validate correct number of rows returned.
+SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have two rows') FROM cifar10_predict;
+
+-- First test that all values are in set of class values; if this breaks, it's definitely a problem.
+SELECT assert(prediction in (0,1),'Predicted value not in set of defined class values for model') FROM cifar10_predict;
+
+-- Then test that each of the two images is correctly predicted. If this breaks, it's likely a different problem.
+SELECT assert(prediction=0,'Incorrect prediction for first image. Predicted: ' || __to_char(prediction) || ', Expected: 0') FROM cifar10_predict WHERE id=1;
+SELECT assert(prediction=1,'Incorrect prediction for second image. Predicted: ' || __to_char(prediction) || ', Expected: 1') FROM cifar10_predict WHERE id=2;
+
+select assert(trap_error($TRAP$madlib_keras_predict(
+ 'keras_saved_out',
+ 'cifar_10_sample_batched',
+ 'id',
+ 'model_arch',
+ 1,
+ 'x',
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+ 'cifar10_predict');$TRAP$) = 1,
+ 'Passing batched image table to predict should error out.');
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index a66a292..84bad08 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -177,6 +177,52 @@ class MadlibKerasFitTestCase(unittest.TestCase):
[0,1,2], [3,3,3], 'dummy_model_json', "foo", "bar", False,
'dummy_prev_state', **k))
+
+class MadlibKerasPredictTestCase(unittest.TestCase):
+ def setUp(self):
+ self.plpy_mock = Mock(spec='error')
+ patches = {
+ 'plpy': plpy
+ }
+
+ self.plpy_mock_execute = MagicMock()
+ plpy.execute = self.plpy_mock_execute
+
+ self.module_patcher = patch.dict('sys.modules', patches)
+ self.module_patcher.start()
+ import madlib_keras_predict
+ self.subject = madlib_keras_predict
+ self.classes = ['train', 'boat', 'car', 'airplane']
+
+ def test_get_class_label(self):
+ # test that index in range returns correct class value
+ self.assertEqual(
+ 'boat',
+ self.subject._get_class_label(self.classes,
+ self.classes.index('boat'))
+ )
+
+ # test that index is returned if class_values param is None
+ self.assertEqual(
+ 5,
+ self.subject._get_class_label(None,5)
+ )
+
+ # test that index too high generates plpy error
+ with self.assertRaises(plpy.PLPYException):
+ self.subject._get_class_label(self.classes, 9)
+
+ # test that index too low generates plpy error
+ with self.assertRaises(plpy.PLPYException):
+ self.subject._get_class_label(self.classes, -1)
+
+ # test that non-integer index generates plpy error
+ with self.assertRaises(plpy.PLPYException):
+ self.subject._get_class_label(self.classes, 4.5)
+
+ def tearDown(self):
+ self.module_patcher.stop()
+
if __name__ == '__main__':
unittest.main()
# ---------------------------------------------------------------------