Repository: systemml
Updated Branches:
  refs/heads/master e11ae6af3 -> 382f847de


http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
index 1e4538a..4733406 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
@@ -23,7 +23,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -36,7 +36,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
 
        @Test
        public void testLocalDataPartitionerDC() {
-               DataPartitionScheme.Result result = launchLocalDataPartitionerDC();
+               DataPartitionLocalScheme.Result result = launchLocalDataPartitionerDC();
 
                Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
                Assert.assertEquals(WORKER_NUM, result.pLabels.size());
@@ -45,7 +45,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
                }
        }
 
-       private void assertDCResult(DataPartitionScheme.Result result, int workerID) {
+       private void assertDCResult(DataPartitionLocalScheme.Result result, int workerID) {
                Assert.assertArrayEquals(generateExpectedData(workerID * (ROW_SIZE / WORKER_NUM) * COL_SIZE, (workerID + 1) * (ROW_SIZE / WORKER_NUM) * COL_SIZE), result.pFeatures.get(workerID).acquireRead().getDenseBlockValues(), 0);
                Assert.assertArrayEquals(generateExpectedData(workerID * (ROW_SIZE / WORKER_NUM), (workerID + 1) * (ROW_SIZE / WORKER_NUM)), result.pLabels.get(workerID).acquireRead().getDenseBlockValues(), 0);
        }
@@ -53,7 +53,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
        @Test
        public void testLocalDataPartitionerDR() {
                MatrixBlock[] mbs = generateData();
-               DataPartitionScheme.Result result = launchLocalDataPartitionerDR(mbs);
+               DataPartitionLocalScheme.Result result = launchLocalDataPartitionerDR(mbs);
 
                Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
                Assert.assertEquals(WORKER_NUM, result.pLabels.size());
@@ -82,7 +82,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
 
        @Test
        public void testLocalDataPartitionerDRR() {
-               DataPartitionScheme.Result result = launchLocalDataPartitionerDRR();
+               DataPartitionLocalScheme.Result result = launchLocalDataPartitionerDRR();
 
                Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
                Assert.assertEquals(WORKER_NUM, result.pLabels.size());
@@ -91,7 +91,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
                }
        }
 
-       private void assertDRRResult(DataPartitionScheme.Result result, int workerID) {
+       private void assertDRRResult(DataPartitionLocalScheme.Result result, int workerID) {
                Tuple2<double[], double[]> expected = generateExpectedData(workerID, WORKER_NUM, ROW_SIZE / WORKER_NUM);
                Assert.assertArrayEquals(expected._1, result.pFeatures.get(workerID).acquireRead().getDenseBlockValues(), 0);
                Assert.assertArrayEquals(expected._2, result.pLabels.get(workerID).acquireRead().getDenseBlockValues(), 0);
@@ -114,7 +114,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
        @Test
        public void testLocalDataPartitionerOR() {
                ParamservUtils.SEED = System.nanoTime();
-               DataPartitionScheme.Result result = launchLocalDataPartitionerOR();
+               DataPartitionLocalScheme.Result result = launchLocalDataPartitionerOR();
 
                Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
                Assert.assertEquals(WORKER_NUM, result.pLabels.size());
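
A note on the renamed type for readers of the tests above: they exercise DataPartitionLocalScheme only through its nested Result holder, whose pFeatures/pLabels lists hold one partition per worker. Below is a minimal orientation sketch in Java; the shape of Result is inferred from the calls above, and the constructor and class body are assumptions, not code from this commit:

  import java.util.List;
  import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;

  // Sketch only (inferred from the asserts above, not from the commit):
  // Result bundles the per-worker partitions as two parallel lists, indexed
  // by worker ID, e.g. result.pFeatures.get(workerID).
  class DataPartitionLocalSchemeSketch {
    static final class Result {
      final List<MatrixObject> pFeatures; // features partition of worker i at index i
      final List<MatrixObject> pLabels;   // labels partition of worker i at index i
      Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels) { // assumed constructor
        this.pFeatures = pFeatures;
        this.pLabels = pLabels;
      }
    }
  }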

http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
index 17bfa4c..464b0b1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -23,9 +23,9 @@ import java.io.IOException;
 import java.util.Arrays;
 
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.junit.Assert;
 import org.junit.Test;

http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
index b0e4a27..8cae4a4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
@@ -26,7 +26,7 @@ import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.parser.Statement;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.junit.Assert;
@@ -51,14 +51,14 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
 
        @Test
        public void testSparkDataPartitionerDC() {
-               DataPartitionScheme.Result localResult = launchLocalDataPartitionerDC();
+               DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerDC();
                Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.DISJOINT_CONTIGUOUS);
 
                // Compare the both
                assertResult(localResult, sparkResult);
        }
 
-       private void assertResult(DataPartitionScheme.Result local, Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> spark) {
+       private void assertResult(DataPartitionLocalScheme.Result local, Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> spark) {
                IntStream.range(0, WORKER_NUM).forEach(w -> {
                        Assert.assertArrayEquals(local.pFeatures.get(w).acquireRead().getDenseBlockValues(), spark.get(w)._1.getDenseBlockValues(), 0);
                        Assert.assertArrayEquals(local.pLabels.get(w).acquireRead().getDenseBlockValues(), spark.get(w)._2.getDenseBlockValues(), 0);
@@ -69,7 +69,7 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
        public void testSparkDataPartitionerDR() {
                ParamservUtils.SEED = System.nanoTime();
                MatrixBlock[] mbs = generateData();
-               DataPartitionScheme.Result localResult = launchLocalDataPartitionerDR(mbs);
+               DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerDR(mbs);
                Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.DISJOINT_RANDOM);
 
                // Compare the both
@@ -78,7 +78,7 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
 
        @Test
        public void testSparkDataPartitionerDRR() {
-               DataPartitionScheme.Result localResult = launchLocalDataPartitionerDRR();
+               DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerDRR();
                Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
 
                // Compare the both
@@ -88,7 +88,7 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
        @Test
        public void testSparkDataPartitionerOR() {
                ParamservUtils.SEED = System.nanoTime();
-               DataPartitionScheme.Result localResult = launchLocalDataPartitionerOR();
+               DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerOR();
                Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.OVERLAP_RESHUFFLE);
 
                // Compare the both

http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index 35b0bd2..5ccda12 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -101,8 +101,8 @@ train = function(matrix[double] X, matrix[double] Y,
   # Regularization
   lambda = 5e-04
 
-  # Create the model object
-  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+  # Create the model list
+  modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
 
   # Create the hyper parameter list
   params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
@@ -110,15 +110,14 @@ train = function(matrix[double] X, matrix[double] Y,
   # Use paramserv function
   modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batchsize, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
 
-  W1 = as.matrix(modelList2["W1"])
-  b1 = as.matrix(modelList2["b1"])
-  W2 = as.matrix(modelList2["W2"])
-  b2 = as.matrix(modelList2["b2"])
-  W3 = as.matrix(modelList2["W3"])
-  b3 = as.matrix(modelList2["b3"])
-  W4 = as.matrix(modelList2["W4"])
-  b4 = as.matrix(modelList2["b4"])
-
+  W1 = as.matrix(modelList2[1])
+  W2 = as.matrix(modelList2[2])
+  W3 = as.matrix(modelList2[3])
+  W4 = as.matrix(modelList2[4])
+  b1 = as.matrix(modelList2[5])
+  b2 = as.matrix(modelList2[6])
+  b3 = as.matrix(modelList2[7])
+  b4 = as.matrix(modelList2[8])
 }
 
 # Should always use 'features' (batch features), 'labels' (batch labels),
@@ -130,27 +129,25 @@ gradients = function(matrix[double] features,
                      list[unknown] model)
           return (list[unknown] gradients) {
 
-# PB: not be able to get scalar from list
-
-  C = as.scalar(hyperparams["C"])
-  Hin = 28
-  Win = 28
-  Hf = 5
-  Wf = 5
-  stride = 1
-  pad = 2
-  lambda = 5e-04
-  F1 = 32
-  F2 = 64
-  N3 = 512
-  W1 = as.matrix(model["W1"])
-  b1 = as.matrix(model["b1"])
-  W2 = as.matrix(model["W2"])
-  b2 = as.matrix(model["b2"])
-  W3 = as.matrix(model["W3"])
-  b3 = as.matrix(model["b3"])
-  W4 = as.matrix(model["W4"])
-  b4 = as.matrix(model["b4"])
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
 
   # Compute forward pass
   ## layer 1: conv1 -> relu1 -> pool1
@@ -202,7 +199,7 @@ gradients = function(matrix[double] features,
   dW3 = dW3 + dW3_reg
   dW4 = dW4 + dW4_reg
 
-  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
 }
 
 # Should use the arguments named 'model', 'gradients', 'hyperparams'
@@ -211,33 +208,32 @@ aggregation = function(list[unknown] model,
                        list[unknown] gradients,
                        list[unknown] hyperparams)
    return (list[unknown] modelResult) {
-
-     W1 = as.matrix(model["W1"])
-     W2 = as.matrix(model["W2"])
-     W3 = as.matrix(model["W3"])
-     W4 = as.matrix(model["W4"])
-     b1 = as.matrix(model["b1"])
-     b2 = as.matrix(model["b2"])
-     b3 = as.matrix(model["b3"])
-     b4 = as.matrix(model["b4"])
-     dW1 = as.matrix(gradients["dW1"])
-     dW2 = as.matrix(gradients["dW2"])
-     dW3 = as.matrix(gradients["dW3"])
-     dW4 = as.matrix(gradients["dW4"])
-     db1 = as.matrix(gradients["db1"])
-     db2 = as.matrix(gradients["db2"])
-     db3 = as.matrix(gradients["db3"])
-     db4 = as.matrix(gradients["db4"])
-     vW1 = as.matrix(model["vW1"])
-     vW2 = as.matrix(model["vW2"])
-     vW3 = as.matrix(model["vW3"])
-     vW4 = as.matrix(model["vW4"])
-     vb1 = as.matrix(model["vb1"])
-     vb2 = as.matrix(model["vb2"])
-     vb3 = as.matrix(model["vb3"])
-     vb4 = as.matrix(model["vb4"])
-     lr = 0.01
-     mu = 0.9
+     W1 = as.matrix(model[1])
+     W2 = as.matrix(model[2])
+     W3 = as.matrix(model[3])
+     W4 = as.matrix(model[4])
+     b1 = as.matrix(model[5])
+     b2 = as.matrix(model[6])
+     b3 = as.matrix(model[7])
+     b4 = as.matrix(model[8])
+     dW1 = as.matrix(gradients[1])
+     dW2 = as.matrix(gradients[2])
+     dW3 = as.matrix(gradients[3])
+     dW4 = as.matrix(gradients[4])
+     db1 = as.matrix(gradients[5])
+     db2 = as.matrix(gradients[6])
+     db3 = as.matrix(gradients[7])
+     db4 = as.matrix(gradients[8])
+     vW1 = as.matrix(model[9])
+     vW2 = as.matrix(model[10])
+     vW3 = as.matrix(model[11])
+     vW4 = as.matrix(model[12])
+     vb1 = as.matrix(model[13])
+     vb2 = as.matrix(model[14])
+     vb3 = as.matrix(model[15])
+     vb4 = as.matrix(model[16])
+     lr = as.double(as.scalar(hyperparams["lr"]))
+     mu = as.double(as.scalar(hyperparams["mu"]))
 
      # Optimize with SGD w/ Nesterov momentum
      [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
@@ -249,7 +245,7 @@ aggregation = function(list[unknown] model,
      [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
      [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
 
-     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+     modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
    }
 
 predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
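
The script rewrite above replaces named list entries (list(W1=W1, ...) read via model["W1"]) with unnamed entries read by 1-based position (model[1]), and stops hardcoding hyperparameters in gradients(), reading them from the named hyperparams list instead. A minimal standalone DML sketch of the two access styles; the values are illustrative, not taken from the commit:

  W1 = matrix(1, rows=2, cols=2)
  b1 = matrix(0, rows=1, cols=2)
  model = list(W1, b1)                    # unnamed list: entries accessed by 1-based index
  hyperparams = list(lr=0.01, mu=0.9)     # named list: entries accessed by key
  W1r = as.matrix(model[1])               # unbox a list slot back into a matrix
  lr = as.double(as.scalar(hyperparams["lr"]))  # unbox and cast a named scalar entry
  print("lr = " + lr)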

http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index a3677aa..e7056f0 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -101,7 +101,7 @@ train = function(matrix[double] X, matrix[double] Y,
   lambda = 5e-04
 
   # Create the model object
-  modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+  modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
 
   # Create the hyper parameter list
   params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
@@ -109,14 +109,14 @@ train = function(matrix[double] X, matrix[double] Y,
   # Use paramserv function
   modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params)
 
-  W1 = as.matrix(modelList2["W1"])
-  b1 = as.matrix(modelList2["b1"])
-  W2 = as.matrix(modelList2["W2"])
-  b2 = as.matrix(modelList2["b2"])
-  W3 = as.matrix(modelList2["W3"])
-  b3 = as.matrix(modelList2["b3"])
-  W4 = as.matrix(modelList2["W4"])
-  b4 = as.matrix(modelList2["b4"])
+  W1 = as.matrix(modelList2[1])
+  W2 = as.matrix(modelList2[2])
+  W3 = as.matrix(modelList2[3])
+  W4 = as.matrix(modelList2[4])
+  b1 = as.matrix(modelList2[5])
+  b2 = as.matrix(modelList2[6])
+  b3 = as.matrix(modelList2[7])
+  b4 = as.matrix(modelList2[8])
 
 }
 
@@ -126,25 +126,25 @@ gradients = function(matrix[double] features,
                      list[unknown] model)
           return (list[unknown] gradients) {
 
-  C = 1
-  Hin = 28
-  Win = 28
-  Hf = 5
-  Wf = 5
-  stride = 1
-  pad = 2
-  lambda = 5e-04
-  F1 = 32
-  F2 = 64
-  N3 = 512
-  W1 = as.matrix(model["W1"])
-  b1 = as.matrix(model["b1"])
-  W2 = as.matrix(model["W2"])
-  b2 = as.matrix(model["b2"])
-  W3 = as.matrix(model["W3"])
-  b3 = as.matrix(model["b3"])
-  W4 = as.matrix(model["W4"])
-  b4 = as.matrix(model["b4"])
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
 
   # Compute forward pass
   ## layer 1: conv1 -> relu1 -> pool1
@@ -196,41 +196,39 @@ gradients = function(matrix[double] features,
   dW3 = dW3 + dW3_reg
   dW4 = dW4 + dW4_reg
 
-  gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
-
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
 }
 
 aggregation = function(list[unknown] model,
                        list[unknown] gradients,
                        list[unknown] hyperparams)
    return (list[unknown] modelResult) {
-
-     W1 = as.matrix(model["W1"])
-     W2 = as.matrix(model["W2"])
-     W3 = as.matrix(model["W3"])
-     W4 = as.matrix(model["W4"])
-     b1 = as.matrix(model["b1"])
-     b2 = as.matrix(model["b2"])
-     b3 = as.matrix(model["b3"])
-     b4 = as.matrix(model["b4"])
-     dW1 = as.matrix(gradients["dW1"])
-     dW2 = as.matrix(gradients["dW2"])
-     dW3 = as.matrix(gradients["dW3"])
-     dW4 = as.matrix(gradients["dW4"])
-     db1 = as.matrix(gradients["db1"])
-     db2 = as.matrix(gradients["db2"])
-     db3 = as.matrix(gradients["db3"])
-     db4 = as.matrix(gradients["db4"])
-     vW1 = as.matrix(model["vW1"])
-     vW2 = as.matrix(model["vW2"])
-     vW3 = as.matrix(model["vW3"])
-     vW4 = as.matrix(model["vW4"])
-     vb1 = as.matrix(model["vb1"])
-     vb2 = as.matrix(model["vb2"])
-     vb3 = as.matrix(model["vb3"])
-     vb4 = as.matrix(model["vb4"])
-     lr = 0.01
-     mu = 0.9
+     W1 = as.matrix(model[1])
+     W2 = as.matrix(model[2])
+     W3 = as.matrix(model[3])
+     W4 = as.matrix(model[4])
+     b1 = as.matrix(model[5])
+     b2 = as.matrix(model[6])
+     b3 = as.matrix(model[7])
+     b4 = as.matrix(model[8])
+     dW1 = as.matrix(gradients[1])
+     dW2 = as.matrix(gradients[2])
+     dW3 = as.matrix(gradients[3])
+     dW4 = as.matrix(gradients[4])
+     db1 = as.matrix(gradients[5])
+     db2 = as.matrix(gradients[6])
+     db3 = as.matrix(gradients[7])
+     db4 = as.matrix(gradients[8])
+     vW1 = as.matrix(model[9])
+     vW2 = as.matrix(model[10])
+     vW3 = as.matrix(model[11])
+     vW4 = as.matrix(model[12])
+     vb1 = as.matrix(model[13])
+     vb2 = as.matrix(model[14])
+     vb3 = as.matrix(model[15])
+     vb4 = as.matrix(model[16])
+     lr = as.double(as.scalar(hyperparams["lr"]))
+     mu = as.double(as.scalar(hyperparams["mu"]))
 
      # Optimize with SGD w/ Nesterov momentum
      [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
@@ -242,7 +240,7 @@ aggregation = function(list[unknown] model,
      [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
      [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
 
-     modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+     modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
    }
 
 predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
