Repository: systemml Updated Branches: refs/heads/master 735c4119c -> 0871f260e
[SYSTEMML-2381] Rework paramserv data partitioner API and tests Closes #783. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0871f260 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0871f260 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0871f260 Branch: refs/heads/master Commit: 0871f260e6fc6d6fed57d2cc249bf4d8beb0a31f Parents: 735c411 Author: EdgarLGB <[email protected]> Authored: Wed Jun 13 08:47:24 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Jun 13 08:57:36 2018 -0700 ---------------------------------------------------------------------- .../paramserv/DataPartitioner.java | 25 +- .../paramserv/DataPartitionerDC.java | 16 +- .../paramserv/DataPartitionerDR.java | 25 +- .../paramserv/DataPartitionerDRR.java | 19 +- .../paramserv/DataPartitionerOR.java | 26 +- .../cp/ParamservBuiltinCPInstruction.java | 15 +- .../paramserv/DataPartitionerTest.java | 238 +++++++++---------- 7 files changed, 187 insertions(+), 177 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java index b94d765..fce6d35 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java @@ -21,27 +21,20 @@ package org.apache.sysml.runtime.controlprogram.paramserv; import java.util.List; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; public abstract class DataPartitioner { - protected static final Log LOG = LogFactory.getLog(DataPartitioner.class.getName()); + public final class Result { + public final List<MatrixObject> pFeatures; + public final List<MatrixObject> pLabels; - public abstract void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels); - - protected void setPartitionedData(List<LocalPSWorker> workers, List<MatrixObject> pfs, List<MatrixObject> pls) { - if (pfs.size() < workers.size()) { - if (LOG.isWarnEnabled()) { - LOG.warn(String.format("There is only %d batches of data but has %d workers. " - + "Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size())); - } - workers = workers.subList(0, pfs.size()); - } - for (int i = 0; i < workers.size(); i++) { - workers.get(i).setFeatures(pfs.get(i)); - workers.get(i).setLabels(pls.get(i)); + public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels) { + this.pFeatures = pFeatures; + this.pLabels = pLabels; } } + + public abstract Result doPartitioning(int workersNum, MatrixObject features, MatrixObject labels); + } http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java index 0426855..4810541 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDC.java @@ -32,17 +32,10 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; * non-overlapping partitions of rows. */ public class DataPartitionerDC extends DataPartitioner { - @Override - public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { - int workerNum = workers.size(); - List<MatrixObject> pfs = doPartitioning(workerNum, features); - List<MatrixObject> pls = doPartitioning(workerNum, labels); - setPartitionedData(workers, pfs, pls); - } private List<MatrixObject> doPartitioning(int k, MatrixObject mo) { List<MatrixObject> list = new ArrayList<>(); - long stepSize = (long) Math.ceil(mo.getNumRows() / k); + long stepSize = (long) Math.ceil((double) mo.getNumRows() / k); long begin = 1; while (begin < mo.getNumRows()) { long end = Math.min(begin - 1 + stepSize, mo.getNumRows()); @@ -52,4 +45,11 @@ public class DataPartitionerDC extends DataPartitioner { } return list; } + + @Override + public Result doPartitioning(int workersNum, MatrixObject features, MatrixObject labels) { + List<MatrixObject> pfs = doPartitioning(workersNum, features); + List<MatrixObject> pls = doPartitioning(workersNum, labels); + return new Result(pfs, pls); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java index ce5f71d..adc6f60 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDR.java @@ -34,22 +34,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; * i.e., sampling without replacement to ensure disjointness. */ public class DataPartitionerDR extends DataPartitioner { - @Override - public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { - // Generate a single permutation matrix (workers use slices) - MatrixBlock permutation = ParamservUtils.generatePermutation((int)features.getNumRows()); - List<MatrixObject> pfs = doPartitioning(workers.size(), features, permutation); - List<MatrixObject> pls = doPartitioning(workers.size(), labels, permutation); - setPartitionedData(workers, pfs, pls); - } - + private List<MatrixObject> doPartitioning(int k, MatrixObject mo, MatrixBlock permutation) { MatrixBlock data = mo.acquireRead(); - int batchSize = (int) Math.ceil(mo.getNumRows() / k); + int batchSize = (int) Math.ceil((double) mo.getNumRows() / k); List<MatrixObject> pMatrices = IntStream.range(0, k).mapToObj(i -> { - int begin = i * batchSize + 1; + int begin = i * batchSize; int end = (int) Math.min((i + 1) * batchSize, mo.getNumRows()); - MatrixBlock slicedPerm = permutation.slice(begin - 1, end - 1); + MatrixBlock slicedPerm = permutation.slice(begin, end - 1); MatrixBlock output = slicedPerm.aggregateBinaryOperations(slicedPerm, data, new MatrixBlock(), InstructionUtils.getMatMultOperator(k)); MatrixObject result = ParamservUtils.newMatrixObject(); @@ -60,4 +52,13 @@ public class DataPartitionerDR extends DataPartitioner { mo.release(); return pMatrices; } + + @Override + public Result doPartitioning(int workersNum, MatrixObject features, MatrixObject labels) { + // Generate a single permutation matrix (workers use slices) + MatrixBlock permutation = ParamservUtils.generatePermutation((int)features.getNumRows()); + List<MatrixObject> pfs = doPartitioning(workersNum, features, permutation); + List<MatrixObject> pls = doPartitioning(workersNum, labels, permutation); + return new Result(pfs, pls); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java index 9d2f666..a2ff5f9 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerDRR.java @@ -35,19 +35,11 @@ import org.apache.sysml.runtime.util.DataConverter; * (target=X, margin=rows, select=(seq(1,nrow(X))%%k)==id) */ public class DataPartitionerDRR extends DataPartitioner { - @Override - public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { - List<MatrixObject> pfs = IntStream.range(0, workers.size()) - .mapToObj(i -> removeEmpty(features, workers.size(), i)).collect(Collectors.toList()); - List<MatrixObject> pls = IntStream.range(0, workers.size()) - .mapToObj(i -> removeEmpty(labels, workers.size(), i)).collect(Collectors.toList()); - setPartitionedData(workers, pfs, pls); - } private MatrixObject removeEmpty(MatrixObject mo, int k, int workerId) { MatrixObject result = ParamservUtils.newMatrixObject(); MatrixBlock tmp = mo.acquireRead(); - double[] data = LongStream.range(0, mo.getNumRows()) + double[] data = LongStream.range(1, mo.getNumRows() + 1) .mapToDouble(l -> l % k == workerId ? 1 : 0).toArray(); MatrixBlock select = DataConverter.convertToMatrixBlock(data, true); MatrixBlock resultMB = tmp.removeEmptyOperations(new MatrixBlock(), true, true, select); @@ -57,4 +49,13 @@ public class DataPartitionerDRR extends DataPartitioner { result.enableCleanup(false); return result; } + + @Override + public Result doPartitioning(int workersNum, MatrixObject features, MatrixObject labels) { + List<MatrixObject> pfs = IntStream.range(0, workersNum) + .mapToObj(i -> removeEmpty(features, workersNum, i)).collect(Collectors.toList()); + List<MatrixObject> pls = IntStream.range(0, workersNum) + .mapToObj(i -> removeEmpty(labels, workersNum, i)).collect(Collectors.toList()); + return new Result(pfs, pls); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java index fabbd74..0bfb4b7 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionerOR.java @@ -33,20 +33,11 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock; * where P is constructed for example with P=table(seq(1,nrow(X),sample(nrow(X), nrow(X)))) */ public class DataPartitionerOR extends DataPartitioner { - @Override - public void doPartitioning(List<LocalPSWorker> workers, MatrixObject features, MatrixObject labels) { - // Generate a different permutation matrix for each worker - List<MatrixBlock> permutation = IntStream.range(0, workers.size()).mapToObj(i -> - ParamservUtils.generatePermutation((int)features.getNumRows())).collect(Collectors.toList()); - List<MatrixObject> pfs = doPartitioning(workers.size(), features, permutation); - List<MatrixObject> pls = doPartitioning(workers.size(), labels, permutation); - setPartitionedData(workers, pfs, pls); - } - - private List<MatrixObject> doPartitioning(int k, MatrixObject mo, List<MatrixBlock> lpermutation) { + + private List<MatrixObject> doPartitioning(int k, MatrixObject mo, List<MatrixBlock> permutations) { MatrixBlock data = mo.acquireRead(); List<MatrixObject> pMatrices = IntStream.range(0, k).mapToObj(i -> { - MatrixBlock permutation = lpermutation.get(i); + MatrixBlock permutation = permutations.get(i); MatrixBlock output = permutation.aggregateBinaryOperations(permutation, data, new MatrixBlock(), InstructionUtils.getMatMultOperator(k)); MatrixObject result = ParamservUtils.newMatrixObject(); @@ -57,4 +48,15 @@ public class DataPartitionerOR extends DataPartitioner { mo.release(); return pMatrices; } + + @Override + public Result doPartitioning(int workersNum, MatrixObject features, MatrixObject labels) { + // Generate a different permutation matrix for each worker + List<MatrixBlock> permutations = IntStream.range(0, workersNum) + .mapToObj(i -> ParamservUtils.generatePermutation((int)features.getNumRows())) + .collect(Collectors.toList()); + List<MatrixObject> pfs = doPartitioning(workersNum, features, permutations); + List<MatrixObject> pls = doPartitioning(workersNum, labels, permutations); + return new Result(pfs, pls); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 8fbc7cc..09caa94 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -413,6 +413,19 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc } private void doDataPartitioning(DataPartitioner dp, MatrixObject features, MatrixObject labels, List<LocalPSWorker> workers) { - dp.doPartitioning(workers, features, labels); + DataPartitioner.Result result = dp.doPartitioning(workers.size(), features, labels); + List<MatrixObject> pfs = result.pFeatures; + List<MatrixObject> pls = result.pLabels; + if (pfs.size() < workers.size()) { + if (LOG.isWarnEnabled()) { + LOG.warn(String.format("There is only %d batches of data but has %d workers. " + + "Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size())); + } + workers = workers.subList(0, pfs.size()); + } + for (int i = 0; i < workers.size(); i++) { + workers.get(i).setFeatures(pfs.get(i)); + workers.get(i).setLabels(pls.get(i)); + } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/0871f260/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java index 8cc7ab7..7e0784e 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/DataPartitionerTest.java @@ -19,40 +19,27 @@ package org.apache.sysml.test.integration.functions.paramserv; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; -import java.util.stream.Collectors; +import java.util.Map; import java.util.stream.IntStream; -import org.apache.sysml.parser.DMLProgram; -import org.apache.sysml.parser.DataIdentifier; -import org.apache.sysml.parser.Expression; -import org.apache.sysml.parser.Statement; -import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock; -import org.apache.sysml.runtime.controlprogram.Program; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner; import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDC; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDR; import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerDRR; -import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker; +import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionerOR; import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils; import org.apache.sysml.runtime.util.DataConverter; import org.junit.Assert; import org.junit.Test; -//TODO test data partitioning on defined API not internal methods, -// potentially remove workers from API to make the data partitioner independent -//TODO test expected behavior not the internal implementation against itself -// (e.g., for DR check that each row is in at most one partition and all rows are distributed) - public class DataPartitionerTest { @Test public void testDataPartitionerDC() { DataPartitioner dp = new DataPartitionerDC(); - List<LocalPSWorker> workers = IntStream.range(0, 2).mapToObj(i -> new LocalPSWorker(i, "updFunc", Statement.PSFrequency.BATCH, 1, 64, null, null, createMockExecutionContext(), null)).collect(Collectors.toList()); double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; MatrixObject features = ParamservUtils.newMatrixObject(); features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); @@ -62,56 +49,80 @@ public class DataPartitionerTest { labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); labels.refreshMetaData(); labels.release(); - dp.doPartitioning(workers, features, labels); - - double[] expected1 = new double[] { 1, 2, 3, 4, 5 }; - double[] realValue1 = workers.get(0).getFeatures().acquireRead().getDenseBlockValues(); - double[] realValue2 = workers.get(0).getLabels().acquireRead().getDenseBlockValues(); - Assert.assertArrayEquals(expected1, realValue1, 0); - Assert.assertArrayEquals(expected1, realValue2, 0); - - double[] expected2 = new double[] { 6, 7, 8, 9, 10 }; - double[] realValue3 = workers.get(1).getFeatures().acquireRead().getDenseBlockValues(); - double[] realValue4 = workers.get(1).getLabels().acquireRead().getDenseBlockValues(); - Assert.assertArrayEquals(expected2, realValue3, 0); - Assert.assertArrayEquals(expected2, realValue4, 0); + DataPartitioner.Result result = dp.doPartitioning(3, features, labels); + + Assert.assertEquals(3, result.pFeatures.size()); + Assert.assertEquals(3, result.pLabels.size()); + + double[] expected1 = new double[] { 1, 2, 3, 4 }; + assertResult(result, 0, expected1); + + double[] expected2 = new double[] { 5, 6, 7, 8 }; + assertResult(result, 1, expected2); + + double[] expected3 = new double[] { 9, 10 }; + assertResult(result, 2, expected3); + } + + private void assertResult(DataPartitioner.Result result, int index, double[] expected) { + List<MatrixObject> pfs = result.pFeatures; + List<MatrixObject> pls = result.pLabels; + double[] realValue1 = pfs.get(index).acquireRead().getDenseBlockValues(); + double[] realValue2 = pls.get(index).acquireRead().getDenseBlockValues(); + Assert.assertArrayEquals(expected, realValue1, 0); + Assert.assertArrayEquals(expected, realValue2, 0); } @Test public void testDataPartitionerDR() { -// DataPartitionerDR dp = new DataPartitionerDR(); -// double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; -// MatrixObject features = ParamservUtils.newMatrixObject(); -// features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); -// features.refreshMetaData(); -// features.release(); -// MatrixObject labels = ParamservUtils.newMatrixObject(); -// labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); -// labels.refreshMetaData(); -// labels.release(); -// -// MatrixBlock permutation = ParamservUtils.generatePermutation(df.length, df.length); -// -// List<MatrixObject> pfs = dp.doPartitioning(2, features, permutation); -// List<MatrixObject> pls = dp.doPartitioning(2, labels, permutation); -// -// double[] expected1 = IntStream.range(0, 5).mapToDouble(i -> permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray(); -// double[] realValue1 = pfs.get(0).acquireRead().getDenseBlockValues(); -// double[] realValue2 = pls.get(0).acquireRead().getDenseBlockValues(); -// Assert.assertArrayEquals(expected1, realValue1, 0); -// Assert.assertArrayEquals(expected1, realValue2, 0); -// -// double[] expected2 = IntStream.range(5, 10).mapToDouble(i -> permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray(); -// double[] realValue3 = pfs.get(1).acquireRead().getDenseBlockValues(); -// double[] realValue4 = pls.get(1).acquireRead().getDenseBlockValues(); -// Assert.assertArrayEquals(expected2, realValue3, 0); -// Assert.assertArrayEquals(expected2, realValue4, 0); + DataPartitioner dp = new DataPartitionerDR(); + double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + MatrixObject features = ParamservUtils.newMatrixObject(); + features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + features.refreshMetaData(); + features.release(); + MatrixObject labels = ParamservUtils.newMatrixObject(); + labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + labels.refreshMetaData(); + labels.release(); + + DataPartitioner.Result result = dp.doPartitioning(4, features, labels); + + Assert.assertEquals(4, result.pFeatures.size()); + Assert.assertEquals(4, result.pLabels.size()); + + // Ensure that the index is accorded between features and labels + IntStream.range(0, result.pFeatures.size()).forEach(i -> { + double[] f = result.pFeatures.get(i).acquireRead().getDenseBlockValues(); + double[] l = result.pLabels.get(i).acquireRead().getDenseBlockValues(); + Assert.assertArrayEquals(f, l, 0); + }); + + assertPermutationDR(df, result.pFeatures); + assertPermutationDR(df, result.pLabels); + } + + private void assertPermutationDR(double[] df, List<MatrixObject> list) { + Map<Double, Integer> dict = new HashMap<>(); + for (double d : df) { + dict.put(d, 0); + } + IntStream.range(0, list.size()).forEach(i -> { + double[] f = list.get(i).acquireRead().getDenseBlockValues(); + for (double d : f) { + dict.compute(d, (k, v) -> v + 1); + } + }); + + // check if all the occurence is equivalent to one + for (Map.Entry<Double, Integer> e : dict.entrySet()) { + Assert.assertEquals(1, (int) e.getValue()); + } } @Test public void testDataPartitionerDRR() { DataPartitioner dp = new DataPartitionerDRR(); - List<LocalPSWorker> workers = IntStream.range(0, 2).mapToObj(i -> new LocalPSWorker(i, "updFunc", Statement.PSFrequency.BATCH, 1, 64, null, null, createMockExecutionContext(), null)).collect(Collectors.toList()); double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; MatrixObject features = ParamservUtils.newMatrixObject(); features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); @@ -121,72 +132,61 @@ public class DataPartitionerTest { labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); labels.refreshMetaData(); labels.release(); - dp.doPartitioning(workers, features, labels); - - //TODO test against four not two workers - double[] expected1 = new double[] { 1, 3, 5, 7, 9 }; - double[] realValue1 = workers.get(0).getFeatures().acquireRead().getDenseBlockValues(); - double[] realValue2 = workers.get(0).getLabels().acquireRead().getDenseBlockValues(); - Assert.assertArrayEquals(expected1, realValue1, 0); - Assert.assertArrayEquals(expected1, realValue2, 0); - - double[] expected2 = new double[] { 2, 4, 6, 8, 10 }; - double[] realValue3 = workers.get(1).getFeatures().acquireRead().getDenseBlockValues(); - double[] realValue4 = workers.get(1).getLabels().acquireRead().getDenseBlockValues(); - Assert.assertArrayEquals(expected2, realValue3, 0); - Assert.assertArrayEquals(expected2, realValue4, 0); + DataPartitioner.Result result = dp.doPartitioning(4, features, labels); + + Assert.assertEquals(4, result.pFeatures.size()); + Assert.assertEquals(4, result.pLabels.size()); + + double[] expected1 = new double[] { 4, 8 }; + assertResult(result, 0, expected1); + + double[] expected2 = new double[] { 1, 5, 9 }; + assertResult(result, 1, expected2); + + double[] expected3 = new double[] { 2, 6, 10 }; + assertResult(result, 2, expected3); + + double[] expected4 = new double[] { 3, 7 }; + assertResult(result, 3, expected4); } @Test public void testDataPartitionerOR() { -// DataPartitionerOR dp = new DataPartitionerOR(); -// double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; -// MatrixObject features = ParamservUtils.newMatrixObject(); -// features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); -// features.refreshMetaData(); -// features.release(); -// MatrixObject labels = ParamservUtils.newMatrixObject(); -// labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); -// labels.refreshMetaData(); -// labels.release(); -// -// MatrixBlock permutation = ParamservUtils.generatePermutation(df.length, df.length); -// -// List<MatrixObject> pfs = dp.doPartitioning(1, features, permutation); -// List<MatrixObject> pls = dp.doPartitioning(1, labels, permutation); -// -// double[] expected1 = IntStream.range(0, 10).mapToDouble(i -> permutation.getSparseBlock().get(i).indexes()[0] + 1).toArray(); -// double[] realValue1 = pfs.get(0).acquireRead().getDenseBlockValues(); -// double[] realValue2 = pls.get(0).acquireRead().getDenseBlockValues(); -// Assert.assertArrayEquals(expected1, realValue1, 0); -// Assert.assertArrayEquals(expected1, realValue2, 0); + DataPartitioner dp = new DataPartitionerOR(); + double[] df = new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + MatrixObject features = ParamservUtils.newMatrixObject(); + features.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + features.refreshMetaData(); + features.release(); + MatrixObject labels = ParamservUtils.newMatrixObject(); + labels.acquireModify(DataConverter.convertToMatrixBlock(df, true)); + labels.refreshMetaData(); + labels.release(); + + DataPartitioner.Result result = dp.doPartitioning(4, features, labels); + + Assert.assertEquals(4, result.pFeatures.size()); + Assert.assertEquals(4, result.pLabels.size()); + + assertPermutationOR(df, result.pFeatures); + assertPermutationOR(df, result.pLabels); } - private ExecutionContext createMockExecutionContext() { - Program prog = new Program(); - ArrayList<DataIdentifier> inputs = new ArrayList<>(); - DataIdentifier features = new DataIdentifier("features"); - features.setDataType(Expression.DataType.MATRIX); - features.setValueType(Expression.ValueType.DOUBLE); - inputs.add(features); - DataIdentifier labels = new DataIdentifier("labels"); - labels.setDataType(Expression.DataType.MATRIX); - labels.setValueType(Expression.ValueType.DOUBLE); - inputs.add(labels); - DataIdentifier model = new DataIdentifier("model"); - model.setDataType(Expression.DataType.LIST); - model.setValueType(Expression.ValueType.UNKNOWN); - inputs.add(model); - - ArrayList<DataIdentifier> outputs = new ArrayList<>(); - DataIdentifier gradients = new DataIdentifier("gradients"); - gradients.setDataType(Expression.DataType.LIST); - gradients.setValueType(Expression.ValueType.UNKNOWN); - outputs.add(gradients); - - FunctionProgramBlock fpb = new FunctionProgramBlock(prog, inputs, outputs); - prog.addProgramBlock(fpb); - prog.addFunctionProgramBlock(DMLProgram.DEFAULT_NAMESPACE, "updFunc", fpb); - return ExecutionContextFactory.createContext(prog); + private void assertPermutationOR(double[] df, List<MatrixObject> list) { + for (MatrixObject mo : list) { + Map<Double, Integer> dict = new HashMap<>(); + for (double d : df) { + dict.put(d, 0); + } + double[] f = mo.acquireRead().getDenseBlockValues(); + for (double d : f) { + dict.compute(d, (k, v) -> v + 1); + } + Assert.assertEquals(10, dict.size()); + // check if all the occurence is equivalent to one + for (Map.Entry<Double, Integer> e : dict.entrySet()) { + Assert.assertEquals(1, (int) e.getValue()); + } + } } }
