OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602795376



##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+       private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+               super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, 
opcode, str);
+       }
+
+       public static TernaryFEDInstruction parseInstruction(String str)
+       {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode=parts[0];
+               CPOperand operand1 = new CPOperand(parts[1]);
+               CPOperand operand2 = new CPOperand(parts[2]);
+               CPOperand operand3 = new CPOperand(parts[3]);
+               CPOperand outOperand = new CPOperand(parts[4]);
+               TernaryOperator op = 
InstructionUtils.parseTernaryOperator(opcode);
+               return new TernaryFEDInstruction(op, operand1, operand2, 
operand3, outOperand, opcode, str);
+       }
+
+       @Override
+       public void processInstruction( ExecutionContext ec ) {
+               MatrixObject mo1 = input1.isMatrix() ? 
ec.getMatrixObject(input1.getName()) : null;
+               MatrixObject mo2 = input2.isMatrix() ? 
ec.getMatrixObject(input2.getName()) : null;
+               MatrixObject mo3 = input3.isMatrix() ? 
ec.getMatrixObject(input3.getName()) : null;
+
+               long matrixInputsCount = List.of(mo1, mo2, 
mo3).stream().filter(Objects::nonNull).count();
+
+               if(matrixInputsCount == 3)
+                       processMatrixInput(ec, mo1, mo2, mo3);
+               else if (matrixInputsCount == 1) {
+                       CPOperand in = mo1 == null ? mo2 == null ? input3 : 
input2 : input1;
+                       mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+                       processMatrixScalarInput(ec, mo1, in);
+               } else
+                       process2MatrixScalarInput(ec, mo1, mo2, mo3);
+       }
+
+       private void processMatrixScalarInput(ExecutionContext ec, MatrixObject 
mo1, CPOperand in) {
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                       new CPOperand[] {in}, new long[] 
{mo1.getFedMapping().getID()});
+               mo1.getFedMapping().execute(getTID(), true, fr1);
+
+               //derive new fed mapping for output
+               MatrixObject out = ec.getMatrixObject(output);
+               out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+       }
+
+       private void process2MatrixScalarInput(ExecutionContext ec, 
MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+               CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+               if(mo1 != null && mo1.isFederated() && mo2 == null) {
+                       mo2 = mo3;
+                       inputArgs = new CPOperand[] {input1, input3};
+               } else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+                       mo1 = mo2;
+                       mo2 = mo3;
+                       inputArgs = new CPOperand[] {input2, input3};
+               } else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+                       mo1 = mo2;
+                       mo2 = ec.getMatrixObject(input1);
+                       inputArgs = new CPOperand[] {input2, input1};
+               } else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+                       mo1 = mo3;
+                       inputArgs = new CPOperand[] {input3, input2};
+               } else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+                       mo1 = mo3;
+                       mo2 = ec.getMatrixObject(input1);
+
+                       inputArgs = new CPOperand[] {input3, input1};
+               }
+
+               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+
+               FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
+                       inputArgs, new long[] {mo1.getFedMapping().getID(), 
fr1[0].getID()});
+
+               FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), 
fr1[0].getID());
+               mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+               //derive new fed mapping for output
+               MatrixObject out = ec.getMatrixObject(output);
+               out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+       }
+
+
+       private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, 
MatrixObject mo2, MatrixObject mo3) {
+               if(!mo1.isFederated())
+                       if(mo2.isFederated()) {
+                               mo1 = mo2;
+                               mo2 = ec.getMatrixObject(input1);
+                       } else {
+                               mo1 = mo3;
+                               mo3 = ec.getMatrixObject(input1);
+                       }
+
+               FederatedRequest fr3;
+               // all 3 inputs aligned on the one worker
+               if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() 
&& mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && 
mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
   > Thanks for the PR!
   > It looks like a good start.
   > 
   > I think we could use some more tests that cover all branches of the 
   > processing of federated ternary instructions. One of the comments I added is 
   > regarding whether a part of the code can even be reached, so maybe we could 
   > think about a test that could cover this part (or if this is not possible, this 
   > part of the code could be removed).
   > I think it would also be interesting to look at other ternary operations, 
   > for instance the "+_" and "-_". This is relevant for L2SVM and I have already 
   > looked at this in a separate branch with a solution that is targeted for this 
   > single purpose, but my approach is still incomplete, so it is more relevant to 
   > build this in your TernaryFEDInstruction version.
   
   What do you mean by "+_" and "-_"? 
   Heavy hitter instructions:
     1  m_l2svm        0,137      1
     2  fed_ba+*       0,116     15
     3  fed_fedinit    0,061      1
     4  ba+*           0,022     29
     5  write          0,012      1
     6  rightIndex     0,011      1
     7  rmvar          0,003    459
     8  +*             0,002     35
     9  1-*            0,002     21
    10  createvar      0,002    227
    11  tak+*          0,002     28
    12  *              0,002     91
    13  list           0,001      4
    14  max            0,001     21
    15  tsmm           0,001     42
    16  r'             0,001     30
    17  >              0,001     14
    18  castdts        0,000     56
    19  -              0,000     35
    20  +              0,000     59




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to