Repository: systemml
Updated Branches:
  refs/heads/master 4225fd8b9 -> b586d1691


[SYSTEMML-2469] Performance distributed paramserv (partition/serialize)

Closes #809.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b586d169
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b586d169
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b586d169

Branch: refs/heads/master
Commit: b586d16913196276d5bbd0c0828389aed7e4d9e3
Parents: 4225fd8
Author: EdgarLGB <[email protected]>
Authored: Sat Jul 28 20:26:53 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Jul 28 20:26:53 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/LocalPSWorker.java |  4 +-
 .../paramserv/ParamservUtils.java               | 57 ++++++++------------
 .../paramserv/spark/rpc/PSRpcObject.java        | 25 +++++----
 .../functions/paramserv/RpcObjectTest.java      | 20 +++----
 4 files changed, 48 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index b8a416f..f76fddb 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -191,8 +191,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
                // Get the gradients
                ListObject gradients = (ListObject) 
_ec.getVariable(_output.getName());
 
-               ParamservUtils.cleanupData(_ec, bFeatures);
-               ParamservUtils.cleanupData(_ec, bLabels);
+               ParamservUtils.cleanupData(_ec, Statement.PS_FEATURES);
+               ParamservUtils.cleanupData(_ec, Statement.PS_LABELS);
                return gradients;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index 9624c55..e9292d1 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -94,14 +94,12 @@ public class ParamservUtils {
                }
                List<Data> newData = IntStream.range(0, 
lo.getLength()).mapToObj(i -> {
                        Data oldData = lo.slice(i);
-                       if (oldData instanceof MatrixObject) {
-                               MatrixObject mo = (MatrixObject) oldData;
-                               return sliceMatrix(mo, 1, mo.getNumRows());
-                       } else if (oldData instanceof ListObject || oldData 
instanceof FrameObject) {
+                       if (oldData instanceof MatrixObject)
+                               return createShallowCopy((MatrixObject) 
oldData);
+                       else if (oldData instanceof ListObject || oldData 
instanceof FrameObject)
                                throw new DMLRuntimeException("Copy list: does 
not support list or frame.");
-                       } else {
+                       else
                                return oldData;
-                       }
                }).collect(Collectors.toList());
                return new ListObject(newData, lo.getNames());
        }
@@ -145,14 +143,14 @@ public class ParamservUtils {
                CacheableData<?> cd = (CacheableData<?>) data;
                cd.enableCleanup(true);
                ec.cleanupCacheableData(cd);
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("%s has been deleted.", 
cd.getFileName()));
-               }
        }
 
