This is an automated email from the ASF dual-hosted git repository. niketanpansare pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new 0b9a5fc [SYSTEMML-540] Added tests for comparing Keras2DML output with TF 0b9a5fc is described below commit 0b9a5fc3e44d649efeebee929f679f0188e57134 Author: Niketan Pansare <npan...@us.ibm.com> AuthorDate: Mon Feb 18 11:54:57 2019 -0800 [SYSTEMML-540] Added tests for comparing Keras2DML output with TF - The test framework is generalized to simplify testing of new layers. - The default values in Keras2DML have been updated to match the default invocation of Keras. - Added Flatten layer in Caffe2DML. - If the user attempts to use dense layer for 3-D inputs, we now throw an error instead of silently giving a wrong answer. - Fixed a bug in conversion of Conv2D weights. - Also, fixed a bug when a neural network is invoked which has no weights. --- src/main/python/systemml/mllearn/estimators.py | 32 ++- src/main/python/systemml/mllearn/keras2caffe.py | 61 ++-- src/main/python/tests/test_nn_numpy.py | 307 ++++++++++++++------- .../scala/org/apache/sysml/api/dl/CaffeLayer.scala | 18 ++ .../org/apache/sysml/api/dl/CaffeNetwork.scala | 1 + 5 files changed, 285 insertions(+), 134 deletions(-) diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py index a2647f1..2c3b6a2 100644 --- a/src/main/python/systemml/mllearn/estimators.py +++ b/src/main/python/systemml/mllearn/estimators.py @@ -1009,8 +1009,8 @@ class Keras2DML(Caffe2DML): """ - def __init__(self, sparkSession, keras_model, input_shape, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, - batch_size=64, max_iter=2000, test_iter=10, test_interval=500, display=100, lr_policy="step", weight_decay=5e-4, regularization_type="L2"): + def __init__(self, sparkSession, keras_model, input_shape=None, transferUsingDF=False, load_keras_weights=True, weights=None, labels=None, + batch_size=64, max_iter=2000, test_iter=0, test_interval=500, display=100, lr_policy="step", weight_decay=0, 
regularization_type="L2"): """ Performs training/prediction for a given keras model. @@ -1018,37 +1018,43 @@ class Keras2DML(Caffe2DML): ---------- sparkSession: PySpark SparkSession keras_model: keras model - input_shape: 3-element list (number of channels, input height, input width) + input_shape: 3-element list (number of channels, input height, input width). If not provided, it is inferred from the input shape of the first layer. transferUsingDF: whether to pass the input dataset via PySpark DataFrame (default: False) load_keras_weights: whether to load weights from the keras_model. If False, the weights will be initialized to random value using NN libraries' init method (default: True) weights: directory whether learned weights are stored (default: None) labels: file containing mapping between index and string labels (default: None) batch_size: size of the input batch (default: 64) - max_iter: maximum number of iterations (default: 1) - test_iter: test_iter for caffe solver (default: 10) + max_iter: maximum number of iterations (default: 2000) + test_iter: test_iter for caffe solver (default: 0) test_interval: test_interval for caffe solver (default: 500) display: display for caffe solver (default: 100) lr_policy: learning rate policy for caffe solver (default: "step") - weight_decay: regularation strength (default: 5e-4) + weight_decay: regularation strength (default: 0, recommended: 5e-4) regularization_type: regularization type (default: "L2") """ from .keras2caffe import convertKerasToCaffeNetwork, convertKerasToCaffeSolver, convertKerasToSystemMLModel import tempfile, keras + if keras.backend.image_data_format() != 'channels_first': + raise Exception('The data format ' + str(keras.backend.image_data_format()) + + ' is not supported. 
Please use keras.backend.set_image_data_format("channels_first")') if isinstance(keras_model, keras.models.Sequential): # Convert the sequential model to functional model if keras_model.model is None: keras_model.build() keras_model = keras_model.model + if input_shape is None: + keras_shape = keras_model.layers[0].input_shape + input_shape = [1, 1, 1] + if len(keras_shape) > 4 or len(keras_shape) <= 1: + raise Exception('Input shape ' + str(keras_shape) + ' is not supported.') + for i in range(len(keras_shape)-1): + input_shape[i] = keras_shape[i+1] # Ignore batch size + elif len(input_shape) > 3 or len(input_shape) == 0: + raise Exception('Input shape ' + str(input_shape) + ' is not supported.') self.name = keras_model.name createJavaObject(sparkSession._sc, 'dummy') if not hasattr(keras_model, 'optimizer'): - keras_model.compile( - loss='categorical_crossentropy', - optimizer=keras.optimizers.SGD( - lr=0.01, - momentum=0.95, - decay=5e-4, - nesterov=True)) + raise Exception('Please compile the model before passing it to Keras2DML') convertKerasToCaffeNetwork( keras_model, self.name + ".proto", diff --git a/src/main/python/systemml/mllearn/keras2caffe.py b/src/main/python/systemml/mllearn/keras2caffe.py index f6d6440..ca0fe3c 100755 --- a/src/main/python/systemml/mllearn/keras2caffe.py +++ b/src/main/python/systemml/mllearn/keras2caffe.py @@ -56,7 +56,8 @@ except ImportError: supportedCaffeActivations = { 'relu': 'ReLU', 'softmax': 'Softmax', - 'sigmoid': 'Sigmoid'} + 'sigmoid': 'Sigmoid' +} supportedLayers = { keras.layers.InputLayer: 'Data', keras.layers.Dense: 'InnerProduct', @@ -70,7 +71,7 @@ supportedLayers = { keras.layers.AveragePooling2D: 'Pooling', keras.layers.SimpleRNN: 'RNN', keras.layers.LSTM: 'LSTM', - keras.layers.Flatten: 'None', + keras.layers.Flatten: 'Flatten', keras.layers.BatchNormalization: 'None', keras.layers.Activation: 'None' } @@ -84,9 +85,10 @@ def _getInboundLayers(layer): for node in inbound_nodes: node_list = node.inbound_layers # 
get layers pointing to this node in_names = in_names + node_list + return list(in_names) # For Caffe2DML to reroute any use of Flatten layers - return list(chain.from_iterable([_getInboundLayers(l) if isinstance( - l, keras.layers.Flatten) else [l] for l in in_names])) + #return list(chain.from_iterable([_getInboundLayers(l) if isinstance( + # l, keras.layers.Flatten) else [l] for l in in_names])) def _getCompensatedAxis(layer): @@ -160,23 +162,19 @@ def _parseKerasLayer(layer): elif layerType == keras.layers.Activation: return [_parseActivation(layer)] param = layerParamMapping[layerType](layer) - paramName = param.keys()[0] + layerArgs = {} + layerArgs['name'] = layer.name if layerType == keras.layers.InputLayer: - ret = { - 'layer': { - 'name': layer.name, - 'type': 'Data', - paramName: param[paramName], - 'top': layer.name, - 'top': 'label'}} + layerArgs['type'] = 'Data' + layerArgs['top'] = 'label' # layer.name: TODO else: - ret = { - 'layer': { - 'name': layer.name, - 'type': supportedLayers[layerType], - 'bottom': _getBottomLayers(layer), - 'top': layer.name, - paramName: param[paramName]}} + layerArgs['type'] = supportedLayers[layerType] + layerArgs['bottom'] = _getBottomLayers(layer) + layerArgs['top'] = layer.name + if len(param) > 0: + paramName = param.keys()[0] + layerArgs[paramName] = param[paramName] + ret = { 'layer': layerArgs } return [ret, _parseActivation( layer, layer.name + '_activation')] if _shouldParseActivation(layer) else [ret] @@ -193,7 +191,6 @@ def _parseBatchNorm(layer): # The special are redirected to their custom parse function in _parseKerasLayer specialLayers = { - keras.layers.Flatten: lambda x: [], keras.layers.BatchNormalization: _parseBatchNorm } @@ -241,12 +238,18 @@ def getRecurrentParam(layer): layer.return_sequences).lower()} +def getInnerProductParam(layer): + if len(layer.output_shape) != 2: + raise Exception('Only 2-D input is supported for the Dense layer in the current implementation, but found ' + + 
str(layer.input_shape) + '. Consider adding a Flatten before ' + str(layer.name)) + return {'num_output': layer.units} + # TODO: Update AveragePooling2D when we add maxpooling support layerParamMapping = { keras.layers.InputLayer: lambda l: {'data_param': {'batch_size': l.batch_size}}, keras.layers.Dense: lambda l: - {'inner_product_param': {'num_output': l.units}}, + {'inner_product_param': getInnerProductParam(l)}, keras.layers.Dropout: lambda l: {'dropout_param': {'dropout_ratio': l.rate}}, keras.layers.Add: lambda l: @@ -267,6 +270,7 @@ layerParamMapping = { {'recurrent_param': getRecurrentParam(l)}, keras.layers.LSTM: lambda l: {'recurrent_param': getRecurrentParam(l)}, + keras.layers.Flatten: lambda l: {}, } @@ -475,7 +479,6 @@ def convertKerasToCaffeSolver(kerasModel, caffeNetworkFilePath, outCaffeSolverFi raise Exception( 'Only sgd (with/without momentum/nesterov), Adam and Adagrad supported.') - def getInputMatrices(layer): if isinstance(layer, keras.layers.SimpleRNN): weights = layer.get_weights() @@ -501,6 +504,11 @@ def getInputMatrices(layer): b_c = b[units * 2: units * 3] b_o = b[units * 3:] return [np.vstack((np.hstack((W_i, W_f, W_o, W_c)), np.hstack((U_i, U_f, U_o, U_c)))).reshape((-1, 4*units)), np.hstack((b_i, b_f, b_o, b_c)).reshape((1, -1))] + elif isinstance(layer, keras.layers.Conv2D): + weights = layer.get_weights() + #filter = np.swapaxes(weights[0].T, 2, 3) # convert RSCK => KCRS format + filter = np.swapaxes(np.swapaxes(np.swapaxes(weights[0], 1, 3), 0, 1), 1, 2) + return [ filter.reshape((filter.shape[0], -1)) , getNumPyMatrixFromKerasWeight(weights[1])] else: return [getNumPyMatrixFromKerasWeight( param) for param in layer.get_weights()] @@ -535,6 +543,9 @@ def convertKerasToSystemMLModel(spark, kerasModel, outDirectory): i == 1 and type(layer) in biasToTranspose) else inputMatrices[i] py4j.java_gateway.get_method(script_java, "in")( potentialVar[i], convertToMatrixBlock(sc, mat)) - script_java.setScriptString(''.join(dmlLines)) - ml = 
sc._jvm.org.apache.sysml.api.mlcontext.MLContext(sc._jsc) - ml.execute(script_java) + script_str = ''.join(dmlLines) + if script_str.strip() != '': + # Only execute if the script is not empty + script_java.setScriptString(script_str) + ml = sc._jvm.org.apache.sysml.api.mlcontext.MLContext(sc._jsc) + ml.execute(script_java) diff --git a/src/main/python/tests/test_nn_numpy.py b/src/main/python/tests/test_nn_numpy.py index 6fd190c..76d9619 100644 --- a/src/main/python/tests/test_nn_numpy.py +++ b/src/main/python/tests/test_nn_numpy.py @@ -21,15 +21,17 @@ #------------------------------------------------------------- # Assumption: pip install keras -# +# # This test validates SystemML's deep learning APIs (Keras2DML, Caffe2DML and nn layer) by comparing the results with that of keras. # # To run: -# - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-memory 10g --driver-class-path SystemML.jar,systemml-*-extra.jar test_nn_numpy.py` +# - Python 2: `PYSPARK_PYTHON=python2 spark-submit --master local[*] --driver-memory 10g --driver-class-path ../../../../target/SystemML.jar,../../../../target/systemml-*-extra.jar test_nn_numpy.py` # - Python 3: `PYSPARK_PYTHON=python3 spark-submit --master local[*] --driver-memory 10g --driver-class-path SystemML.jar,systemml-*-extra.jar test_nn_numpy.py` # Make the `systemml` package importable import os +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' +os.environ['CUDA_VISIBLE_DEVICES'] = '' import sys path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../") sys.path.insert(0, path) @@ -38,114 +40,227 @@ import unittest import numpy as np from keras.models import Sequential -from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout, Flatten, LSTM, UpSampling2D, SimpleRNN +from keras.layers import Input, Dense, Conv2D, MaxPooling2D, Dropout, Flatten, LSTM, UpSampling2D, SimpleRNN, Activation +from keras.optimizers import SGD from keras import backend as K from keras.models import Model 
from systemml.mllearn import Keras2DML from pyspark.sql import SparkSession +from pyspark import SparkContext +from keras.utils import np_utils +from scipy import stats +from sklearn.preprocessing import normalize +from operator import mul batch_size = 32 -input_shape = (3,64,64) -K.set_image_data_format("channels_first") +K.set_image_data_format('channels_first') # K.set_image_dim_ordering("th") -keras_tensor = np.random.rand(batch_size,input_shape[0], input_shape[1], input_shape[2]) -sysml_matrix = keras_tensor.reshape((batch_size, -1)) + +def get_tensor(shape, random=True): + if shape[0] is None: + # Use the first dimension is None, use batch size: + shape = list(shape) + shape[0] = batch_size + return (np.random.randint(100, size=shape) + 1) / 100 + tmp_dir = 'tmp_dir' +sc = SparkContext() spark = SparkSession.builder.getOrCreate() +sc.setLogLevel('ERROR') -def are_predictions_all_close(keras_model, rtol=1e-05, atol=1e-08): - sysml_model = Keras2DML(spark, keras_model, input_shape=input_shape, weights=tmp_dir) - keras_preds = keras_model.predict(keras_tensor).flatten() - sysml_preds = sysml_model.predict_proba(sysml_matrix).flatten() - #print(str(keras_preds)) - #print(str(sysml_preds)) - return np.allclose(keras_preds, sysml_preds, rtol=rtol, atol=atol) +def initialize_weights(model): + for l in range(len(model.layers)): + if model.layers[l].get_weights() is not None or len(model.layers[l].get_weights()) > 0: + model.layers[l].set_weights([get_tensor(elem.shape) for elem in model.layers[l].get_weights()]) + return model -class TestNNLibrary(unittest.TestCase): - def test_1layer_upsample_predictions1(self): - keras_model = Sequential() - keras_model.add(UpSampling2D(size=(2, 2), input_shape=input_shape)) - keras_model.add(Flatten()) - keras_model.add(Dense(10, activation='softmax')) - self.failUnless(are_predictions_all_close(keras_model, atol=1e-06)) +def get_input_output_shape(layers): + tmp_keras_model = Sequential() + for layer in layers: + 
tmp_keras_model.add(layer) + return tmp_keras_model.layers[0].input_shape, tmp_keras_model.layers[-1].output_shape - def test_1layer_upsample_predictions2(self): - keras_model = Sequential() - keras_model.add(UpSampling2D(size=(2, 3), input_shape=input_shape)) - keras_model.add(Flatten()) - keras_model.add(Dense(10, activation='softmax')) - self.failUnless(are_predictions_all_close(keras_model, atol=1e-06)) - - def test_1layer_cnn_predictions(self): - keras_model = Sequential() - keras_model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape, padding='valid')) - keras_model.add(Flatten()) - keras_model.add(Dense(10, activation='softmax')) - self.failUnless(are_predictions_all_close(keras_model)) - - def test_multilayer_cnn_predictions(self): - keras_model = Sequential() - keras_model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape, padding='valid')) - keras_model.add(MaxPooling2D(pool_size=(2, 2))) - keras_model.add(Conv2D(64, (3, 3), activation='relu')) - keras_model.add(MaxPooling2D(pool_size=(2, 2))) +def get_one_hot_encoded_labels(output_shape): + output_cells = reduce(mul, list(output_shape[1:]), 1) + y = np.array(np.random.choice(output_cells, batch_size)) + y[0] = output_cells - 1 + one_hot_labels = np_utils.to_categorical(y, num_classes=output_cells) + return one_hot_labels + +def get_sysml_model(keras_model): + sysml_model = Keras2DML(spark, keras_model, weights=tmp_dir, max_iter=1, batch_size=batch_size) + # For apples-to-apples comparison of output probabilities: + # By performing one-hot encoding outside, we ensure that the ordering of the TF columns + # matches that of SystemML + sysml_model.set(train_algo='batch', perform_one_hot_encoding=False) + # print('Script:' + str(sysml_model.get_training_script())) + return sysml_model + +def base_test(layers, add_dense=False, test_backward=True): + layers = [layers] if not isinstance(layers, list) else layers + in_shape, output_shape = 
get_input_output_shape(layers) + # -------------------------------------- + # Create Keras model + keras_model = Sequential() + for layer in layers: + keras_model.add(layer) + if len(output_shape) > 2: + # Flatten the last layer activation before feeding it to the softmax loss keras_model.add(Flatten()) - keras_model.add(Dense(256, activation='softmax')) - keras_model.add(Dropout(0.25)) - keras_model.add(Dense(10, activation='softmax')) - self.failUnless(are_predictions_all_close(keras_model)) - - def test_simplernn_predictions1(self): - data_dim = 16 - timesteps = 8 - num_classes = 10 - batch_size = 64 - model = Sequential() - model.add(SimpleRNN(32, return_sequences=False, input_shape=(timesteps, data_dim))) - model.add(Dense(10, activation='softmax')) - x_train = np.random.random((batch_size, timesteps, data_dim)) - y_train = np.random.random((batch_size, num_classes)) - from systemml.mllearn import Keras2DML - sysml_model = Keras2DML(spark, model, input_shape=(timesteps,data_dim,1), weights='weights_dir').set(debug=True) - keras_preds = model.predict(x_train).flatten() - sysml_preds = sysml_model.predict_proba(x_train.reshape((batch_size, -1))).flatten() - self.failUnless(np.allclose(sysml_preds, keras_preds)) - - def test_simplernn_predictions2(self): - data_dim = 16 - timesteps = 8 - num_classes = 10 - batch_size = 100 - model = Sequential() - model.add(SimpleRNN(32, return_sequences=False, input_shape=(timesteps, data_dim))) - model.add(Dense(10, activation='softmax')) - x_train = np.random.random((batch_size, timesteps, data_dim)) - y_train = np.random.random((batch_size, num_classes)) - from systemml.mllearn import Keras2DML - sysml_model = Keras2DML(spark, model, input_shape=(timesteps,data_dim,1), weights='weights_dir').set(debug=True) - keras_preds = model.predict(x_train).flatten() - sysml_preds = sysml_model.predict_proba(x_train.reshape((batch_size, -1))).flatten() - self.failUnless(np.allclose(sysml_preds, keras_preds)) - - def 
test_lstm_predictions1(self): - data_dim = 32 - timesteps = 8 - num_classes = 10 - batch_size = 64 - w1 = np.random.random((data_dim, 4*data_dim)) - w2 = np.random.random((data_dim, 4*data_dim)) - b = np.zeros(128) - model = Sequential() - model.add(LSTM(32, return_sequences=False, recurrent_activation='sigmoid', input_shape=(timesteps, data_dim), weights=[w1, w2, b])) - model.add(Dense(10, activation='softmax')) - x_train = np.random.random((batch_size, timesteps, data_dim)) - y_train = np.random.random((batch_size, num_classes)) - from systemml.mllearn import Keras2DML - sysml_model = Keras2DML(spark, model, input_shape=(timesteps,data_dim,1), weights='weights_dir').set(debug=True) - keras_preds = model.predict(x_train) - sysml_preds = sysml_model.predict_proba(x_train.reshape((batch_size, -1))) - np.allclose(sysml_preds, keras_preds) + if add_dense: + keras_model.add(Dense(num_labels, activation='softmax')) + else: + keras_model.add(Activation('softmax')) + keras_model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.1, decay=0, momentum=0, nesterov=False)) + # -------------------------------------- + keras_model = initialize_weights(keras_model) + sysml_model = get_sysml_model(keras_model) + keras_tensor = get_tensor(in_shape) + sysml_matrix = keras_tensor.reshape((batch_size, -1)) + #if len(keras_tensor.shape) == 4: + # keras_tensor = np.flip(keras_tensor, 1) + # -------------------------------------- + sysml_preds = sysml_model.predict_proba(sysml_matrix) + if test_backward: + one_hot_labels = get_one_hot_encoded_labels(keras_model.layers[-1].output_shape) + sysml_model.fit(sysml_matrix, one_hot_labels) + sysml_preds = sysml_model.predict_proba(sysml_matrix) + keras_preds = keras_model.predict(keras_tensor) + if test_backward: + keras_model.train_on_batch(keras_tensor, one_hot_labels) + keras_preds = keras_model.predict(keras_tensor) + # -------------------------------------- + if len(output_shape) == 4: + # Flatten doesnot respect channel_first, 
so reshuffle the dimensions: + keras_preds = keras_preds.reshape((batch_size, output_shape[2], output_shape[3], output_shape[1])) + keras_preds = np.swapaxes(keras_preds, 2, 3) # (h,w,c) -> (h,c,w) + keras_preds = np.swapaxes(keras_preds, 1, 2) # (h,c,w) -> (c,h,w) + elif len(output_shape) > 4: + raise Exception('Unsupported output shape:' + str(output_shape)) + # -------------------------------------- + return sysml_preds, keras_preds, keras_model, output_shape + +def debug_layout(sysml_preds, keras_preds): + for i in range(len(keras_preds.shape)): + print('After flipping along axis=' + str(i) + ' => ' + str(np.allclose(sysml_preds, np.flip(keras_preds, i).flatten()))) + +def test_forward(layers): + sysml_preds, keras_preds, keras_model, output_shape = base_test(layers, test_backward=False) + ret = np.allclose(sysml_preds.flatten(), keras_preds.flatten()) + if not ret: + print('The forward test failed for the model:' + str(keras_model.summary())) + print('SystemML output:' + str(sysml_preds)) + print('Keras output:' + str(keras_preds)) + #debug_layout(sysml_preds.flatten(), + # keras_preds.reshape((-1, output_shape[1], output_shape[2], output_shape[3]))) + return ret + +def test_backward(layers): + sysml_preds, keras_preds, keras_model, output_shape = base_test(layers, test_backward=True) + ret = np.allclose(sysml_preds.flatten(), keras_preds.flatten()) + if not ret: + print('The backward test failed for the model:' + str(keras_model.summary())) + print('SystemML output:' + str(sysml_preds)) + print('Keras output:' + str(keras_preds)) + # debug_layout(sysml_preds.flatten(), + # keras_preds.reshape((-1, output_shape[1], output_shape[2], output_shape[3]))) + return ret + + +class TestNNLibrary(unittest.TestCase): + + def test_dense_forward(self): + self.failUnless(test_forward(Dense(10, input_shape=[30]))) + + def test_dense_backward(self): + self.failUnless(test_backward(Dense(10, input_shape=[30]))) + + def test_lstm_forward1(self): + 
self.failUnless(test_forward(LSTM(2, return_sequences=True, activation='tanh', stateful=False, recurrent_activation='sigmoid', input_shape=(3, 4)))) + + #def test_lstm_backward1(self): + # self.failUnless(test_backward(LSTM(2, return_sequences=True, activation='tanh', stateful=False, recurrent_activation='sigmoid', input_shape=(3, 4)))) + + def test_lstm_forward2(self): + self.failUnless(test_forward(LSTM(10, return_sequences=False, activation='tanh', stateful=False, recurrent_activation='sigmoid', input_shape=(30, 20)))) + + def test_lstm_backward2(self): + self.failUnless(test_backward(LSTM(10, return_sequences=False, activation='tanh', stateful=False, recurrent_activation='sigmoid', input_shape=(30, 20)))) + + def test_dense_relu_forward(self): + self.failUnless(test_forward(Dense(10, activation='relu', input_shape=[30]))) + + def test_dense_relu_backward(self): + self.failUnless(test_backward(Dense(10, activation='relu', input_shape=[30]))) + + def test_dense_sigmoid_forward(self): + self.failUnless(test_forward(Dense(10, activation='sigmoid', input_shape=[30]))) + + def test_dense_sigmoid_backward(self): + self.failUnless(test_backward(Dense(10, activation='sigmoid', input_shape=[30]))) + + def test_dense_softmax_forward(self): + self.failUnless(test_forward(Dense(10, activation='softmax', input_shape=[30]))) + + def test_dense_softmax_backward(self): + self.failUnless(test_backward(Dense(10, activation='softmax', input_shape=[30]))) + + def test_maxpool2d_forward(self): + self.failUnless(test_forward(MaxPooling2D(pool_size=(2, 2), input_shape=(1, 64, 32)))) + + def test_maxpool2d_backward(self): + self.failUnless(test_backward(MaxPooling2D(pool_size=(2, 2), input_shape=(1, 64, 32)))) + + def test_maxpool2d_multi_channel_forward(self): + self.failUnless(test_forward(MaxPooling2D(pool_size=(2, 2), input_shape=(3, 64, 32)))) + + def test_maxpool2d_multi_channel_backward(self): + self.failUnless(test_backward(MaxPooling2D(pool_size=(2, 2), input_shape=(3, 64, 
32)))) + + def test_conv2d_forward_single_channel_input_output(self): + # 1-channel input and output + self.failUnless( + test_forward(Conv2D(1, kernel_size=(3, 3), input_shape=(1, 64, 64), activation='relu', padding='valid'))) + + def test_conv2d_forward_single_channel_input(self): + # 1-channel input + self.failUnless( + test_forward(Conv2D(32, kernel_size=(3, 3), input_shape=(1, 64, 64), activation='relu', padding='valid'))) + + def test_conv2d_forward_single_channel_output(self): + # 1-channel output + self.failUnless( + test_forward(Conv2D(1, kernel_size=(3, 3), input_shape=(3, 64, 64), activation='relu', padding='valid'))) + + def test_conv2d_forward(self): + self.failUnless( + test_forward(Conv2D(32, kernel_size=(3, 3), input_shape=(3, 64, 32), activation='relu', padding='valid'))) + + def test_conv2d_backward_single_channel_input_output(self): + # 1-channel input and output + self.failUnless( + test_backward(Conv2D(1, kernel_size=(3, 3), input_shape=(1, 64, 64), activation='relu', padding='valid'))) + + def test_conv2d_backward_single_channel_input(self): + # 1-channel input + self.failUnless( + test_backward(Conv2D(32, kernel_size=(3, 3), input_shape=(1, 64, 64), activation='relu', padding='valid'))) + + def test_conv2d_backward_single_channel_output(self): + # 1-channel output + self.failUnless( + test_backward(Conv2D(1, kernel_size=(3, 3), input_shape=(3, 64, 64), activation='relu', padding='valid'))) + + def test_conv2d_backward(self): + self.failUnless( + test_backward(Conv2D(32, kernel_size=(3, 3), input_shape=(3, 64, 32), activation='relu', padding='valid'))) + + def test_upsampling_forward(self): + self.failUnless(test_forward(UpSampling2D(size=(2, 2), input_shape=(3, 64, 32)))) + + def test_upsampling_backward(self): + self.failUnless(test_backward(UpSampling2D(size=(2, 2), input_shape=(3, 64, 32)))) if __name__ == '__main__': unittest.main() diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala 
b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala index b290983..f405fb2 100644 --- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala +++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala @@ -405,6 +405,24 @@ class Elementwise(val param: LayerParameter, val id: Int, val net: CaffeNetwork) override def biasShape(): Array[Int] = null } +class Flatten(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer { + override def sourceFileName = null + override def init(dmlScript: StringBuilder): Unit = {} + override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = assign(dmlScript, out, X) + override def backward(dmlScript: StringBuilder, outSuffix: String): Unit = assignDoutToDX(dmlScript, outSuffix) + override def weightShape(): Array[Int] = null + override def biasShape(): Array[Int] = null + var _childLayers: List[CaffeLayer] = null + var _out: (String, String, String) = null + override def outputShape = { + if (_out == null) { + if (_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList + _out = (int_mult(_childLayers(0).outputShape._1, _childLayers(0).outputShape._2, _childLayers(0).outputShape._3), "1", "1") + } + _out + } +} + class Concat(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer { override def sourceFileName = null override def init(dmlScript: StringBuilder): Unit = {} diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala index 297176f..d3449f3 100644 --- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala +++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala @@ -249,6 +249,7 @@ class CaffeNetwork(netFilePath: String, val currentPhase: Phase, var numChannels case "softmax" => new Softmax(param, id, this) case "rnn" => new RNN(param, id, this) case "lstm" => new LSTM(param, id, this) + case "flatten" => new 
Flatten(param, id, this) case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported") } }