This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 6e00b69 [MINOR] Modified federated BinarySPInstruction handling
6e00b69 is described below
commit 6e00b6912df94d64461be668958833b1cbdcfcdc
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Fri Oct 1 17:15:42 2021 +0200
[MINOR] Modified federated BinarySPInstruction handling
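This change slightly modifies the handling of BinarySPInstruction in
FEDInstructionUtils.checkAndReplaceSP(...): the federated matrix
multiplication check moves from the top of the method into the
BinarySPInstruction branch and now covers CpmmSPInstruction and
RmmSPInstruction in addition to MapmmSPInstruction. To reflect the wider
scope, MapmmFEDInstruction is renamed to MMFEDInstruction and its
parseInstruction(...) accepts the additional opcodes.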
Closes #1408.
---
.../instructions/fed/FEDInstructionUtils.java | 21 +++++++++++----------
...pmmFEDInstruction.java => MMFEDInstruction.java} | 12 +++++++-----
2 files changed, 18 insertions(+), 15 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 107edac..d410042 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -69,6 +69,7 @@ import
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorBroadcastSP
import
org.apache.sysds.runtime.instructions.spark.BinaryTensorTensorSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
import
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
@@ -80,6 +81,7 @@ import
org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
@@ -279,14 +281,7 @@ public class FEDInstructionUtils {
public static Instruction checkAndReplaceSP(Instruction inst,
ExecutionContext ec) {
FEDInstruction fedinst = null;
- if (inst instanceof MapmmSPInstruction) {
- MapmmSPInstruction instruction = (MapmmSPInstruction)
inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject)
data).isFederatedExcept(FType.BROADCAST)) {
- fedinst =
MapmmFEDInstruction.parseInstruction(instruction.getInstructionString());
- }
- }
- else if(inst instanceof CastSPInstruction){
+ if(inst instanceof CastSPInstruction){
CastSPInstruction ins = (CastSPInstruction) inst;
if((ins.getOpcode().equalsIgnoreCase(UnaryCP.CAST_AS_FRAME_OPCODE) ||
ins.getOpcode().equalsIgnoreCase(UnaryCP.CAST_AS_MATRIX_OPCODE))
&& ins.input1.isMatrix() &&
ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)){
@@ -377,8 +372,14 @@ public class FEDInstructionUtils {
}
else if (inst instanceof BinarySPInstruction) {
BinarySPInstruction instruction = (BinarySPInstruction)
inst;
-
- if(inst instanceof QuantilePickSPInstruction) {
+ if (inst instanceof MapmmSPInstruction || inst
instanceof CpmmSPInstruction || inst instanceof RmmSPInstruction) {
+ Data data = ec.getVariable(instruction.input1);
+ if (data instanceof MatrixObject &&
((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+ fedinst =
MMFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ }
+ else
+ if(inst instanceof QuantilePickSPInstruction) {
QuantilePickSPInstruction qinstruction =
(QuantilePickSPInstruction) inst;
Data data = ec.getVariable(qinstruction.input1);
if(data instanceof MatrixObject &&
((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
similarity index 96%
rename from
src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
rename to
src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
index 0d09411..d680f51 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MapmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMFEDInstruction.java
@@ -21,9 +21,11 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.concurrent.Future;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.MapMult;
+import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -41,18 +43,18 @@ import
org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-public class MapmmFEDInstruction extends BinaryFEDInstruction
+public class MMFEDInstruction extends BinaryFEDInstruction
{
- private MapmmFEDInstruction(Operator op, CPOperand in1, CPOperand in2,
CPOperand out, MapMult.CacheType type,
+ private MMFEDInstruction(Operator op, CPOperand in1, CPOperand in2,
CPOperand out, MapMult.CacheType type,
boolean outputEmpty, AggBinaryOp.SparkAggType aggtype, String
opcode, String istr) {
super(FEDType.MAPMM, op, in1, in2, out, opcode, istr);
}
- public static MapmmFEDInstruction parseInstruction( String str ) {
+ public static MMFEDInstruction parseInstruction( String str ) {
String parts[] =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
- if(!opcode.equalsIgnoreCase(MapMult.OPCODE))
+ if(!ArrayUtils.contains(new String[] {MapMult.OPCODE,
PMMJ.OPCODE, "cpmm", "rmm"}, opcode))
throw new
DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " +
opcode);
CPOperand in1 = new CPOperand(parts[1]);
@@ -63,7 +65,7 @@ public class MapmmFEDInstruction extends BinaryFEDInstruction
AggBinaryOp.SparkAggType aggtype =
AggBinaryOp.SparkAggType.valueOf(parts[6]);
AggregateBinaryOperator aggbin =
InstructionUtils.getMatMultOperator(1);
- return new MapmmFEDInstruction(aggbin, in1, in2, out, type,
outputEmpty, aggtype, opcode, str);
+ return new MMFEDInstruction(aggbin, in1, in2, out, type,
outputEmpty, aggtype, opcode, str);
}
public void processInstruction(ExecutionContext ec) {