Repository: systemml Updated Branches: refs/heads/master 4225fd8b9 -> b586d1691
[SYSTEMML-2469] Performance distributed paramserv (partition/serialize) Closes #809. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b586d169 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b586d169 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b586d169 Branch: refs/heads/master Commit: b586d16913196276d5bbd0c0828389aed7e4d9e3 Parents: 4225fd8 Author: EdgarLGB <[email protected]> Authored: Sat Jul 28 20:26:53 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 28 20:26:53 2018 -0700 ---------------------------------------------------------------------- .../controlprogram/paramserv/LocalPSWorker.java | 4 +- .../paramserv/ParamservUtils.java | 57 ++++++++------------ .../paramserv/spark/rpc/PSRpcObject.java | 25 +++++---- .../functions/paramserv/RpcObjectTest.java | 20 +++---- 4 files changed, 48 insertions(+), 58 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java index b8a416f..f76fddb 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java @@ -191,8 +191,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> { // Get the gradients ListObject gradients = (ListObject) _ec.getVariable(_output.getName()); - ParamservUtils.cleanupData(_ec, bFeatures); - ParamservUtils.cleanupData(_ec, bLabels); + ParamservUtils.cleanupData(_ec, Statement.PS_FEATURES); + ParamservUtils.cleanupData(_ec, Statement.PS_LABELS); return gradients; } http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java index 9624c55..e9292d1 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java @@ -94,14 +94,12 @@ public class ParamservUtils { } List<Data> newData = IntStream.range(0, lo.getLength()).mapToObj(i -> { Data oldData = lo.slice(i); - if (oldData instanceof MatrixObject) { - MatrixObject mo = (MatrixObject) oldData; - return sliceMatrix(mo, 1, mo.getNumRows()); - } else if (oldData instanceof ListObject || oldData instanceof FrameObject) { + if (oldData instanceof MatrixObject) + return createShallowCopy((MatrixObject) oldData); + else if (oldData instanceof ListObject || oldData instanceof FrameObject) throw new DMLRuntimeException("Copy list: does not support list or frame."); - } else { + else return oldData; - } }).collect(Collectors.toList()); return new ListObject(newData, lo.getNames()); } @@ -145,14 +143,14 @@ public class ParamservUtils { CacheableData<?> cd = (CacheableData<?>) data; cd.enableCleanup(true); ec.cleanupCacheableData(cd); - if (LOG.isDebugEnabled()) { - LOG.debug(String.format("%s has been deleted.", cd.getFileName())); - } } - public static void cleanupMatrixObject(ExecutionContext ec, MatrixObject mo) { - mo.enableCleanup(true); - ec.cleanupCacheableData(mo); + public static void cleanupData(ExecutionContext ec, String varName) { + cleanupData(ec, ec.removeVariable(varName)); + } + + public static void cleanupListObject(ListObject lo) { + cleanupListObject(ExecutionContextFactory.createContext(), lo); } public static MatrixObject newMatrixObject(MatrixBlock mb) { @@ -168,6 +166,10 @@ public class ParamservUtils { result.enableCleanup(cleanup); return result; } + + public static MatrixObject createShallowCopy(MatrixObject mo) { + return newMatrixObject(mo.acquireReadAndRelease(), false); + } /** * Slice the matrix @@ -178,11 +180,8 @@ public class ParamservUtils { * @return new sliced matrix */ public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) { - MatrixBlock mb = mo.acquireRead(); - MatrixObject result = newMatrixObject(sliceMatrixBlock(mb, rl, rh)); - result.enableCleanup(false); - mo.release(); - return result; + MatrixBlock mb = mo.acquireReadAndRelease(); + return newMatrixObject(sliceMatrixBlock(mb, rl, rh), false); } /** @@ -335,35 +334,23 @@ public class ParamservUtils { /** * Assemble the matrix of features and labels according to the rowID * - * @param numRows row size of the data * @param featuresRDD indexed features matrix block * @param labelsRDD indexed labels matrix block * @return Assembled rdd with rowID as key while matrix of features and labels as value (rowID -> features, labels) */ - public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> assembleTrainingData(long numRows, JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD, JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) { - JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(numRows, featuresRDD); - JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(numRows, labelsRDD); + public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> assembleTrainingData(JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD, JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) { + JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(featuresRDD); + JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(labelsRDD); //TODO Add an additional physical operator which broadcasts the labels directly (broadcast join with features) if certain memory budgets are satisfied return fRDD.join(lRDD); } - private static JavaPairRDD<Long, MatrixBlock> groupMatrix(long numRows, JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) { + private static JavaPairRDD<Long, MatrixBlock> groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) { //TODO could use join and aggregation to avoid unnecessary shuffle introduced by reduceByKey return rdd.mapToPair(input -> new Tuple2<>(input._1.getRowIndex(), new Tuple2<>(input._1.getColumnIndex(), input._2))) .aggregateByKey(new LinkedList<Tuple2<Long, MatrixBlock>>(), - new Partitioner() { - private static final long serialVersionUID = -7032660778344579236L; - @Override - public int getPartition(Object rblkID) { - return Math.toIntExact((Long) rblkID); - } - @Override - public int numPartitions() { - return Math.toIntExact(numRows); - } - }, (list, input) -> { - list.add(input); + list.add(input); return list; }, (l1, l2) -> { @@ -392,7 +379,7 @@ public class ParamservUtils { DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows()); JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils - .assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels)) + .assembleTrainingData(featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels)) .flatMapToPair(mapper) // Do the data partitioning on spark (workerID => (rowBlockID, (single row features, single row labels)) // Aggregate the partitioned matrix according to rowID for each worker // i.e. (workerID => ordered list[(rowBlockID, (single row features, single row labels)] http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java index 411822f..816cefd 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java @@ -46,30 +46,33 @@ public abstract class PSRpcObject { /** * Deep serialize and write of a list object (currently only support list containing matrices) * @param lo a list object containing only matrices - * @param dos output data to write to + * @param output output data to write to */ - protected void serializeAndWriteListObject(ListObject lo, DataOutput dos) throws IOException { + protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException { validateListObject(lo); - dos.writeInt(lo.getLength()); //write list length - dos.writeBoolean(lo.isNamedList()); //write list named + output.writeInt(lo.getLength()); //write list length + output.writeBoolean(lo.isNamedList()); //write list named for (int i = 0; i < lo.getLength(); i++) { if (lo.isNamedList()) - dos.writeUTF(lo.getName(i)); //write name + output.writeUTF(lo.getName(i)); //write name ((MatrixObject) lo.getData().get(i)) - .acquireReadAndRelease().write(dos); //write matrix + .acquireReadAndRelease().write(output); //write matrix } + // Cleanup the list object + // because it is transferred to remote worker in binary format + ParamservUtils.cleanupListObject(lo); } - protected ListObject readAndDeserialize(DataInput dis) throws IOException { - int listLen = dis.readInt(); + protected ListObject readAndDeserialize(DataInput input) throws IOException { + int listLen = input.readInt(); List<Data> data = new ArrayList<>(); - List<String> names = dis.readBoolean() ? + List<String> names = input.readBoolean() ? new ArrayList<>() : null; for(int i=0; i<listLen; i++) { if( names != null ) - names.add(dis.readUTF()); + names.add(input.readUTF()); MatrixBlock mb = new MatrixBlock(); - mb.readFields(dis); + mb.readFields(input); data.add(ParamservUtils.newMatrixObject(mb, false)); } return new ListObject(data, names); http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java index f2df1e6..17bfa4c 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java @@ -32,27 +32,27 @@ import org.junit.Test; public class RpcObjectTest { - @Test - public void testPSRpcCall() throws IOException { + private ListObject generateData() { MatrixObject mo1 = SerializationTest.generateDummyMatrix(10); MatrixObject mo2 = SerializationTest.generateDummyMatrix(20); - ListObject lo = new ListObject(Arrays.asList(mo1, mo2)); - PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo); + return new ListObject(Arrays.asList(mo1, mo2)); + } + + @Test + public void testPSRpcCall() throws IOException { + PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, generateData()); PSRpcCall actual = new PSRpcCall(expected.serialize()); Assert.assertTrue(Arrays.equals( - expected.serialize().array(), + new PSRpcCall(PSRpcObject.PUSH, 1, generateData()).serialize().array(), actual.serialize().array())); } @Test public void testPSRpcResponse() throws IOException { - MatrixObject mo1 = SerializationTest.generateDummyMatrix(10); - MatrixObject mo2 = SerializationTest.generateDummyMatrix(20); - ListObject lo = new ListObject(Arrays.asList(mo1, mo2)); - PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.Type.SUCCESS, lo); + PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.Type.SUCCESS, generateData()); PSRpcResponse actual = new PSRpcResponse(expected.serialize()); Assert.assertTrue(Arrays.equals( - expected.serialize().array(), + new PSRpcResponse(PSRpcResponse.Type.SUCCESS, generateData()).serialize().array(), actual.serialize().array())); } }
