This is an automated email from the ASF dual-hosted git repository.

njayaram pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 3a4a8e3a4f5688f3af18e09735822edd53376a31
Author: Nandish Jayaram <[email protected]>
AuthorDate: Wed Mar 13 12:04:58 2019 -0700

    Deep Learning: Add unit test cases
    
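    This commit adds unit tests for the deep learning module (madlib_keras),
    covering fit transition, model state (de)serialization, data loading,
    device selection and input shape validation.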
    
    Co-authored-by: Nikhil Kak <[email protected]>
---
 .../convex/test/unit_tests/test_madlib_keras.py_in | 303 +++++++++++++++++++++
 1 file changed, 303 insertions(+)

diff --git 
a/src/ports/postgres/modules/convex/test/unit_tests/test_madlib_keras.py_in 
b/src/ports/postgres/modules/convex/test/unit_tests/test_madlib_keras.py_in
new file mode 100644
index 0000000..4a2691d
--- /dev/null
+++ b/src/ports/postgres/modules/convex/test/unit_tests/test_madlib_keras.py_in
@@ -0,0 +1,303 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import numpy as np
+from os import path
+# Make the modules under test importable. Applying dirname() four times to
+# this file path yields modules/ (the file lives in convex/test/unit_tests/);
+# applying it three times yields convex/, presumably where madlib_keras lives.
+sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
+sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+
+import keras
+from keras.models import *
+from keras.layers import *
+import unittest
+from mock import *
+# plpy_mock stands in for the PL/Python plpy module; setUp installs it in
+# sys.modules under the name plpy before madlib_keras is imported.
+import plpy_mock as plpy
+# NOTE(review): the two star imports below duplicate the keras imports above
+# and are redundant.
+from keras.models import *
+from keras.layers import *
+
+
+# This .py_in file is preprocessed by m4. Switch the m4 quote delimiters so
+# that backtick and apostrophe characters in the Python code below are
+# passed through untouched.
+m4_changequote(`<!', `!>')
+
+class MadlibKerasFitTestCase(unittest.TestCase):
+    """Unit tests for madlib_keras, run against the plpy_mock module.
+
+    Covers fit_transition (first, middle and last buffer of a segment and
+    rows with NULL variables), model state (de)serialization,
+    get_data_as_np_array, get_device_name_for_keras and
+    _validate_input_shapes.
+    """
+
+    def setUp(self):
+        # madlib_keras is imported while sys.modules is patched so that its
+        # plpy import resolves to plpy_mock. patch.dict restores sys.modules
+        # to its original contents when it is stopped.
+        patches = {
+            'plpy': plpy
+        }
+        self.module_patcher = patch.dict('sys.modules', patches)
+        self.module_patcher.start()
+        self.addCleanup(self.module_patcher.stop)
+
+        # Each test sets the rows plpy.execute returns through this mock.
+        # patch.object restores (or removes) plpy.execute after the test
+        # instead of leaving the mock installed on the shared plpy_mock module.
+        self.plpy_mock_execute = MagicMock()
+        execute_patcher = patch.object(plpy, 'execute',
+                                       self.plpy_mock_execute, create=True)
+        execute_patcher.start()
+        self.addCleanup(execute_patcher.stop)
+
+        import madlib_keras
+        self.subject = madlib_keras
+
+        self.model = Sequential()
+        self.model.add(Conv2D(2, kernel_size=(1, 1), activation='relu',
+                              input_shape=(1, 1, 1), padding='same'))
+        self.model.add(Flatten())
+
+        self.compile_params = ("'optimizer'=SGD(lr=0.01, decay=1e-6, "
+                               "nesterov=True), "
+                               "'loss'='categorical_crossentropy', "
+                               "'metrics'=['accuracy']")
+        self.fit_params = "'batch_size'=1, 'epochs'=1"
+        self.model_weights = [3, 4, 5, 6]
+        self.loss = 1.3
+        self.accuracy = 0.34
+        self.all_seg_ids = [0, 1, 2]
+        self.total_buffers_per_seg = [3, 3, 3]
+
+    def _mock_attr(self, target, name):
+        """Replace target.name with a fresh Mock for the current test.
+
+        The original attribute is restored when the test finishes, so mocks
+        installed on shared modules do not leak into other tests.
+        Returns the installed Mock.
+        """
+        patcher = patch.object(target, name, Mock())
+        mocked = patcher.start()
+        self.addCleanup(patcher.stop)
+        return mocked
+
+    def _mock_session_functions(self):
+        """Mock K.set_session and clear_keras_session for the current test.
+
+        Returns the (set_session, clear_keras_session) mocks.
+        """
+        # TODO should we mock tensorflow's close_session and keras'
+        # clear_session instead of mocking the function `clear_keras_session`
+        set_session = self._mock_attr(self.subject.K, 'set_session')
+        clear_session = self._mock_attr(self.subject, 'clear_keras_session')
+        return set_session, clear_session
+
+    def _build_state(self, buffer_count):
+        """Return the float32 state [loss, accuracy, buffer_count, *weights]
+        built from the fixture values set in setUp."""
+        state = [self.loss, self.accuracy, buffer_count]
+        state.extend(self.model_weights)
+        return np.array(state, dtype=np.float32)
+
+    @staticmethod
+    def _buffer_count_of(serialized_state):
+        """Return the buffer count stored at index 2 of a serialized
+        float32 model state."""
+        return np.frombuffer(serialized_state, dtype=np.float32)[2]
+
+    def _run_fit_transition_after_first_buffer(self, buffer_count):
+        """Run fit_transition with a compiled model already cached in SD,
+        as is the case for every buffer after the first one.
+
+        Returns a (new_model_state, SD dict) pair.
+        """
+        state = self._build_state(buffer_count)
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', state.tobytes())
+        k = {'SD': {'buffer_count': buffer_count,
+                    'segment_model': self.model}}
+        new_model_state = self.subject.fit_transition(
+            state.tobytes(), [[0.5]], [0], 1, 2, self.all_seg_ids,
+            self.total_buffers_per_seg, self.model.to_json(), None,
+            self.fit_params, False, 'dummy_previous_state', **k)
+        return new_model_state, k['SD']
+
+    def test_fit_transition_first_buffer_pass(self):
+        set_session, clear_session = self._mock_session_functions()
+        buffer_count = 0
+        previous_state = self._build_state(buffer_count)
+
+        k = {'SD': {'buffer_count': buffer_count}}
+        new_model_state = self.subject.fit_transition(
+            None, [[0.5]], [0], 1, 2, self.all_seg_ids,
+            self.total_buffers_per_seg, self.model.to_json(),
+            self.compile_params, self.fit_params, False,
+            previous_state.tobytes(), **k)
+
+        self.assertEqual(1, self._buffer_count_of(new_model_state))
+        # set_session must get called exactly once, for the first buffer.
+        self.assertEqual(1, set_session.call_count)
+        # The session must not be cleared after the first buffer.
+        self.assertEqual(0, clear_session.call_count)
+        self.assertEqual(1, k['SD']['buffer_count'])
+        # The model is cached in SD for the following buffers.
+        self.assertIsNotNone(k['SD'].get('segment_model'))
+
+    def test_fit_transition_last_buffer_pass(self):
+        set_session, clear_session = self._mock_session_functions()
+
+        new_model_state, sd = self._run_fit_transition_after_first_buffer(2)
+
+        self.assertEqual(3, self._buffer_count_of(new_model_state))
+        # Only the first buffer sets the session.
+        self.assertEqual(0, set_session.call_count)
+        # This buffer brings the count to 3, the segment's total in
+        # total_buffers_per_seg, so the keras session must be cleared.
+        self.assertEqual(1, clear_session.call_count)
+        self.assertEqual(3, sd['buffer_count'])
+
+    def test_fit_transition_middle_buffer_pass(self):
+        set_session, clear_session = self._mock_session_functions()
+
+        new_model_state, sd = self._run_fit_transition_after_first_buffer(1)
+
+        self.assertEqual(2, self._buffer_count_of(new_model_state))
+        # A middle buffer neither sets nor clears the keras session.
+        self.assertEqual(0, set_session.call_count)
+        self.assertEqual(0, clear_session.call_count)
+        self.assertEqual(2, sd['buffer_count'])
+
+    def test_deserialize_weights_merge_null_state_returns_none(self):
+        self.assertIsNone(self.subject.deserialize_weights_merge(None))
+
+    def test_deserialize_weights_merge_returns_not_none(self):
+        dummy_model_state = np.array([0, 1, 2, 3, 4, 5, 6], dtype=np.float32)
+        res = self.subject.deserialize_weights_merge(
+            dummy_model_state.tobytes())
+        self.assertEqual(0, res[0])
+        self.assertEqual(1, res[1])
+        self.assertEqual(2, res[2])
+        self.assertEqual([3, 4, 5, 6], res[3].tolist())
+
+    def test_deserialize_weights_null_input_returns_none(self):
+        dummy_model_state = np.array([0, 1, 2, 3, 4, 5, 6], dtype=np.float32)
+        self.assertIsNone(self.subject.deserialize_weights(
+            dummy_model_state.tobytes(), None))
+        self.assertIsNone(self.subject.deserialize_weights(None, [1, 2, 3]))
+        self.assertIsNone(self.subject.deserialize_weights(None, None))
+
+    def test_deserialize_weights_valid_input_returns_not_none(self):
+        dummy_model_state = np.array([0, 1, 2, 3, 4, 5], dtype=np.float32)
+        dummy_model_shape = [(2, 1, 1, 1), (1,)]
+        res = self.subject.deserialize_weights(dummy_model_state.tobytes(),
+                                               dummy_model_shape)
+        self.assertEqual(0, res[0])
+        self.assertEqual(1, res[1])
+        self.assertEqual(2, res[2])
+        self.assertEqual([[[[3.0]]], [[[4.0]]]], res[3][0].tolist())
+        self.assertEqual([5], res[3][1].tolist())
+
+    def test_deserialize_weights_invalid_input_fails(self):
+        # The state holds only [loss, accuracy, buffer_count] and no model
+        # weights; an empty weight array cannot be reshaped into (2, 1, 1, 1),
+        # so a ValueError is expected.
+        invalid_model_state = np.array([0, 1, 2], dtype=np.float32)
+        dummy_model_shape = [(2, 1, 1, 1), (1,)]
+        with self.assertRaises(ValueError):
+            self.subject.deserialize_weights(invalid_model_state.tobytes(),
+                                             dummy_model_shape)
+
+        # Only two weights (3 and 4) follow the three header values; they
+        # cannot be reshaped into (2, 2, 3, 1), so a ValueError is expected.
+        invalid_model_state = np.array([0, 1, 2, 3, 4], dtype=np.float32)
+        dummy_model_shape = [(2, 2, 3, 1), (1,)]
+        with self.assertRaises(ValueError):
+            self.subject.deserialize_weights(invalid_model_state.tobytes(),
+                                             dummy_model_shape)
+
+    def test_deserialize_iteration_state_none_input_returns_none(self):
+        self.assertIsNone(self.subject.deserialize_iteration_state(None))
+
+    def test_deserialize_iteration_state_returns_valid_output(self):
+        dummy_iteration_state = np.array([0, 1, 2, 3, 4, 5], dtype=np.float32)
+        res = self.subject.deserialize_iteration_state(
+            dummy_iteration_state.tobytes())
+        self.assertEqual(0, res[0])
+        self.assertEqual(1, res[1])
+        expected_state = np.array([0, 0, 0, 3, 4, 5], dtype=np.float32)
+        self.assertEqual(expected_state.tobytes(), res[2])
+
+    def test_serialize_weights_none_weights_returns_none(self):
+        self.assertIsNone(self.subject.serialize_weights(0, 1, 2, None))
+
+    def test_serialize_weights_valid_output(self):
+        res = self.subject.serialize_weights(0, 1, 2, [np.array([1, 3]),
+                                                       np.array([4, 5])])
+        expected = np.array([0, 1, 2, 1, 3, 4, 5], dtype=np.float32)
+        self.assertEqual(expected.tobytes(), res)
+
+    def test_serialize_weights_merge_none_weights_returns_none(self):
+        self.assertIsNone(self.subject.serialize_weights_merge(0, 1, 2, None))
+
+    def test_serialize_weights_merge_valid_output(self):
+        res = self.subject.serialize_weights_merge(0, 1, 2,
+                                                   np.array([1, 3, 4, 5]))
+        expected = np.array([0, 1, 2, 1, 3, 4, 5], dtype=np.float32)
+        self.assertEqual(expected.tobytes(), res)
+
+    def test_get_data_as_np_array_one_image_per_row(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1, 2]], 'y': 0},
+                                               {'x': [[5, 6]], 'y': 1}]
+        x_res, y_res = self.subject.get_data_as_np_array('foo', 'y', 'x',
+                                                         [1, 1, 2], 3)
+        self.assertEqual([[[[1, 2]]], [[[5, 6]]]], x_res.tolist())
+        self.assertEqual([[1, 0, 0], [0, 1, 0]], y_res.tolist())
+
+    def test_get_data_as_np_array_multiple_images_per_row(self):
+        self.plpy_mock_execute.return_value = [
+            {'x': [[1, 2], [3, 4]], 'y': [0, 2]},
+            {'x': [[5, 6], [7, 8]], 'y': [1, 0]}]
+        x_res, y_res = self.subject.get_data_as_np_array('foo', 'y', 'x',
+                                                         [1, 1, 2], 3)
+        self.assertEqual([[[[1, 2]]], [[[3, 4]]], [[[5, 6]]], [[[7, 8]]]],
+                         x_res.tolist())
+        self.assertEqual([[1, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]],
+                         y_res.tolist())
+
+    def test_get_data_as_np_array_float_input_shape(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1, 2]], 'y': 0},
+                                               {'x': [[5, 6]], 'y': 1}]
+        x_res, y_res = self.subject.get_data_as_np_array(
+            'foo', 'y', 'x', [1.5, 1.9, 2.3], 3)
+        self.assertEqual([[[[1, 2]]], [[[5, 6]]]], x_res.tolist())
+        self.assertEqual([[1, 0, 0], [0, 1, 0]], y_res.tolist())
+
+    def test_get_data_as_np_array_invalid_input_shape(self):
+        self.plpy_mock_execute.return_value = [{'x': [[1, 2]], 'y': 0},
+                                               {'x': [[5, 6]], 'y': 1}]
+        # A ValueError is expected because each image holds 2 values, which
+        # cannot be reshaped into the input shape (1, 1, 3).
+        with self.assertRaises(ValueError):
+            self.subject.get_data_as_np_array('foo', 'y', 'x', [1, 1, 3], 3)
+
+    def test_get_device_name_for_keras(self):
+        import os
+        # The CPU path is expected to set CUDA_VISIBLE_DEVICES; patch.dict
+        # restores the process environment afterwards so that later tests
+        # are unaffected.
+        with patch.dict(os.environ):
+            self.assertEqual(
+                '/gpu:0', self.subject.get_device_name_for_keras(True, 1, 3))
+            self.assertEqual(
+                '/cpu:0', self.subject.get_device_name_for_keras(False, 1, 3))
+            self.assertEqual('-1', os.environ["CUDA_VISIBLE_DEVICES"])
+
+    def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
+        # The incoming state must be returned unchanged whenever the
+        # independent and/or dependent variable of the row is None.
+        k = {}
+        for ind_var, dep_var in [(None, [0]), ([[0.5]], None), (None, None)]:
+            self.assertEqual('dummy_state', self.subject.fit_transition(
+                'dummy_state', ind_var, dep_var, 1, 2, [0, 1, 2], [3, 3, 3],
+                'dummy_model_json', "foo", "bar", False, 'dummy_prev_state',
+                **k))
+
+    def test_validate_input_shapes_shapes_do_not_match(self):
+        # Each row returned by plpy.execute has n_* values that disagree
+        # with the requested input shape, so PLPYException must be raised.
+        mismatched_cases = [
+            ({'n_0': 32, 'n_1': 32}, [32, 32, 3]),
+            ({'n_0': 3, 'n_1': 32, 'n_2': 32}, [32, 32, 3]),
+            ({'n_0': 3, 'n_1': None, 'n_2': None}, [3, 32]),
+        ]
+        for row, input_shape in mismatched_cases:
+            self.plpy_mock_execute.return_value = [row]
+            with self.assertRaises(plpy.PLPYException):
+                self.subject._validate_input_shapes('foo', 'bar', input_shape)
+
+    def test_validate_input_shapes_shapes_match(self):
+        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32,
+                                                'n_2': 3}]
+        self.subject._validate_input_shapes('foo', 'bar', [32, 32, 3])
+
+if __name__ == "__main__":
+    # Run every test case in this file when it is executed directly.
+    unittest.main()
+# ---------------------------------------------------------------------

Reply via email to