-       public static void cleanupMatrixObject(ExecutionContext ec, 
MatrixObject mo) {
-               mo.enableCleanup(true);
-               ec.cleanupCacheableData(mo);
+       public static void cleanupData(ExecutionContext ec, String varName) {
+               cleanupData(ec, ec.removeVariable(varName));
+       }
+
+       public static void cleanupListObject(ListObject lo) {
+               cleanupListObject(ExecutionContextFactory.createContext(), lo);
        }
 
        public static MatrixObject newMatrixObject(MatrixBlock mb) {
@@ -168,6 +166,10 @@ public class ParamservUtils {
                result.enableCleanup(cleanup);
                return result;
        }
+       
+       public static MatrixObject createShallowCopy(MatrixObject mo) {
+               return newMatrixObject(mo.acquireReadAndRelease(), false);
+       }
 
        /**
         * Slice the matrix
@@ -178,11 +180,8 @@ public class ParamservUtils {
         * @return new sliced matrix
         */
        public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long 
rh) {
-               MatrixBlock mb = mo.acquireRead();
-               MatrixObject result = newMatrixObject(sliceMatrixBlock(mb, rl, 
rh));
-               result.enableCleanup(false);
-               mo.release();
-               return result;
+               MatrixBlock mb = mo.acquireReadAndRelease();
+               return newMatrixObject(sliceMatrixBlock(mb, rl, rh), false);
        }
 
        /**
@@ -335,35 +334,23 @@ public class ParamservUtils {
        /**
         * Assemble the matrix of features and labels according to the rowID
         *
-        * @param numRows row size of the data
         * @param featuresRDD indexed features matrix block
         * @param labelsRDD indexed labels matrix block
         * @return Assembled rdd with rowID as key while matrix of features and 
labels as value (rowID -> features, labels)
         */
-       public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> 
assembleTrainingData(long numRows, JavaPairRDD<MatrixIndexes, MatrixBlock> 
featuresRDD, JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) {
-               JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(numRows, 
featuresRDD);
-               JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(numRows, 
labelsRDD);
+       public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> 
assembleTrainingData(JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD, 
JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) {
+               JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(featuresRDD);
+               JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(labelsRDD);
                //TODO Add an additional physical operator which broadcasts the 
labels directly (broadcast join with features) if certain memory budgets are 
satisfied
                return fRDD.join(lRDD);
        }
 
-       private static JavaPairRDD<Long, MatrixBlock> groupMatrix(long numRows, 
JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
+       private static JavaPairRDD<Long, MatrixBlock> 
groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
                //TODO could use join and aggregation to avoid unnecessary 
shuffle introduced by reduceByKey
                return rdd.mapToPair(input -> new 
Tuple2<>(input._1.getRowIndex(), new Tuple2<>(input._1.getColumnIndex(), 
input._2)))
                        .aggregateByKey(new LinkedList<Tuple2<Long, 
MatrixBlock>>(),
-                               new Partitioner() {
-                                       private static final long 
serialVersionUID = -7032660778344579236L;
-                                       @Override
-                                       public int getPartition(Object rblkID) {
-                                               return Math.toIntExact((Long) 
rblkID);
-                                       }
-                                       @Override
-                                       public int numPartitions() {
-                                               return Math.toIntExact(numRows);
-                                       }
-                               },
                                (list, input) -> {
-                                       list.add(input); 
+                                       list.add(input);
                                        return list;
                                }, 
                                (l1, l2) -> {
@@ -392,7 +379,7 @@ public class ParamservUtils {
 
                DataPartitionerSparkMapper mapper = new 
DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
                JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = 
ParamservUtils
-                       .assembleTrainingData(features.getNumRows(), 
featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID 
=> (features, labels))
+                       .assembleTrainingData(featuresRDD, labelsRDD) // 
Combine features and labels into a pair (rowBlockID => (features, labels))
                        .flatMapToPair(mapper) // Do the data partitioning on 
spark (workerID => (rowBlockID, (single row features, single row labels))
                        // Aggregate the partitioned matrix according to rowID 
for each worker
                        // i.e. (workerID => ordered list[(rowBlockID, (single 
row features, single row labels)]

http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
index 411822f..816cefd 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
@@ -46,30 +46,33 @@ public abstract class PSRpcObject {
        /**
         * Deep serialize and write of a list object (currently only support 
list containing matrices)
         * @param lo a list object containing only matrices
-        * @param dos output data to write to
+        * @param output output data to write to
         */
-       protected void serializeAndWriteListObject(ListObject lo, DataOutput 
dos) throws IOException {
+       protected void serializeAndWriteListObject(ListObject lo, DataOutput 
output) throws IOException {
                validateListObject(lo);
-               dos.writeInt(lo.getLength()); //write list length
-               dos.writeBoolean(lo.isNamedList()); //write list named
+               output.writeInt(lo.getLength()); //write list length
+               output.writeBoolean(lo.isNamedList()); //write list named
                for (int i = 0; i < lo.getLength(); i++) {
                        if (lo.isNamedList())
-                               dos.writeUTF(lo.getName(i)); //write name
+                               output.writeUTF(lo.getName(i)); //write name
                        ((MatrixObject) lo.getData().get(i))
-                               .acquireReadAndRelease().write(dos); //write 
matrix
+                               .acquireReadAndRelease().write(output); //write 
matrix
                }
+               // Cleanup the list object
+               // because it is transferred to remote worker in binary format
+               ParamservUtils.cleanupListObject(lo);
        }
        
-       protected ListObject readAndDeserialize(DataInput dis) throws 
IOException {
-               int listLen = dis.readInt();
+       protected ListObject readAndDeserialize(DataInput input) throws 
IOException {
+               int listLen = input.readInt();
                List<Data> data = new ArrayList<>();
-               List<String> names = dis.readBoolean() ?
+               List<String> names = input.readBoolean() ?
                        new ArrayList<>() : null;
                for(int i=0; i<listLen; i++) {
                        if( names != null )
-                               names.add(dis.readUTF());
+                               names.add(input.readUTF());
                        MatrixBlock mb = new MatrixBlock();
-                       mb.readFields(dis);
+                       mb.readFields(input);
                        data.add(ParamservUtils.newMatrixObject(mb, false));
                }
                return new ListObject(data, names);

http://git-wip-us.apache.org/repos/asf/systemml/blob/b586d169/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
index f2df1e6..17bfa4c 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -32,27 +32,27 @@ import org.junit.Test;
 
 public class RpcObjectTest {
 
-       @Test
-       public void testPSRpcCall() throws IOException {
+       private ListObject generateData() {
                MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
                MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
-               ListObject lo = new ListObject(Arrays.asList(mo1, mo2));
-               PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo);
+               return new ListObject(Arrays.asList(mo1, mo2));
+       }
+
+       @Test
+       public void testPSRpcCall() throws IOException {
+               PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, 
generateData());
                PSRpcCall actual = new PSRpcCall(expected.serialize());
                Assert.assertTrue(Arrays.equals(
-                       expected.serialize().array(),
+                       new PSRpcCall(PSRpcObject.PUSH, 1, 
generateData()).serialize().array(),
                        actual.serialize().array()));
        }
 
        @Test
        public void testPSRpcResponse() throws IOException {
-               MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
-               MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
-               ListObject lo = new ListObject(Arrays.asList(mo1, mo2));
-               PSRpcResponse expected = new 
PSRpcResponse(PSRpcResponse.Type.SUCCESS, lo);
+               PSRpcResponse expected = new 
PSRpcResponse(PSRpcResponse.Type.SUCCESS, generateData());
                PSRpcResponse actual = new PSRpcResponse(expected.serialize());
                Assert.assertTrue(Arrays.equals(
-                       expected.serialize().array(),
+                       new PSRpcResponse(PSRpcResponse.Type.SUCCESS, 
generateData()).serialize().array(),
                        actual.serialize().array()));
        }
 }

Reply via email to