Repository: systemml
Updated Branches:
  refs/heads/master 735c4119c -> 0871f260e


[SYSTEMML-2381] Rework paramserv data partitioner API and tests

Closes #783.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0871f260
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0871f260
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0871f260

Branch: refs/heads/master
Commit: 0871f260e6fc6d6fed57d2cc249bf4d8beb0a31f
Parents: 735c411
Author: EdgarLGB <[email protected]>
Authored: Wed Jun 13 08:47:24 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jun 13 08:57:36 2018 -0700

----------------------------------------------------------------------
 .../paramserv/DataPartitioner.java              |  25 +-
 .../paramserv/DataPartitionerDC.java            |  16 +-
 .../paramserv/DataPartitionerDR.java            |  25 +-
 .../paramserv/DataPartitionerDRR.java           |  19 +-
 .../paramserv/DataPartitionerOR.java            |  26 +-
 .../cp/ParamservBuiltinCPInstruction.java       |  15 +-
 .../paramserv/DataPartitionerTest.java          | 238 +++++++++----------
 7 files changed, 187 insertions(+), 177 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
index b94d765..fce6d35 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
@@ -21,27 +21,20 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
 
 import java.util.List;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 
 public abstract class DataPartitioner {
 
-       protected static final Log LOG = 
LogFactory.getLog(DataPartitioner.class.getName());
+       public final class Result {
+               public final List<MatrixObject> pFeatures;
+               public final List<MatrixObject> pLabels;
 
-       public abstract void doPartitioning(List<LocalPSWorker> workers, 
MatrixObject features, MatrixObject labels);
-
-       protected void setPartitionedData(List<LocalPSWorker> workers, 
List<MatrixObject> pfs, List<MatrixObject> pls) {
-               if (pfs.size() < workers.size()) {
-                       if (LOG.isWarnEnabled()) {
-                               LOG.warn(String.format("There is only %d 
batches of data but has %d workers. "
-                                       + "Hence, reset the number of workers 
with %d.", pfs.size(), workers.size(), pfs.size()));
-                       }
-                       workers = workers.subList(0, pfs.size());
-               }
-               for (int i = 0; i < workers.size(); i++) {
-                       workers.get(i).setFeatures(pfs.get(i));
-                       workers.get(i).setLabels(pls.get(i));
+               public Result(List<MatrixObject> pFeatures, List<MatrixObject> 
pLabels) {
+                       this.pFeatures = pFeatures;
+                       this.pLabels = pLabels;
                }
        }
+
+       public abstract Result doPartitioning(int workersNum, MatrixObject 
features, MatrixObject labels);
+
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java
index 0426855..4810541 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java
@@ -32,17 +32,10 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
  * non-overlapping partitions of rows.
  */
 public class DataPartitionerDC extends DataPartitioner {
-       @Override
-       public void doPartitioning(List<LocalPSWorker> workers, MatrixObject 
features, MatrixObject labels) {
-               int workerNum = workers.size();
-               List<MatrixObject> pfs = doPartitioning(workerNum, features);
-               List<MatrixObject> pls = doPartitioning(workerNum, labels);
-               setPartitionedData(workers, pfs, pls);
-       }
 
        private List<MatrixObject> doPartitioning(int k, MatrixObject mo) {
                List<MatrixObject> list = new ArrayList<>();
-               long stepSize = (long) Math.ceil(mo.getNumRows() / k);
+               long stepSize = (long) Math.ceil((double) mo.getNumRows() / k);
                long begin = 1;
                while (begin < mo.getNumRows()) {
                        long end = Math.min(begin - 1 + stepSize, 
mo.getNumRows());
@@ -52,4 +45,11 @@ public class DataPartitionerDC extends DataPartitioner {
                }
                return list;
        }
+
+       @Override
+       public Result doPartitioning(int workersNum, MatrixObject features, 
MatrixObject labels) {
+               List<MatrixObject> pfs = doPartitioning(workersNum, features);
+               List<MatrixObject> pls = doPartitioning(workersNum, labels);
+               return new Result(pfs, pls);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java
index ce5f71d..adc6f60 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java
@@ -34,22 +34,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
  * i.e., sampling without replacement to ensure disjointness.
  */
 public class DataPartitionerDR extends DataPartitioner {
-       @Override
-       public void doPartitioning(List<LocalPSWorker> workers, MatrixObject 
features, MatrixObject labels) {
-               // Generate a single permutation matrix (workers use slices)
-               MatrixBlock permutation = 
ParamservUtils.generatePermutation((int)features.getNumRows());
-               List<MatrixObject> pfs = doPartitioning(workers.size(), 
features, permutation);
-               List<MatrixObject> pls = doPartitioning(workers.size(), labels, 
permutation);
-               setPartitionedData(workers, pfs, pls);
-       }
-       
+
        private List<MatrixObject> doPartitioning(int k, MatrixObject mo, 
MatrixBlock permutation) {
                MatrixBlock data = mo.acquireRead();
-               int batchSize = (int) Math.ceil(mo.getNumRows() / k);
+               int batchSize = (int) Math.ceil((double) mo.getNumRows() / k);
                List<MatrixObject> pMatrices = IntStream.range(0, k).mapToObj(i 
-> {
-                       int begin = i * batchSize + 1;
+                       int begin = i * batchSize;
                        int end = (int) Math.min((i + 1) * batchSize, 
mo.getNumRows());
-                       MatrixBlock slicedPerm = permutation.slice(begin - 1, 
end - 1);
+                       MatrixBlock slicedPerm = permutation.slice(begin, end - 
1);
                        MatrixBlock output = 
slicedPerm.aggregateBinaryOperations(slicedPerm,
                                data, new MatrixBlock(), 
InstructionUtils.getMatMultOperator(k));
                        MatrixObject result = ParamservUtils.newMatrixObject();
@@ -60,4 +52,13 @@ public class DataPartitionerDR extends DataPartitioner {
                mo.release();
                return pMatrices;
        }
+
+       @Override
+       public Result doPartitioning(int workersNum, MatrixObject features, 
MatrixObject labels) {
+               // Generate a single permutation matrix (workers use slices)
+               MatrixBlock permutation = 
ParamservUtils.generatePermutation((int)features.getNumRows());
+               List<MatrixObject> pfs = doPartitioning(workersNum, features, 
permutation);
+               List<MatrixObject> pls = doPartitioning(workersNum, labels, 
permutation);
+               return new Result(pfs, pls);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java
index 9d2f666..a2ff5f9 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java
@@ -35,19 +35,11 @@ import org.apache.sysml.runtime.util.DataConverter;
  * (target=X, margin=rows, select=(seq(1,nrow(X))%%k)==id)
  */
 public class DataPartitionerDRR extends DataPartitioner {
-       @Override
-       public void doPartitioning(List<LocalPSWorker> workers, MatrixObject 
features, MatrixObject labels) {
-               List<MatrixObject> pfs = IntStream.range(0, workers.size())
-                       .mapToObj(i -> removeEmpty(features, workers.size(), 
i)).collect(Collectors.toList());
-               List<MatrixObject> pls = IntStream.range(0, workers.size())
-                       .mapToObj(i -> removeEmpty(labels, workers.size(), 
i)).collect(Collectors.toList());
-               setPartitionedData(workers, pfs, pls);
-       }
 
        private MatrixObject removeEmpty(MatrixObject mo, int k, int workerId) {
                MatrixObject result = ParamservUtils.newMatrixObject();
                MatrixBlock tmp = mo.acquireRead();
-               double[] data = LongStream.range(0, mo.getNumRows())
+               double[] data = LongStream.range(1, mo.getNumRows() + 1)
                        .mapToDouble(l -> l % k == workerId ? 1 : 0).toArray();
                MatrixBlock select = DataConverter.convertToMatrixBlock(data, 
true);
                MatrixBlock resultMB = tmp.removeEmptyOperations(new 
MatrixBlock(), true, true, select);
@@ -57,4 +49,13 @@ public class DataPartitionerDRR extends DataPartitioner {
                result.enableCleanup(false);
                return result;
        }
+
+       @Override
+       public Result doPartitioning(int workersNum, MatrixObject features, 
MatrixObject labels) {
+               List<MatrixObject> pfs = IntStream.range(0, workersNum)
+                       .mapToObj(i -> removeEmpty(features, workersNum, 
i)).collect(Collectors.toList());
+               List<MatrixObject> pls = IntStream.range(0, workersNum)
+                   .mapToObj(i -> removeEmpty(labels, workersNum, 
i)).collect(Collectors.toList());
+               return new Result(pfs, pls);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java
index fabbd74..0bfb4b7 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java
@@ -33,20 +33,11 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
  * where P is constructed for example with 
P=table(seq(1,nrow(X),sample(nrow(X), nrow(X))))
  */
 public class DataPartitionerOR extends DataPartitioner {
-       @Override
-       public void doPartitioning(List<LocalPSWorker> workers, MatrixObject 
features, MatrixObject labels) {
-               // Generate a different permutation matrix for each worker
-               List<MatrixBlock> permutation = IntStream.range(0, 
workers.size()).mapToObj(i -> 
-                       
ParamservUtils.generatePermutation((int)features.getNumRows())).collect(Collectors.toList());
-               List<MatrixObject> pfs = doPartitioning(workers.size(), 
features, permutation);
-               List<MatrixObject> pls = doPartitioning(workers.size(), labels, 
permutation);
-               setPartitionedData(workers, pfs, pls);
-       }
-       
-       private List<MatrixObject> doPartitioning(int k, MatrixObject mo, 
List<MatrixBlock> lpermutation) {
+
+       private List<MatrixObject> doPartitioning(int k, MatrixObject mo, 
List<MatrixBlock> permutations) {
                MatrixBlock data = mo.acquireRead();
                List<MatrixObject> pMatrices = IntStream.range(0, k).mapToObj(i 
-> {
-                       MatrixBlock permutation = lpermutation.get(i);
+                       MatrixBlock permutation = permutations.get(i);
                        MatrixBlock output = 
permutation.aggregateBinaryOperations(permutation,
                                data, new MatrixBlock(), 
InstructionUtils.getMatMultOperator(k));
                        MatrixObject result = ParamservUtils.newMatrixObject();
@@ -57,4 +48,15 @@ public class DataPartitionerOR extends DataPartitioner {
                mo.release();
                return pMatrices;
        }
+
+       @Override
+       public Result doPartitioning(int workersNum, MatrixObject features, 
MatrixObject labels) {
+               // Generate a different permutation matrix for each worker
+               List<MatrixBlock> permutations = IntStream.range(0, workersNum)
+                       .mapToObj(i -> 
ParamservUtils.generatePermutation((int)features.getNumRows()))
+               .collect(Collectors.toList());
+               List<MatrixObject> pfs = doPartitioning(workersNum, features, 
permutations);
+               List<MatrixObject> pls = doPartitioning(workersNum, labels, 
permutations);
+               return new Result(pfs, pls);
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 8fbc7cc..09caa94 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -413,6 +413,19 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
        }
 
        private void doDataPartitioning(DataPartitioner dp, MatrixObject 
features, MatrixObject labels, List<LocalPSWorker> workers) {
-               dp.doPartitioning(workers, features, labels);
+               DataPartitioner.Result result = 
dp.doPartitioning(workers.size(), features, labels);
+               List<MatrixObject> pfs = result.pFeatures;
+               List<MatrixObject> pls = result.pLabels;
+               if (pfs.size() < workers.size()) {
+                       if (LOG.isWarnEnabled()) {
+                               LOG.warn(String.format("There is only %d 
batches of data but has %d workers. "
+                                       + "Hence, reset the number of workers 
with %d.", pfs.size(), workers.size(), pfs.size()));
+                       }
+                       workers = workers.subList(0, pfs.size());
+               }
+               for (int i = 0; i < workers.size(); i++) {
+                       workers.get(i).setFeatures(pfs.get(i));
+                       workers.get(i).setLabels(pls.get(i));
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java
index 8cc7ab7..7e0784e 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java
@@ -19,40 +19,27 @@
 
 package org.apache.sysml.test.integration.functions.paramserv;
 
-import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
-import java.util.stream.Collectors;
+import java.util.Map;
 import java.util.stream.IntStream;
 
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.DataIdentifier;
-import org.apache.sysml.parser.Expression;
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
-import org.apache.sysml.runtime.controlprogram.Program;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
 import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC;
+import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDR;
 import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDRR;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerOR;
 import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysml.runtime.util.DataConverter;
 import org.junit.Assert;
 import org.junit.Test;
 
-//TODO test data partitioning on defined API not internal methods,
-// potentially remove workers from API to make the data partitioner independent
-//TODO test expected behavior not the internal implementation against itself 
-// (e.g., for DR check that each row is in at most one partition and all rows 
are distributed)
-
 public class DataPartitionerTest {
 
        @Test
        public void testDataPartitionerDC() {
                DataPartitioner dp = new DataPartitionerDC();
-               List<LocalPSWorker> workers = IntStream.range(0, 2).mapToObj(i 
-> new LocalPSWorker(i, "updFunc", Statement.PSFrequency.BATCH, 1, 64, null, 
null, createMockExecutionContext(), null)).collect(Collectors.toList());
                double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
                MatrixObject features = ParamservUtils.newMatrixObject();
                features.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
@@ -62,56 +49,80 @@ public class DataPartitionerTest {
                labels.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
                labels.refreshMetaData();
                labels.release();
-               dp.doPartitioning(workers, features, labels);
-
-               double[] expected1 = new double[] { 1, 2, 3, 4, 5 };
-               double[] realValue1 = 
workers.get(0).getFeatures().acquireRead().getDenseBlockValues();
-               double[] realValue2 = 
workers.get(0).getLabels().acquireRead().getDenseBlockValues();
-               Assert.assertArrayEquals(expected1, realValue1, 0);
-               Assert.assertArrayEquals(expected1, realValue2, 0);
-
-               double[] expected2 = new double[] { 6, 7, 8, 9, 10 };
-               double[] realValue3 = 
workers.get(1).getFeatures().acquireRead().getDenseBlockValues();
-               double[] realValue4 = 
workers.get(1).getLabels().acquireRead().getDenseBlockValues();
-               Assert.assertArrayEquals(expected2, realValue3, 0);
-               Assert.assertArrayEquals(expected2, realValue4, 0);
+               DataPartitioner.Result result = dp.doPartitioning(3, features, 
labels);
+
+               Assert.assertEquals(3, result.pFeatures.size());
+               Assert.assertEquals(3, result.pLabels.size());
+
+               double[] expected1 = new double[] { 1, 2, 3, 4 };
+               assertResult(result, 0, expected1);
+
+               double[] expected2 = new double[] { 5, 6, 7, 8 };
+               assertResult(result, 1, expected2);
+
+               double[] expected3 = new double[] { 9, 10 };
+               assertResult(result, 2, expected3);
+       }
+
+       private void assertResult(DataPartitioner.Result result, int index, 
double[] expected) {
+               List<MatrixObject> pfs = result.pFeatures;
+               List<MatrixObject> pls = result.pLabels;
+               double[] realValue1 = 
pfs.get(index).acquireRead().getDenseBlockValues();
+               double[] realValue2 = 
pls.get(index).acquireRead().getDenseBlockValues();
+               Assert.assertArrayEquals(expected, realValue1, 0);
+               Assert.assertArrayEquals(expected, realValue2, 0);
        }
 
        @Test
        public void testDataPartitionerDR() {
-//             DataPartitionerDR dp = new DataPartitionerDR();
-//             double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
-//             MatrixObject features = ParamservUtils.newMatrixObject();
-//             features.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
-//             features.refreshMetaData();
-//             features.release();
-//             MatrixObject labels = ParamservUtils.newMatrixObject();
-//             labels.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
-//             labels.refreshMetaData();
-//             labels.release();
-//
-//             MatrixBlock permutation = 
ParamservUtils.generatePermutation(df.length, df.length);
-//
-//             List<MatrixObject> pfs = dp.doPartitioning(2, features, 
permutation);
-//             List<MatrixObject> pls = dp.doPartitioning(2, labels, 
permutation);
-//
-//             double[] expected1 = IntStream.range(0, 5).mapToDouble(i -> 
permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray();
-//             double[] realValue1 = 
pfs.get(0).acquireRead().getDenseBlockValues();
-//             double[] realValue2 = 
pls.get(0).acquireRead().getDenseBlockValues();
-//             Assert.assertArrayEquals(expected1, realValue1, 0);
-//             Assert.assertArrayEquals(expected1, realValue2, 0);
-//
-//             double[] expected2 = IntStream.range(5, 10).mapToDouble(i -> 
permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray();
-//             double[] realValue3 = 
pfs.get(1).acquireRead().getDenseBlockValues();
-//             double[] realValue4 = 
pls.get(1).acquireRead().getDenseBlockValues();
-//             Assert.assertArrayEquals(expected2, realValue3, 0);
-//             Assert.assertArrayEquals(expected2, realValue4, 0);
+               DataPartitioner dp = new DataPartitionerDR();
+               double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
+               MatrixObject features = ParamservUtils.newMatrixObject();
+               features.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
+               features.refreshMetaData();
+               features.release();
+               MatrixObject labels = ParamservUtils.newMatrixObject();
+               labels.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
+               labels.refreshMetaData();
+               labels.release();
+
+               DataPartitioner.Result result = dp.doPartitioning(4, features, 
labels);
+
+               Assert.assertEquals(4, result.pFeatures.size());
+               Assert.assertEquals(4, result.pLabels.size());
+
+               // Ensure that the index is accorded between features and labels
+               IntStream.range(0, result.pFeatures.size()).forEach(i -> {
+                       double[] f = 
result.pFeatures.get(i).acquireRead().getDenseBlockValues();
+                       double[] l = 
result.pLabels.get(i).acquireRead().getDenseBlockValues();
+                       Assert.assertArrayEquals(f, l, 0);
+               });
+
+               assertPermutationDR(df, result.pFeatures);
+               assertPermutationDR(df, result.pLabels);
+       }
+
+       private void assertPermutationDR(double[] df, List<MatrixObject> list) {
+               Map<Double, Integer> dict = new HashMap<>();
+               for (double d : df) {
+                       dict.put(d, 0);
+               }
+               IntStream.range(0, list.size()).forEach(i -> {
+                       double[] f = 
list.get(i).acquireRead().getDenseBlockValues();
+                       for (double d : f) {
+                               dict.compute(d, (k, v) -> v + 1);
+                       }
+               });
+
+               // check if all the occurence is equivalent to one
+               for (Map.Entry<Double, Integer> e : dict.entrySet()) {
+                       Assert.assertEquals(1, (int) e.getValue());
+               }
        }
 
        @Test
        public void testDataPartitionerDRR() {
                DataPartitioner dp = new DataPartitionerDRR();
-               List<LocalPSWorker> workers = IntStream.range(0, 2).mapToObj(i 
-> new LocalPSWorker(i, "updFunc", Statement.PSFrequency.BATCH, 1, 64, null, 
null, createMockExecutionContext(), null)).collect(Collectors.toList());
                double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
                MatrixObject features = ParamservUtils.newMatrixObject();
                features.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
@@ -121,72 +132,61 @@ public class DataPartitionerTest {
                labels.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
                labels.refreshMetaData();
                labels.release();
-               dp.doPartitioning(workers, features, labels);
-
-               //TODO test against four not two workers
-               double[] expected1 = new double[] { 1, 3, 5, 7, 9 };
-               double[] realValue1 = 
workers.get(0).getFeatures().acquireRead().getDenseBlockValues();
-               double[] realValue2 = 
workers.get(0).getLabels().acquireRead().getDenseBlockValues();
-               Assert.assertArrayEquals(expected1, realValue1, 0);
-               Assert.assertArrayEquals(expected1, realValue2, 0);
-
-               double[] expected2 = new double[] { 2, 4, 6, 8, 10 };
-               double[] realValue3 = 
workers.get(1).getFeatures().acquireRead().getDenseBlockValues();
-               double[] realValue4 = 
workers.get(1).getLabels().acquireRead().getDenseBlockValues();
-               Assert.assertArrayEquals(expected2, realValue3, 0);
-               Assert.assertArrayEquals(expected2, realValue4, 0);
+               DataPartitioner.Result result = dp.doPartitioning(4, features, 
labels);
+
+               Assert.assertEquals(4, result.pFeatures.size());
+               Assert.assertEquals(4, result.pLabels.size());
+
+               double[] expected1 = new double[] { 4, 8 };
+               assertResult(result, 0, expected1);
+
+               double[] expected2 = new double[] { 1, 5, 9 };
+               assertResult(result, 1, expected2);
+
+               double[] expected3 = new double[] { 2, 6, 10 };
+               assertResult(result, 2, expected3);
+
+               double[] expected4 = new double[] { 3, 7 };
+               assertResult(result, 3, expected4);
        }
 
        @Test
        public void testDataPartitionerOR() {
-//             DataPartitionerOR dp = new DataPartitionerOR();
-//             double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
-//             MatrixObject features = ParamservUtils.newMatrixObject();
-//             features.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
-//             features.refreshMetaData();
-//             features.release();
-//             MatrixObject labels = ParamservUtils.newMatrixObject();
-//             labels.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
-//             labels.refreshMetaData();
-//             labels.release();
-//
-//             MatrixBlock permutation = 
ParamservUtils.generatePermutation(df.length, df.length);
-//
-//             List<MatrixObject> pfs = dp.doPartitioning(1, features, 
permutation);
-//             List<MatrixObject> pls = dp.doPartitioning(1, labels, 
permutation);
-//
-//             double[] expected1 = IntStream.range(0, 10).mapToDouble(i -> 
permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray();
-//             double[] realValue1 = 
pfs.get(0).acquireRead().getDenseBlockValues();
-//             double[] realValue2 = 
pls.get(0).acquireRead().getDenseBlockValues();
-//             Assert.assertArrayEquals(expected1, realValue1, 0);
-//             Assert.assertArrayEquals(expected1, realValue2, 0);
+               DataPartitioner dp = new DataPartitionerOR();
+               double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
+               MatrixObject features = ParamservUtils.newMatrixObject();
+               features.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
+               features.refreshMetaData();
+               features.release();
+               MatrixObject labels = ParamservUtils.newMatrixObject();
+               labels.acquireModify(DataConverter.convertToMatrixBlock(df, 
true));
+               labels.refreshMetaData();
+               labels.release();
+
+               DataPartitioner.Result result = dp.doPartitioning(4, features, 
labels);
+
+               Assert.assertEquals(4, result.pFeatures.size());
+               Assert.assertEquals(4, result.pLabels.size());
+
+               assertPermutationOR(df, result.pFeatures);
+               assertPermutationOR(df, result.pLabels);
        }
 
-       private ExecutionContext createMockExecutionContext() {
-               Program prog = new Program();
-               ArrayList<DataIdentifier> inputs = new ArrayList<>();
-               DataIdentifier features = new DataIdentifier("features");
-               features.setDataType(Expression.DataType.MATRIX);
-               features.setValueType(Expression.ValueType.DOUBLE);
-               inputs.add(features);
-               DataIdentifier labels = new DataIdentifier("labels");
-               labels.setDataType(Expression.DataType.MATRIX);
-               labels.setValueType(Expression.ValueType.DOUBLE);
-               inputs.add(labels);
-               DataIdentifier model = new DataIdentifier("model");
-               model.setDataType(Expression.DataType.LIST);
-               model.setValueType(Expression.ValueType.UNKNOWN);
-               inputs.add(model);
-
-               ArrayList<DataIdentifier> outputs = new ArrayList<>();
-               DataIdentifier gradients = new DataIdentifier("gradients");
-               gradients.setDataType(Expression.DataType.LIST);
-               gradients.setValueType(Expression.ValueType.UNKNOWN);
-               outputs.add(gradients);
-
-               FunctionProgramBlock fpb = new FunctionProgramBlock(prog, 
inputs, outputs);
-               prog.addProgramBlock(fpb);
-               prog.addFunctionProgramBlock(DMLProgram.DEFAULT_NAMESPACE, 
"updFunc", fpb);
-               return ExecutionContextFactory.createContext(prog);
+       private void assertPermutationOR(double[] df, List<MatrixObject> list) {
+               for (MatrixObject mo : list) {
+                       Map<Double, Integer> dict = new HashMap<>();
+                       for (double d : df) {
+                               dict.put(d, 0);
+                       }
+                       double[] f = mo.acquireRead().getDenseBlockValues();
+                       for (double d : f) {
+                               dict.compute(d, (k, v) -> v + 1);
+                       }
+                       Assert.assertEquals(10, dict.size());
+                       // check if all the occurence is equivalent to one
+                       for (Map.Entry<Double, Integer> e : dict.entrySet()) {
+                               Assert.assertEquals(1, (int) e.getValue());
+                       }
+               }
        }
 }

Reply via email to