[SYSTEMML-540] Support loading batch normalization weights from .caffemodel files in Caffe2DML

- Also fixed Scala formatting.

Closes #662.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f07b5a2d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f07b5a2d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f07b5a2d

Branch: refs/heads/master
Commit: f07b5a2d92f95f28bcdf141d700fc1be0887d735
Parents: ebb6ea6
Author: Niketan Pansare <[email protected]>
Authored: Fri Sep 15 11:00:06 2017 -0700
Committer: Niketan Pansare <[email protected]>
Committed: Fri Sep 15 11:01:49 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  732 ++++++-------
 .../apache/sysml/api/dl/Caffe2DMLLoader.scala   |   20 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 1002 ++++++++++--------
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |  216 ++--
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  193 ++--
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  566 +++++-----
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  484 ++++-----
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  264 +++--
 .../sysml/api/ml/BaseSystemMLRegressor.scala    |   68 +-
 .../apache/sysml/api/ml/LinearRegression.scala  |   93 +-
 .../sysml/api/ml/LogisticRegression.scala       |  117 +-
 .../org/apache/sysml/api/ml/NaiveBayes.scala    |   62 +-
 .../apache/sysml/api/ml/PredictionUtils.scala   |   32 +-
 .../scala/org/apache/sysml/api/ml/SVM.scala     |   81 +-
 .../org/apache/sysml/api/ml/ScriptsUtils.scala  |   18 +-
 .../scala/org/apache/sysml/api/ml/Utils.scala   |   49 +-
 16 files changed, 2100 insertions(+), 1897 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala 
b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index a62fae2..6e3e1dc 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -35,10 +35,10 @@ import java.util.HashSet
 import org.apache.sysml.api.DMLScript
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -55,7 +55,7 @@ import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 DESIGN OF CAFFE2DML:
 
 1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the 
key methods that were to be implemented are:
-- `getTrainingScript` for the Estimator class. 
+- `getTrainingScript` for the Estimator class.
 - `getPredictionScript` for the Model class.
 
 These methods should be the starting point of any developer to understand the 
DML generated for training and prediction respectively.
@@ -74,7 +74,7 @@ caffe.proto ---> protoc ---> 
target/generated-sources/caffe/Caffe.java
 - Just like the classes generated by Dml.g4 are used to parse input DML file,
 the target/generated-sources/caffe/Caffe.java class is used to parse the input 
caffe network/deploy prototxt and solver files.
 
-- You can think of .caffemodel file as DML file with matrix values encoded in 
it (please see below example). 
+- You can think of .caffemodel file as DML file with matrix values encoded in 
it (please see below example).
 So it is possible to read .caffemodel file with the Caffe.java class. This is 
done in Utils.scala's readCaffeNet method.
 
 X = matrix("1.2 3.5 0.999 7.123", rows=2, cols=2)
@@ -91,7 +91,7 @@ trait CaffeLayer {
   def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
   def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
   ...
-} 
+}
 trait CaffeSolver {
   def sourceFileName:String;
   def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
@@ -114,67 +114,85 @@ To shield from network files that violates this 
restriction, Caffe2DML performs
 6. Caffe2DML also expects the layers to be in sorted order.
 
 
***************************************************************************************/
-
-object Caffe2DML  {
-  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName()) 
+object Caffe2DML {
+  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
   def layerDir = "nn/layers/"
   def optimDir = "nn/optim/"
-  
+
   // Naming conventions:
-  val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = 
"num_images"; val numValidationImages = "num_validation"
+  val X    = "X"; val y        = "y"; val batchSize = "BATCH_SIZE"; val 
numImages = "num_images"; val numValidationImages = "num_validation"
   val XVal = "X_val"; val yVal = "y_val"
-  
+
   val USE_NESTEROV_UDF = {
     // Developer environment variable flag 'USE_NESTEROV_UDF' until codegen 
starts working.
     // Then, we will remove this flag and also the class 
org.apache.sysml.udf.lib.SGDNesterovUpdate
     val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
     envFlagNesterovUDF != null && envFlagNesterovUDF.toBoolean
   }
-  
+
   def main(args: Array[String]): Unit = {
-       // Arguments: [train_script | predict_script] $OUTPUT_DML_FILE 
$SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
-       if(args.length < 6) throwUsageError
-       val outputDMLFile = args(1)
-       val solverFile = args(2)
-       val inputChannels = args(3)
-       val inputHeight = args(4)
-       val inputWidth = args(5)
-       val caffeObj = new Caffe2DML(new SparkContext(), solverFile, 
inputChannels, inputHeight, inputWidth)
-       if(args(0).equals("train_script")) {
-               
Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, 
outputDMLFile)
-       }
-       else if(args(0).equals("predict_script")) {
-               Utils.writeToFile(new 
Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, 
outputDMLFile)
-       }
-       else {
-               throwUsageError
-       }
-  }
-  def throwUsageError():Unit = {
-       throw new RuntimeException("Incorrect usage: train_script 
OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH"); 
+    // Arguments: [train_script | predict_script] $OUTPUT_DML_FILE 
$SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
+    if (args.length < 6) throwUsageError
+    val outputDMLFile = args(1)
+    val solverFile    = args(2)
+    val inputChannels = args(3)
+    val inputHeight   = args(4)
+    val inputWidth    = args(5)
+    val caffeObj      = new Caffe2DML(new SparkContext(), solverFile, 
inputChannels, inputHeight, inputWidth)
+    if (args(0).equals("train_script")) {
+      Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, 
outputDMLFile)
+    } else if (args(0).equals("predict_script")) {
+      Utils.writeToFile(new 
Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, 
outputDMLFile)
+    } else {
+      throwUsageError
+    }
   }
+  def throwUsageError(): Unit =
+    throw new RuntimeException("Incorrect usage: train_script OUTPUT_DML_FILE 
SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH");
 }
 
-class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter, 
-    val solver:CaffeSolver, val net:CaffeNetwork, 
-    val lrPolicy:LearningRatePolicy, val numChannels:String, val 
height:String, val width:String) extends Estimator[Caffe2DMLModel] 
-  with BaseSystemMLClassifier with DMLGenerator {
+class Caffe2DML(val sc: SparkContext,
+                val solverParam: Caffe.SolverParameter,
+                val solver: CaffeSolver,
+                val net: CaffeNetwork,
+                val lrPolicy: LearningRatePolicy,
+                val numChannels: String,
+                val height: String,
+                val width: String)
+    extends Estimator[Caffe2DMLModel]
+    with BaseSystemMLClassifier
+    with DMLGenerator {
   // --------------------------------------------------------------
   // Invoked by Python, MLPipeline
-  def this(sc: SparkContext, solver1:Caffe.SolverParameter, 
networkPath:String, numChannels:String, height:String, width:String) {
-    this(sc, solver1, Utils.parseSolver(solver1), 
-        new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, 
height, width),
-        new LearningRatePolicy(solver1), numChannels, height, width)
+  def this(sc: SparkContext, solver1: Caffe.SolverParameter, networkPath: 
String, numChannels: String, height: String, width: String) {
+    this(
+      sc,
+      solver1,
+      Utils.parseSolver(solver1),
+      new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, 
height, width),
+      new LearningRatePolicy(solver1),
+      numChannels,
+      height,
+      width
+    )
   }
-  def this(sc: SparkContext, solver1:Caffe.SolverParameter, 
numChannels:String, height:String, width:String) {
-    this(sc, solver1, Utils.parseSolver(solver1), new 
CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, 
width), 
-        new LearningRatePolicy(solver1), numChannels, height, width)
+  def this(sc: SparkContext, solver1: Caffe.SolverParameter, numChannels: 
String, height: String, width: String) {
+    this(
+      sc,
+      solver1,
+      Utils.parseSolver(solver1),
+      new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, 
height, width),
+      new LearningRatePolicy(solver1),
+      numChannels,
+      height,
+      width
+    )
   }
