This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e00b69  [MINOR] Modified federated BinarySPInstruction handling
6e00b69 is described below

commit 6e00b6912df94d64461be668958833b1cbdcfcdc
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Fri Oct 1 17:15:42 2021 +0200

    [MINOR] Modified federated BinarySPInstruction handling
    
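    This change moves the handling of MapmmSPInstruction into the
    BinarySPInstruction branch of FEDInstructionUtils.checkAndReplaceSP(...)
    and extends it to CpmmSPInstruction and RmmSPInstruction. To support
    this, MapmmFEDInstruction is renamed to MMFEDInstruction, and its
    parseInstruction(...) now accepts the MapMult, PMMJ, "cpmm" and "rmm"
    opcodes instead of only the MapMult opcode.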
    Closes #1408.
---
 .../instructions/fed/FEDInstructionUtils.java       | 21 +++++++++++----------
 ...pmmFEDInstruction.java => MMFEDInstruction.java} | 12 +++++++-----
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 107edac..d410042 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -69,6 +69,7 @@ import org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorBroadcastSP
 import 
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
 import 
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
@@ -80,6 +81,7 @@ import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
@@ -279,14 +281,7 @@ public class FEDInstructionUtils {
 
        public static Instruction checkAndReplaceSP(Instruction inst, 
ExecutionContext ec) {
                FEDInstruction fedinst = null;
-               if (inst instanceof MapmmSPInstruction) {
-                       MapmmSPInstruction instruction = (MapmmSPInstruction) 
inst;
-                       Data data = ec.getVariable(instruction.input1);
-                       if (data instanceof MatrixObject && ((MatrixObject) 
data).isFederatedExcept(FType.BROADCAST)) {
-                               fedinst = 
MapmmFEDInstruction.parseInstruction(instruction.getInstructionString());
-                       }
-               }
-               else if(inst instanceof CastSPInstruction){
+               if(inst instanceof CastSPInstruction){
                        CastSPInstruction ins = (CastSPInstruction) inst;
                        
if((ins.getOpcode().equalsIgnoreCase(UnaryCP.CAST_AS_FRAME_OPCODE) || 
ins.getOpcode().equalsIgnoreCase(UnaryCP.CAST_AS_MATRIX_OPCODE))
                                && ins.input1.isMatrix() && 
ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)){
@@ -377,8 +372,14 @@ public class FEDInstructionUtils {
                }
                else if (inst instanceof BinarySPInstruction) {
                        BinarySPInstruction instruction = (BinarySPInstruction) 
inst;
-
-                       if(inst instanceof QuantilePickSPInstruction) {
+                       if (inst instanceof MapmmSPInstruction || inst 
instanceof CpmmSPInstruction || inst instanceof RmmSPInstruction) {
+                               Data data = ec.getVariable(instruction.input1);
+                               if (data instanceof MatrixObject && 
((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+                                       fedinst = 
MMFEDInstruction.parseInstruction(instruction.getInstructionString());
+                               }
+                       }
+                       else
+                               if(inst instanceof QuantilePickSPInstruction) {
                                QuantilePickSPInstruction qinstruction = 
(QuantilePickSPInstruction) inst;
                                Data data = ec.getVariable(qinstruction.input1);
                                if(data instanceof MatrixObject && 
((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
similarity index 96%
rename from src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
rename to src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
index 0d09411..d680f51 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
@@ -21,9 +21,11 @@ package org.apache.sysds.runtime.instructions.fed;
 
 import java.util.concurrent.Future;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.lops.MapMult;
+import org.apache.sysds.lops.PMMJ;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -41,18 +43,18 @@ import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
-public class MapmmFEDInstruction extends BinaryFEDInstruction
+public class MMFEDInstruction extends BinaryFEDInstruction
 {
-       private MapmmFEDInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out, MapMult.CacheType type,
+       private MMFEDInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out, MapMult.CacheType type,
                boolean outputEmpty, AggBinaryOp.SparkAggType aggtype, String 
opcode, String istr) {
                super(FEDType.MAPMM, op, in1, in2, out, opcode, istr);
        }
 
-       public static MapmmFEDInstruction parseInstruction( String str ) {
+       public static MMFEDInstruction parseInstruction( String str ) {
                String parts[] = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
 
-               if(!opcode.equalsIgnoreCase(MapMult.OPCODE))
+               if(!ArrayUtils.contains(new String[] {MapMult.OPCODE, 
PMMJ.OPCODE, "cpmm", "rmm"}, opcode))
                        throw new 
DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + 
opcode);
 
                CPOperand in1 = new CPOperand(parts[1]);
@@ -63,7 +65,7 @@ public class MapmmFEDInstruction extends BinaryFEDInstruction
                AggBinaryOp.SparkAggType aggtype = 
AggBinaryOp.SparkAggType.valueOf(parts[6]);
 
                AggregateBinaryOperator aggbin = 
InstructionUtils.getMatMultOperator(1);
-               return new MapmmFEDInstruction(aggbin, in1, in2, out, type, 
outputEmpty, aggtype, opcode, str);
+               return new MMFEDInstruction(aggbin, in1, in2, out, type, 
outputEmpty, aggtype, opcode, str);
        }
 
        public void processInstruction(ExecutionContext ec) {

Reply via email to