This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a11235  [SYSTEMDS-2855] Fix missing federated col-partitioned matrix multiply
0a11235 is described below

commit 0a112356e059c20baf609cd6c1f06a232ddd2f4c
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Feb 9 17:30:47 2021 +0100

    [SYSTEMDS-2855] Fix missing federated col-partitioned matrix multiply
    
    This patch adds the missing support for federated matrix multiplication
    with column-partitioned federated matrices. In addition, we lowered the
    log level of the federated request command messages from info to debug
    (and of the full command dump from debug to trace) to reduce the
    default output in local tests.
---
 .../federated/FederatedWorkerHandler.java             |  8 ++++----
 .../controlprogram/federated/FederationMap.java       | 19 +++++++++++--------
 .../fed/AggregateBinaryFEDInstruction.java            | 13 +++++++++++++
 .../runtime/instructions/fed/FEDInstructionUtils.java |  2 +-
 4 files changed, 29 insertions(+), 13 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index a75c97a..57d5ba3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -96,10 +96,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
                for(int i = 0; i < requests.length; i++) {
                        FederatedRequest request = requests[i];
-                       if(log.isInfoEnabled()) {
-                               log.info("Executing command " + (i + 1) + "/" + 
requests.length + ": " + request.getType().name());
-                               if(log.isDebugEnabled()) {
-                                       log.debug("full command: " + 
request.toString());
+                       if(log.isDebugEnabled()) {
+                               log.debug("Executing command " + (i + 1) + "/" 
+ requests.length + ": " + request.getType().name());
+                               if(log.isTraceEnabled()) {
+                                       log.trace("full command: " + 
request.toString());
                                }
                        }
                        PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index e933979..4f70dd0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -143,18 +143,21 @@ public class FederationMap {
                // prepare broadcast id and pin input
                long id = FederationUtils.getNextFedDataID();
                CacheBlock cb = data.acquireReadAndRelease();
-
+               
                // prepare indexing ranges
                int[][] ix = new int[_fedMap.size()][];
                int pos = 0;
                for(Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet()) {
-                       int rl, ru, cl, cu;
-                       // TODO Handle different cases than ROW aligned 
Matrices.
-                       rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
-                       ru = transposed ? cb.getNumRows() - 1 : 
e.getKey().getEndDimsInt()[0] - 1;
-                       cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
-                       cu = transposed ? e.getKey().getEndDimsInt()[0] - 1 : 
cb.getNumColumns() - 1;
-                       ix[pos++] = new int[] {rl, ru, cl, cu};
+                       int beg = e.getKey().getBeginDimsInt()[(_type == 
FType.ROW ? 0 : 1)];
+                       int end = e.getKey().getEndDimsInt()[(_type == 
FType.ROW ? 0 : 1)];
+                       int nr = _type == FType.ROW ? cb.getNumRows() : 
cb.getNumColumns();
+                       int nc = _type == FType.ROW ? cb.getNumColumns() : 
cb.getNumRows();
+                       int rl = transposed ? 0 : beg;
+                       int ru = transposed ? nr - 1 : end - 1;
+                       int cl = transposed ? beg : 0;
+                       int cu = transposed ? end - 1 : nc - 1;
+                       ix[pos++] = _type == FType.ROW ?
+                               new int[] {rl, ru, cl, cu} : new int[] {cl, cu, 
rl, ru};
                }
 
                // multi-threaded block slicing and federation request creation
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6ed642e..12616ed 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -110,6 +110,19 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                        MatrixBlock ret = FederationUtils.aggAdd(tmp);
                        ec.setMatrixOutput(output.getName(), ret);
                }
+               //#3 col-federated matrix vector multiplication
+               else if (mo1.isFederated(FType.COL)) {// VM + MM
+                       //construct commands: broadcast rhs, fed mv, retrieve 
results
+                       FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, true);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
+                               new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
+                       FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                       FederatedRequest fr4 = 
mo1.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
+                       //execute federated operations and aggregate
+                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3, fr4);
+                       MatrixBlock ret = FederationUtils.aggAdd(tmp);
+                       ec.setMatrixOutput(output.getName(), ret);
+               }
                else { //other combinations
                        throw new DMLRuntimeException("Federated 
AggregateBinary not supported with the "
                                + "following federated objects: 
"+mo1.isFederated()+":"+mo1.getFedMapping()
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 845f8a4..6c0e3ba 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -78,7 +78,7 @@ public class FEDInstructionUtils {
                        if( instruction.input1.isMatrix() && 
instruction.input2.isMatrix() ) {
                                MatrixObject mo1 = 
ec.getMatrixObject(instruction.input1);
                                MatrixObject mo2 = 
ec.getMatrixObject(instruction.input2);
-                               if (mo1.isFederated(FType.ROW) || 
mo2.isFederated(FType.ROW)) {
+                               if (mo1.isFederated(FType.ROW) || 
mo2.isFederated(FType.ROW) || mo1.isFederated(FType.COL)) {
                                        fedinst = 
AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
                                }
                        }

Reply via email to