-  def this(sc: SparkContext, solverPath:String, numChannels:String, 
height:String, width:String) {
+  def this(sc: SparkContext, solverPath: String, numChannels: String, height: 
String, width: String) {
     this(sc, Utils.readCaffeSolver(solverPath), numChannels, height, width)
   }
-  val uid:String = "caffe_classifier_" + (new Random).nextLong
+  val uid: String = "caffe_classifier_" + (new Random).nextLong
   override def copy(extra: org.apache.spark.ml.param.ParamMap): 
Estimator[Caffe2DMLModel] = {
     val that = new Caffe2DML(sc, solverParam, solver, net, lrPolicy, 
numChannels, height, width)
     copyValues(that, extra)
@@ -188,221 +206,223 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
     mloutput = baseFit(df, sc)
     new Caffe2DMLModel(this)
   }
-       // --------------------------------------------------------------
+  // --------------------------------------------------------------
   // Returns true if last 2 of 4 dimensions are 1.
   // The first dimension refers to number of input datapoints.
   // The second dimension refers to number of classes.
-  def isClassification():Boolean = {
+  def isClassification(): Boolean = {
     val outShape = getOutputShapeOfLastLayer
     return outShape._2 == 1 && outShape._3 == 1
   }
-  def getOutputShapeOfLastLayer():(Int, Int, Int) = {
+  def getOutputShapeOfLastLayer(): (Int, Int, Int) = {
     val out = net.getCaffeLayer(net.getLayers().last).outputShape
-    (out._1.toInt, out._2.toInt, out._3.toInt) 
+    (out._1.toInt, out._2.toInt, out._3.toInt)
   }
-  
+
   // Used for simplifying transfer learning
-  private val layersToIgnore:HashSet[String] = new HashSet[String]() 
-  def setWeightsToIgnore(layerName:String):Unit = layersToIgnore.add(layerName)
-  def setWeightsToIgnore(layerNames:ArrayList[String]):Unit = 
layersToIgnore.addAll(layerNames)
-         
+  private val layersToIgnore: HashSet[String]                 = new 
HashSet[String]()
+  def setWeightsToIgnore(layerName: String): Unit             = 
layersToIgnore.add(layerName)
+  def setWeightsToIgnore(layerNames: ArrayList[String]): Unit = 
layersToIgnore.addAll(layerNames)
+
   // Input parameters to prediction and scoring script
-  val inputs:java.util.HashMap[String, String] = new java.util.HashMap[String, 
String]()
-  def setInput(key: String, value:String):Unit = inputs.put(key, value)
+  val inputs: java.util.HashMap[String, String]  = new 
java.util.HashMap[String, String]()
+  def setInput(key: String, value: String): Unit = inputs.put(key, value)
   customAssert(solverParam.getTestIterCount <= 1, "Multiple test_iter 
variables are not supported")
   customAssert(solverParam.getMaxIter > 0, "Please set max_iter to a positive 
value")
   
customAssert(net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[IsLossLayer]).length
 == 1, "Expected exactly one loss layer")
-    
+
   // TODO: throw error or warning if user tries to set solver_mode == GPU 
instead of using setGPU method
-  
+
   // Method called by Python mllearn to visualize variable of certain layer
-  def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = 
visualizeLayer(net, layerName, varType, aggFn)
-  
-  def getTrainAlgo():String = if(inputs.containsKey("$train_algo")) 
inputs.get("$train_algo") else "minibatch"
-  def getTestAlgo():String = if(inputs.containsKey("$test_algo")) 
inputs.get("$test_algo") else "minibatch"
+  def visualizeLayer(layerName: String, varType: String, aggFn: String): Unit 
= visualizeLayer(net, layerName, varType, aggFn)
 
-  def summary(sparkSession:org.apache.spark.sql.SparkSession):Unit = {
+  def getTrainAlgo(): String = if (inputs.containsKey("$train_algo")) 
inputs.get("$train_algo") else "minibatch"
+  def getTestAlgo(): String  = if (inputs.containsKey("$test_algo")) 
inputs.get("$test_algo") else "minibatch"
+
+  def summary(sparkSession: org.apache.spark.sql.SparkSession): Unit = {
     val header = Seq("Name", "Type", "Output", "Weight", "Bias", "Top", 
"Bottom")
-    val entries = net.getLayers.map(l => (l, net.getCaffeLayer(l))).map(l => {
-      val layer = l._2
-      (l._1, layer.param.getType, 
-          "(, " + layer.outputShape._1 + ", " + layer.outputShape._2 + ", " + 
layer.outputShape._3 + ")",
-          if(layer.weightShape != null) "[" + layer.weightShape()(0) + " X " + 
layer.weightShape()(1) + "]" else "",
-          if(layer.biasShape != null) "[" + layer.biasShape()(0) + " X " + 
layer.biasShape()(1) + "]" else "",
-          layer.param.getTopList.mkString(","),
-          layer.param.getBottomList.mkString(",")
-      )
-    })
+    val entries = net.getLayers
+      .map(l => (l, net.getCaffeLayer(l)))
+      .map(l => {
+        val layer = l._2
+        (l._1,
+         layer.param.getType,
+         "(, " + layer.outputShape._1 + ", " + layer.outputShape._2 + ", " + 
layer.outputShape._3 + ")",
+         if (layer.weightShape != null) "[" + layer.weightShape()(0) + " X " + 
layer.weightShape()(1) + "]" else "",
+         if (layer.biasShape != null) "[" + layer.biasShape()(0) + " X " + 
layer.biasShape()(1) + "]" else "",
+         layer.param.getTopList.mkString(","),
+         layer.param.getBottomList.mkString(","))
+      })
     import sparkSession.implicits._
-    sc.parallelize(entries).toDF(header : _*).show(net.getLayers.size)
+    sc.parallelize(entries).toDF(header: _*).show(net.getLayers.size)
   }
-  
+
   // 
================================================================================================
   // The below method parses the provided network and solver file and 
generates DML script.
-       def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = 
{
-         val startTrainingTime = System.nanoTime()
-         
-    reset                                 // Reset the state of DML generator 
for training script.
-    
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
+    val startTrainingTime = System.nanoTime()
+
+    reset // Reset the state of DML generator for training script.
+
     // Flags passed by user
-         val DEBUG_TRAINING = if(inputs.containsKey("$debug")) 
inputs.get("$debug").toLowerCase.toBoolean else false
-         assign(tabDMLScript, "debug", if(DEBUG_TRAINING) "TRUE" else "FALSE")
-         
-         appendHeaders(net, solver, true)      // Appends DML corresponding to 
source and externalFunction statements.
-         readInputData(net, true)              // Read X_full and y_full
-         // Initialize the layers and solvers. Reads weights and bias if 
$weights is set.
-         initWeights(net, solver, inputs.containsKey("$weights"), 
layersToIgnore)
-         
-         // Split into training and validation set
-         // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, 
Caffe2DML.yVal and Caffe2DML.numImages
-         val shouldValidate = solverParam.getTestInterval > 0 && 
solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
-         trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
-         
-         // Set iteration-related variables such as num_iters_per_epoch, lr, 
etc.
-         ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, 
Caffe2DML.batchSize)
-         assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-         assign(tabDMLScript, "max_iter", ifdef("$max_iter", 
solverParam.getMaxIter.toString))
-         assign(tabDMLScript, "e", "0")
-         
-         val lossLayers = getLossLayers(net)
-         // 
----------------------------------------------------------------------------
-         // Main logic
-         forBlock("iter", "1", "max_iter") {
-               performTrainingIter(lossLayers, shouldValidate)
-               if(getTrainAlgo.toLowerCase.equals("batch")) {
-                       assign(tabDMLScript, "e", "iter")
-                       tabDMLScript.append("# Learning rate\n")
-                       lrPolicy.updateLearningRate(tabDMLScript)
-               }
-               else {
-                       ifBlock("iter %% num_iters_per_epoch == 0") {
-                               // After every epoch, update the learning rate
-                               assign(tabDMLScript, "e", "e + 1")
-                               tabDMLScript.append("# Learning rate\n")
-                               lrPolicy.updateLearningRate(tabDMLScript)
-                       }
-               }
-         }
-         // 
----------------------------------------------------------------------------
-         
-         // Check if this is necessary
-         if(doVisualize) tabDMLScript.append("print(" + 
asDMLString("Visualization counter:") + " + viz_counter)")
-         
-         val trainingScript = tabDMLScript.toString()
-         // Print script generation time and the DML script on stdout
-         System.out.println("Time taken to generate training script from Caffe 
proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
-         if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
-         
-         // Set input/output variables and execute the script
-         val script = dml(trainingScript).in(inputs)
-         net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != 
null).map(l => script.out(l.weight))
-         net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l 
=> script.out(l.bias))
-         (script, "X_full", "y_full")
-       }
-       // 
================================================================================================
-  
-  private def performTrainingIter(lossLayers:List[IsLossLayer], 
shouldValidate:Boolean):Unit = {
-       getTrainAlgo.toLowerCase match {
-      case "minibatch" => 
-          getTrainingBatch(tabDMLScript)
-          // -------------------------------------------------------
-          // Perform forward, backward and update on minibatch
-          forward; backward; update
-          // -------------------------------------------------------
-          displayLoss(lossLayers(0), shouldValidate)
-          performSnapshot
+    val DEBUG_TRAINING = if (inputs.containsKey("$debug")) 
inputs.get("$debug").toLowerCase.toBoolean else false
+    assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
+
+    appendHeaders(net, solver, true) // Appends DML corresponding to source 
and externalFunction statements.
+    readInputData(net, true)         // Read X_full and y_full
+    // Initialize the layers and solvers. Reads weights and bias if $weights 
is set.
+    initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
+
+    // Split into training and validation set
+    // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal 
and Caffe2DML.numImages
+    val shouldValidate = solverParam.getTestInterval > 0 && 
solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
+    trainTestSplit(if (shouldValidate) solverParam.getTestIter(0) else 0)
+
+    // Set iteration-related variables such as num_iters_per_epoch, lr, etc.
+    ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, 
Caffe2DML.batchSize)
+    assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
+    assign(tabDMLScript, "max_iter", ifdef("$max_iter", 
solverParam.getMaxIter.toString))
+    assign(tabDMLScript, "e", "0")
+
+    val lossLayers = getLossLayers(net)
+    // 
----------------------------------------------------------------------------
+    // Main logic
+    forBlock("iter", "1", "max_iter") {
+      performTrainingIter(lossLayers, shouldValidate)
+      if (getTrainAlgo.toLowerCase.equals("batch")) {
+        assign(tabDMLScript, "e", "iter")
+        tabDMLScript.append("# Learning rate\n")
+        lrPolicy.updateLearningRate(tabDMLScript)
+      } else {
+        ifBlock("iter %% num_iters_per_epoch == 0") {
+          // After every epoch, update the learning rate
+          assign(tabDMLScript, "e", "e + 1")
+          tabDMLScript.append("# Learning rate\n")
+          lrPolicy.updateLearningRate(tabDMLScript)
+        }
+      }
+    }
+    // 
----------------------------------------------------------------------------
+
+    // Check if this is necessary
+    if (doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization 
counter:") + " + viz_counter)")
+
+    val trainingScript = tabDMLScript.toString()
+    // Print script generation time and the DML script on stdout
+    System.out.println("Time taken to generate training script from Caffe 
proto: " + ((System.nanoTime() - startTrainingTime) * 1e-9) + " seconds.")
+    if (DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
+
+    // Set input/output variables and execute the script
+    val script = dml(trainingScript).in(inputs)
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => 
script.out(l.weight))
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => 
script.out(l.bias))
+    (script, "X_full", "y_full")
+  }
+  // 
================================================================================================
+
+  private def performTrainingIter(lossLayers: List[IsLossLayer], 
shouldValidate: Boolean): Unit =
+    getTrainAlgo.toLowerCase match {
+      case "minibatch" =>
+        getTrainingBatch(tabDMLScript)
+        // -------------------------------------------------------
+        // Perform forward, backward and update on minibatch
+        forward; backward; update
+        // -------------------------------------------------------
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       case "batch" => {
-             // -------------------------------------------------------
-             // Perform forward, backward and update on entire dataset
-             forward; backward; update
-             // -------------------------------------------------------
-             displayLoss(lossLayers(0), shouldValidate)
-             performSnapshot
+        // -------------------------------------------------------
+        // Perform forward, backward and update on entire dataset
+        forward; backward; update
+        // -------------------------------------------------------
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       }
       case "allreduce_parallel_batches" => {
-         // This setting uses the batch size provided by the user
-             if(!inputs.containsKey("$parallel_batches")) {
-               throw new RuntimeException("The parameter parallel_batches is 
required for allreduce_parallel_batches")
-             }
-             // The user specifies the number of parallel_batches
-             // This ensures that the user of generated script remembers to 
provide the commandline parameter $parallel_batches
-             assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
-             assign(tabDMLScript, "group_batch_size", "parallel_batches*" + 
Caffe2DML.batchSize)
-             assign(tabDMLScript, "groups", "as.integer(ceil(" + 
Caffe2DML.numImages + "/group_batch_size))")
-             // Grab groups of mini-batches
-             forBlock("g", "1", "groups") {
-               // Get next group of mini-batches
-               assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) 
%% " + Caffe2DML.numImages + " + 1")
-               assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages 
+ ", group_beg + group_batch_size - 1)")
-               assign(tabDMLScript, "X_group_batch", Caffe2DML.X + 
"[group_beg:group_end,]")
-               assign(tabDMLScript, "y_group_batch", Caffe2DML.y + 
"[group_beg:group_end,]")
-               initializeGradients("parallel_batches")
-               assign(tabDMLScript, "X_group_batch_size", 
nrow("X_group_batch"))
-               parForBlock("j", "1", "parallel_batches") {
-                 // Get a mini-batch in this group
-                 assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize 
+ ") %% nrow(X_group_batch) + 1")
-                 assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " 
+ Caffe2DML.batchSize + " - 1)")
-                 assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
-                 assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
-                 forward; backward
-                 flattenGradients
-               }
-               aggregateAggGradients    
-               update
-               // -------------------------------------------------------
-               assign(tabDMLScript, "Xb", "X_group_batch")
-               assign(tabDMLScript, "yb", "y_group_batch")
-               displayLoss(lossLayers(0), shouldValidate)
-               performSnapshot
-             }
-      }
-      case "allreduce" => {
-         // This is distributed synchronous gradient descent
-         // -------------------------------------------------------
-         // Perform forward, backward and update on minibatch in parallel
-         assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") 
%% " + Caffe2DML.numImages + " + 1")
-         assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " 
- 1, " + Caffe2DML.numImages + ")")
-         assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
-         assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
-         assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-          tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
-          val localBatchSize = "local_batch_size"
-          initializeGradients(localBatchSize)
-          parForBlock("j", "1", localBatchSize) {
-            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
-            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+        // This setting uses the batch size provided by the user
+        if (!inputs.containsKey("$parallel_batches")) {
+          throw new RuntimeException("The parameter parallel_batches is 
required for allreduce_parallel_batches")
+        }
+        // The user specifies the number of parallel_batches
+        // This ensures that the user of generated script remembers to provide 
the commandline parameter $parallel_batches
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
+        assign(tabDMLScript, "group_batch_size", "parallel_batches*" + 
Caffe2DML.batchSize)
+        assign(tabDMLScript, "groups", "as.integer(ceil(" + 
Caffe2DML.numImages + "/group_batch_size))")
+        // Grab groups of mini-batches
+        forBlock("g", "1", "groups") {
+          // Get next group of mini-batches
+          assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + 
Caffe2DML.numImages + " + 1")
+          assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", 
group_beg + group_batch_size - 1)")
+          assign(tabDMLScript, "X_group_batch", Caffe2DML.X + 
"[group_beg:group_end,]")
+          assign(tabDMLScript, "y_group_batch", Caffe2DML.y + 
"[group_beg:group_end,]")
+          initializeGradients("parallel_batches")
+          assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+          parForBlock("j", "1", "parallel_batches") {
+            // Get a mini-batch in this group
+            assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") 
%% nrow(X_group_batch) + 1")
+            assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + 
Caffe2DML.batchSize + " - 1)")
+            assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+            assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
             forward; backward
-          flattenGradients
+            flattenGradients
           }
-          aggregateAggGradients    
+          aggregateAggGradients
           update
           // -------------------------------------------------------
           assign(tabDMLScript, "Xb", "X_group_batch")
           assign(tabDMLScript, "yb", "y_group_batch")
           displayLoss(lossLayers(0), shouldValidate)
           performSnapshot
+        }
+      }
+      case "allreduce" => {
+        // This is distributed synchronous gradient descent
+        // -------------------------------------------------------
+        // Perform forward, backward and update on minibatch in parallel
+        assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") 
%% " + Caffe2DML.numImages + " + 1")
+        assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 
1, " + Caffe2DML.numImages + ")")
+        assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
+        assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
+        assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+        tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
+        val localBatchSize = "local_batch_size"
+        initializeGradients(localBatchSize)
+        parForBlock("j", "1", localBatchSize) {
+          assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+          assign(tabDMLScript, "yb", "y_group_batch[j,]")
+          forward; backward
+          flattenGradients
+        }
+        aggregateAggGradients
+        update
+        // -------------------------------------------------------
+        assign(tabDMLScript, "Xb", "X_group_batch")
+        assign(tabDMLScript, "yb", "y_group_batch")
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       }
       case _ => throw new DMLRuntimeException("Unsupported train algo:" + 
getTrainAlgo)
     }
