This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 6dd66b8  [SYSTEMDS-2784] Enable lineage-based reuse in federated workers
6dd66b8 is described below

commit 6dd66b8257b43146fe3cd31bab61f3595184928a
Author: arnabp <arnab.ph...@tugraz.at>
AuthorDate: Sat Jan 2 22:55:37 2021 +0100

    [SYSTEMDS-2784] Enable lineage-based reuse in federated workers
    
    This patch builds the initial infrastructure for lineage-based
    reuse in federated workers. Changes include:
     - Lineage tracing of InitFEDInstruction.
     - Lineage tracing of READ and PUT requests. For PUT, the hash of
       the LineageItem is sent with the request; it will be replaced
       by an Adler32 checksum in future commits.
     - Disabling compiler-assisted optimizations for lineage-based
       reuse (e.g. mark for caching) in the workers.
     - Testing infrastructure.
---
 .../controlprogram/federated/FederatedRequest.java |  13 +++
 .../federated/FederatedWorkerHandler.java          |  18 +++
 .../fed/AggregateBinaryFEDInstruction.java         |   6 +
 .../instructions/fed/InitFEDInstruction.java       |  34 +++++-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  13 ++-
 .../test/functions/lineage/FedFullReuseTest.java   | 128 +++++++++++++++++++++
 .../scripts/functions/lineage/FedFullReuse1.dml    |  30 +++++
 .../functions/lineage/FedFullReuse1Reference.dml   |  28 +++++
 8 files changed, 268 insertions(+), 2 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 6c9be16..33dad44 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -23,8 +23,10 @@ import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
 
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.utils.Statistics;
 
 public class FederatedRequest implements Serializable {
@@ -45,6 +47,7 @@ public class FederatedRequest implements Serializable {
        private long _tid;
        private List<Object> _data;
        private boolean _checkPrivacy;
+       private List<Integer> _lineageHash;
        
        
        public FederatedRequest(RequestType method) {
@@ -117,6 +120,16 @@ public class FederatedRequest implements Serializable {
                return _checkPrivacy;
        }
        
+       public void setLineageHash(LineageItem[] liItems) {
+               // copy the hash of the corresponding lineage DAG
+               // TODO: copy both Adler32 checksum (on data) and hash (on 
lineage DAG)
+               _lineageHash = Arrays.stream(liItems).map(li -> 
li.hashCode()).collect(Collectors.toList());
+       }
+       
+       public int getLineageHash(int i) {
+               return _lineageHash.get(i);
+       }
+       
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index e3ec403..5c0a0bc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -48,6 +48,8 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -232,6 +234,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                cd.enableCleanup(false); // guard against deletion
                _ecm.get(tid).setVariable(String.valueOf(id), cd);
 
+               if (DMLScript.LINEAGE)
+                       // create a literal type lineage item with the file name
+                       _ecm.get(tid).getLineage().set(String.valueOf(id), new 
LineageItem(filename));
+
                if(dataType == Types.DataType.FRAME) {
                        FrameObject frameObject = (FrameObject) cd;
                        frameObject.acquireRead();
@@ -264,6 +270,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
                // set variable and construct empty response
                ec.setVariable(varname, data);
+               if (DMLScript.LINEAGE)
+                       // TODO: Identify MO uniquely. Use Adler32 checksum.
+                       ec.getLineage().set(varname, new 
LineageItem(String.valueOf(request.getLineageHash(0))));
+
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
 
@@ -299,6 +309,14 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                pb.getInstructions().clear();
                Instruction receivedInstruction = 
InstructionParser.parseSingleInstruction((String) request.getParam(0));
                pb.getInstructions().add(receivedInstruction);
+
+               if (DMLScript.LINEAGE)
+                       // Compiler assisted optimizations are not applicable 
for Fed workers.
+                       // e.g. isMarkedForCaching fails as output operands are 
saved in the 
+                       // symbol table only after the instruction execution 
finishes. 
+                       // NOTE: In shared JVM, this will disable compiler 
assistance even for the coordinator 
+                       LineageCacheConfig.setCompAssRW(false);
+
                try {
                        pb.execute(ec); // execute single instruction
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6ed642e..4a8194b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed;
 
 import java.util.concurrent.Future;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -31,6 +32,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
@@ -78,6 +80,10 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                else if(mo1.isFederated(FType.ROW)) { // MV + MM
                        //construct commands: broadcast rhs, fed mv, retrieve 
results
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                       if (DMLScript.LINEAGE)
+                               //also copy the hash of the lineage DAG
+                               
fr1.setLineageHash(LineageItemUtils.getLineage(ec, input1));
+                               //TODO: calculate Adler32 checksum on data, and 
move this code inside FederationMap.
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
                        if( mo2.getNumColumns() == 1 ) { //MV
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 17e2855..bc16149 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -56,9 +56,11 @@ import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
-public class InitFEDInstruction extends FEDInstruction {
+public class InitFEDInstruction extends FEDInstruction implements 
LineageTraceable {
 
        private static final Log LOG = 
LogFactory.getLog(InitFEDInstruction.class.getName());
 
@@ -342,4 +344,34 @@ public class InitFEDInstruction extends FEDInstruction {
                        throw new DMLRuntimeException("Exception in frame 
response from federated worker.", e);
                }
        }
+
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               String type = ec.getScalarInput(_type).getStringValue();
+               ListObject addresses = ec.getListObject(_addresses.getName());
+               ListObject ranges = ec.getListObject(_ranges.getName());
+               LineageItem[] liInputs = new LineageItem[addresses.getLength()];
+
+               for(int i = 0; i < addresses.getLength(); i++) {
+                       Data addressData = addresses.getData().get(i);
+                       if(addressData instanceof StringObject) {
+                               String address = 
((StringObject)addressData).getStringValue();
+                               // get beginning and end of data ranges
+                               List<Data> rangesData = ranges.getData();
+                               List<Data> beginDimsData = ((ListObject) 
rangesData.get(i*2)).getData();
+                               List<Data> endDimsData = ((ListObject) 
rangesData.get(i*2+1)).getData();
+                               String rl = 
((ScalarObject)beginDimsData.get(0)).getStringValue();
+                               String cl = 
((ScalarObject)beginDimsData.get(1)).getStringValue();
+                               String ru = 
((ScalarObject)endDimsData.get(0)).getStringValue();
+                               String cu = 
((ScalarObject)endDimsData.get(1)).getStringValue();
+                               // form a string with all the information and 
create a lineage item
+                               String data = 
InstructionUtils.concatOperands(type, address, rl, cl, ru, cu);
+                               liInputs[i] = new LineageItem(data);
+                       }
+                       else {
+                               throw new DMLRuntimeException("federated 
instruction only takes strings as addresses");
+                       }
+               }
+               return Pair.of(_output.getName(), new LineageItem(getOpcode(), 
liInputs));
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 0143fed..d51f05b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -36,6 +36,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.stream.Collectors;
 
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
@@ -1530,7 +1531,11 @@ public abstract class AutomatedTestBase {
         * @return the thread associated with the worker.
         */
        protected Thread startLocalFedWorkerThread(int port) {
-               return startLocalFedWorkerThread(port, FED_WORKER_WAIT);
+               return startLocalFedWorkerThread(port, null, FED_WORKER_WAIT);
+       }
+
+       protected Thread startLocalFedWorkerThread(int port, String[] 
otherArgs) {
+               return startLocalFedWorkerThread(port, otherArgs, 
FED_WORKER_WAIT);
        }
 
        /**
@@ -1543,11 +1548,17 @@ public abstract class AutomatedTestBase {
         * @return the thread associated with the worker.
         */
        protected Thread startLocalFedWorkerThread(int port, int sleep) {
+               return startLocalFedWorkerThread(port, null, sleep);
+       }
+       protected Thread startLocalFedWorkerThread(int port, String[] 
otherArgs, int sleep) {
                Thread t = null;
                String[] fedWorkArgs = {"-w", Integer.toString(port)};
                ArrayList<String> args = new ArrayList<>();
 
                addProgramIndependentArguments(args);
+               
+               if (otherArgs != null)
+                       
args.addAll(Arrays.stream(otherArgs).collect(Collectors.toList()));
 
                for(int i = 0; i < fedWorkArgs.length; i++)
                        args.add(fedWorkArgs[i]);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java 
b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
new file mode 100644
index 0000000..00c6d6f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FedFullReuseTest extends AutomatedTestBase {
+
+       private final static String TEST_DIR = "functions/lineage/";
+       private final static String TEST_NAME = "FedFullReuse1";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FedFullReuseTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows have to be even and > 1
+               return Arrays.asList(new Object[][] {
+                       // {2, 1000}, {10, 100},
+                       {100, 10}, 
+                       //{1000, 1},
+                       // {10, 2000}, {2000, 10}
+               });
+       }
+
+       @Test
+       public void federatedReuseMM() {    //reuse inside federated workers
+               federatedReuse();
+       }
+       
+       public void federatedReuse() {
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               int halfRows = rows / 2;
+               // Share two matrices between two federated worker
+               double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+               double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+               double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44);
+               double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21);
+
+               writeInputMatrixWithMTD("X1", X1, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("X2", X2, false, new 
MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("Y1", Y1, false, new 
MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+               writeInputMatrixWithMTD("Y2", Y2, false, new 
MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols));
+
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               String[] otherargs = new String[] {"-lineage", "reuse_full"};
+               Lineage.resetInternalState();
+               Thread t1 = startLocalFedWorkerThread(port1, otherargs, 
FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, otherargs);
+
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix. Reuse of ba+*.
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "-lineage", "reuse_full",
+                       "-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), 
"Y1=" + input("Y1"),
+                       "Y2=" + input("Y2"), "Z=" + expected("Z")};
+               runTest(true, false, null, -1);
+               long mmCount = Statistics.getCPHeavyHitterCount("ba+*");
+
+               // Run actual dml script with federated matrix
+               // The fed workers reuse ba+*
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats","-lineage", "reuse_full",
+                       "-nvargs", "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                       "X2=" + TestUtils.federatedAddress(port2, input("X2")),
+                       "Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
+                       "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), 
"r=" + rows, "c=" + cols, "Z=" + output("Z")};
+               runTest(true, false, null, -1);
+               long mmCount_fed = Statistics.getCPHeavyHitterCount("ba+*");
+
+               // compare results 
+               compareResults(1e-9);
+               // compare matrix multiplication count
+               // #federated execution of ba+* = #threads times #non-federated 
execution of ba+* (after reuse) 
+               Assert.assertTrue("Violated reuse count: "+mmCount_fed+" == 
"+mmCount*2, 
+                               mmCount_fed == mmCount * 2); // #threads = 2
+
+               TestUtils.shutdownThreads(t1, t2);
+       }
+
+}
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1.dml 
b/src/test/scripts/functions/lineage/FedFullReuse1.dml
new file mode 100644
index 0000000..4597332
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)));
+Y = federated(addresses=list($Y1, $Y2),
+    ranges=list(list(0, 0), list($c, $r / 2), list(0, $r / 2), list($c, $r)));
+
+for(i in 1:10)
+  Z = X %*% Y;
+
+write(Z, $Z);
diff --git a/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml 
b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
new file mode 100644
index 0000000..6049f5d
--- /dev/null
+++ b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2));
+Y = cbind(read($Y1), read($Y2));
+
+for(i in 1:10)
+  Z = X %*% Y;
+
+write(Z, $Z);

Reply via email to