This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 71116ac  [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
71116ac is described below

commit 71116aca666633859284741878ef16248e1f32dd
Author: baunsgaard <[email protected]>
AuthorDate: Mon Dec 13 17:58:26 2021 +0100

    [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
    
    This PR cleans up AggregateBinaryCPInstruction to isolate the compressed
    and transposed instruction paths.
    A remaining TODO is to add the rewrite t(x) %*% y -> t(t(y) %*% x) inside
    the transposed path, to optimize the multiplication when one side is cheap
    to transpose.
    
    Closes #1482
---
 .../cp/AggregateBinaryCPInstruction.java           | 87 ++++++++++++++++------
 1 file changed, 64 insertions(+), 23 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
index 6981877..934ca7f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
@@ -22,22 +22,23 @@ package org.apache.sysds.runtime.instructions.cp;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 
 public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
        // private static final Log LOG = 
LogFactory.getLog(AggregateBinaryCPInstruction.class.getName());
 
-       public boolean transposeLeft;
-       public boolean transposeRight;
+       final public boolean transposeLeft;
+       final public boolean transposeRight;
 
        private AggregateBinaryCPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out, String opcode,
                String istr) {
                super(CPType.AggregateBinary, op, in1, in2, out, opcode, istr);
+               transposeLeft = false;
+               transposeRight = false;
        }
 
        private AggregateBinaryCPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out, String opcode,
@@ -61,17 +62,30 @@ public class AggregateBinaryCPInstruction extends 
BinaryCPInstruction {
                CPOperand out = new CPOperand(parts[3]);
                int k = Integer.parseInt(parts[4]);
                AggregateBinaryOperator aggbin = 
InstructionUtils.getMatMultOperator(k);
-               if ( numFields == 6 ){
+               if(numFields == 6) {
                        boolean isLeftTransposed = 
Boolean.parseBoolean(parts[5]);
                        boolean isRightTransposed = 
Boolean.parseBoolean(parts[6]);
                        return new AggregateBinaryCPInstruction(aggbin, in1, 
in2, out, opcode, str, isLeftTransposed,
                                isRightTransposed);
                }
-               else return new AggregateBinaryCPInstruction(aggbin, in1, in2, 
out, opcode, str);
+               else
+                       return new AggregateBinaryCPInstruction(aggbin, in1, 
in2, out, opcode, str);
        }
 
        @Override
        public void processInstruction(ExecutionContext ec) {
+               // check compressed inputs
+               final boolean comp1 = 
ec.getMatrixObject(input1.getName()).isCompressed();
+               final boolean comp2 = 
ec.getMatrixObject(input2.getName()).isCompressed();
+               if(comp1 || comp2)
+                       processCompressedAggregateBinary(ec, comp1, comp2);
+               else if(transposeLeft || transposeRight)
+                       processTransposedFusedAggregateBinary(ec);
+               else
+                       precessNormal(ec);
+       }
+
+       private void precessNormal(ExecutionContext ec) {
                // get inputs
                MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
                MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
@@ -80,26 +94,53 @@ public class AggregateBinaryCPInstruction extends 
BinaryCPInstruction {
                AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
                MatrixBlock ret;
 
-               if(matBlock1 instanceof CompressedMatrixBlock) {
-                       CompressedMatrixBlock main = (CompressedMatrixBlock) 
matBlock1;
-                       ret = main.aggregateBinaryOperations(matBlock1, 
matBlock2, new MatrixBlock(), ab_op, transposeLeft, transposeRight);
+               ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, 
new MatrixBlock(), ab_op);
+
+               // release inputs/outputs
+               ec.releaseMatrixInput(input1.getName());
+               ec.releaseMatrixInput(input2.getName());
+               ec.setMatrixOutput(output.getName(), ret);
+       }
+
+       private void processTransposedFusedAggregateBinary(ExecutionContext ec) 
{
+               MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+               MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+               // compute matrix multiplication
+               AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
+               MatrixBlock ret;
+
+               // TODO: Use rewrite rule here t(x) %*% y -> t(t(y) %*% x)
+               if(transposeLeft) {
+                       matBlock1 = LibMatrixReorg.transpose(matBlock1, 
ab_op.getNumThreads());
+                       ec.releaseMatrixInput(input1.getName());
                }
-               else if(matBlock2 instanceof CompressedMatrixBlock) {
-                       CompressedMatrixBlock main = (CompressedMatrixBlock) 
matBlock2;
-                       ret = main.aggregateBinaryOperations(matBlock1, 
matBlock2, new MatrixBlock(), ab_op, transposeLeft, transposeRight);
+               if(transposeRight) {
+                       matBlock2 = LibMatrixReorg.transpose(matBlock2, 
ab_op.getNumThreads());
+                       ec.releaseMatrixInput(input2.getName());
+               }
+
+               ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, 
new MatrixBlock(), ab_op);
+               ec.releaseMatrixInput(input1.getName());
+               ec.releaseMatrixInput(input2.getName());
+               ec.setMatrixOutput(output.getName(), ret);
+       }
+
+       private void processCompressedAggregateBinary(ExecutionContext ec, 
boolean c1, boolean c2) {
+               MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+               MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+               // compute matrix multiplication
+               AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
+               MatrixBlock ret;
+
+               if(c1) {
+                       CompressedMatrixBlock main = (CompressedMatrixBlock) 
matBlock1;
+                       ret = main.aggregateBinaryOperations(matBlock1, 
matBlock2, new MatrixBlock(), ab_op, transposeLeft,
+                               transposeRight);
                }
                else {
-                       // todo move rewrite rule here. to do 
-                       // t(x) %*% y -> t(t(y) %*% x)
-                       if(transposeLeft){
-                               ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), ab_op.getNumThreads());
-                               matBlock1 = matBlock1.reorgOperations(r_op, new 
MatrixBlock(), 0, 0, 0);
-                       }
-                       if(transposeRight){
-                               ReorgOperator r_op = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), ab_op.getNumThreads());
-                               matBlock2 = matBlock2.reorgOperations(r_op, new 
MatrixBlock(), 0, 0, 0);
-                       }
-                       ret = matBlock1.aggregateBinaryOperations(matBlock1, 
matBlock2, new MatrixBlock(), ab_op);
+                       CompressedMatrixBlock main = (CompressedMatrixBlock) 
matBlock2;
+                       ret = main.aggregateBinaryOperations(matBlock1, 
matBlock2, new MatrixBlock(), ab_op, transposeLeft,
+                               transposeRight);
                }
 
                // release inputs/outputs

Reply via email to