-  }
   // 
-------------------------------------------------------------------------------------------
   // Helper functions to generate DML
   // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and 
Caffe2DML.numImages
-  private def trainTestSplit(numValidationBatches:Int):Unit = {
-    if(numValidationBatches > 0) {
-      if(solverParam.getDisplay <= 0) 
+  private def trainTestSplit(numValidationBatches: Int): Unit =
+    if (numValidationBatches > 0) {
+      if (solverParam.getDisplay <= 0)
         throw new DMLRuntimeException("Since test_iter and test_interval is 
greater than zero, you should set display to be greater than zero")
       tabDMLScript.append(Caffe2DML.numValidationImages).append(" = " + 
numValidationBatches + " * " + Caffe2DML.batchSize + "\n")
       tabDMLScript.append("# Sanity check to ensure that validation set is not 
too large\n")
       val maxValidationSize = "ceil(0.3 * " + Caffe2DML.numImages + ")"
-      ifBlock(Caffe2DML.numValidationImages  + " > " + maxValidationSize) {
+      ifBlock(Caffe2DML.numValidationImages + " > " + maxValidationSize) {
         assign(tabDMLScript, "max_test_iter", "floor(" + maxValidationSize + " 
/ " + Caffe2DML.batchSize + ")")
-        tabDMLScript.append("stop(" +
-            dmlConcat(asDMLString("Too large validation size. Please reduce 
test_iter to "), "max_test_iter") 
-            + ")\n")
+        tabDMLScript.append(
+          "stop(" +
+          dmlConcat(asDMLString("Too large validation size. Please reduce 
test_iter to "), "max_test_iter")
+          + ")\n"
+        )
       }
       val one = "1"
-      val rl = int_add(Caffe2DML.numValidationImages, one)
+      val rl  = int_add(Caffe2DML.numValidationImages, one)
       rightIndexing(tabDMLScript.append(Caffe2DML.X).append(" = "), "X_full", 
rl, Caffe2DML.numImages, null, null)
       tabDMLScript.append("; ")
       rightIndexing(tabDMLScript.append(Caffe2DML.y).append(" = "), "y_full", 
rl, Caffe2DML.numImages, null, null)
@@ -412,41 +432,39 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
       rightIndexing(tabDMLScript.append(Caffe2DML.yVal).append(" = "), 
"y_full", one, Caffe2DML.numValidationImages, null, null)
       tabDMLScript.append("; ")
       tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y)\n")
-    }
-    else {
+    } else {
       assign(tabDMLScript, Caffe2DML.X, "X_full")
-           assign(tabDMLScript, Caffe2DML.y, "y_full")
-           tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + 
Caffe2DML.y + ")\n")
+      assign(tabDMLScript, Caffe2DML.y, "y_full")
+      tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + Caffe2DML.y 
+ ")\n")
     }
-  }
-  
+
   // Append the DML to display training and validation loss
-  private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit 
= {
-    if(solverParam.getDisplay > 0) {
+  private def displayLoss(lossLayer: IsLossLayer, shouldValidate: Boolean): 
Unit = {
+    if (solverParam.getDisplay > 0) {
       // Append the DML to compute training loss
-      if(!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
+      if (!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
         // Compute training loss for allreduce
         tabDMLScript.append("# Compute training loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getDisplay + " == 0") {
           assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", 
"0")
           lossLayer.computeLoss(dmlScript, numTabs)
           assign(tabDMLScript, "training_loss", "loss"); assign(tabDMLScript, 
"training_accuracy", "accuracy")
-          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-              asDMLString(", training loss:"), "training_loss", asDMLString(", 
training accuracy:"), "training_accuracy" )))
+          tabDMLScript.append(
+            print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", 
training loss:"), "training_loss", asDMLString(", training accuracy:"), 
"training_accuracy"))
+          )
           appendTrainingVisualizationBody(dmlScript, numTabs)
           printClassificationReport
         }
-      }
-      else {
+      } else {
         Caffe2DML.LOG.info("Training loss is not printed for train_algo=" + 
getTrainAlgo)
       }
-      if(shouldValidate) {
-        if(  getTrainAlgo.toLowerCase.startsWith("allreduce") &&
+      if (shouldValidate) {
+        if (getTrainAlgo.toLowerCase.startsWith("allreduce") &&
             getTestAlgo.toLowerCase.startsWith("allreduce")) {
           Caffe2DML.LOG.warn("The setting: train_algo=" + getTrainAlgo + " and 
test_algo=" + getTestAlgo + " is not recommended. Consider changing 
test_algo=minibatch")
         }
         // Append the DML to compute validation loss
-        val numValidationBatches = if(solverParam.getTestIterCount > 0) 
solverParam.getTestIter(0) else 0
+        val numValidationBatches = if (solverParam.getTestIterCount > 0) 
solverParam.getTestIter(0) else 0
         tabDMLScript.append("# Compute validation loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getTestInterval + " == 0") {
           assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", 
"0")
@@ -455,11 +473,11 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "validation_loss", "0")
               assign(tabDMLScript, "validation_accuracy", "0")
               forBlock("iVal", "1", "num_iters_per_epoch") {
-                 getValidationBatch(tabDMLScript)
-                 forward;  lossLayer.computeLoss(dmlScript, numTabs)
+                getValidationBatch(tabDMLScript)
+                forward; lossLayer.computeLoss(dmlScript, numTabs)
                 tabDMLScript.append("validation_loss = validation_loss + 
loss\n")
                 tabDMLScript.append("validation_accuracy = validation_accuracy 
+ accuracy\n")
-               }
+              }
               tabDMLScript.append("validation_accuracy = validation_accuracy / 
num_iters_per_epoch\n")
             }
             case "batch" => {
@@ -467,16 +485,16 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
               net.getLayers.map(layer => 
net.getCaffeLayer(layer).forward(tabDMLScript, false))
               lossLayer.computeLoss(dmlScript, numTabs)
               assign(tabDMLScript, "validation_loss", "loss"); 
assign(tabDMLScript, "validation_accuracy", "accuracy")
-              
+
             }
             case "allreduce_parallel_batches" => {
               // This setting uses the batch size provided by the user
-              if(!inputs.containsKey("$parallel_batches")) {
+              if (!inputs.containsKey("$parallel_batches")) {
                 throw new RuntimeException("The parameter parallel_batches is 
required for allreduce_parallel_batches")
               }
               // The user specifies the number of parallel_batches
               // This ensures that the user of generated script remembers to 
provide the commandline parameter $parallel_batches
-              assign(tabDMLScript, "parallel_batches_val", 
"$parallel_batches") 
+              assign(tabDMLScript, "parallel_batches_val", "$parallel_batches")
               assign(tabDMLScript, "group_batch_size_val", 
"parallel_batches_val*" + Caffe2DML.batchSize)
               assign(tabDMLScript, "groups_val", "as.integer(ceil(" + 
Caffe2DML.numValidationImages + "/group_batch_size_val))")
               assign(tabDMLScript, "validation_accuracy", "0")
@@ -511,8 +529,8 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "group_validation_loss", matrix("0", 
Caffe2DML.numValidationImages, "1"))
               assign(tabDMLScript, "group_validation_accuracy", matrix("0", 
Caffe2DML.numValidationImages, "1"))
               parForBlock("iVal", "1", Caffe2DML.numValidationImages) {
-                assign(tabDMLScript, "Xb",  Caffe2DML.XVal + "[iVal,]")
-                assign(tabDMLScript, "yb",  Caffe2DML.yVal + "[iVal,]")
+                assign(tabDMLScript, "Xb", Caffe2DML.XVal + "[iVal,]")
+                assign(tabDMLScript, "yb", Caffe2DML.yVal + "[iVal,]")
                 net.getLayers.map(layer => 
net.getCaffeLayer(layer).forward(tabDMLScript, false))
                 lossLayer.computeLoss(dmlScript, numTabs)
                 assign(tabDMLScript, "group_validation_loss[iVal,1]", "loss")
@@ -521,124 +539,132 @@ class Caffe2DML(val sc: SparkContext, val 
solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "validation_loss", 
"sum(group_validation_loss)")
               assign(tabDMLScript, "validation_accuracy", 
"mean(group_validation_accuracy)")
             }
-            
+
             case _ => throw new DMLRuntimeException("Unsupported test algo:" + 
getTestAlgo)
           }
-          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-              asDMLString(", validation loss:"), "validation_loss", 
asDMLString(", validation accuracy:"), "validation_accuracy" )))
+          tabDMLScript.append(
+            print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", 
validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), 
"validation_accuracy"))
+          )
           appendValidationVisualizationBody(dmlScript, numTabs)
         }
       }
     }
   }
-  
-  private def performSnapshot():Unit = {
-    if(solverParam.getSnapshot > 0) {
+  private def appendSnapshotWrite(varName: String, fileName: String): Unit =
+    tabDMLScript.append(write(varName, "snapshot_dir + \"" + fileName + "\"", 
"binary"))
+  private def performSnapshot(): Unit =
+    if (solverParam.getSnapshot > 0) {
       ifBlock("iter %% " + solverParam.getSnapshot + " == 0") {
         tabDMLScript.append("snapshot_dir= \"" + solverParam.getSnapshotPrefix 
+ "\" + \"/iter_\" + iter + \"/\"\n")
-        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l 
=> tabDMLScript.append(
-               "write(" + l.weight + ", snapshot_dir + \"" + l.param.getName + 
"_weight.mtx\", format=\"binary\")\n"))
-               net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != 
null).map(l => tabDMLScript.append(
-                       "write(" + l.bias + ", snapshot_dir + \"" + 
l.param.getName + "_bias.mtx\", format=\"binary\")\n"))
+        val allLayers = net.getLayers.map(net.getCaffeLayer(_))
+        allLayers.filter(_.weight != null).map(l => 
appendSnapshotWrite(l.weight, l.param.getName + "_weight.mtx"))
+        allLayers.filter(_.bias != null).map(l => appendSnapshotWrite(l.bias, 
l.param.getName + "_bias.mtx"))
       }
-       }
-  }
-  
-  private def forward():Unit = {
+    }
+
+  private def forward(): Unit = {
     tabDMLScript.append("# Perform forward pass\n")
-         net.getLayers.map(layer => 
net.getCaffeLayer(layer).forward(tabDMLScript, false))
+    net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, 
false))
   }
