This is an automated email from the ASF dual-hosted git repository. arnabp20 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new 6dd66b8 [SYSTEMDS-2784] Enable lineage-based reuse in federated workers 6dd66b8 is described below commit 6dd66b8257b43146fe3cd31bab61f3595184928a Author: arnabp <arnab.ph...@tugraz.at> AuthorDate: Sat Jan 2 22:55:37 2021 +0100 [SYSTEMDS-2784] Enable lineage-based reuse in federated workers This patch builds the initial infrastructure for lineage-based reuse in federated workers. Changes include: - Lineage tracing for InitFEDInstruction. - Lineage tracing of READ and PUT requests. For PUT, the hash of the lineage item is sent along with the request; this hash will be replaced by an Adler32 checksum in future commits. - Disabling compiler-assisted optimizations for lineage-based reuse (e.g., mark for caching) in the workers. - Testing infrastructure. --- .../controlprogram/federated/FederatedRequest.java | 13 +++ .../federated/FederatedWorkerHandler.java | 18 +++ .../fed/AggregateBinaryFEDInstruction.java | 6 + .../instructions/fed/InitFEDInstruction.java | 34 +++++- .../org/apache/sysds/test/AutomatedTestBase.java | 13 ++- .../test/functions/lineage/FedFullReuseTest.java | 128 +++++++++++++++++++++ .../scripts/functions/lineage/FedFullReuse1.dml | 30 +++++ .../functions/lineage/FedFullReuse1Reference.dml | 28 +++++ 8 files changed, 268 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java index 6c9be16..33dad44 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java @@ -23,8 +23,10 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.lineage.LineageItem; import 
org.apache.sysds.utils.Statistics; public class FederatedRequest implements Serializable { @@ -45,6 +47,7 @@ public class FederatedRequest implements Serializable { private long _tid; private List<Object> _data; private boolean _checkPrivacy; + private List<Integer> _lineageHash; public FederatedRequest(RequestType method) { @@ -117,6 +120,16 @@ public class FederatedRequest implements Serializable { return _checkPrivacy; } + public void setLineageHash(LineageItem[] liItems) { + // copy the hash of the corresponding lineage DAG + // TODO: copy both Adler32 checksum (on data) and hash (on lineage DAG) + _lineageHash = Arrays.stream(liItems).map(li -> li.hashCode()).collect(Collectors.toList()); + } + + public int getLineageHash(int i) { + return _lineageHash.get(i); + } + @Override public String toString() { StringBuilder sb = new StringBuilder("FederatedRequest["); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index e3ec403..5c0a0bc 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -48,6 +48,8 @@ import org.apache.sysds.runtime.instructions.cp.ListObject; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.lineage.LineageCacheConfig; +import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.runtime.privacy.DMLPrivacyException; @@ -232,6 +234,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { cd.enableCleanup(false); // guard against deletion 
_ecm.get(tid).setVariable(String.valueOf(id), cd); + if (DMLScript.LINEAGE) + // create a literal type lineage item with the file name + _ecm.get(tid).getLineage().set(String.valueOf(id), new LineageItem(filename)); + if(dataType == Types.DataType.FRAME) { FrameObject frameObject = (FrameObject) cd; frameObject.acquireRead(); @@ -264,6 +270,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { // set variable and construct empty response ec.setVariable(varname, data); + if (DMLScript.LINEAGE) + // TODO: Identify MO uniquely. Use Adler32 checksum. + ec.getLineage().set(varname, new LineageItem(String.valueOf(request.getLineageHash(0)))); + return new FederatedResponse(ResponseType.SUCCESS_EMPTY); } @@ -299,6 +309,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { pb.getInstructions().clear(); Instruction receivedInstruction = InstructionParser.parseSingleInstruction((String) request.getParam(0)); pb.getInstructions().add(receivedInstruction); + + if (DMLScript.LINEAGE) + // Compiler assisted optimizations are not applicable for Fed workers. + // e.g. isMarkedForCaching fails as output operands are saved in the + // symbol table only after the instruction execution finishes. 
+ // NOTE: In shared JVM, this will disable compiler assistance even for the coordinator + LineageCacheConfig.setCompAssRW(false); + try { pb.execute(ec); // execute single instruction } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java index 6ed642e..4a8194b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java @@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.fed; import java.util.concurrent.Future; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; @@ -31,6 +32,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -78,6 +80,10 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction { else if(mo1.isFederated(FType.ROW)) { // MV + MM //construct commands: broadcast rhs, fed mv, retrieve results FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2); + if (DMLScript.LINEAGE) + //also copy the hash of the lineage DAG + fr1.setLineageHash(LineageItemUtils.getLineage(ec, input1)); + //TODO: calculate Adler32 checksum on data, and move this code inside FederationMap. 
FederatedRequest fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()}); if( mo2.getNumColumns() == 1 ) { //MV diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java index 17e2855..bc16149 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java @@ -56,9 +56,11 @@ import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.cp.ListObject; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.lineage.LineageTraceable; import org.apache.sysds.runtime.meta.DataCharacteristics; -public class InitFEDInstruction extends FEDInstruction { +public class InitFEDInstruction extends FEDInstruction implements LineageTraceable { private static final Log LOG = LogFactory.getLog(InitFEDInstruction.class.getName()); @@ -342,4 +344,34 @@ public class InitFEDInstruction extends FEDInstruction { throw new DMLRuntimeException("Exception in frame response from federated worker.", e); } } + + @Override + public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) { + String type = ec.getScalarInput(_type).getStringValue(); + ListObject addresses = ec.getListObject(_addresses.getName()); + ListObject ranges = ec.getListObject(_ranges.getName()); + LineageItem[] liInputs = new LineageItem[addresses.getLength()]; + + for(int i = 0; i < addresses.getLength(); i++) { + Data addressData = addresses.getData().get(i); + if(addressData instanceof StringObject) { + String address = ((StringObject)addressData).getStringValue(); + // get beginning and end of data ranges + 
List<Data> rangesData = ranges.getData(); + List<Data> beginDimsData = ((ListObject) rangesData.get(i*2)).getData(); + List<Data> endDimsData = ((ListObject) rangesData.get(i*2+1)).getData(); + String rl = ((ScalarObject)beginDimsData.get(0)).getStringValue(); + String cl = ((ScalarObject)beginDimsData.get(1)).getStringValue(); + String ru = ((ScalarObject)endDimsData.get(0)).getStringValue(); + String cu = ((ScalarObject)endDimsData.get(1)).getStringValue(); + // form a string with all the information and create a lineage item + String data = InstructionUtils.concatOperands(type, address, rl, cl, ru, cu); + liInputs[i] = new LineageItem(data); + } + else { + throw new DMLRuntimeException("federated instruction only takes strings as addresses"); + } + } + return Pair.of(_output.getName(), new LineageItem(getOpcode(), liInputs)); + } } diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java index 0143fed..d51f05b 100644 --- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java +++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java @@ -36,6 +36,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.stream.Collectors; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; @@ -1530,7 +1531,11 @@ public abstract class AutomatedTestBase { * @return the thread associated with the worker. */ protected Thread startLocalFedWorkerThread(int port) { - return startLocalFedWorkerThread(port, FED_WORKER_WAIT); + return startLocalFedWorkerThread(port, null, FED_WORKER_WAIT); + } + + protected Thread startLocalFedWorkerThread(int port, String[] otherArgs) { + return startLocalFedWorkerThread(port, otherArgs, FED_WORKER_WAIT); } /** @@ -1543,11 +1548,17 @@ public abstract class AutomatedTestBase { * @return the thread associated with the worker. 
*/ protected Thread startLocalFedWorkerThread(int port, int sleep) { + return startLocalFedWorkerThread(port, null, sleep); + } + protected Thread startLocalFedWorkerThread(int port, String[] otherArgs, int sleep) { Thread t = null; String[] fedWorkArgs = {"-w", Integer.toString(port)}; ArrayList<String> args = new ArrayList<>(); addProgramIndependentArguments(args); + + if (otherArgs != null) + args.addAll(Arrays.stream(otherArgs).collect(Collectors.toList())); for(int i = 0; i < fedWorkArgs.length; i++) args.add(fedWorkArgs[i]); diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java new file mode 100644 index 0000000..00c6d6f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/lineage/FedFullReuseTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.lineage; + +import java.util.Arrays; +import java.util.Collection; + +import org.apache.sysds.runtime.lineage.Lineage; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class FedFullReuseTest extends AutomatedTestBase { + + private final static String TEST_DIR = "functions/lineage/"; + private final static String TEST_NAME = "FedFullReuse1"; + private final static String TEST_CLASS_DIR = TEST_DIR + FedFullReuseTest.class.getSimpleName() + "/"; + + private final static int blocksize = 1024; + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"})); + } + + @Parameterized.Parameters + public static Collection<Object[]> data() { + // rows have to be even and > 1 + return Arrays.asList(new Object[][] { + // {2, 1000}, {10, 100}, + {100, 10}, + //{1000, 1}, + // {10, 2000}, {2000, 10} + }); + } + + @Test + public void federatedReuseMM() { //reuse inside federated workers + federatedReuse(); + } + + public void federatedReuse() { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + // write input matrices + int halfRows = rows / 2; + // Share two matrices between two federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + double[][] Y1 = getRandomMatrix(cols, halfRows, 0, 1, 1, 44); + 
double[][] Y2 = getRandomMatrix(cols, halfRows, 0, 1, 1, 21); + + writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("Y1", Y1, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols)); + writeInputMatrixWithMTD("Y2", Y2, false, new MatrixCharacteristics(cols, halfRows, blocksize, halfRows * cols)); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + String[] otherargs = new String[] {"-lineage", "reuse_full"}; + Lineage.resetInternalState(); + Thread t1 = startLocalFedWorkerThread(port1, otherargs, FED_WORKER_WAIT_S); + Thread t2 = startLocalFedWorkerThread(port2, otherargs); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix. Reuse of ba+*. 
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-stats", "-lineage", "reuse_full", + "-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"), + "Y2=" + input("Y2"), "Z=" + expected("Z")}; + runTest(true, false, null, -1); + long mmCount = Statistics.getCPHeavyHitterCount("ba+*"); + + // Run actual dml script with federated matrix + // The fed workers reuse ba+* + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats","-lineage", "reuse_full", + "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), + "Y1=" + TestUtils.federatedAddress(port1, input("Y1")), + "Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")}; + runTest(true, false, null, -1); + long mmCount_fed = Statistics.getCPHeavyHitterCount("ba+*"); + + // compare results + compareResults(1e-9); + // compare matrix multiplication count + // #federated execution of ba+* = #threads times #non-federated execution of ba+* (after reuse) + Assert.assertTrue("Violated reuse count: "+mmCount_fed+" == "+mmCount*2, + mmCount_fed == mmCount * 2); // #threads = 2 + + TestUtils.shutdownThreads(t1, t2); + } + +} diff --git a/src/test/scripts/functions/lineage/FedFullReuse1.dml b/src/test/scripts/functions/lineage/FedFullReuse1.dml new file mode 100644 index 0000000..4597332 --- /dev/null +++ b/src/test/scripts/functions/lineage/FedFullReuse1.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = federated(addresses=list($X1, $X2), + ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))); +Y = federated(addresses=list($Y1, $Y2), + ranges=list(list(0, 0), list($c, $r / 2), list(0, $r / 2), list($c, $r))); + +for(i in 1:10) + Z = X %*% Y; + +write(Z, $Z); diff --git a/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml new file mode 100644 index 0000000..6049f5d --- /dev/null +++ b/src/test/scripts/functions/lineage/FedFullReuse1Reference.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +X = rbind(read($X1), read($X2)); +Y = cbind(read($Y1), read($Y2)); + +for(i in 1:10) + Z = X %*% Y; + +write(Z, $Z);