This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 19e656e09e [SYSTEMDS-3185] Lineage trace federated broadcast slices
19e656e09e is described below

commit 19e656e09eda8310cc9c31f67a73c6496f7ff27b
Author: ywcb00 <[email protected]>
AuthorDate: Fri Oct 22 18:02:04 2021 +0200

    [SYSTEMDS-3185] Lineage trace federated broadcast slices
    
    This patch addresses the problem of distinguishing between different slices
    of the same broadcast data object. The current logic identifies broadcast
    slices only by their original data object, which leads to incorrect reuse
    when two different slices are taken from the same original file. This patch
    manually creates a right-indexing (RightIndex) lineage item for each slice,
    so that every slice is uniquely identified.
    
    Closes #1574
---
 .../controlprogram/context/ExecutionContext.java   |  20 +--
 .../controlprogram/federated/FederatedRequest.java |   6 +-
 .../controlprogram/federated/FederationMap.java    |  50 ++++---
 .../sysds/runtime/instructions/cp/CPOperand.java   |   7 +-
 .../instructions/fed/IndexingFEDInstruction.java   |   3 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    |  13 +-
 .../instructions/fed/SpoofFEDInstruction.java      |   2 +-
 ...acheTest.java => FederatedReuseSlicesTest.java} | 156 +++++++++++++++------
 .../multitenant/FederatedReadCacheTest.dml         |  45 ------
 .../multitenant/FederatedReuseSlicesTest.dml       |  85 +++++++++++
 10 files changed, 263 insertions(+), 124 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 7d30c5ff4e..dfad1e14cb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -245,10 +245,14 @@ public class ExecutionContext {
        }
 
        public MatrixLineagePair getMatrixLineagePair(CPOperand cpo) {
-               MatrixObject mo = getMatrixObject(cpo);
+               return getMatrixLineagePair(cpo.getName());
+       }
+
+       public MatrixLineagePair getMatrixLineagePair(String varname) {
+               MatrixObject mo = getMatrixObject(varname);
                if(mo == null)
                        return null;
-               return MatrixLineagePair.of(mo, DMLScript.LINEAGE ? 
getLineageItem(cpo) : null);
+               return MatrixLineagePair.of(mo, DMLScript.LINEAGE ? 
getLineageItem(varname) : null);
        }
 
        public TensorObject getTensorObject(String varname) {
@@ -849,18 +853,14 @@ public class ExecutionContext {
                LineageDebugger.maintainSpecialValueBits(_lineage, inst, this);
        }
 
-       public String getSingleLineageTrace(CPOperand cpo) {
-               if(!DMLScript.LINEAGE)
-                       return null;
-               if( _lineage == null )
-                       throw new DMLRuntimeException("Lineage Trace 
unavailable.");
-               return _lineage.serializeSingleTrace(cpo);
+       public LineageItem getLineageItem(CPOperand input) {
+               return getLineageItem(input.getName());
        }
 
-       public LineageItem getLineageItem(CPOperand input) {
+       public LineageItem getLineageItem(String varname) {
                if( _lineage == null )
                        throw new DMLRuntimeException("Lineage Trace 
unavailable.");
-               return _lineage.get(input);
+               return _lineage.get(varname);
        }
 
        public LineageItem getOrCreateLineageItem(CPOperand input) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 1566a65314..76bd5b3d93 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -35,6 +35,8 @@ import 
org.apache.sysds.runtime.controlprogram.caching.CacheDataOutput;
 import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageItem;
 
 public class FederatedRequest implements Serializable {
        private static final long serialVersionUID = 5946781306963870394L;
@@ -71,9 +73,9 @@ public class FederatedRequest implements Serializable {
                this(method, id, Arrays.asList(data));
        }
 
-       public FederatedRequest(RequestType method, String linTrace, long id, 
Object ... data) {
+       public FederatedRequest(RequestType method, LineageItem linItem, long 
id, Object ... data) {
                this(method, id, Arrays.asList(data));
-               _lineageTrace = linTrace;
+               _lineageTrace = (linItem != null) ? 
Lineage.serializeSingleTrace(linItem) : null;
        }
 
        public FederatedRequest(RequestType method, long id, List<Object> data) 
{
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index e00c1ee549..1642d86553 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -32,15 +32,20 @@ import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.sysds.lops.RightIndex;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
-import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -117,10 +122,10 @@ public class FederationMap {
        }
 
        public FederatedRequest broadcast(MatrixLineagePair moLin) {
-               return broadcast(moLin.getMO(), 
Lineage.serializeSingleTrace(moLin.getLI()));
+               return broadcast(moLin.getMO(), moLin.getLI());
        }
 
-       private FederatedRequest broadcast(CacheableData<?> data, final String 
lineageTrace) {
+       private FederatedRequest broadcast(CacheableData<?> data, LineageItem 
lineageItem) {
                // reuse existing broadcast variable
                if( data.isFederated(FType.BROADCAST) )
                        return new FederatedRequest(RequestType.NOOP, 
data.getFedMapping().getID());
@@ -132,7 +137,7 @@ public class FederationMap {
                // is fine, because with broadcast all data on all workers)
                data.setFedMapping(copyWithNewIDAndRange(
                        cb.getNumRows(), cb.getNumColumns(), id, 
FType.BROADCAST));
-               return new FederatedRequest(RequestType.PUT_VAR, lineageTrace, 
id, cb);
+               return new FederatedRequest(RequestType.PUT_VAR, lineageItem, 
id, cb);
        }
 
        public FederatedRequest broadcast(ScalarObject scalar) {
@@ -147,7 +152,7 @@ public class FederationMap {
 
        public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin,
                boolean transposed) {
-               return broadcastSliced(moLin.getMO(), 
Lineage.serializeSingleTrace(moLin.getLI()),
+               return broadcastSliced(moLin.getMO(), moLin.getLI(),
                        transposed);
        }
 
@@ -155,12 +160,12 @@ public class FederationMap {
         * Creates separate slices of an input data object according to the 
index ranges of federated data. These slices
         * are then wrapped in separate federated requests for broadcasting.
         *
-        * @param data         input data object (matrix, tensor, frame)
-        * @param lineageTrace the serialized lineage trace of the data
-        * @param transposed   false: slice according to federated data, true: 
slice according to transposed federated data
+        * @param data        input data object (matrix, tensor, frame)
+        * @param lineageItem the lineage item of the data
+        * @param transposed  false: slice according to federated data, true: 
slice according to transposed federated data
         * @return array of federated requests corresponding to federated data
         */
-       private FederatedRequest[] broadcastSliced(CacheableData<?> data, 
String lineageTrace,
+       private FederatedRequest[] broadcastSliced(CacheableData<?> data, 
LineageItem lineageItem,
                boolean transposed) {
                if( _type == FType.FULL )
                        return new FederatedRequest[]{broadcast(data)};
@@ -198,8 +203,7 @@ public class FederationMap {
                // multi-threaded block slicing and federation request creation
                else {
                        Arrays.parallelSetAll(ret,
-                               i -> new FederatedRequest(RequestType.PUT_VAR, 
lineageTrace, id,
-                                       cb.slice(ix[i][0], ix[i][1], ix[i][2], 
ix[i][3], new MatrixBlock())));
+                               i -> sliceBroadcastBlock(ix[i], id, cb, 
lineageItem, false));
                }
                return ret;
        }
@@ -210,11 +214,10 @@ public class FederationMap {
 
        public FederatedRequest[] broadcastSliced(MatrixLineagePair moLin,
                boolean isFrame, int[][] ix) {
-               return broadcastSliced(moLin.getMO(), 
Lineage.serializeSingleTrace(moLin.getLI()),
-                       isFrame, ix);
+               return broadcastSliced(moLin.getMO(), moLin.getLI(), isFrame, 
ix);
        }
 
-       public FederatedRequest[] broadcastSliced(CacheableData<?> data, String 
lineageTrace,
+       public FederatedRequest[] broadcastSliced(CacheableData<?> data, 
LineageItem lineageItem,
                boolean isFrame, int[][] ix) {
                if( _type == FType.FULL )
                        return new FederatedRequest[]{broadcast(data)};
@@ -226,11 +229,26 @@ public class FederationMap {
                // multi-threaded block slicing and federation request creation
                FederatedRequest[] ret = new FederatedRequest[ix.length];
                Arrays.setAll(ret,
-                       i -> new FederatedRequest(RequestType.PUT_VAR, 
lineageTrace, id,
-                               cb.slice(ix[i][0], ix[i][1], ix[i][2], 
ix[i][3], isFrame ? new FrameBlock() : new MatrixBlock())));
+                       i -> sliceBroadcastBlock(ix[i], id, cb, lineageItem, 
isFrame));
                return ret;
        }
 
+       private FederatedRequest sliceBroadcastBlock(int[] ix, long id, 
CacheBlock cb, LineageItem objLi, boolean isFrame) {
+               LineageItem li = null;
+               if(objLi != null) {
+                       // manually create a lineage item for indexing to 
complete the lineage trace for slicing
+                       CPOperand rl = new CPOperand(String.valueOf(ix[0] + 1), 
ValueType.INT64, DataType.SCALAR, true);
+                       CPOperand ru = new CPOperand(String.valueOf(ix[1] + 1), 
ValueType.INT64, DataType.SCALAR, true);
+                       CPOperand cl = new CPOperand(String.valueOf(ix[2] + 1), 
ValueType.INT64, DataType.SCALAR, true);
+                       CPOperand cu = new CPOperand(String.valueOf(ix[3] + 1), 
ValueType.INT64, DataType.SCALAR, true);
+                       li = new LineageItem(RightIndex.OPCODE, new 
LineageItem[]{objLi, rl.getLiteralLineageItem(),
+                               ru.getLiteralLineageItem(), 
cl.getLiteralLineageItem(), cu.getLiteralLineageItem()});
+               }
+               FederatedRequest fr = new FederatedRequest(RequestType.PUT_VAR, 
li, id,
+                       cb.slice(ix[0], ix[1], ix[2], ix[3], isFrame ? new 
FrameBlock() : new MatrixBlock()));
+               return fr;
+       }
+
        /**
         * helper function for checking multiple allowed alignment types
         * @param that FederationMap to check alignment with
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
index 657025994b..0f343cd7be 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPOperand.java
@@ -26,6 +26,7 @@ import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 
 public class CPOperand
@@ -179,7 +180,11 @@ public class CPOperand
                        getName(), getDataType().name(),
                        getValueType().name(), String.valueOf(isLiteral()));
        }
-       
+
+       public LineageItem getLiteralLineageItem() {
+               return new LineageItem(getLineageLiteral());
+       }
+
        public String getLineageLiteral(ScalarObject so) {
                return getLineageLiteral(so, isLiteral());
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index 3ac2a8726b..bc70b398f9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -25,6 +25,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
@@ -271,7 +272,7 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
                FederatedRequest tmp = new 
FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new 
MatrixCharacteristics(-1, -1), in1.getDataType());
                fedMap.execute(getTID(), true, tmp);
 
-               FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, 
ec.getSingleLineageTrace(input2),
+               FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, 
DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
                        input2.isFrame(), sliceIxs);
                FederatedRequest[] fr2 = 
FederationUtils.callInstruction(instStrings, output, id, new 
CPOperand[]{input1, input2},
                        new long[]{fedMap.getID(), fr1[0].getID()}, null);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index d1b773d87a..edae75ed39 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -56,6 +56,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.Respo
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
 import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -342,7 +343,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        throw new DMLRuntimeException("Unsupported margin 
identifier '" + margin + "'.");
 
                FrameObject mo = (FrameObject) getTarget(ec);
-               MatrixObject select = params.containsKey("select") ? 
ec.getMatrixObject(params.get("select")) : null;
+               MatrixLineagePair select = params.containsKey("select") ? 
ec.getMatrixLineagePair(params.get("select")) : null;
                FrameObject out = ec.getFrameObject(output);
 
                boolean marginRow = params.get("margin").equals("rows");
@@ -378,10 +379,10 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        for(int i = 1; i < colSums.size(); i++)
                                s = s.binaryOperationsInPlace(plus, 
colSums.get(i));
                        s = s.binaryOperationsInPlace(greater, new 
MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
-                       select = ExecutionContext.createMatrixObject(s);
+                       select = 
MatrixLineagePair.of(ExecutionContext.createMatrixObject(s), null);
 
                        long varID = FederationUtils.getNextFedDataID();
-                       ec.setVariable(String.valueOf(varID), select);
+                       ec.setVariable(String.valueOf(varID), select.getMO());
                        params.put("select", String.valueOf(varID));
                        // construct new string
                        String[] oldString = 
InstructionUtils.getInstructionParts(instString);
@@ -505,7 +506,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        throw new DMLRuntimeException("Unsupported margin 
identifier '" + margin + "'.");
 
                MatrixObject mo = (MatrixObject) getTarget(ec);
-               MatrixObject select = params.containsKey("select") ? 
ec.getMatrixObject(params.get("select")) : null;
+               MatrixLineagePair select = params.containsKey("select") ? 
ec.getMatrixLineagePair(params.get("select")) : null;
                MatrixObject out = ec.getMatrixObject(output);
 
                boolean marginRow = params.get("margin").equals("rows");
@@ -541,10 +542,10 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        for(int i = 1; i < colSums.size(); i++)
                                s = s.binaryOperationsInPlace(plus, 
colSums.get(i));
                        s = s.binaryOperationsInPlace(greater, new 
MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
-                       select = ExecutionContext.createMatrixObject(s);
+                       select = 
MatrixLineagePair.of(ExecutionContext.createMatrixObject(s), null);
 
                        long varID = FederationUtils.getNextFedDataID();
-                       ec.setVariable(String.valueOf(varID), select);
+                       ec.setVariable(String.valueOf(varID), select.getMO());
                        params.put("select", String.valueOf(varID));
                        // construct new string
                        String[] oldString = 
InstructionUtils.getInstructionParts(instString);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 6d44928c76..0edad44621 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -118,7 +118,7 @@ public class SpoofFEDInstruction extends FEDInstruction
                for(CPOperand cpo : _inputs) {
                        Data tmpData = ec.getVariable(cpo);
                        if(tmpData instanceof MatrixObject) {
-                               MatrixLineagePair mo = new 
MatrixLineagePair((MatrixObject) tmpData, ec.getLineageItem(cpo));
+                               MatrixLineagePair mo = 
MatrixLineagePair.of((MatrixObject) tmpData, ec.getLineageItem(cpo));
                                if(mo.isFederatedExcept(FType.BROADCAST)) {
                                        frIds[index++] = 
mo.getFedMapping().getID();
                                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReadCacheTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
similarity index 52%
rename from 
src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReadCacheTest.java
rename to 
src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
index 5ceb6f5688..9f15f49d22 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReadCacheTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedReuseSlicesTest.java
@@ -24,7 +24,6 @@ import java.util.Collection;
 import java.util.HashMap;
 
 import org.apache.commons.lang3.ArrayUtils;
-import org.apache.commons.lang3.StringUtils;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -40,11 +39,11 @@ import org.junit.runners.Parameterized;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-public class FederatedReadCacheTest extends MultiTenantTestBase {
-       private final static String TEST_NAME = "FederatedReadCacheTest";
+public class FederatedReuseSlicesTest extends MultiTenantTestBase {
+       private final static String TEST_NAME = "FederatedReuseSlicesTest";
 
        private final static String TEST_DIR = 
"functions/federated/multitenant/";
-       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedReadCacheTest.class.getSimpleName() + "/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedReuseSlicesTest.class.getSimpleName() + "/";
 
        private final static double TOLERANCE = 0;
 
@@ -62,16 +61,17 @@ public class FederatedReadCacheTest extends 
MultiTenantTestBase {
        public static Collection<Object[]> data() {
                return Arrays.asList(
                        new Object[][] {
-                               {100, 1000, 0.9, false},
-                               // {1000, 100, 0.9, true},
+                               {100, 200, 0.9, false},
+                               // {200, 100, 0.9, true},
                                // {100, 1000, 0.01, false},
                                // {1000, 100, 0.01, true},
                });
        }
 
        private enum OpType {
-               PLUS_SCALAR,
-               MODIFIED_VAL,
+               EW_MULT,
+               RM_EMPTY,
+               PARFOR_DIV,
        }
 
        @Override
@@ -80,28 +80,40 @@ public class FederatedReadCacheTest extends 
MultiTenantTestBase {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
        }
 
+       @Test
+       public void testElementWisePlusCP() {
+               runReuseSlicesTest(OpType.EW_MULT, 4, ExecMode.SINGLE_NODE);
+       }
+
        @Test
        @Ignore
-       public void testPlusScalarCP() {
-               runReadCacheTest(OpType.PLUS_SCALAR, 3, ExecMode.SINGLE_NODE);
+       public void testElementWisePlusSP() {
+               runReuseSlicesTest(OpType.EW_MULT, 4, ExecMode.SPARK);
+       }
+
+       @Test
+       public void testRemoveEmptyCP() {
+               runReuseSlicesTest(OpType.RM_EMPTY, 4, ExecMode.SINGLE_NODE);
        }
 
        @Test
-       public void testPlusScalarSP() {
-               runReadCacheTest(OpType.PLUS_SCALAR, 3, ExecMode.SPARK);
+       @Ignore // NOTE: federated removeEmpty not supported in spark execution 
mode yet
+       public void testRemoveEmptySP() {
+               runReuseSlicesTest(OpType.RM_EMPTY, 4, ExecMode.SPARK);
        }
 
        @Test
-       public void testModifiedValCP() {
-               runReadCacheTest(OpType.MODIFIED_VAL, 4, ExecMode.SINGLE_NODE);
+       @Ignore
+       public void testParforDivCP() {
+               runReuseSlicesTest(OpType.PARFOR_DIV, 4, ExecMode.SINGLE_NODE);
        }
 
        @Test
-       public void testModifiedValSP() {
-               runReadCacheTest(OpType.MODIFIED_VAL, 4, ExecMode.SPARK);
+       public void testParforDivSP() {
+               runReuseSlicesTest(OpType.PARFOR_DIV, 4, ExecMode.SPARK);
        }
 
-       private void runReadCacheTest(OpType opType, int numCoordinators, 
ExecMode execMode) {
+       private void runReuseSlicesTest(OpType opType, int numCoordinators, 
ExecMode execMode) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
@@ -133,7 +145,7 @@ public class FederatedReadCacheTest extends 
MultiTenantTestBase {
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
 
-               int[] workerPorts = startFedWorkers(4);
+               int[] workerPorts = startFedWorkers(4, new String[]{"-lineage", 
"reuse"});
 
                rtplatform = execMode;
                if(rtplatform == ExecMode.SPARK) {
@@ -144,21 +156,30 @@ public class FederatedReadCacheTest extends 
MultiTenantTestBase {
 
                // start the coordinator processes
                String scriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-stats", "100", "-fedStats", 
"100", "-nvargs",
+               programArgs = new String[] {"-config", CONFIG_DIR + 
"SystemDS-MultiTenant-config.xml",
+                       "-lineage", "reuse", "-stats", "100", "-fedStats", 
"100", "-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(workerPorts[0], 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(workerPorts[1], 
input("X2")),
                        "in_X3=" + TestUtils.federatedAddress(workerPorts[2], 
input("X3")),
                        "in_X4=" + TestUtils.federatedAddress(workerPorts[3], 
input("X4")),
                        "rows=" + rows, "cols=" + cols, "testnum=" + 
Integer.toString(opType.ordinal()),
                        "rP=" + Boolean.toString(rowPartitioned).toUpperCase()};
-               for(int counter = 0; counter < numCoordinators; counter++)
+               for(int counter = 0; counter < numCoordinators; counter++) {
+                       // start coordinators with alternating boolean 
mod_fedMap --> change order of fed partitions
                        startCoordinator(execMode, scriptName,
-                               ArrayUtils.addAll(programArgs, "out_S=" + 
output("S" + counter)));
+                               ArrayUtils.addAll(programArgs, "out_S=" + 
output("S" + counter),
+                                       "mod_fedMap=" + 
Boolean.toString(counter % 2 == 1).toUpperCase()));
+
+                       // wait for the coordinator processes to end and verify 
the results
+                       String coordinatorOutput = waitForCoordinators();
+
+                       if(counter <= 1) // instructions are only executed for 
the first two coordinators
+                               Assert.assertTrue(checkForHeavyHitter(opType, 
coordinatorOutput, execMode));
+                       // verify that the matrix object has been taken from 
cache
+                       Assert.assertTrue(checkForReuses(opType, 
coordinatorOutput, execMode, counter));
+               }
 
-               // wait for the coordinator processes to end and verify the 
results
-               String coordinatorOutput = waitForCoordinators();
-               System.out.println(coordinatorOutput);
-               verifyResults(opType, coordinatorOutput, execMode);
+               verifyResults();
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
@@ -172,34 +193,85 @@ public class FederatedReadCacheTest extends MultiTenantTestBase {
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
 
-       private void verifyResults(OpType opType, String outputLog, ExecMode execMode) {
-               Assert.assertTrue(checkForHeavyHitter(opType, outputLog, execMode));
-               // verify that the matrix object has been taken from cache
-               Assert.assertTrue(outputLog.contains("Fed ReadCache (Hits, Bytes):\t"
-                       + Integer.toString((coordinatorProcesses.size()-1) * workerProcesses.size()) + "/"));
-
+       private void verifyResults() {
                // compare the results via files
-               HashMap<CellIndex, Double> refResults   = readDMLMatrixFromOutputDir("S" + 0);
+               HashMap<CellIndex, Double> refResults0  = readDMLMatrixFromOutputDir("S" + 0);
+               HashMap<CellIndex, Double> refResults1  = readDMLMatrixFromOutputDir("S" + 1);
                Assert.assertFalse("The result of the first coordinator, which is taken as reference, is empty.",
-                       refResults.isEmpty());
-               for(int counter = 1; counter < coordinatorProcesses.size(); counter++) {
+                       refResults0.isEmpty());
+               Assert.assertFalse("The result of the second coordinator, which is taken as reference, is empty.",
+                       refResults1.isEmpty());
+
+               boolean compareEqual = true;
+               for(CellIndex index : refResults0.keySet()) {
+                       compareEqual &= refResults0.get(index).equals(refResults1.get(index));
+                       if(!compareEqual)
+                               break;
+               }
+               Assert.assertFalse("The result of the first coordinator should be different than the "
+                       + "result of the second coordinator (due to modified federated maps).", compareEqual);
+
+               for(int counter = 2; counter < coordinatorProcesses.size(); counter++) {
                        HashMap<CellIndex, Double> fedResults = readDMLMatrixFromOutputDir("S" + counter);
-                       TestUtils.compareMatrices(fedResults, refResults, TOLERANCE, "Fed" + counter, "FedRef");
+                       TestUtils.compareMatrices(fedResults, (counter % 2 == 0) ? refResults0 : refResults1,
+                               TOLERANCE, "Fed" + counter, "FedRef");
                }
        }
 
        private boolean checkForHeavyHitter(OpType opType, String outputLog, ExecMode execMode) {
+               boolean retVal = false;
                switch(opType) {
-                       case PLUS_SCALAR:
-                               return checkForHeavyHitter(outputLog, "fed_+");
-                       case MODIFIED_VAL:
-                               return checkForHeavyHitter(outputLog, "fed_*") && checkForHeavyHitter(outputLog, "fed_+");
+                       case EW_MULT:
+                               retVal = checkForHeavyHitter(outputLog, "fed_*");
+                               break;
+                       case RM_EMPTY:
+                               retVal = checkForHeavyHitter(outputLog, "fed_rmempty");
+                               retVal &= checkForHeavyHitter(outputLog, "fed_uak+");
+                               break;
+                       case PARFOR_DIV:
+                               retVal = checkForHeavyHitter(outputLog, "fed_/");
+                               retVal &= checkForHeavyHitter(outputLog, (execMode == ExecMode.SPARK) ? "fed_rblk" : "fed_uak+");
+                               break;
                }
-               return false;
+               return retVal;
        }
 
        private boolean checkForHeavyHitter(String outputLog, String hhString) {
-               int occurrences = StringUtils.countMatches(outputLog, hhString);
-               return (occurrences == coordinatorProcesses.size());
+               return outputLog.contains(hhString);
+       }
+
+       private boolean checkForReuses(OpType opType, String outputLog, ExecMode execMode, int coordIX) {
+               final String LINCACHE_MULTILVL = "LinCache MultiLvl (Ins/SB/Fn):\t";
+               final String LINCACHE_WRITES = "LinCache writes (Mem/FS/Del):\t";
+               final String FED_LINEAGEPUT = "Fed PutLineage (Count, Items):\t";
+               boolean retVal = false;
+               int multiplier = 1;
+               int numInst = -1;
+               switch(opType) {
+                       case EW_MULT:
+                               numInst = 1;
+                               break;
+                       case RM_EMPTY:
+                               numInst = 1;
+                               break;
+                       case PARFOR_DIV: // number of instructions times number of iterations of the parfor loop
+                               multiplier = 3;
+                               numInst = ((execMode == ExecMode.SPARK) ? 1 : 2) * multiplier;
+                               break;
+               }
+               if(coordIX <= 1) {
+                       retVal = outputLog.contains(LINCACHE_MULTILVL + "0/");
+                       retVal &= outputLog.contains(LINCACHE_WRITES + Integer.toString(
+                               (((coordIX == 0) ? 1 : 0) + numInst) // read + instructions
+                               * workerProcesses.size()) + "/");
+               }
+               else {
+                       retVal = outputLog.contains(LINCACHE_MULTILVL
+                               + Integer.toString(numInst * workerProcesses.size()) + "/");
+                       retVal &= outputLog.contains(LINCACHE_WRITES + "0/");
+               }
+               retVal &= outputLog.contains(FED_LINEAGEPUT
+                       + Integer.toString(workerProcesses.size() * multiplier) + "/");
+               return retVal;
        }
 }
diff --git a/src/test/scripts/functions/federated/multitenant/FederatedReadCacheTest.dml b/src/test/scripts/functions/federated/multitenant/FederatedReadCacheTest.dml
deleted file mode 100644
index 83da5febc1..0000000000
--- a/src/test/scripts/functions/federated/multitenant/FederatedReadCacheTest.dml
+++ /dev/null
@@ -1,45 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-if ($rP) {
-  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
-                 list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
-} else {
-  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
-        ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
-               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
-}
-
-testnum = $testnum;
-
-if(testnum == 0) { # PLUS_SCALAR
-  X = X + 1;
-  while(FALSE) { }
-  S = as.matrix(sum(X));
-}
-else if(testnum == 1) { # MODIFIED_VAL
-  X[nrow(X)/2, ncol(X)/2] = (X[nrow(X)/2, ncol(X)/2] + 1) * 10;
-  while(FALSE) { }
-  S = as.matrix(sum(X));
-}
-
-write(S, $out_S);
diff --git a/src/test/scripts/functions/federated/multitenant/FederatedReuseSlicesTest.dml b/src/test/scripts/functions/federated/multitenant/FederatedReuseSlicesTest.dml
new file mode 100644
index 0000000000..e3c1b89423
--- /dev/null
+++ b/src/test/scripts/functions/federated/multitenant/FederatedReuseSlicesTest.dml
@@ -0,0 +1,85 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rowPart = $rP;
+modFedMap = $mod_fedMap;
+
+if(modFedMap) { # change order of federated partitions
+  addr=list($in_X2, $in_X3, $in_X4, $in_X1);
+} else {
+  addr=list($in_X1, $in_X2, $in_X3, $in_X4);
+}
+if (rowPart) {
+  X = federated(addresses=addr,
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+          list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+  X = federated(addresses=addr,
+        ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+          list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
+
+testnum = $testnum;
+
+while(FALSE) { }
+
+if(testnum == 0) { # EW_MULT
+  Y = rand(rows=$rows, cols=$cols, seed=1234);
+
+  S = X * Y;
+  while(FALSE) { }
+}
+else if(testnum == 1) { # RM_EMPTY
+  if(rowPart) {
+    margin="rows";
+    Y = matrix(0, nrow(X), 1);
+    Y[floor(nrow(X) / 2) + 1 : nrow(X), 1] = matrix(1, floor(nrow(X) / 2), 1);
+  } else {
+    margin="cols";
+    Y = matrix(0, ncol(X), 1);
+    Y[floor(ncol(X) / 2) + 1 : ncol(X), 1] = matrix(1, floor(ncol(X) / 2), 1);
+  }
+  while(FALSE) { }
+  Y = Y - 1;
+  while(FALSE) { }
+  Y = Y + 1;
+  while(FALSE) { }
+
+  Z = removeEmpty(target=X, margin=margin, select=Y);
+  while(FALSE) { }
+  S = as.matrix(sum(Z));
+}
+else if(testnum == 2) { # PARFOR_DIV
+  Y = rand(rows=$rows, cols=$cols, seed=1234);
+
+  numiter = 3;
+  Z = matrix(0, rows=numiter, cols=1);
+  parfor(i in 1:numiter) {
+    Y_tmp = Y + i;
+    while(FALSE) { }
+    Z_tmp = X / Y_tmp;
+    while(FALSE) { }
+    Z[i, 1] = sum(Z_tmp);
+  }
+  S = as.matrix(sum(Z));
+}
+
+write(S, $out_S);

Reply via email to