-  private def backward():Unit = {
+  private def backward(): Unit = {
     tabDMLScript.append("# Perform backward pass\n")
     net.getLayers.reverse.map(layer => 
net.getCaffeLayer(layer).backward(tabDMLScript, ""))
   }
-  private def update():Unit = {
+  private def update(): Unit = {
     tabDMLScript.append("# Update the parameters\n")
     net.getLayers.map(layer => solver.update(tabDMLScript, 
net.getCaffeLayer(layer)))
   }
-  private def initializeGradients(parallel_batches:String):Unit = {
+  private def initializeGradients(parallel_batches: String): Unit = {
     tabDMLScript.append("# Data structure to store gradients computed in 
parallel\n")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", 
matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
-      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", 
matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias)))) 
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", 
matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
+        if (l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", 
matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias))))
+      })
   }
-  private def flattenGradients():Unit = {
+  private def flattenGradients(): Unit = {
     tabDMLScript.append("# Flatten and store gradients for this parallel 
execution\n")
     // Note: We multiply by a weighting to allow for proper gradient averaging 
during the
     // aggregation even with uneven batch sizes.
     assign(tabDMLScript, "weighting", "nrow(Xb)/X_group_batch_size")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", 
-          matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " 
* weighting") 
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", 
-          matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias)))  + " * 
weighting")
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", 
matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " * 
weighting")
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", 
matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias))) + " * weighting")
+      })
   }
