ywcb00 commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r700188882



##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
                        mo1 = ec.getMatrixObject(input3);
                }
 
-               long dim1 = Collections.max(Arrays.asList(dims1), 
Long::compare);
-               boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 
&& dims1.length == Arrays.stream(dims1).distinct().count();
+               // static non-partitioned output dimension (same for all 
federated partitions)
+               long staticDim = Collections.max(Arrays.asList(dims1), 
Long::compare);
+               boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-               processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, 
fedOutput, dims1, dims2);
+               processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, 
fedOutput, staticDim, dims2);
        }
 
+       /**
+        * Broadcast, execute, and finalize the federated instruction according 
to
+        * the specified inputs.
+        *
+        * @param ec execution context
+        * @param mo1 input matrix object 1
+        * @param mo2 input matrix object 2
+        * @param mo3 input matrix object 3 or null
+        * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+        * @param reversedWeights boolean indicating if inputs mo1 and mo3 are 
reversed
+        * @param fedOutput boolean indicating if output can be kept federated
+        * @param staticDim static non-partitioned dimension of the output
+        * @param dims2 dimensions of the partial outputs along the federated 
partitioning
+        */
        private void processRequest(ExecutionContext ec, MatrixObject mo1, 
MatrixObject mo2, MatrixObject mo3,
-               boolean reversed, boolean reversedWeights, boolean fedOutput, 
Long[] dims1, Long[] dims2) {
-               Future<FederatedResponse>[] ffr;
+               boolean reversed, boolean reversedWeights, boolean fedOutput, 
long staticDim, Long[] dims2) {
+
+               FederationMap fedMap = mo1.getFedMapping();
+
+               FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+               FederatedRequest[] fr2 = null;
+               FederatedRequest fr3, fr4, fr5;
+               fr3 = fr4 = fr5 = null;
+               Future<FederatedResponse>[] ffr = null;
 
-               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
-               FederatedRequest fr2, fr3;
                if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-               && mo1.getFedMapping().isAligned(mo3.getFedMapping(), 
AlignType.FULL)) { // mo1 and mo3 federated and aligned
+                       && fedMap.isAligned(mo3.getFedMapping(), 
AlignType.FULL)) { // mo1 and mo3 federated and aligned
                        if(!reversed)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fedMap.getID(), 
fr1[0].getID(), mo3.getFedMapping().getID()});
                        else
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr2, fr3);
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fedMap.getID(), mo3.getFedMapping().getID()});
                }
                else if(mo3 == null) {
                        if(!reversed)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
-                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
+                                       new long[] {fedMap.getID(), 
fr1[0].getID()});
                        else
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
-                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID()});
-
-                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr2, fr3);
-
-               } else {
-                       FederatedRequest[] fr4 = 
mo1.getFedMapping().broadcastSliced(mo3, false);
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
+                                       new long[] {fr1[0].getID(), 
fedMap.getID()});
+               }
+               else {
+                       fr2 = fedMap.broadcastSliced(mo3, false);
                        if(!reversed && !reversedWeights)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fedMap.getID(), 
fr1[0].getID(), fr2[0].getID()});
                        else if(reversed && !reversedWeights)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID(), fr4[0].getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fedMap.getID(), fr2[0].getID()});
                        else
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] {fr1[0].getID(), 
fr4[0].getID(), mo1.getFedMapping().getID()});
-
-                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr4, fr2, fr3);
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fr2[0].getID(), fedMap.getID()});
                }
 
-               if(fedOutput && isFedOutput(ffr, dims1)) {
+               if(fedOutput) {
+                       if(fr2 != null) // broadcasted mo3
+                               fedMap.execute(getTID(), true, fr1, fr2, fr3);
+                       else
+                               fedMap.execute(getTID(), true, fr1, fr3);
+
                        MatrixObject out = ec.getMatrixObject(output);
-                       FederationMap newFedMap = 
modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-                       setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+                       FederationMap newFedMap = 
modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+                               staticDim, dims2, reversed);
+                       setFedOutput(mo1, out, newFedMap, staticDim, dims2, 
reversed);
                } else {
+                       fr4 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+                       fr5 = fedMap.cleanup(getTID(), fr3.getID());
+                       if(fr2 != null) // broadcasted mo3
+                               ffr = fedMap.execute(getTID(), true, fr1, fr2, 
fr3, fr4, fr5);
+                       else
+                               ffr = fedMap.execute(getTID(), true, fr1, fr3, 
fr4, fr5);
+
                        ec.setMatrixOutput(output.getName(), aggResult(ffr));
                }
        }
 
-       boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-               boolean fedOutput = true;
-
-               long fedSize = Collections.max(Arrays.asList(dims1), 
Long::compare) / ffr.length;
-               try {
-                       MatrixBlock curr;
-                       MatrixBlock prev =(MatrixBlock) 
ffr[0].get().getData()[0];
-                       for(int i = 1; i < ffr.length && fedOutput; i++) {
-                               curr = (MatrixBlock) ffr[i].get().getData()[0];
-                               MatrixBlock sliced = curr.slice((int) 
(curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-                               if(curr.getNumColumns() != prev.getNumColumns())
-                                       return false;
-
-                               // no intersection
-                               if(curr.getNumRows() == (i+1) * 
prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-                                       && (curr.getNumRows() - 
sliced.getNumRows()) == i * prev.getNumRows()
-                                       && curr.getNonZeros() - 
sliced.getNonZeros() == 0)
-                                       continue;
-
-                               // check intersect with AND and compare number 
of nnz
-                               MatrixBlock prevExtend = new 
MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-                               prevExtend.copy(0, prev.getNumRows()-1, 0, 
prev.getNumColumns()-1, prev, true);
-
-                               MatrixBlock  intersect = 
curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), 
prevExtend);
-                               if(intersect.getNonZeros() != 0)
-                                       fedOutput = false;
-                               prev = sliced;
-                       }
-               }
-               catch(Exception e) {
-                       e.printStackTrace();
-               }
-               return fedOutput;
-       }
+       /**
+        * Evaluate if the output can be kept federated on the different 
federated
+        * sites or if the output needs to be aggregated on the coordinator, 
based
+        * on the output ranges of mo2.

Review comment:
       I tried to add a bit more explanation there (when the output can stay federated vs. when it must be aggregated on the coordinator) — I'm not sure it is easy to understand yet.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to