This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e38638  [SYSTEMDS-3086] Fix federated wdivmm basic (no result consolidation)
5e38638 is described below

commit 5e386384296a781cf6d3adf1ffe52105e4356407
Author: ywcb00 <[email protected]>
AuthorDate: Mon Aug 9 21:36:23 2021 +0200

    [SYSTEMDS-3086] Fix federated wdivmm basic (no result consolidation)
    
    Closes #1361.
---
 .../controlprogram/context/ExecutionContext.java   |  4 +
 .../fed/QuaternaryWCeMMFEDInstruction.java         |  2 +-
 .../fed/QuaternaryWDivMMFEDInstruction.java        | 89 +++++++++++++++++-----
 3 files changed, 77 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 0bee2ef..75591b6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -305,6 +305,10 @@ public class ExecutionContext {
        public MatrixBlock getMatrixInput(String varName) {
                return getMatrixObject(varName).acquireRead();
        }
+       
+       public MatrixBlock getMatrixInput(CPOperand input) {
+               return getMatrixObject(input.getName()).acquireRead();
+       }
 
        /**
         * Pins a matrix variable into memory and returns the internal matrix 
block.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
index 7ae87e2..d2aa182 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.java
@@ -67,7 +67,7 @@ public class QuaternaryWCeMMFEDInstruction extends 
QuaternaryFEDInstruction
                if(qop.hasFourInputs()) {
                        eps = (_input4.getDataType() == DataType.SCALAR) ?
                                ec.getScalarInput(_input4) :
-                               new 
DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
+                               new 
DoubleObject(ec.getMatrixInput(_input4).quickGetValue(0, 0));
                }
 
                if(X.isFederated()) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
index ed0d2a8..e2d83d8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
@@ -37,11 +37,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
 
 import java.util.ArrayList;
 import java.util.concurrent.Future;
+import java.util.stream.IntStream;
 
 public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 {
@@ -60,32 +60,35 @@ public class QuaternaryWDivMMFEDInstruction extends 
QuaternaryFEDInstruction
         * @param out             The Federated Result Z
         * @param opcode          ...
         * @param instruction_str ...
-        */
-       protected QuaternaryWDivMMFEDInstruction(Operator operator,
+       */
+
+       private QuaternaryOperator _qop;
+
+       protected QuaternaryWDivMMFEDInstruction(QuaternaryOperator operator,
                CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, 
CPOperand out, String opcode, String instruction_str)
        {
                super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, 
opcode, instruction_str);
+               _qop = operator;
        }
 
        @Override
        public void processInstruction(ExecutionContext ec)
        {
-               QuaternaryOperator qop = (QuaternaryOperator) _optr;
-               final WDivMMType wdivmm_type = qop.wtype3;
+               final WDivMMType wdivmm_type = _qop.wtype3;
                MatrixObject X = ec.getMatrixObject(input1);
                MatrixObject U = ec.getMatrixObject(input2);
                MatrixObject V = ec.getMatrixObject(input3);
                ScalarObject eps = null;
                MatrixObject MX = null;
 
-               if(qop.hasFourInputs()) {
+               if(_qop.hasFourInputs()) {
                        if(wdivmm_type == WDivMMType.MULT_MINUS_4_LEFT || 
wdivmm_type == WDivMMType.MULT_MINUS_4_RIGHT) {
                                MX = ec.getMatrixObject(_input4);
                        }
                        else {
                                eps = (_input4.getDataType() == 
DataType.SCALAR) ?
                                        ec.getScalarInput(_input4) :
-                                       new 
DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
+                                       new 
DoubleObject(ec.getMatrixInput(_input4).quickGetValue(0, 0));
                        }
                }
 
@@ -93,7 +96,7 @@ public class QuaternaryWDivMMFEDInstruction extends 
QuaternaryFEDInstruction
                        FederationMap fedMap = X.getFedMapping();
                        ArrayList<FederatedRequest[]> frSliced = new 
ArrayList<>();
                        ArrayList<FederatedRequest> frB = new ArrayList<>(); // 
FederatedRequests of broadcasts
-                       long[] varNewIn = new long[qop.hasFourInputs() ? 4 : 3];
+                       long[] varNewIn = new long[_qop.hasFourInputs() ? 4 : 
3];
                        varNewIn[0] = fedMap.getID();
 
                        if(X.isFederated(FType.ROW)) { // row partitioned X
@@ -151,21 +154,26 @@ public class QuaternaryWDivMMFEDInstruction extends 
QuaternaryFEDInstruction
                        }
 
                        FederatedRequest frComp = 
FederationUtils.callInstruction(instString, output,
-                               qop.hasFourInputs() ? new CPOperand[]{input1, 
input2, input3, _input4}
+                               _qop.hasFourInputs() ? new CPOperand[]{input1, 
input2, input3, _input4}
                                : new CPOperand[]{input1, input2, input3}, 
varNewIn);
 
                        // get partial results from federated workers
-                       FederatedRequest frGet = new 
FederatedRequest(RequestType.GET_VAR, frComp.getID());
+                       FederatedRequest frGet = null;
 
                        ArrayList<FederatedRequest> frC = new ArrayList<>();
-                       frC.add(fedMap.cleanup(getTID(), frComp.getID()));
+                       if((wdivmm_type.isLeft() && X.isFederated(FType.ROW))
+                               || (wdivmm_type.isRight() && 
X.isFederated(FType.COL))) { // output needs local aggregation
+                               frGet = new 
FederatedRequest(RequestType.GET_VAR, frComp.getID());
+                               frC.add(fedMap.cleanup(getTID(), 
frComp.getID()));
+                       }
                        for(FederatedRequest[] frS : frSliced)
                                frC.add(fedMap.cleanup(getTID(), 
frS[0].getID()));
                        for(FederatedRequest fr : frB)
                                frC.add(fedMap.cleanup(getTID(), fr.getID()));
 
-                       FederatedRequest[] frAll = 
ArrayUtils.addAll(ArrayUtils.addAll(
-                               frB.toArray(new FederatedRequest[0]), frComp, 
frGet),
+                       FederatedRequest[] frAll = ArrayUtils.addAll(frGet == 
null ?
+                               ArrayUtils.addAll(frB.toArray(new 
FederatedRequest[0]), frComp) :
+                               ArrayUtils.addAll(frB.toArray(new 
FederatedRequest[0]), frComp, frGet),
                                frC.toArray(new FederatedRequest[0]));
 
                        // execute federated instructions
@@ -174,14 +182,13 @@ public class QuaternaryWDivMMFEDInstruction extends 
QuaternaryFEDInstruction
                                        getTID(), true, frSliced.toArray(new 
FederatedRequest[0][]), frAll);
 
                        if((wdivmm_type.isLeft() && X.isFederated(FType.ROW))
-                               || (wdivmm_type.isRight() && 
X.isFederated(FType.COL))) {
+                               || (wdivmm_type.isRight() && 
X.isFederated(FType.COL))) { // local aggregation
                                // aggregate partial results from federated 
responses
                                AggregateUnaryOperator aop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
                                ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, response, fedMap));
                        }
                        else if(wdivmm_type.isLeft() || wdivmm_type.isRight() 
|| wdivmm_type.isBasic()) {
-                               // bind partial results from federated responses
-                               ec.setMatrixOutput(output.getName(), 
FederationUtils.bind(response, false));
+                               setFederatedOutput(X, U, V, ec, frComp.getID());
                        }
                        else {
                                throw new DMLRuntimeException("Federated WDivMM 
only supported for BASIC, LEFT or RIGHT variants.");
@@ -192,5 +199,53 @@ public class QuaternaryWDivMMFEDInstruction extends 
QuaternaryFEDInstruction
                                + X.isFederated() + ", " + U.isFederated() + ", 
" + V.isFederated() + ")");
                }
        }