-  private def aggregateAggGradients():Unit = {
+  private def aggregateAggGradients(): Unit = {
     tabDMLScript.append("# Aggregate the gradients\n")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, 
-          matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight))) 
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dBias, 
-          matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, 
matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight)))
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dBias, 
matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
+      })
   }
   // 
-------------------------------------------------------------------------------------------
 }
 
-class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val 
solver:CaffeSolver,
-    val net:CaffeNetwork, val lrPolicy:LearningRatePolicy,
-    val estimator:Caffe2DML) 
-  extends Model[Caffe2DMLModel] with HasMaxOuterIter with 
BaseSystemMLClassifierModel with DMLGenerator {
+class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: 
CaffeSolver, val net: CaffeNetwork, val lrPolicy: LearningRatePolicy, val 
estimator: Caffe2DML)
+    extends Model[Caffe2DMLModel]
+    with HasMaxOuterIter
+    with BaseSystemMLClassifierModel
+    with DMLGenerator {
   // --------------------------------------------------------------
   // Invoked by Python, MLPipeline
-  val uid:String = "caffe_model_" + (new Random).nextLong 
-  def this(estimator:Caffe2DML) =  {
-    this(Utils.numClasses(estimator.net), estimator.sc, estimator.solver,
-        estimator.net,
-        // new CaffeNetwork(estimator.solverParam.getNet, 
caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, 
estimator.width), 
-        estimator.lrPolicy, estimator) 
+  val uid: String = "caffe_model_" + (new Random).nextLong
+  def this(estimator: Caffe2DML) = {
+    this(
+      Utils.numClasses(estimator.net),
+      estimator.sc,
+      estimator.solver,
+      estimator.net,
+      // new CaffeNetwork(estimator.solverParam.getNet, 
caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, 
estimator.width),
+      estimator.lrPolicy,
+      estimator
+    )
   }
