Repository: systemml
Updated Branches:
  refs/heads/master 45eec2d25 -> 9dc354ac2


[SYSTEMML-540] Added optimizer support in Keras2DML

- Also, updated the documentation.
- Added a controlled error in lstm when the number of training data points
  is not a multiple of the batch size.
- Added perform_one_hot_encoding flag to deal with non-label data.
- Bug fix for EuclideanLoss layer in Caffe2DML.
- Added regularization support in Caffe2DML.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9dc354ac
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9dc354ac
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9dc354ac

Branch: refs/heads/master
Commit: 9dc354ac2f1ca0378a1ec76317d3df72cfb6f380
Parents: 45eec2d
Author: Niketan Pansare <[email protected]>
Authored: Thu Jan 11 15:14:25 2018 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Thu Jan 11 15:18:21 2018 -0800

----------------------------------------------------------------------
 docs/beginners-guide-keras2dml.md               |  82 ++++++++++++-
 docs/reference-guide-caffe2dml.md               |  29 ++++-
 scripts/nn/layers/lstm.dml                      |   9 +-
 src/main/python/systemml/mllearn/estimators.py  |  28 +++--
 src/main/python/systemml/mllearn/keras2caffe.py | 117 +++++++++++++++----
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  36 ++++--
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |  96 +++++++++++----
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  95 +++++++++++++--
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |   9 +-
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  10 +-
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |   2 +-
 11 files changed, 428 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/docs/beginners-guide-keras2dml.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-keras2dml.md 
b/docs/beginners-guide-keras2dml.md
index fd2af87..c99334e 100644
--- a/docs/beginners-guide-keras2dml.md
+++ b/docs/beginners-guide-keras2dml.md
@@ -53,10 +53,84 @@ from systemml.mllearn import Keras2DML
 import keras
 from keras.applications.resnet50 import preprocess_input, decode_predictions, 
ResNet50
 
-model = 
ResNet50(weights='imagenet',include_top=True,pooling='None',input_shape=(224,224,3))
-model.compile(optimizer='sgd', loss= 'categorical_crossentropy')
+keras_model = 
ResNet50(weights='imagenet',include_top=True,pooling='None',input_shape=(224,224,3))
+keras_model.compile(optimizer='sgd', loss= 'categorical_crossentropy')
 
-resnet = Keras2DML(spark,model,input_shape=(3,224,224))
-resnet.summary()
+sysml_model = Keras2DML(spark, keras_model,input_shape=(3,224,224))
+sysml_model.summary()
 ```
 
+# Frequently asked questions
+
+#### What is the mapping between Keras' parameters and Caffe's solver specification?
+
+| | Specified via the given parameter in the Keras2DML constructor | From input Keras' model | Corresponding parameter in the Caffe solver file |
+|---|---|---|---|
+| Solver type | | `type(keras_model.optimizer)`. Supported types: `keras.optimizers.{SGD, Adagrad, Adam}` | `type` |
+| Maximum number of iterations | `max_iter` | The `epochs` parameter in the `fit` method is not supported. | `max_iter` |
+| Validation dataset | `test_iter` (explained in the below section) | The `validation_data` parameter in the `fit` method is not supported. | `test_iter` |
+| Monitoring the loss | `display, test_interval` (explained in the below section) | The `LossHistory` callback in the `fit` method is not supported. | `display, test_interval` |
+| Learning rate schedule | `lr_policy` | The `LearningRateScheduler` callback in the `fit` method is not supported. | `lr_policy` (default: step) |
+| Base learning rate | | `keras_model.optimizer.lr` | `base_lr` |
+| Learning rate decay over each update | | `keras_model.optimizer.decay` | `gamma` |
+| Global regularizer to use for all layers | `regularization_type,weight_decay` | The current version of Keras2DML does not support custom regularizers per layer. | `regularization_type,weight_decay` |
+| If type of the optimizer is `keras.optimizers.SGD` | | `momentum, nesterov` | `momentum, type` |
+| If type of the optimizer is `keras.optimizers.Adam` | | `beta_1, beta_2, epsilon`. The parameter `amsgrad` is not supported. | `momentum, momentum2, delta` |
+| If type of the optimizer is `keras.optimizers.Adagrad` | | `epsilon` | `delta` |
+
+#### How do I specify the batch size and the number of epochs?
+
+Since Keras2DML is an mllearn API, it does not accept the batch size and the number of epochs as parameters of the `fit` method.
+Instead, these values are passed via the `batch_size` and `max_iter` parameters of the Keras2DML constructor.
+For example, the equivalent Python code for `keras_model.fit(features, labels, epochs=10, batch_size=64)` is as follows:
+
+```python
+import math
+from systemml.mllearn import Keras2DML
+epochs = 10
+batch_size = 64
+num_samples = features.shape[0]
+# Use float division so that ceil also rounds up under Python 2
+max_iter = int(epochs*math.ceil(num_samples/float(batch_size)))
+sysml_model = Keras2DML(spark, keras_model, batch_size=batch_size, max_iter=max_iter, ...)
+sysml_model.fit(features, labels)
+``` 
+
+#### What optimizer and loss does Keras2DML use by default if `keras_model` is not compiled?
+
+If the user does not `compile` the Keras model, then we use cross-entropy loss and the SGD optimizer with Nesterov momentum:
+
+```python 
+keras_model.compile(loss='categorical_crossentropy', optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.95, decay=5e-4, nesterov=True))
+```
+
+#### What is the learning rate schedule used?
+
+Keras2DML does not support the `LearningRateScheduler` callback.
+Instead, one can select one of the following learning rate schedules by using the `lr_policy` parameter of the constructor:
+- `step`: return `base_lr * gamma ^ (floor(iter / stepsize))` (default schedule)
+- `fixed`: always return `base_lr`.
+- `exp`: return `base_lr * gamma ^ iter`
+- `inv`: return `base_lr * (1 + gamma * iter) ^ (- power)`
+- `poly`: the effective learning rate follows a polynomial decay, reaching zero at `max_iter`. return `base_lr * (1 - iter/max_iter) ^ (power)`
+- `sigmoid`: the effective learning rate follows a sigmoid decay. return `base_lr * (1/(1 + exp(-gamma * (iter - stepsize))))`
+
+#### How to set the size of the validation dataset?
+
+The size of the validation dataset is determined by the parameter `test_iter` and the batch size. For example, if the batch size is 64 and
+`test_iter` is set to 10 in the `Keras2DML` constructor, then the validation size is 640. This setting generates the following DML code internally:
+
+```python
+num_images = nrow(y_full)
+BATCH_SIZE = 64
+num_validation = 10 * BATCH_SIZE
+X = X_full[(num_validation+1):num_images,]; y = y_full[(num_validation+1):num_images,]
+X_val = X_full[1:num_validation,]; y_val = y_full[1:num_validation,]
+num_images = nrow(y)
+``` 
+
+#### How to monitor loss via command-line?
+
+To monitor the loss, set the parameters `display`, `test_iter` and `test_interval` in the `Keras2DML` constructor.
+For example, for the expression `Keras2DML(..., display=100, test_iter=10, test_interval=500)`, we
+- display the training loss and accuracy every 100 iterations and
+- carry out validation every 500 training iterations and display the validation loss and accuracy.
+

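The learning rate schedules listed in the FAQ above can be sketched in plain Python. This is only an illustration of the formulas, not code from SystemML: the helper name `learning_rate` is hypothetical, and the defaults for `gamma`, `stepsize` and `power` are borrowed from the solver example in the Caffe2DML reference guide.

```python
import math

def learning_rate(policy, it, base_lr=0.01, gamma=0.95, stepsize=100000,
                  power=0.75, max_iter=2000):
    """Hypothetical helper: effective learning rate at iteration `it`."""
    if policy == "fixed":
        return base_lr
    if policy == "step":
        # Drop the rate by a factor of gamma every `stepsize` iterations
        return base_lr * gamma ** math.floor(it / float(stepsize))
    if policy == "exp":
        return base_lr * gamma ** it
    if policy == "inv":
        return base_lr * (1.0 + gamma * it) ** (-power)
    if policy == "poly":
        # Reaches zero when it == max_iter
        return base_lr * (1.0 - it / float(max_iter)) ** power
    if policy == "sigmoid":
        z = -gamma * (it - stepsize)
        # exp(z) overflows a double for z > ~709; the rate is effectively 0 there
        return 0.0 if z > 700 else base_lr / (1.0 + math.exp(z))
    raise ValueError("Unsupported lr_policy: " + str(policy))
```

For instance, `learning_rate("step", 250000)` applies the `gamma` factor twice, because two full `stepsize` windows have elapsed.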
http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/docs/reference-guide-caffe2dml.md
----------------------------------------------------------------------
diff --git a/docs/reference-guide-caffe2dml.md 
b/docs/reference-guide-caffe2dml.md
index be8c078..0e191dd 100644
--- a/docs/reference-guide-caffe2dml.md
+++ b/docs/reference-guide-caffe2dml.md
@@ -578,7 +578,34 @@ The parameter `lr_policy` specifies the learning rate 
decay policy. Caffe2DML su
 - `inv`: return `base_lr * (1 + gamma * iter) ^ (- power)`
 - `poly`: the effective learning rate follows a polynomial decay, reaching zero at `max_iter`. return `base_lr * (1 - iter/max_iter) ^ (power)`
 - `sigmoid`: the effective learning rate follows a sigmoid decay. return `base_lr * (1/(1 + exp(-gamma * (iter - stepsize))))`
-      
+
+
+The parameters `base_lr` and `lr_policy` are required; all other parameters are optional:
+```
+lr_policy: "step" # learning rate policy: drop the learning rate in "steps"
+                  # by a factor of gamma every stepsize iterations (required)
+base_lr: 0.01     # begin training at a learning rate of 0.01 (required)
+gamma: 0.95       # drop the learning rate by the given factor (optional, default value: 0.95)
+stepsize: 100000  # drop the learning rate every 100K iterations (optional, default value: 100000)
+power: 0.75       # (optional, default value: 0.75)
+``` 
+
+#### How do I regularize weight matrices in the neural network?
+
+The user can specify the type of regularization using the parameter `regularization_type` in the solver file.
+The valid values are `L2` (default) and `L1`.
+Caffe2DML then invokes the backward function of the layers `nn/layers/l2_reg.dml` and `nn/layers/l1_reg.dml` respectively.
+The regularization strength is set using the property `weight_decay` in the solver file:
+```
+regularization_type: "L2"
+weight_decay: 5e-4
+```
+
+Like the learning rate, you can customize the regularization strength of a given layer by specifying the property `decay_mult` in the network file:
+```
+param { lr_mult: 1 decay_mult: 1 }
+```  
+
 #### How to set batch size ?
 
 Batch size is set in `data_param` of the Data layer:

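To make the effect of `regularization_type`, `weight_decay` and `decay_mult` concrete, here is a minimal sketch of the gradient term that regularization adds to a weight matrix. It assumes the conventional definitions (L2 penalty `0.5 * lambda * sum(W^2)`, L1 penalty `lambda * sum(|W|)`) and Caffe's convention that the per-layer strength is `weight_decay * decay_mult`; the DML of `l2_reg.dml` and `l1_reg.dml` is not part of this commit, so treat both as assumptions. The function name is hypothetical.

```python
def regularization_gradient(weights, weight_decay=5e-4,
                            regularization_type="L2", decay_mult=1.0):
    """Hypothetical helper: gradient of the penalty w.r.t. each weight."""
    lam = weight_decay * decay_mult  # assumed per-layer strength
    if regularization_type == "L2":
        # d/dw (0.5 * lam * w^2) = lam * w
        return [[lam * w for w in row] for row in weights]
    if regularization_type == "L1":
        # d/dw (lam * |w|) = lam * sign(w), taking sign(0) = 0
        return [[lam * ((w > 0) - (w < 0)) for w in row] for row in weights]
    raise ValueError("Unsupported regularization_type: " + str(regularization_type))
```

Setting `decay_mult: 0` for a parameter would, under these assumptions, switch regularization off for that parameter only.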
http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/scripts/nn/layers/lstm.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/lstm.dml b/scripts/nn/layers/lstm.dml
index 696ed88..664a1e2 100644
--- a/scripts/nn/layers/lstm.dml
+++ b/scripts/nn/layers/lstm.dml
@@ -168,10 +168,11 @@ backward = function(matrix[double] dout, matrix[double] 
dc,
   N = nrow(X)
   M = as.integer(ncol(W)/4)
   N1 = nrow(out0)
-  if(N < N1) {
-    # Allow for smaller out0 for last batch
-    out0 = out0[1:N,]
-    c0 = c0[1:N,]
+  if(N != N1) {
+    # Allow for smaller out0 for last batch 
+    # out0 = out0[1:N,]
+    # c0 = c0[1:N,]
+    stop("Unsupported operation: The batch size of previous iteration " + N1 + 
" is different than the batch size of current iteration " + N)
   }
   dX = matrix(0, rows=N, cols=T*D)
   dW = matrix(0, rows=D+M, cols=4*M)

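Because the LSTM backward pass now stops when the batch size changes between iterations, it can help to check on the Python side that the training data splits into equally sized batches before calling `fit`. The following is a minimal sketch; the helper name `check_uniform_batches` is hypothetical and not part of SystemML.

```python
def check_uniform_batches(num_samples, batch_size):
    """Hypothetical pre-flight check: return the number of batches per epoch,
    or raise if the last mini-batch would be smaller than batch_size."""
    remainder = num_samples % batch_size
    if remainder != 0:
        raise ValueError(
            "num_samples=%d is not a multiple of batch_size=%d "
            "(last batch would have %d rows)" % (num_samples, batch_size, remainder))
    return num_samples // batch_size
```

Holding out `test_iter * batch_size` rows for validation removes a whole number of batches, so it does not change whether this check passes.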
http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py 
b/src/main/python/systemml/mllearn/estimators.py
index f1e793b..72a6f55 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -213,11 +213,13 @@ class BaseSystemMLEstimator(Estimator):
         if y is None:
             return self._fit(X)
         elif y is not None and isinstance(X, SUPPORTED_TYPES) and 
isinstance(y, SUPPORTED_TYPES):
-            y = self.encode(y)
+            # Do not encode if y is a numpy matrix => useful for segmentation
+            skipEncodingY = len(y.shape) == 2 and y.shape[0] != 1 and 
y.shape[1] != 1
+            y = y if skipEncodingY else self.encode(y)
             if self.transferUsingDF:
                 pdfX = convertToPandasDF(X)
                 pdfY = convertToPandasDF(y)
-                if getNumCols(pdfY) != 1:
+                if getNumCols(pdfY) != 1 and not skipEncodingY:
                     raise Exception('y should be a column vector')
                 if pdfX.shape[0] != pdfY.shape[0]:
                     raise Exception('Number of rows of X and y should match')
@@ -227,7 +229,7 @@ class BaseSystemMLEstimator(Estimator):
                 self.fit_df(df)
             else:
                 numColsy = getNumCols(y)
-                if numColsy != 1:
+                if numColsy != 1 and not skipEncodingY:
                     raise Exception('Expected y to be a column vector')
                 self.fit_numpy(X, y)
             if self.setOutputRawPredictionsToFalse:
@@ -842,7 +844,7 @@ class Caffe2DML(BaseSystemMLClassifier):
         if ignore_weights is not None:
             self.estimator.setWeightsToIgnore(ignore_weights)
             
-    def set(self, debug=None, train_algo=None, test_algo=None, 
parallel_batches=None, output_activations=None):
+    def set(self, debug=None, train_algo=None, test_algo=None, 
parallel_batches=None, output_activations=None, perform_one_hot_encoding=None):
         """
         Set input to Caffe2DML
         
@@ -853,12 +855,14 @@ class Caffe2DML(BaseSystemMLClassifier):
         test_algo: can be minibatch, batch, allreduce_parallel_batches or 
allreduce (default: minibatch)
         parallel_batches: number of parallel batches
         output_activations: (developer flag) directory to output activations 
of each layer as csv while prediction. To be used only in batch mode (default: 
None)
+        perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: True)
         """
         if debug is not None: self.estimator.setInput("$debug", 
str(debug).upper())
         if train_algo is not None: self.estimator.setInput("$train_algo", 
str(train_algo).lower())
         if test_algo is not None: self.estimator.setInput("$test_algo", 
str(test_algo).lower())
         if parallel_batches is not None: 
self.estimator.setInput("$parallel_batches", str(parallel_batches))
         if output_activations is not None: 
self.estimator.setInput("$output_activations", str(output_activations))
+        if perform_one_hot_encoding is not None: 
self.estimator.setInput("$perform_one_hot_encoding", 
str(perform_one_hot_encoding).lower())
         return self
     
     def summary(self):
@@ -884,7 +888,7 @@ class Keras2DML(Caffe2DML):
 
     """
 
-    def __init__(self, sparkSession, keras_model, input_shape, 
transferUsingDF=False, weights=None, labels=None):
+    def __init__(self, sparkSession, keras_model, input_shape, 
transferUsingDF=False, weights=None, labels=None, batch_size=64, max_iter=2000, 
test_iter=10, test_interval=500, display=100, lr_policy="step", 
weight_decay=5e-4, regularization_type="L2"):
         """
         Performs training/prediction for a given keras model.
 
@@ -896,6 +900,14 @@ class Keras2DML(Caffe2DML):
         transferUsingDF: whether to pass the input dataset via PySpark 
DataFrame (default: False)
         weights: directory whether learned weights are stored (default: None)
         labels: file containing mapping between index and string labels 
(default: None)
+        batch_size: size of the input batch (default: 64)
+        max_iter: maximum number of iterations (default: 2000)
+        test_iter: test_iter for caffe solver (default: 10)
+        test_interval: test_interval for caffe solver (default: 500)
+        display: display for caffe solver (default: 100)
+        lr_policy: learning rate policy for caffe solver (default: "step")
+        weight_decay: regularization strength (default: 5e-4)
+        regularization_type: regularization type (default: "L2")
         """
         from .keras2caffe import *
         import tempfile
@@ -906,8 +918,10 @@ class Keras2DML(Caffe2DML):
             keras_model = keras_model.model
         self.name = keras_model.name
         createJavaObject(sparkSession._sc, 'dummy')
-        convertKerasToCaffeNetwork(keras_model, self.name + ".proto")
-        convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name 
+ "_solver.proto")
+        if not hasattr(keras_model, 'optimizer'):
+                   keras_model.compile(loss='categorical_crossentropy', 
optimizer=keras.optimizers.SGD(lr=0.01, momentum=0.95, decay=5e-4, 
nesterov=True))
+        convertKerasToCaffeNetwork(keras_model, self.name + ".proto", 
int(batch_size))
+        convertKerasToCaffeSolver(keras_model, self.name + ".proto", self.name 
+ "_solver.proto", int(max_iter), int(test_iter), int(test_interval), 
int(display), lr_policy, weight_decay, regularization_type)
         self.weights = tempfile.mkdtemp() if weights is None else weights
         convertKerasToSystemMLModel(sparkSession, keras_model, self.weights)
         if labels is not None and (labels.startswith('https:') or 
labels.startswith('http:')):

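The `skipEncodingY` condition in the `estimators.py` hunk above depends only on the shape of `y`. The sketch below restates it on plain shape tuples so the cases are easy to see; the function name is hypothetical.

```python
def should_skip_label_encoding(shape):
    """Mirror of the skipEncodingY test: True only for a genuine 2-D matrix,
    i.e. neither a 1-D array nor a single row or single column."""
    return len(shape) == 2 and shape[0] != 1 and shape[1] != 1
```

A 1-D label array or an `(N, 1)` column is still passed through `encode`, whereas an `(N, k)` matrix with `k > 1` (for example, a segmentation target) is passed on unchanged.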
http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/python/systemml/mllearn/keras2caffe.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/keras2caffe.py 
b/src/main/python/systemml/mllearn/keras2caffe.py
index 3cf710c..81fb63a 100755
--- a/src/main/python/systemml/mllearn/keras2caffe.py
+++ b/src/main/python/systemml/mllearn/keras2caffe.py
@@ -28,6 +28,8 @@ from itertools import chain, imap
 from ..converters import *
 from ..classloader import *
 import keras
+from keras import backend as K
+from keras.layers import Activation
 
 try:
     import py4j.java_gateway
@@ -137,7 +139,7 @@ def _parseKerasLayer(layer):
        param = layerParamMapping[layerType](layer)
        paramName = param.keys()[0]
        if layerType == keras.layers.InputLayer:
-               ret = { 'layer': { 'name':layer.name, 'type':'Data', 
'top':layer.name, paramName:param[paramName] } }
+               ret = { 'layer': { 'name':layer.name, 'type':'Data', 
paramName:param[paramName], 'top':layer.name, 'top':'label' } }
        else:
                ret = { 'layer': { 'name':layer.name, 
'type':supportedLayers[layerType], 'bottom':_getBottomLayers(layer), 
'top':layer.name, paramName:param[paramName] } }
        return [ ret, _parseActivation(layer, layer.name + '_activation') ] if 
_shouldParseActivation(layer)  else [ ret ]
@@ -155,8 +157,6 @@ specialLayers = {
     keras.layers.BatchNormalization: _parseBatchNorm
     }
        
-batchSize = 64
-
 def getConvParam(layer):
        stride = (1, 1) if layer.strides is None else layer.strides
        padding = [layer.kernel_size[0] / 2, layer.kernel_size[1] / 2] if 
layer.padding == 'same' else [0, 0]
@@ -181,7 +181,7 @@ def getRecurrentParam(layer):
 # TODO: Update AveragePooling2D when we add maxpooling support 
 layerParamMapping = {
     keras.layers.InputLayer: lambda l: \
-        {'data_param': {'batch_size': batchSize}},
+        {'data_param': {'batch_size': l.batch_size}},
     keras.layers.Dense: lambda l: \
         {'inner_product_param': {'num_output': l.units}},
     keras.layers.Dropout: lambda l: \
@@ -210,16 +210,58 @@ def _checkIfValid(myList, fn, errorMessage):
        if len(unsupported_elems) != 0:
                raise ValueError(errorMessage + 
str(np.array(myList)[unsupported_elems]))
 
-def convertKerasToCaffeNetwork(kerasModel, outCaffeNetworkFilePath):
+def _transformLayer(layer, batch_size):
+       if type(layer) == keras.layers.InputLayer:
+               layer.batch_size = batch_size
+       return [ layer ]
+
+def _appendKerasLayers(fileHandle, kerasLayers, batch_size):
+       if len(kerasLayers) >= 1:
+               transformedLayers = list(chain.from_iterable(imap(lambda layer: 
_transformLayer(layer, batch_size), kerasLayers)))  
+               jsonLayers = list(chain.from_iterable(imap(lambda layer: 
_parseKerasLayer(layer), transformedLayers)))
+               parsedLayers = list(chain.from_iterable(imap(lambda layer: 
_parseJSONObject(layer), jsonLayers)))
+               fileHandle.write(''.join(parsedLayers))
+               fileHandle.write('\n')
+       
+def lossLayerStr(layerType, bottomLayer):
+       return 'layer {\n  name: "loss"\n  type: "' + layerType + '"\n  bottom: 
"' + bottomLayer + '"\n  bottom: "label"\n  top: "loss"\n}\n'
+       
+def _appendKerasLayerWithoutActivation(fileHandle, layer, batch_size):
+       if type(layer) != keras.layers.Activation:
+               lastLayerActivation = layer.activation
+               layer.activation = keras.activations.linear
+               _appendKerasLayers(fileHandle, [layer], batch_size)
+               layer.activation = lastLayerActivation
+
+def _getExactlyOneBottomLayer(layer):
+       bottomLayers = _getBottomLayers(layer)
+       if len(bottomLayers) != 1:
+               raise Exception('Expected only one bottom layer for ' + 
str(layer.name) + ', but found ' + str(bottomLayers))
+       return bottomLayers[0]
+
+def _isMeanSquaredError(loss):
+       return loss == 'mean_squared_error' or loss == 'mse' or loss == 'MSE' 
+       
+def convertKerasToCaffeNetwork(kerasModel, outCaffeNetworkFilePath, 
batch_size):
        _checkIfValid(kerasModel.layers, lambda layer: False if type(layer) in 
supportedLayers else True, 'Unsupported Layers:')
-       #unsupported_layers = np.array([False if type(layer) in supportedLayers 
else True for layer in kerasModel.layers])
-       #if len(np.where(unsupported_layers)[0]) != 0:
-       #       raise TypeError('Unsupported Layers:' + 
str(np.array(kerasModel.layers)[np.where(unsupported_layers)[0]]))
-       # Core logic: model.layers.flatMap(layer => 
_parseJSONObject(_parseKerasLayer(layer)))
-       jsonLayers = list(chain.from_iterable(imap(lambda layer: 
_parseKerasLayer(layer), kerasModel.layers)))
-       parsedLayers = list(chain.from_iterable(imap(lambda layer: 
_parseJSONObject(layer), jsonLayers)))
        with open(outCaffeNetworkFilePath, 'w') as f:
-               f.write(''.join(parsedLayers))
+               # Write the parsed layers for all but the last layer
+               _appendKerasLayers(f, kerasModel.layers[:-1], batch_size)
+               # Now process the last layer with loss
+               lastLayer = kerasModel.layers[-1]
+               if _isMeanSquaredError(kerasModel.loss):
+                       _appendKerasLayers(f, [ lastLayer ], batch_size)
+                       f.write(lossLayerStr('EuclideanLoss', lastLayer.name))
+               elif kerasModel.loss == 'categorical_crossentropy':
+                       _appendKerasLayerWithoutActivation(f, lastLayer, 
batch_size)
+                       bottomLayer = _getExactlyOneBottomLayer(lastLayer) if 
type(lastLayer) == keras.layers.Activation else lastLayer.name  
+                       lastLayerActivation = 
str(keras.activations.serialize(lastLayer.activation))
+                       if lastLayerActivation == 'softmax' and kerasModel.loss 
== 'categorical_crossentropy':
+                               f.write(lossLayerStr('SoftmaxWithLoss', 
bottomLayer))
+                       else:
+                               raise Exception('Unsupported loss layer ' + 
str(kerasModel.loss) + ' (where last layer activation ' + lastLayerActivation + 
').')
+               else:
+                       raise Exception('Unsupported loss layer ' + str(kerasModel.loss) + '.')
 
 
 def getNumPyMatrixFromKerasWeight(param):
@@ -234,23 +276,52 @@ def getNumPyMatrixFromKerasWeight(param):
 
 
 defaultSolver = """
-base_lr: 0.01
-momentum: 0.9
-weight_decay: 5e-4
-lr_policy: "exp"
-gamma: 0.95
-display: 100
 solver_mode: CPU
-type: "SGD"
-max_iter: 2000
-test_iter: 10
-test_interval: 500
 """
 
-def convertKerasToCaffeSolver(kerasModel, caffeNetworkFilePath, 
outCaffeSolverFilePath):
+def evaluateValue(val):
+       if type(val) == int or type(val) == float:
+               return float(val)
+       else:
+               return K.eval(val)
+       
+def convertKerasToCaffeSolver(kerasModel, caffeNetworkFilePath, 
outCaffeSolverFilePath, max_iter, test_iter, test_interval, display, lr_policy, 
weight_decay, regularization_type):
+       if type(kerasModel.optimizer) == keras.optimizers.SGD:
+               solver = 'type: "Nesterov"\n' if kerasModel.optimizer.nesterov 
else 'type: "SGD"\n'
+       elif type(kerasModel.optimizer) == keras.optimizers.Adagrad:
+               solver = 'type: "Adagrad"\n'
+       elif type(kerasModel.optimizer) == keras.optimizers.Adam:
+               solver = 'type: "Adam"\n'
+       else:
+               raise Exception('Only sgd (with/without momentum/nesterov), 
Adam and Adagrad supported.')
+       base_lr = evaluateValue(kerasModel.optimizer.lr) if 
hasattr(kerasModel.optimizer, 'lr') else 0.01
+       gamma = evaluateValue(kerasModel.optimizer.decay) if 
hasattr(kerasModel.optimizer, 'decay') else 0.0
        with open(outCaffeSolverFilePath, 'w') as f:
                f.write('net: "' + caffeNetworkFilePath + '"\n')
                f.write(defaultSolver)
+               f.write(solver)
+               f.write('lr_policy: "' + lr_policy + '"\n')
+               f.write('regularization_type: "' + str(regularization_type) + 
'"\n')
+               f.write('weight_decay: ' + str(weight_decay) + '\n')
+               f.write('max_iter: ' + str(max_iter) + '\ntest_iter: ' + 
str(test_iter) + '\ntest_interval: ' + str(test_interval) + '\n')
+               f.write('display: ' + str(display) + '\n')
+               f.write('base_lr: ' + str(base_lr) + '\n')
+               f.write('gamma: ' + str(gamma) + '\n')
+               if type(kerasModel.optimizer) == keras.optimizers.SGD:
+                       momentum = evaluateValue(kerasModel.optimizer.momentum) 
if hasattr(kerasModel.optimizer, 'momentum') else 0.0
+                       f.write('momentum: ' + str(momentum) + '\n')
+               elif type(kerasModel.optimizer) == keras.optimizers.Adam:
+                       momentum = evaluateValue(kerasModel.optimizer.beta_1) 
if hasattr(kerasModel.optimizer, 'beta_1') else 0.9
+                       momentum2 = evaluateValue(kerasModel.optimizer.beta_2) 
if hasattr(kerasModel.optimizer, 'beta_2') else 0.999
+                       delta = evaluateValue(kerasModel.optimizer.epsilon) if 
hasattr(kerasModel.optimizer, 'epsilon') else 1e-8
+                       f.write('momentum: ' + str(momentum) + '\n')
+                       f.write('momentum2: ' + str(momentum2) + '\n')
+                       f.write('delta: ' + str(delta) + '\n')
+               elif type(kerasModel.optimizer) == keras.optimizers.Adagrad:
+                       delta = evaluateValue(kerasModel.optimizer.epsilon) if 
hasattr(kerasModel.optimizer, 'epsilon') else 1e-8
+                       f.write('delta: ' + str(delta) + '\n')
+               else:
+                       raise Exception('Only sgd (with/without 
momentum/nesterov), Adam and Adagrad supported.')
 
 
 def getInputMatrices(layer):

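The optimizer-specific branch of `convertKerasToCaffeSolver` above can be summarized without Keras as a mapping from optimizer settings to Caffe solver entries. The sketch below uses a hypothetical function `solver_entries` that takes the optimizer's class name and its hyperparameters as plain keyword arguments; the fallback values are the ones used in the diff.

```python
def solver_entries(optimizer, **params):
    """Hypothetical stand-in for the optimizer branch of the solver conversion.
    Returns (key, value) pairs in the order they would be written."""
    if optimizer == "SGD":
        solver_type = "Nesterov" if params.get("nesterov", False) else "SGD"
        return [("type", solver_type),
                ("momentum", params.get("momentum", 0.0))]
    if optimizer == "Adam":
        return [("type", "Adam"),
                ("momentum", params.get("beta_1", 0.9)),
                ("momentum2", params.get("beta_2", 0.999)),
                ("delta", params.get("epsilon", 1e-8))]
    if optimizer == "Adagrad":
        return [("type", "Adagrad"),
                ("delta", params.get("epsilon", 1e-8))]
    raise ValueError("Only SGD (with/without momentum/nesterov), "
                     "Adam and Adagrad are supported.")
```

For example, `solver_entries("SGD", momentum=0.95, nesterov=True)` selects the `Nesterov` solver type, matching the default optimizer that Keras2DML compiles when the model has none.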
http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala 
b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 789d08a..0a215b1 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -357,6 +357,10 @@ class Caffe2DML(val sc: SparkContext,
     
     System.out.println("* => memory in megabytes assuming the parameters 
(input, output activations, weights and backpropagation errors) are in double 
precision and in dense format.")
   }
+  
+  def setDebugFlags(isDebug:Boolean):Unit = {
+    net.getLayers.map(layer => {net.getCaffeLayer(layer).debugLayer = isDebug})
+  }
 
   // 
================================================================================================
   // The below method parses the provided network and solver file and 
generates DML script.
@@ -368,9 +372,11 @@ class Caffe2DML(val sc: SparkContext,
     // Flags passed by user
     val DEBUG_TRAINING = if (inputs.containsKey("$debug")) 
inputs.get("$debug").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
+    setDebugFlags(DEBUG_TRAINING)
 
     appendHeaders(net, solver, true) // Appends DML corresponding to source 
and externalFunction statements.
-    readInputData(net, true)         // Read X_full and y_full
+    val performOneHotEncoding = 
!inputs.containsKey("$perform_one_hot_encoding") || 
inputs.get("$perform_one_hot_encoding").toBoolean
+    readInputData(net, true, performOneHotEncoding)         // Read X_full and 
y_full
     // Initialize the layers and solvers. Reads weights and bias if $weights 
is set.
     initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
 
@@ -389,7 +395,7 @@ class Caffe2DML(val sc: SparkContext,
     // 
----------------------------------------------------------------------------
     // Main logic
     forBlock("iter", "1", "max_iter") {
-      performTrainingIter(lossLayers, shouldValidate)
+      performTrainingIter(lossLayers, shouldValidate, performOneHotEncoding)
       if (getTrainAlgo.toLowerCase.equals("batch")) {
         assign(tabDMLScript, "e", "iter")
         tabDMLScript.append("# Learning rate\n")
@@ -414,11 +420,14 @@ class Caffe2DML(val sc: SparkContext,
     val script = dml(trainingScript).in(inputs)
     net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => 
script.out(l.weight))
     net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => 
script.out(l.bias))
+    
+    setDebugFlags(false)
+    
     (script, "X_full", "y_full")
   }
   // 
================================================================================================
 
-  private def performTrainingIter(lossLayers: List[IsLossLayer], 
shouldValidate: Boolean): Unit =
+  private def performTrainingIter(lossLayers: List[IsLossLayer], 
shouldValidate: Boolean, performOneHotEncoding:Boolean): Unit =
     getTrainAlgo.toLowerCase match {
       case "minibatch" =>
         getTrainingBatch(tabDMLScript)
@@ -426,14 +435,14 @@ class Caffe2DML(val sc: SparkContext,
         // Perform forward, backward and update on minibatch
         forward; backward; update
         // -------------------------------------------------------
-        displayLoss(lossLayers(0), shouldValidate)
+        displayLoss(lossLayers(0), shouldValidate, performOneHotEncoding)
         performSnapshot
       case "batch" => {
         // -------------------------------------------------------
         // Perform forward, backward and update on entire dataset
         forward; backward; update
         // -------------------------------------------------------
-        displayLoss(lossLayers(0), shouldValidate)
+        displayLoss(lossLayers(0), shouldValidate, performOneHotEncoding)
         performSnapshot
       }
       case "allreduce_parallel_batches" => {
@@ -469,7 +478,7 @@ class Caffe2DML(val sc: SparkContext,
           // -------------------------------------------------------
           assign(tabDMLScript, "Xb", "X_group_batch")
           assign(tabDMLScript, "yb", "y_group_batch")
-          displayLoss(lossLayers(0), shouldValidate)
+          displayLoss(lossLayers(0), shouldValidate, performOneHotEncoding)
           performSnapshot
         }
       }
@@ -496,7 +505,7 @@ class Caffe2DML(val sc: SparkContext,
         // -------------------------------------------------------
         assign(tabDMLScript, "Xb", "X_group_batch")
         assign(tabDMLScript, "yb", "y_group_batch")
-        displayLoss(lossLayers(0), shouldValidate)
+        displayLoss(lossLayers(0), shouldValidate, performOneHotEncoding)
         performSnapshot
       }
       case _ => throw new DMLRuntimeException("Unsupported train algo:" + 
getTrainAlgo)
@@ -537,7 +546,7 @@ class Caffe2DML(val sc: SparkContext,
     }
 
   // Append the DML to display training and validation loss
-  private def displayLoss(lossLayer: IsLossLayer, shouldValidate: Boolean): 
Unit = {
+  private def displayLoss(lossLayer: IsLossLayer, shouldValidate: Boolean, 
performOneHotEncoding:Boolean): Unit = {
     if (solverParam.getDisplay > 0) {
       // Append the DML to compute training loss
       if (!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
@@ -550,7 +559,9 @@ class Caffe2DML(val sc: SparkContext,
           tabDMLScript.append(
             print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", 
training loss:"), "training_loss", asDMLString(", training accuracy:"), 
"training_accuracy"))
           )
-          printClassificationReport
+          if(performOneHotEncoding) {
+            printClassificationReport
+          }
         }
       } else {
         Caffe2DML.LOG.info("Training loss is not printed for train_algo=" + 
getTrainAlgo)
@@ -743,9 +754,11 @@ class Caffe2DMLModel(val numClasses: String, val sc: 
SparkContext, val solver: C
 
     val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) 
estimator.inputs.get("$debug").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
+    estimator.setDebugFlags(DEBUG_PREDICTION)
 
     appendHeaders(net, solver, false) // Appends DML corresponding to source 
and externalFunction statements.
-    readInputData(net, false)         // Read X_full and y_full
+    val performOneHotEncoding = 
!estimator.inputs.containsKey("$perform_one_hot_encoding") || 
estimator.inputs.get("$perform_one_hot_encoding").toBoolean
+    readInputData(net, false, performOneHotEncoding)         // Read X_full 
and y_full
     assign(tabDMLScript, "X", "X_full")
 
     // Initialize the layers and solvers. Reads weights and bias if 
readWeights is true.
@@ -837,6 +850,9 @@ class Caffe2DMLModel(val numClasses: String, val sc: 
SparkContext, val solver: C
       net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l 
=> script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
       net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => 
script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
     }
+    
+    estimator.setDebugFlags(false)
+    
     (script, "X_full")
   }
   // 
================================================================================================
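
Note on the $perform_one_hot_encoding flag used above: it is treated as true unless it is explicitly supplied and parses to false. The following Python sketch mirrors that defaulting rule outside of Scala. The function name and the plain dict standing in for `inputs` are illustrative assumptions, not part of Caffe2DML.

```python
def perform_one_hot_encoding(inputs):
    """Mirror of: !inputs.containsKey(key) || inputs.get(key).toBoolean

    `inputs` is a plain dict of script arguments (hypothetical stand-in
    for the map used by Caffe2DML).
    """
    key = "$perform_one_hot_encoding"
    if key not in inputs:
        # Flag absent: keep the previous behaviour and one-hot encode labels.
        return True
    value = inputs[key].lower()
    # Scala's String.toBoolean accepts only "true"/"false" (case-insensitive).
    if value not in ("true", "false"):
        raise ValueError("Expected true or false for " + key + ", got: " + inputs[key])
    return value == "true"
```

With this rule, existing callers that never set the flag keep getting one-hot encoded labels and the classification report, while regression-style callers can opt out by passing "false".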

http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala 
b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index cdecdce..65a9921 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -42,6 +42,22 @@ trait CaffeLayer extends BaseDMLGenerator {
     computedOutputShape
   }
   // -------------------------------------------------
+  var debugLayer = false
+  def validateDimensions(dmlScript: StringBuilder, mat:String, 
expectedNumRows:String, expectedNumCols:String, optionalString:String=""):Unit 
= {
+    if(debugLayer) {
+      val msg = " in " + sourceFileName + "(" + optionalString + ") script."
+      if(expectedNumRows != null) {
+        dmlScript.append("\nif( " + expectedNumRows + " != nrow(" + mat + ")) 
{\n")
+        dmlScript.append("\tstop(\"Incorrect number of rows for " + mat + msg 
+ " Expected:\" + " + expectedNumRows + " + \" but found \" +  nrow(" + mat + 
") )") 
+        dmlScript.append("\n}\n")
+      }
+      if(expectedNumCols != null) {
+        dmlScript.append("\nif( " + expectedNumCols + " != ncol(" + mat + ")) 
{\n")
+        dmlScript.append("\tstop(\"Incorrect number of columns for " + mat + 
msg + " Expected:\" + " + expectedNumCols + " + \" but found \" +  ncol(" + mat 
+ ") )") 
+        dmlScript.append("\n}\n")
+      }
+    }
+  }
   var computedBottomLayerOutputShape: (String, String, String) = null
   def bottomLayerOutputShape: (String, String, String) = {
     if (computedBottomLayerOutputShape == null) {
@@ -532,27 +548,25 @@ class Concat(val param: LayerParameter, val id: Int, val 
net: CaffeNetwork) exte
 
 // L2 loss function.
 class EuclideanLoss(val param: LayerParameter, val id: Int, val net: 
CaffeNetwork) extends CaffeLayer with IsLossLayer {
-  override def sourceFileName: String = if (!isSegmentationProblem()) 
"l2_loss" else throw new DMLRuntimeException("Segmentation is not supported for 
EuclideanLoss in Caffe2DML yet")
+  override def sourceFileName: String = "l2_loss"
   override def weightShape(): Array[Int] = null
   override def biasShape(): Array[Int]   = null
   
-  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
-    invokeForward(dmlScript, List[String](out), scores, "yb")
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = 
+    assign(dmlScript, out, scores)
   
-  override def backward(dmlScript: StringBuilder,outSuffix: String): Unit = 
+  override def backward(dmlScript: StringBuilder,outSuffix: String): Unit =  {
+      invokeForward(dmlScript, List[String](out), scores, "yb")
       invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id + 
outSuffix), scores, "yb")
-  
-  override def computeLoss(dmlScript: StringBuilder,numTabs: Int): Unit =
-    if (!isSegmentationProblem()) {
-      val tabBuilder = new StringBuilder
-      for (i <- 0 until numTabs) tabBuilder.append("\t")
-      val tabs = tabBuilder.toString
-      dmlScript.append("tmp_loss = l2_loss::forward(" + commaSep(out, "yb") + 
")\n")
-      dmlScript.append(tabs).append("loss = loss + tmp_loss\n")
-      dmlScript.append(tabs).append("accuracy = -1\n")
-    } else {
-      throw new RuntimeException("Computation of loss for SoftmaxWithLoss is 
not implemented for segmentation problem")
-    }
+  }
+  override def computeLoss(dmlScript: StringBuilder,numTabs: Int): Unit = {
+    val tabBuilder = new StringBuilder
+    for (i <- 0 until numTabs) tabBuilder.append("\t")
+    val tabs = tabBuilder.toString
+    invokeForward(dmlScript, List[String]("tmp_loss"), scores, "yb")
+    dmlScript.append(tabs).append("loss = loss + tmp_loss\n")
+    dmlScript.append(tabs).append("accuracy = -1\n")
+  }
 }
 
 class SoftmaxWithLoss(val param: LayerParameter, val id: Int, val net: 
CaffeNetwork) extends CaffeLayer with IsLossLayer {
@@ -853,8 +867,14 @@ class InnerProduct(val param: LayerParameter, val id: Int, 
val net: CaffeNetwork
    * Outputs:
    *  - out: Outputs, of shape (N, M).
    */
-  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = {
+    val D = numFeatures
+    val M = numNeurons
+    validateDimensions(dmlScript, X, null, D)
+    validateDimensions(dmlScript, weight, D, M, "forward")
+    validateDimensions(dmlScript, bias, "1", M)
     invokeForward(dmlScript, List[String](out), X, weight, bias)
+  }
   /*
    * Computes the backward pass for a fully-connected (affine) layer
    * with M neurons.
@@ -870,8 +890,15 @@ class InnerProduct(val param: LayerParameter, val id: Int, 
val net: CaffeNetwork
    *  - dW: Gradient wrt `W`, of shape (D, M).
    *  - db: Gradient wrt `b`, of shape (1, M).
    */
-  override def backward(dmlScript: StringBuilder, outSuffix: String) =
+  override def backward(dmlScript: StringBuilder, outSuffix: String) = {
+    val D = numFeatures
+    val M = numNeurons
+    validateDimensions(dmlScript, dout, null, M)
+    validateDimensions(dmlScript, X, null, D)
+    validateDimensions(dmlScript, weight, D, M, "backward")
+    validateDimensions(dmlScript, bias, "1", M)
     invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, 
dBias), dout, X, weight, bias)
+  }
   // -------------------------------------------------
   // num_output (c_o): the number of filters
   def numNeurons  = param.getInnerProductParam.getNumOutput.toString
@@ -935,15 +962,44 @@ class LSTM(val param: LayerParameter, val id: Int, val 
net: CaffeNetwork) extend
   
   override def init(dmlScript: StringBuilder) = {
     invokeInit(dmlScript, List[String](weight, bias, out0, c0), 
Caffe2DML.batchSize, input_features, M)
+    // Also, initialize gradient wrt `c` to empty matrix 
+    dmlScript.append(dc0 + " = matrix(0, rows=" + Caffe2DML.batchSize + ", 
cols=" + M + ")\n")
   }
   
   override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = {
-    invokeForward(dmlScript, List[String](out, c, cache_out, cache_c, 
cache_ifog), X, weight, bias, timesteps, input_features, 
return_sequences.toString.toUpperCase, out0, c0)
+    val N:String = null // output_features.toString
+    val T = timesteps()
+    val D = input_features()
+    validateDimensions(dmlScript, X, N, T + "*" + D)
+    validateDimensions(dmlScript, out0, N, M)
+    validateDimensions(dmlScript, c0, N, M)
+    validateDimensions(dmlScript, weight, D + "+" + M, 4 + "*" + M)
+    validateDimensions(dmlScript, bias, "1", 4 + "*" + M)
+    invokeForward(dmlScript, List[String](out, c, cache_out, cache_c, 
cache_ifog), X, weight, bias, T, D, return_sequences.toString.toUpperCase, 
out0, c0)
+    // This validates whether the output is of correct dimensions
+    validateDimensions(dmlScript, out, null, int_mult(outputShape._1, 
outputShape._2, outputShape._3))
   }
   
   override def backward(dmlScript: StringBuilder, outSuffix: String) = {
+    val T = timesteps()
+    val D = input_features()
+    if(return_sequences) {
+      validateDimensions(dmlScript, dout, null, T + "*" + M)
+    }
+    else {
+      validateDimensions(dmlScript, dout, null, M)
+    }
+    validateDimensions(dmlScript, dc0, null, M)
+    validateDimensions(dmlScript, X, null, T + "*" + D)
+    validateDimensions(dmlScript, out0, null, M)
+    validateDimensions(dmlScript, c0, null, M)
+    validateDimensions(dmlScript, cache_out, T, null)
+    validateDimensions(dmlScript, cache_c, T, null)
+    validateDimensions(dmlScript, cache_ifog, T, null)
+    validateDimensions(dmlScript, weight, D + "+" + M, 4 + "*" + M)
+    validateDimensions(dmlScript, bias, "1", 4 + "*" + M)
     invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, 
dBias, dout0, dc0), dout, dc0, X, weight, bias,
-        timesteps, input_features, return_sequences.toString.toUpperCase, 
out0, c0, cache_out, cache_c, cache_ifog)
+        T, D, return_sequences.toString.toUpperCase, out0, c0, cache_out, 
cache_c, cache_ifog)
   }
   
   val cache_out = "cache_out_" + id
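
Note on validateDimensions above: when debugLayer is set, it appends a DML guard that stops the script if a matrix has an unexpected number of rows or columns. The Python sketch below reproduces the same string construction so the generated DML can be inspected. The function signature is an illustrative assumption; only the emitted text follows the diff.

```python
def validate_dimensions(mat, expected_rows, expected_cols, source_file,
                        debug_layer=True, tag=""):
    """Return the DML guard text that the Scala helper would append.

    Pass None for expected_rows or expected_cols to skip that check.
    """
    if not debug_layer:
        return ""
    msg = " in " + source_file + "(" + tag + ") script."
    parts = []
    checks = ((expected_rows, "nrow", "rows"), (expected_cols, "ncol", "columns"))
    for expected, fn, label in checks:
        if expected is None:
            continue
        actual = fn + "(" + mat + ")"
        parts.append("\nif( " + expected + " != " + actual + ") {\n")
        parts.append('\tstop("Incorrect number of ' + label + " for " + mat + msg
                     + ' Expected:" + ' + expected + ' + " but found " +  ' + actual + " )")
        parts.append("\n}\n")
    return "".join(parts)


if __name__ == "__main__":
    # Example: check only the row count of a weight matrix named W.
    print(validate_dimensions("W", "50", None, "affine", tag="forward"))
```

Because the expected sizes are passed as strings, expressions such as T + "*" + D end up verbatim in the generated DML and are evaluated when the script runs, not when it is generated.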

http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala 
b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index a61ff10..d0d738e 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -52,12 +52,25 @@ trait CaffeSolver {
       ret
     }
 
-  def l2reg_update(lambda: Double, dmlScript: StringBuilder, layer: 
CaffeLayer): Unit =
+  def regularization_update(regularizationType:String, lambda: Double, 
dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
     // val donotRegularizeLayers:Boolean = layer.isInstanceOf[BatchNorm] || 
layer.isInstanceOf[Scale];
+    val regularizationSource = 
+      if(regularizationType.toLowerCase.equals("l2")) "l2_reg"
+      else if(regularizationType.toLowerCase.equals("l1")) "l1_reg"
+      else null
+    if(regularizationSource == null) {
+      throw new DMLRuntimeException("Unsupported regularization_type:" + 
regularizationType + ". Please use either L2 or L1.")
+    }
+    
     if (lambda != 0 && layer.shouldUpdateWeight) {
-      dmlScript.append("\t").append(layer.dWeight + "_reg = l2_reg::backward(" 
+ layer.weight + ", " + lambda + ")\n")
+      // Use layer-specific decay multiplier, if param { lr_mult: 1 
decay_mult: 1 } is specified in the network file
+      val hasDecayMult = layer.param.getParamList != null && 
layer.param.getParamList.size >= 1 && 
layer.param.getParamList.get(0).hasDecayMult
+      val newLambda = if(hasDecayMult) 
layer.param.getParamList.get(0).getDecayMult * lambda else lambda
+      
+      dmlScript.append("\t").append(layer.dWeight + "_reg = " + 
regularizationSource + "::backward(" + layer.weight + ", " + newLambda + ")\n")
       dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " 
+ " + layer.dWeight + "_reg\n")
     }
+  }
 }
 
 class LearningRatePolicy(lr_policy: String = "exp", base_lr: Double = 0.01) {
@@ -87,7 +100,7 @@ class LearningRatePolicy(lr_policy: String = "exp", base_lr: 
Double = 0.01) {
   }
 }
 
-class SGD(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
+class SGD(regularizationType:String = "L2", lambda: Double = 5e-04, momentum: 
Double = 0.9) extends CaffeSolver {
   /*
    * Performs an SGD update with momentum.
    *
@@ -112,7 +125,7 @@ class SGD(lambda: Double = 5e-04, momentum: Double = 0.9) 
extends CaffeSolver {
    *      input `X`.
    */
   def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
-    l2reg_update(lambda, dmlScript, layer)
+    regularization_update(regularizationType, lambda, dmlScript, layer)
     if (momentum == 0) {
       // Use sgd
       if (layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight 
+ " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) 
+ ")\n")
@@ -143,7 +156,7 @@ class SGD(lambda: Double = 5e-04, momentum: Double = 0.9) 
extends CaffeSolver {
   def sourceFileName: String = if (momentum == 0) "sgd" else "sgd_momentum"
 }
 
-class AdaGrad(lambda: Double = 5e-04, epsilon: Double = 1e-6) extends 
CaffeSolver {
+class AdaGrad(regularizationType:String = "L2", lambda: Double = 5e-04, 
epsilon: Double = 1e-6) extends CaffeSolver {
   /*
    * Performs an Adagrad update.
    *
@@ -172,7 +185,7 @@ class AdaGrad(lambda: Double = 5e-04, epsilon: Double = 
1e-6) extends CaffeSolve
    *      gradients, of same shape as `X`.
    */
   def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
-    l2reg_update(lambda, dmlScript, layer)
+    regularization_update(regularizationType, lambda, dmlScript, layer)
     if (layer.shouldUpdateWeight)
       dmlScript
         .append("\t")
@@ -195,7 +208,73 @@ class AdaGrad(lambda: Double = 5e-04, epsilon: Double = 
1e-6) extends CaffeSolve
   def sourceFileName: String = "adagrad"
 }
 
-class Nesterov(lambda: Double = 5e-04, momentum: Double = 0.9) extends 
CaffeSolver {
+class Adam(regularizationType:String = "L2", lambda: Double = 5e-04, 
momentum:Double = 0.9, momentum2:Double = 0.999, delta:Double = 1e-8) extends 
CaffeSolver {
+  /*
+   * Performs an Adam update.
+   *
+   * Reference:
+   *  - Adam: A Method for Stochastic Optimization, Kingma, Ba.
+   *    - http://arxiv.org/abs/1412.6980
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient wrt `X` of a loss function being optimized, of
+   *      same shape as `X`.
+   *  - lr: Learning rate.  Recommended value is 0.001.
+   *  - beta1: Exponential decay rate for the 1st moment estimates.
+   *      Recommended value is 0.9.
+   *  - beta2: Exponential decay rate for the 2nd moment estimates.
+   *      Recommended value is 0.999.
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Recommended value is 1e-8.
+   *  - t: Timestep, starting at 0.
+   *  - m: State containing the 1st moment (mean) estimate by
+   *      maintaining exponential moving averages of the gradients, of
+   *      same shape as `X`.
+   *  - v: State containing the 2nd raw moment (uncentered variance)
+   *      estimate by maintaining exponential moving averages of the
+   *      squared gradients, of same shape as `X`.
+   *
+   * Outputs:
+   *  - X: Updated parameters `X`, of same shape as input `X`.
+   *  - m: Updated state containing the 1st moment (mean) estimate by
+   *      maintaining exponential moving averages of the gradients, of
+   *      same shape as `X`.
+   *  - v: Updated state containing the 2nd raw moment (uncentered
+   *      variance) estimate by maintaining exponential moving averages
+   *      of the squared gradients, of same shape as `X`.
+   */
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    regularization_update(regularizationType, lambda, dmlScript, layer)
+    val t:String = "iter - 1" // adam::update expects t to start at 0, whereas iter starts at 1
+    // X, dX, double lr, double beta1, double beta2, epsilon, int t, 
matrix[double] m, matrix[double] v
+    if (layer.shouldUpdateWeight)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.weight, layer.weight + "_m", layer.weight + 
"_v") + "] " +
+          "= adam::update(" + commaSep(layer.weight, layer.dWeight, 
getWeightLr(layer), 
+              momentum.toString, momentum2.toString, delta.toString,  t,
+              layer.weight + "_m", layer.weight + "_v") + ")\n"
+        )
+    if (layer.shouldUpdateBias)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.bias, layer.bias + "_m", layer.bias + "_v") + 
"] " +
+          "= adam::update(" + commaSep(layer.bias, layer.dBias, 
getBiasLr(layer), 
+              momentum.toString, momentum2.toString, delta.toString,  t, 
+              layer.bias + "_m", layer.bias + "_v") + ")\n"

+        )
+  }
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    if (layer.shouldUpdateWeight) dmlScript.append("[ " + layer.weight + "_m, 
" + layer.weight + "_v ] = adam::init(" + layer.weight + ")\n")
+    if (layer.shouldUpdateBias) dmlScript.append("[ " + layer.bias + "_m, " + 
layer.bias + "_v ] = adam::init(" + layer.bias + ")\n")
+  }
+  def sourceFileName: String = "adam"
+}
+
+class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, 
momentum: Double = 0.9) extends CaffeSolver {
   /*
    * Performs an SGD update with Nesterov momentum.
    *
@@ -232,7 +311,7 @@ class Nesterov(lambda: Double = 5e-04, momentum: Double = 
0.9) extends CaffeSolv
     val fn            = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else 
"sgd_nesterov::update"
     val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
     if (!Caffe2DML.USE_NESTEROV_UDF) {
-      l2reg_update(lambda, dmlScript, layer)
+      regularization_update(regularizationType, lambda, dmlScript, layer)
     }
     if (layer.shouldUpdateWeight)
       dmlScript
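
Note on the solver changes above: the regularization gradient is added to dWeight before the solver update, and the new Adam solver keeps per-parameter m and v state. The Python sketch below illustrates both steps on flat lists of numbers. It assumes that l2_reg::backward returns lambda * W, that l1_reg::backward returns lambda * sign(W), and that adam::update follows the bias-corrected step from Kingma and Ba; none of these DML scripts appear in this diff, so treat the formulas as assumptions. All function names are illustrative.

```python
import math


def reg_gradient(weights, reg_type, lam, decay_mult=None):
    """Regularization gradient, scaled by a layer-specific decay_mult if given."""
    kind = reg_type.lower()
    if kind not in ("l1", "l2"):
        raise ValueError("Unsupported regularization_type:" + reg_type)
    effective = lam * decay_mult if decay_mult is not None else lam
    if kind == "l2":
        # Assumed gradient of 0.5 * lambda * sum(W^2)
        return [effective * w for w in weights]
    # Assumed gradient of lambda * sum(|W|): lambda * sign(W)
    return [effective * ((w > 0) - (w < 0)) for w in weights]


def adam_update(x, dx, lr, beta1, beta2, epsilon, t, m, v):
    """One Adam step; t is the zero-based timestep, as in the doc comment."""
    t = t + 1
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, dx)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, dx)]
    # Bias-corrected step size
    step = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    x = [xi - step * mi / (math.sqrt(vi) + epsilon) for xi, mi, vi in zip(x, m, v)]
    return x, m, v


if __name__ == "__main__":
    w = [1.0, -2.0]
    grad = [0.5, 0.5]
    reg = reg_gradient(w, "L2", 5e-4)
    grad = [g + r for g, r in zip(grad, reg)]  # dWeight = dWeight + dWeight_reg
    w, m, v = adam_update(w, grad, 0.001, 0.9, 0.999, 1e-8, 0, [0.0, 0.0], [0.0, 0.0])
    print(w)
```

Each parameter carries its own m and v, which is why the weight and the bias need separate state variables in the generated DML.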

http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala 
b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 304f788..0231354 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -261,6 +261,7 @@ trait DMLGenerator extends SourceDMLGenerator with 
NextBatchGenerator {
   // Appends DML corresponding to source and externalFunction statements.
   def appendHeaders(net: CaffeNetwork, solver: CaffeSolver, isTraining: 
Boolean): Unit = {
     // Append source statements for layers as well as solver
+    source(net, solver, if (isTraining) Array[String]("l1_reg") else null)
     source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
 
     if (isTraining) {
@@ -284,14 +285,16 @@ trait DMLGenerator extends SourceDMLGenerator with 
NextBatchGenerator {
     assign(tabDMLScript, varName, "read(" + pathVar + ")")
   }
 
-  def readInputData(net: CaffeNetwork, isTraining: Boolean): Unit = {
+  def readInputData(net: CaffeNetwork, isTraining: Boolean, 
performOneHotEncoding:Boolean): Unit = {
     // Read and convert to one-hot encoding
     readMatrix("X_full", "$X")
     if (isTraining) {
       readMatrix("y_full", "$y")
       tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
-      tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based 
labels) \n")
-      tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + 
",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+      if(performOneHotEncoding) {
+        tabDMLScript.append("# Convert to one-hot encoding (Assumption: 
1-based labels) \n")
+        tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + 
",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+      }
     } else {
       tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
     }
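
Note on the one-hot step in readInputData above: the generated DML statement table(seq(1,N,1), y_full, N, numClasses) places a 1 in row i at the column given by the 1-based label y_full[i]. The Python sketch below shows the same mapping on a small list. The function name is illustrative, and raising on out-of-range labels is a choice made for this sketch, not a statement about how DML's table() behaves.

```python
def one_hot(labels, num_classes):
    """Convert 1-based class labels into rows of a one-hot matrix."""
    rows = []
    for label in labels:
        idx = int(label)
        if not 1 <= idx <= num_classes:
            # Sketch-only safeguard; labels are assumed to be 1..num_classes.
            raise ValueError("label out of range: " + str(label))
        row = [0.0] * num_classes
        row[idx - 1] = 1.0
        rows.append(row)
    return rows


if __name__ == "__main__":
    print(one_hot([1, 3, 2], 3))
```

When $perform_one_hot_encoding is false, this conversion is skipped and y_full is used as read, which is what a regression target with EuclideanLoss needs.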

http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala 
b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 19596c3..5939cf1 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -71,12 +71,14 @@ object Utils {
     val momentum = if (solver.hasMomentum) solver.getMomentum else 0.0
     val lambda   = if (solver.hasWeightDecay) solver.getWeightDecay else 0.0
     val delta    = if (solver.hasDelta) solver.getDelta else 0.0
+    val regularizationType = if(solver.hasRegularizationType) 
solver.getRegularizationType else "L2"
 
     solver.getType.toLowerCase match {
-      case "sgd"      => new SGD(lambda, momentum)
-      case "adagrad"  => new AdaGrad(lambda, delta)
-      case "nesterov" => new Nesterov(lambda, momentum)
-      case _          => throw new DMLRuntimeException("The solver type is not 
supported: " + solver.getType + ". Try: SGD, AdaGrad or Nesterov.")
+      case "sgd"      => new SGD(regularizationType, lambda, momentum)
+      case "adagrad"  => new AdaGrad(regularizationType, lambda, delta)
+      case "nesterov" => new Nesterov(regularizationType, lambda, momentum)
+      case "adam"        => new Adam(regularizationType, lambda, momentum, 
if(solver.hasMomentum2) solver.getMomentum2 else 0.0, delta)
+      case _          => throw new DMLRuntimeException("The solver type is not 
supported: " + solver.getType + ". Try: SGD, AdaGrad, Nesterov or Adam.")
     }
 
   }

http://git-wip-us.apache.org/repos/asf/systemml/blob/9dc354ac/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git 
a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala 
b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index ce92321..97abe9e 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -257,7 +257,7 @@ trait BaseSystemMLClassifierModel extends 
BaseSystemMLEstimatorModel {
     val freeMem = Runtime.getRuntime().freeMemory();
     if(freeMem < OptimizerUtils.getLocalMemBudget()) {
        val LOG = 
LogFactory.getLog(classOf[BaseSystemMLClassifierModel].getName())
-       LOG.warn("SystemML local memory budget:" + 
OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " mb. Approximate 
free memory abailable:" + OptimizerUtils.toMB(freeMem));
+       LOG.warn("SystemML local memory budget:" + 
OptimizerUtils.toMB(OptimizerUtils.getLocalMemBudget()) + " mb. Approximate 
free memory available:" + OptimizerUtils.toMB(freeMem));
     }
     val ret = (new 
MLContext(sc)).execute(script1).getMatrix("Prediction").toMatrixBlock
 