-}
 
+       /**
+        * Set the federated output according to the output data characteristics of
+        * the different wdivmm types
+        */
+       private void setFederatedOutput(MatrixObject X, MatrixObject U, 
MatrixObject V, ExecutionContext ec, long fedMapID) {
+               final WDivMMType wdivmm_type = _qop.wtype3;
+               MatrixObject out = ec.getMatrixObject(output);
+               FederationMap outFedMap = 
X.getFedMapping().copyWithNewID(fedMapID);
+
+               long rows = -1;
+               long cols = -1;
+               if(wdivmm_type.isBasic()) {
+                       // BASIC: preserve dimensions of X
+                       rows = X.getNumRows();
+                       cols = X.getNumColumns();
+               }
+               else if(wdivmm_type.isLeft()) {
+                       // LEFT: nrows of transposed X, ncols of U
+                       rows = X.getNumColumns();
+                       cols = U.getNumColumns();
+                       outFedMap = modifyFedRanges(outFedMap.transpose(), 
cols, 1);
+               }
+               else if(wdivmm_type.isRight()) {
+                       // RIGHT: nrows of X, ncols of V
+                       rows = X.getNumRows();
+                       cols = V.getNumColumns();
+                       outFedMap = modifyFedRanges(outFedMap, cols, 1);
+               }
+               out.setFedMapping(outFedMap);
+               out.getDataCharacteristics().set(rows, cols, (int) 
X.getBlocksize());
+       }
+
+       /**
+        * Takes the federated mapping and sets one dimension of all federated 
ranges
+        * to the specified value.
+        *
+        * @param fedMap     the original federated mapping
+        * @param value      long value for setting the dimension
+        * @param dim        indicates if the row (0) or column (1) dimension 
should be set to value
+        * @return FederationMap with the modified federated ranges
+        */
+       private static FederationMap modifyFedRanges(FederationMap fedMap, long 
value, int dim) {
+               IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
+                       fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+                       fedMap.getFederatedRanges()[i].setEndDim(dim, value);
+               });
+               return fedMap;
+       }
+}

Reply via email to