-      
+
   override def copy(extra: org.apache.spark.ml.param.ParamMap): Caffe2DMLModel 
= {
     val that = new Caffe2DMLModel(numClasses, sc, solver, net, lrPolicy, 
estimator)
     copyValues(that, extra)
   }
   // --------------------------------------------------------------
-  
-  def modelVariables():List[String] = {
-    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != 
null).map(_.weight) ++
-    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(_.bias)
+
+  def modelVariables(): List[String] = {
+    val allLayers = net.getLayers.map(net.getCaffeLayer(_))
+    allLayers.filter(_.weight != null).map(_.weight) ++ 
allLayers.filter(_.bias != null).map(_.bias)
   }
-    
+
   // 
================================================================================================
   // The below method parses the provided network and solver file and 
generates DML script.
-  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) = {
     val startPredictionTime = System.nanoTime()
-    
-         reset                                  // Reset the state of DML 
generator for training script.
-         
-         val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) 
estimator.inputs.get("$debug").toLowerCase.toBoolean else false
-         assign(tabDMLScript, "debug", if(DEBUG_PREDICTION) "TRUE" else 
"FALSE")
-    
-    appendHeaders(net, solver, false)      // Appends DML corresponding to 
source and externalFunction statements.
-    readInputData(net, false)              // Read X_full and y_full
+
+    reset // Reset the state of DML generator for training script.
+
+    val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) 
estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
+
+    appendHeaders(net, solver, false) // Appends DML corresponding to source 
and externalFunction statements.
+    readInputData(net, false)         // Read X_full and y_full
     assign(tabDMLScript, "X", "X_full")
