This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git


The following commit(s) were added to refs/heads/master by this push:
     new 36a5a9c  DL: Add tests for madlib_keras_predict
36a5a9c is described below

commit 36a5a9c5b26a2a5a05f7d9708e32fa58ca7c82aa
Author: Domino Valdano <[email protected]>
AuthorDate: Tue Apr 2 16:33:05 2019 -0700

    DL: Add tests for madlib_keras_predict
    
    dev-check tests:
    - Validate table format of predict output
    - Test that predicted values are in class value set
    - Test that predictions are accurate
    - Test to make sure we error out on batched input
    
    Unit tests:
    - Only _get_class_label for now.  (Tests that the correct
      class value is returned, and that plpy.error is raised on an
      invalid index.)
    
    Other changes:
    - Add normalization of training data at beginning of dev-check
      test file.  (Currently in order for predict to work, image data
      must be divided by 255 before training, to match what's done
      inside the predict function.)
    
    - Add additional error-checking in _get_class_label so that
      unit test passes.
    
    Closes #363
---
 .../deep_learning/madlib_keras_predict.py_in       | 17 +++++---
 .../modules/deep_learning/test/madlib_keras.sql_in | 50 ++++++++++++++++++----
 .../test/unit_tests/test_madlib_keras.py_in        | 46 ++++++++++++++++++++
 3 files changed, 98 insertions(+), 15 deletions(-)

diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index c484c09..6b1dbc6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -108,12 +108,15 @@ def _get_class_label(class_values, class_index):
         scalar. If class_values is None, returns class_index, else returns
         class_values[class_index].
     """
-    if class_values:
-        if class_index < len(class_values):
-            return class_values[class_index]
-        else:
-            plpy.error("Invalid class index {0} returned from Keras predict. "\
-                "Index value must be less than {1}".format(
+    if not class_values:
+        return class_index
+    elif class_index != int(class_index):
+        plpy.error("Invalid class index {0} returned from Keras predict. "\
+                   "Index value must be an integer".format(
+                    class_index))
+    elif class_index < 0 or class_index >= len(class_values):
+        plpy.error("Invalid class index {0} returned from Keras predict. "\
+                   "Index value must be less than {1}".format(
                     class_index, len(class_values)))
     else:
-        return class_index
+        return class_values[class_index]
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
index 923d0b6..7865e55 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras.sql_in
@@ -33,7 +33,7 @@ copy cifar_10_sample from stdin delimiter '|';
 
 DROP TABLE IF EXISTS cifar_10_sample_batched;
 DROP TABLE IF EXISTS cifar_10_sample_batched_summary;
-SELECT 
minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 
2);
+SELECT 
minibatch_preprocessor_dl('cifar_10_sample','cifar_10_sample_batched','y','x', 
2, 255);
 
 DROP TABLE IF EXISTS model_arch;
 SELECT load_keras_model('model_arch',
@@ -61,10 +61,10 @@ ALTER TABLE model_arch RENAME model_id TO id;
 -- Please do not break up the compile_params string
 -- It might break the assertion
 
-DROP TABLE IF EXISTS keras_out, keras_out_summary;
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
 SELECT madlib_keras_fit(
     'cifar_10_sample_batched',
-    'keras_out',
+    'keras_saved_out',
     'dependent_var',
     'independent_var',
     'model_arch',
@@ -83,7 +83,7 @@ SELECT assert(
         end_training_time > start_training_time AND
         source_table = 'cifar_10_sample_batched' AND
         validation_table = 'cifar_10_sample_batched' AND
-        model = 'keras_out' AND
+        model = 'keras_saved_out' AND
         dependent_varname = 'dependent_var' AND
         independent_varname = 'independent_var' AND
         name is NULL AND
@@ -105,9 +105,13 @@ SELECT assert(
         array_upper(accuracy_iter_validation, 1) = 3 AND
         array_upper(loss_iter_validation, 1) = 3 ,
         'Keras model output Summary Validation failed. Actual:' || 
__to_char(summary))
-FROM (SELECT * FROM keras_out_summary) summary;
+FROM (SELECT * FROM keras_saved_out_summary) summary;
 
-SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
+SELECT assert(accuracy_validation > 0.9999,
+    'Validation accuracy after 3 iterations of training is only ' || 
__to_char(100*accuracy) || '%, should have reached 100%')
+    FROM keras_saved_out_summary;
+
+SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_saved_out) k;
 
 
 -- Test for
@@ -161,10 +165,9 @@ FROM (SELECT * FROM keras_out_summary) summary;
 
 SELECT assert(model_data IS NOT NULL , 'Keras model output validation failed') 
FROM (SELECT * FROM keras_out) k;
 
--- Temporary predict test, to be updated as part of another jira
 DROP TABLE IF EXISTS cifar10_predict;
 SELECT madlib_keras_predict(
-    'keras_out',
+    'keras_saved_out',
     'cifar_10_sample',
     'id',
     'model_arch',
@@ -172,3 +175,34 @@ SELECT madlib_keras_predict(
     'x',
     $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     'cifar10_predict');
+
+-- Validate that prediction output table exists and has correct schema
+SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be 
INTEGER type')
+    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+        AND attname = 'id';
+
+SELECT assert(UPPER(atttypid::regtype::TEXT) =
+    'DOUBLE PRECISION', 'prediction column should be DOUBLE PRECISION type')
+    FROM pg_attribute WHERE attrelid = 'cifar10_predict'::regclass
+        AND attname = 'prediction';
+
+-- Validate correct number of rows returned.
+SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have 
two rows') FROM cifar10_predict;
+
+-- First test that all values are in set of class values; if this breaks, it's 
definitely a problem.
+SELECT assert(prediction in (0,1),'Predicted value not in set of defined class 
values for model') FROM cifar10_predict;
+
+-- Then test that each of the two images is correctly predicted.  If this 
breaks, it's likely a different problem.
+SELECT assert(prediction=0,'Incorrect prediction for first image.  Predicted: 
' || __to_char(prediction) || ', Expected: 0') FROM cifar10_predict WHERE id=1;
+SELECT assert(prediction=1,'Incorrect prediction for second image.  Predicted: 
' || __to_char(prediction) || ', Expected: 1') FROM cifar10_predict WHERE id=2;
+
+select assert(trap_error($TRAP$madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample_batched',
+    'id',
+    'model_arch',
+    1,
+    'x',
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), 
loss='categorical_crossentropy', metrics=['accuracy']$$::text,
+    'cifar10_predict');$TRAP$) = 1,
+    'Passing batched image table to predict should error out.');
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index a66a292..84bad08 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -177,6 +177,52 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             [0,1,2], [3,3,3], 'dummy_model_json', "foo", "bar", False,
             'dummy_prev_state', **k))
 
+
+class MadlibKerasPredictTestCase(unittest.TestCase):
+    def setUp(self):
+        self.plpy_mock = Mock(spec='error')
+        patches = {
+            'plpy': plpy
+        }
+
+        self.plpy_mock_execute = MagicMock()
+        plpy.execute = self.plpy_mock_execute
+
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        import madlib_keras_predict
+        self.subject = madlib_keras_predict
+        self.classes = ['train', 'boat', 'car', 'airplane']
+
+    def test_get_class_label(self):
+        # test that index in range returns correct class value
+        self.assertEqual(
+                           'boat',
+                            self.subject._get_class_label(self.classes,
+                                                          
self.classes.index('boat'))
+                        )
+
+        # test that index is returned if class_values param is None
+        self.assertEqual(
+                            5,
+                            self.subject._get_class_label(None,5)
+                        )
+
+        # test that index too high generates plpy error
+        with self.assertRaises(plpy.PLPYException):
+            self.subject._get_class_label(self.classes, 9)
+
+        # test that index too low generates plpy error
+        with self.assertRaises(plpy.PLPYException):
+            self.subject._get_class_label(self.classes, -1)
+
+        # test that non-integer index generates plpy error
+        with self.assertRaises(plpy.PLPYException):
+            self.subject._get_class_label(self.classes, 4.5)
+
+    def tearDown(self):
+        self.module_patcher.stop()
+
 if __name__ == '__main__':
     unittest.main()
 # ---------------------------------------------------------------------

Reply via email to