-    
+
     // Initialize the layers and solvers. Reads weights and bias if 
readWeights is true.
-    if(!estimator.inputs.containsKey("$weights") && estimator.mloutput == 
null) 
+    if (!estimator.inputs.containsKey("$weights") && estimator.mloutput == 
null)
       throw new DMLRuntimeException("Cannot call predict/score without calling 
either fit or by providing weights")
     val readWeights = estimator.inputs.containsKey("$weights") || 
estimator.mloutput != null
     initWeights(net, solver, readWeights)
-         
-         // Donot update mean and variance in batchnorm
-         updateMeanVarianceForBatchNorm(net, false)
-         
-         val lossLayers = getLossLayers(net)
-         val lastLayerShape = estimator.getOutputShapeOfLastLayer
-         assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, 
(lastLayerShape._1*lastLayerShape._2*lastLayerShape._3).toString))
-         estimator.getTestAlgo.toLowerCase match {
+
+    // Donot update mean and variance in batchnorm
+    updateMeanVarianceForBatchNorm(net, false)
+
+    val lossLayers     = getLossLayers(net)
+    val lastLayerShape = estimator.getOutputShapeOfLastLayer
+    assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, 
(lastLayerShape._1 * lastLayerShape._2 * lastLayerShape._3).toString))
+    estimator.getTestAlgo.toLowerCase match {
       case "minibatch" => {
         ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, 
Caffe2DML.batchSize)
         forBlock("iter", "1", "num_iters") {
@@ -654,12 +680,12 @@ class Caffe2DMLModel(val numClasses:String, val sc: 
SparkContext, val solver:Caf
       }
       case "allreduce_parallel_batches" => {
         // This setting uses the batch size provided by the user
-        if(!estimator.inputs.containsKey("$parallel_batches")) {
+        if (!estimator.inputs.containsKey("$parallel_batches")) {
           throw new RuntimeException("The parameter parallel_batches is 
required for allreduce_parallel_batches")
         }
         // The user specifies the number of parallel_batches
         // This ensures that the user of generated script remembers to provide 
the commandline parameter $parallel_batches
-        assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
         assign(tabDMLScript, "group_batch_size", "parallel_batches*" + 
Caffe2DML.batchSize)
         assign(tabDMLScript, "groups", "as.integer(ceil(" + 
Caffe2DML.numImages + "/group_batch_size))")
         // Grab groups of mini-batches
@@ -688,70 +714,66 @@ class Caffe2DMLModel(val numClasses:String, val sc: 
SparkContext, val solver:Caf
       }
       case _ => throw new DMLRuntimeException("Unsupported test algo:" + 
estimator.getTestAlgo)
     }
-    
-    if(estimator.inputs.containsKey("$output_activations")) {
-      if(estimator.getTestAlgo.toLowerCase.equals("batch")) {
-        net.getLayers.map(layer => 
-          tabDMLScript.append(write(net.getCaffeLayer(layer).out, 
-              estimator.inputs.get("$output_activations") + "/" + 
net.getCaffeLayer(layer).param.getName + "_activations.mtx", "csv") + "\n")
-        )  
-      }
-      else {
+
+    if (estimator.inputs.containsKey("$output_activations")) {
+      if (estimator.getTestAlgo.toLowerCase.equals("batch")) {
+        net.getLayers.map(
+          layer =>
+            tabDMLScript.append(
+              write(net.getCaffeLayer(layer).out, 
estimator.inputs.get("$output_activations") + "/" + 
net.getCaffeLayer(layer).param.getName + "_activations.mtx", "csv") + "\n"
+          )
+        )
+      } else {
         throw new DMLRuntimeException("Incorrect usage of output_activations. 
It should be only used in batch mode.")
       }
     }
-               
-               val predictionScript = dmlScript.toString()
-               System.out.println("Time taken to generate prediction script 
from Caffe proto:" + ((System.nanoTime() - startPredictionTime)*1e-9) + "secs." 
)
-               if(DEBUG_PREDICTION) 
Utils.prettyPrintDMLScript(predictionScript)
-               
-               // Reset state of BatchNorm layer
-               updateMeanVarianceForBatchNorm(net, true)
-               
-         val script = dml(predictionScript).out("Prob").in(estimator.inputs)
-         if(estimator.mloutput != null) {
-           // fit was called
-         net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != 
null).map(l => script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
-         net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l 
=> script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
-         }
-         (script, "X_full")
+
+    val predictionScript = dmlScript.toString()
+    System.out.println("Time taken to generate prediction script from Caffe 
proto:" + ((System.nanoTime() - startPredictionTime) * 1e-9) + "secs.")
+    if (DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
+
+    // Reset state of BatchNorm layer
+    updateMeanVarianceForBatchNorm(net, true)
+
+    val script = dml(predictionScript).out("Prob").in(estimator.inputs)
+    if (estimator.mloutput != null) {
+      // fit was called
+      net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l 
=> script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
+      net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => 
script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
+    }
+    (script, "X_full")
   }
   // 
================================================================================================
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  
+
+  def baseEstimator(): BaseSystemMLEstimator = estimator
+
   // Prediction
-  def transform(X: MatrixBlock): MatrixBlock = {
-    if(estimator.isClassification) {
+  def transform(X: MatrixBlock): MatrixBlock =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction assuming classification")
       baseTransform(X, sc, "Prob")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransform(X, sc, "Prob", outShape._1.toInt, outShape._2.toInt, 
outShape._3.toInt)
     }
-  }
-  def transform_probability(X: MatrixBlock): MatrixBlock = {
-    if(estimator.isClassification) {
+  def transform_probability(X: MatrixBlock): MatrixBlock =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction of probability assuming classification")
       baseTransformProbability(X, sc, "Prob")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction of probability assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransformProbability(X, sc, "Prob", outShape._1.toInt, 
outShape._2.toInt, outShape._3.toInt)
     }
-  } 
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = {
-    if(estimator.isClassification) {
+
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction assuming classification")
       baseTransform(df, sc, "Prob", true)
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransform(df, sc, "Prob", true, outShape._1.toInt, 
outShape._2.toInt, outShape._3.toInt)
     }
-  }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala 
b/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
index 30d86fd..19aff63 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,16 +22,16 @@ package org.apache.sysml.api.dl
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.net.MalformedURLException;
-import java.net.URL; 
+import java.net.URL;
 import java.net.URLClassLoader;
 import java.io.File;
 
 class Caffe2DMLLoader {
-  def loadCaffe2DML(filePath:String):Unit = {
-    val url = new File(filePath).toURI().toURL();
-               val classLoader = 
ClassLoader.getSystemClassLoader().asInstanceOf[URLClassLoader];
-               val method = 
classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL]);
-               method.setAccessible(true);
-         method.invoke(classLoader, url);
+  def loadCaffe2DML(filePath: String): Unit = {
+    val url         = new File(filePath).toURI().toURL();
+    val classLoader = 
ClassLoader.getSystemClassLoader().asInstanceOf[URLClassLoader];
+    val method      = classOf[URLClassLoader].getDeclaredMethod("addURL", 
classOf[URL]);
+    method.setAccessible(true);
+    method.invoke(classLoader, url);
   }
-}
\ No newline at end of file
+}

Reply via email to