This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1329f3db21 [SYSTEMDS-3525] Binary Inplace Operations
1329f3db21 is described below

commit 1329f3db21b21cbf470ebdef63ca61af1e0a64a8
Author: baunsgaard <[email protected]>
AuthorDate: Fri Apr 21 12:28:41 2023 +0200

    [SYSTEMDS-3525] Binary Inplace Operations
    
    This commit initializes the in-place logic for binary operations.
    Initially, it is used only in the very specific case of division by a
    vector that contains neither NaN nor zero values, and where the input is
    not used by any other operator.
    
    Additionally, this commit adds a parameterized test that verifies
    equivalent behavior of the in-place operations and the normal operations.
    
    Closes #1808
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     |  45 ++--
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  93 ++++++--
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  |  12 +
 src/main/java/org/apache/sysds/lops/Binary.java    |  14 +-
 .../sysds/runtime/instructions/Instruction.java    |   2 +-
 .../instructions/cp/BinaryCPInstruction.java       |   2 +-
 .../cp/BinaryMatrixMatrixCPInstruction.java        |  62 +++--
 .../runtime/matrix/data/LibMatrixBincell.java      | 180 +++++++++++++--
 .../runtime/matrix/operators/BinaryOperator.java   | 159 ++++++++++++-
 .../matrix/BinaryOperationInPlaceTest.java         | 251 ++++++++++++++++++++-
 .../BinaryOperationInPlaceTestParameterized.java   | 190 ++++++++++++++++
 11 files changed, 921 insertions(+), 89 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 468caa707e..f60ee6dd15 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -337,6 +337,28 @@ public class AggUnaryOp extends MultiThreadedHop
        }
        
 
+       private boolean inputAlreadySpark(){
+               return (!(getInput(0) instanceof DataOp)  //input is not 
checkpoint
+               && getInput(0).optFindExecType() == ExecType.SPARK);
+       }
+
+       private boolean inputOnlyRDD(){
+               return (getInput(0) instanceof DataOp && 
((DataOp)getInput(0)).hasOnlyRDD());
+       } 
+
+       private boolean onlyOneParent(){
+               return getInput(0).getParent().size()==1;
+       }
+
+       private boolean allParentsSpark(){
+               return getInput(0).getParent().stream().filter(h -> h != this)
+                                       .allMatch(h -> h.optFindExecType(false) 
== ExecType.SPARK);
+       }
+
+       private boolean inputDoesNotRequireAggregation(){
+               return !requiresAggregation(getInput(0), _direction);
+       }
+
        @Override
        protected ExecType optFindExecType(boolean transitive) {
                
@@ -351,17 +373,14 @@ public class AggUnaryOp extends MultiThreadedHop
                }
                else
                {
-                       if ( OptimizerUtils.isMemoryBasedOptLevel() ) 
-                       {
+                       if ( OptimizerUtils.isMemoryBasedOptLevel()) {
                                _etype = findExecTypeByMemEstimate();
                        }
                        // Choose CP, if the input dimensions are below 
threshold or if the input is a vector
-                       else if ( getInput().get(0).areDimsBelowThreshold() || 
getInput().get(0).isVector() )
-                       {
+                       else if(getInput().get(0).areDimsBelowThreshold() || 
getInput().get(0).isVector()) {
                                _etype = ExecType.CP;
                        }
-                       else 
-                       {
+                       else {
                                _etype = REMOTE;
                        }
                        
@@ -372,14 +391,12 @@ public class AggUnaryOp extends MultiThreadedHop
                //spark-specific decision refinement (execute unary aggregate 
w/ spark input and 
                //single parent also in spark because it's likely cheap and 
reduces data transfer)
                //we also allow multiple parents, if all other parents are 
already in Spark mode
-               if( transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP
-                       && ((!(getInput(0) instanceof DataOp)  //input is not 
checkpoint
-                               && getInput(0).optFindExecType() == 
ExecType.SPARK)
-                               || (getInput(0) instanceof DataOp && 
((DataOp)getInput(0)).hasOnlyRDD()))
-                       && (getInput(0).getParent().size()==1 //uagg is only 
parent, or 
-                               || getInput(0).getParent().stream().filter(h -> 
h != this)
-                                       .allMatch(h -> h.optFindExecType(false) 
== ExecType.SPARK)
-                               || !requiresAggregation(getInput(0), 
_direction)) ) //w/o agg
+
+               boolean shouldEvaluateIfSpark =  transitive && _etype == 
ExecType.CP && _etypeForced != ExecType.CP;
+
+               if( shouldEvaluateIfSpark
+                       && (inputAlreadySpark() || inputOnlyRDD())
+                       && (onlyOneParent() || allParentsSpark() || 
inputDoesNotRequireAggregation() ))
                {
                        //pull unary aggregate into spark 
                        _etype = ExecType.SPARK;
diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index 2346eeebfe..04585d7dc4 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.hops;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.AggOp;
 import org.apache.sysds.common.Types.DataType;
@@ -27,6 +29,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.OpOpDnn;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
@@ -52,21 +55,21 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
 
-/* Binary (cell operations): aij + bij
+/** Binary (cell operations): aij + bij
  *             Properties: 
  *                     Symbol: *, -, +, ...
  *                     2 Operands
  *             Semantic: align indices (sort), then perform operation
  */
-
 public class BinaryOp extends MultiThreadedHop {
-       // private static final Log LOG =  
LogFactory.getLog(BinaryOp.class.getName());
+       protected static final Log LOG =  
LogFactory.getLog(BinaryOp.class.getName());
 
        //we use the full remote memory budget (but reduced by sort buffer), 
        public static final double APPEND_MEM_MULTIPLIER = 1.0;
 
        private OpOp2 op;
        private boolean outer = false;
+       private boolean inplace = false;
        
        public static AppendMethod FORCED_APPEND_METHOD = null;
        public static MMBinaryMethod FORCED_BINARY_METHOD = null;
@@ -126,6 +129,10 @@ public class BinaryOp extends MultiThreadedHop {
        public boolean isOuter(){
                return outer;
        }
+
+       public boolean isInplace(){
+               return inplace;
+       }
        
        @Override
        public boolean isGPUEnabled() {
@@ -435,7 +442,7 @@ public class BinaryOp extends MultiThreadedHop {
                        else { //general case
                                tmp = new Binary(getInput(0).constructLops(), 
getInput(1).constructLops(),
                                        op, getDataType(), getValueType(), et,
-                                       
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
+                                       
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), inplace);
                        }
                        
                        setOutputDimensions(tmp);
@@ -477,7 +484,7 @@ public class BinaryOp extends MultiThreadedHop {
                                else
                                        binary = new 
Binary(getInput(0).constructLops(), getInput(1).constructLops(),
                                                op, getDataType(), 
getValueType(), et,
-                                               
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
+                                               
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads), inplace);
 
                                setOutputDimensions(binary);
                                setLineNumbers(binary);
@@ -700,6 +707,44 @@ public class BinaryOp extends MultiThreadedHop {
                return true;
        }
        
+       private static boolean isReplace(Hop h) {
+               return h instanceof ParameterizedBuiltinOp && //
+                       ((ParameterizedBuiltinOp) h).getOp() == 
ParamBuiltinOp.REPLACE;
+       }
+
+       private static boolean isReplaceWithPattern(ParameterizedBuiltinOp h, 
double pattern, double replace) {
+               Hop pat = h.getParameterHop("pattern");
+               Hop rep = h.getParameterHop("replacement");
+               if(pat instanceof LiteralOp && rep instanceof LiteralOp) {
+                       double patOb = ((LiteralOp) pat).getDoubleValue();
+                       double repOb = ((LiteralOp) rep).getDoubleValue();
+                       return ((Double.isNaN(pattern) && Double.isNaN(patOb)) 
// is both NaN
+                               || Double.compare(pattern, patOb) == 0) // Is 
equivalent pattern
+                               && Double.compare(replace, repOb) == 0; // is 
equivalent replace.
+               }
+               return false;
+       }
+
+       private static boolean doesNotContainNanAndInf(Hop p1) {
+               if(isReplace(p1)) {
+                       Hop p2 = p1.getInput().get(0);
+                       if(isReplace(p2)) {
+                               ParameterizedBuiltinOp pp1 = 
(ParameterizedBuiltinOp) p1;
+                               ParameterizedBuiltinOp pp2 = 
(ParameterizedBuiltinOp) p2;
+                               return (isReplaceWithPattern(pp1, Double.NaN, 
1) && isReplaceWithPattern(pp2, 0, 1)) ||
+                                       (isReplaceWithPattern(pp2, Double.NaN, 
1) && isReplaceWithPattern(pp1, 0, 1));
+                       }
+               }
+               return false;
+       }
+
+       private boolean memOfInputIsLessThanBudget() {
+               final double in1Memory = getInput().get(0).getMemEstimate();
+               final double in2Memory = getInput().get(1).getMemEstimate();
+               final double budget = OptimizerUtils.getLocalMemBudget();
+               return in1Memory + in2Memory < budget;
+       }
+
        @Override
        protected ExecType optFindExecType(boolean transitive) {
                
@@ -755,20 +800,34 @@ public class BinaryOp extends MultiThreadedHop {
                        checkAndSetInvalidCPDimsAndSize();
                }
 
-               //spark-specific decision refinement (execute unary scalar w/ 
spark input and 
-               //single parent also in spark because it's likely cheap and 
reduces intermediates)
-               if( transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP && _etypeForced != ExecType.FED
-                       && getDataType().isMatrix() && (dt1.isScalar() || 
dt2.isScalar()) 
-                       && supportsMatrixScalarOperations()                     
     //scalar operations
-                       && !(getInput().get(dt1.isScalar()?1:0) instanceof 
DataOp)   //input is not checkpoint
-                       && 
getInput().get(dt1.isScalar()?1:0).getParent().size()==1  //unary scalar is 
only parent
-                       && 
!HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar()?1:0)) //single 
block triggered exec
-                       && getInput().get(dt1.isScalar()?1:0).optFindExecType() 
== ExecType.SPARK )
-               {
-                       //pull unary scalar operation into spark 
+               //spark-specific decision refinement (execute unary scalar w/ 
spark input and
+               // single parent also in spark because it's likely cheap and 
reduces intermediates)
+               if(transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP && _etypeForced != ExecType.FED &&
+                       getDataType().isMatrix() // output should be a matrix
+                       && (dt1.isScalar() || dt2.isScalar()) // one side 
should be scalar
+                       && supportsMatrixScalarOperations() // scalar operations
+                       && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof 
DataOp) // input is not checkpoint
+                       && getInput().get(dt1.isScalar() ? 1 : 
0).getParent().size() == 1 // unary scalar is only parent
+                       && 
!HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // 
single block triggered exec
+                       && getInput().get(dt1.isScalar() ? 1 : 
0).optFindExecType() == ExecType.SPARK) {
+                       // pull unary scalar operation into spark
                        _etype = ExecType.SPARK;
                }
-               
+
+               if( transitive && _etypeForced != ExecType.SPARK && 
_etypeForced != ExecType.FED && //
+                       getDataType().isMatrix() // Output is a matrix
+                       && op == OpOp2.DIV // Operation is division
+                       && dt1.isMatrix() // Left hand side is a Matrix
+                       // right hand side is a scalar or a vector.
+                       && (dt2.isScalar() || (dt2.isMatrix() & 
getInput().get(1).isVector())) //
+                       && memOfInputIsLessThanBudget() //
+                       && getInput().get(0).getExecType() != ExecType.SPARK // 
Is not already a spark operation
+                       && doesNotContainNanAndInf(getInput().get(1)) // 
Guaranteed not to densify the operation
+               ) {
+                       inplace = true;
+                       _etype = ExecType.CP;
+               }
+
                //ensure cp exec type for single-node operations
                if ( op == OpOp2.SOLVE ) {
                        if (isGPUEnabled())
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 4404579894..59b957ac5b 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -734,6 +734,18 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                        _etype = ExecType.CP;
                }
 
+               // If previous instructions were in spark force aggregating
+               // parameterized operations to be executed in spark
+               if(transitive && _etype == ExecType.CP && _etypeForced != 
ExecType.CP) {
+                       switch(_op) {
+                               case CONTAINS:
+                                       if(getTargetHop().optFindExecType() == 
ExecType.SPARK)
+                                               _etype = ExecType.SPARK;
+                               default:
+                                       // Do not change execution type.
+                       }
+               }
+
                //mark for recompile (forever)
                setRequiresRecompileIfNecessary();
                
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java 
b/src/main/java/org/apache/sysds/lops/Binary.java
index 6949188564..c1c233e4f2 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -30,14 +30,14 @@ import org.apache.sysds.common.Types.ValueType;
 
 
 /**
- * Lop to perform binary operation. Both inputs must be matrices or vectors. 
- * Example - A = B + C, where B and C are matrices or vectors.
+ * Lop to perform binary operation. Both inputs must be matrices, vectors or 
scalars. 
+ * Example - A = B + C.
  */
-
 public class Binary extends Lop 
 {
        private OpOp2 operation;
        private final int _numThreads;
+       private final boolean inplace;
        
        /**
         * Constructor to perform a binary operation.
@@ -55,9 +55,14 @@ public class Binary extends Lop
        }
        
        public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType 
vt, ExecType et, int k) {
+               this(input1, input2, op, dt, vt, et, k, false);
+       }
+
+       public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType 
vt, ExecType et, int k, boolean inplace) {
                super(Lop.Type.Binary, dt, vt);
                init(input1, input2, op, dt, vt, et);
                _numThreads = k;
+               this.inplace = inplace; 
        }
        
        private void init(Lop input1, Lop input2, OpOp2 op, DataType dt, 
ValueType vt, ExecType et)  {
@@ -107,6 +112,9 @@ public class Binary extends Lop
                else if( getExecType() == ExecType.FED )
                        ret = InstructionUtils.concatOperands(ret, 
String.valueOf(_numThreads), _fedOutput.name());
 
+               if (getExecType() == ExecType.CP && inplace)
+                       ret = InstructionUtils.concatOperands(ret, "InPlace");
+
                return ret;
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index e0fbecaaea..3190bad650 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -38,7 +38,7 @@ public abstract class Instruction
                FEDERATED
        }
        
-       private static final Log LOG = 
LogFactory.getLog(Instruction.class.getName());
+       protected static final Log LOG = 
LogFactory.getLog(Instruction.class.getName());
        protected final Operator _optr;
 
        protected Instruction(Operator _optr){
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index fddb8301a9..28b8775ebd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -67,7 +67,7 @@ public abstract class BinaryCPInstruction extends 
ComputationCPInstruction {
        
        private static String[] parseBinaryInstruction(String instr, CPOperand 
in1, CPOperand in2, CPOperand out) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(instr);
-               InstructionUtils.checkNumFields ( parts, 3, 4, 5 );
+               InstructionUtils.checkNumFields ( parts, 3, 4, 5, 6 );
                in1.split(parts[1]);
                in2.split(parts[2]);
                out.split(parts[3]);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 565210b585..20119ceacd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -23,19 +23,34 @@ import 
org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
+import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class BinaryMatrixMatrixCPInstruction extends BinaryCPInstruction {
 
+       private final boolean inplace;
+
        protected BinaryMatrixMatrixCPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out, String opcode,
                String istr) {
                super(CPType.Binary, op, in1, in2, out, opcode, istr);
                if(op instanceof BinaryOperator) {
-                       String[] parts = 
InstructionUtils.getInstructionParts(istr);
-                       ((BinaryOperator) 
op).setNumThreads(Integer.parseInt(parts[parts.length - 1]));
+                       final String[] parts = 
InstructionUtils.getInstructionParts(istr);
+                       if(parts.length == 5) {
+                               ((BinaryOperator) 
op).setNumThreads(Integer.parseInt(parts[parts.length - 1]));
+                               inplace = false;
+                       }
+                       else {
+                               ((BinaryOperator) 
op).setNumThreads(Integer.parseInt(parts[parts.length - 2]));
+                               if(parts[parts.length - 1].equals("InPlace"))
+                                       inplace = true;
+                               else
+                                       inplace = false;
+                       }
                }
+               else
+                       inplace = false;
        }
 
        @Override
@@ -49,25 +64,36 @@ public class BinaryMatrixMatrixCPInstruction extends 
BinaryCPInstruction {
 
                MatrixBlock retBlock;
 
-               if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) 
&& !compressedLeft && !compressedRight)
-                       retBlock = 
LibCommonsMath.matrixMatrixOperations(inBlock1, inBlock2, getOpcode());
+               if(inplace && (compressedLeft || compressedRight))
+                       LOG.error("Not supporting inplace compressed binary 
operations yet");
+
+               if(inplace && !(compressedLeft || compressedRight)) {
+                       inBlock1 = LibMatrixBincell.bincellOpInPlace(inBlock1, 
inBlock2, (BinaryOperator) _optr);
+                       // Release the memory occupied by input matrices
+                       ec.releaseMatrixInput(input1.getName(), 
input2.getName());
+                       // Cleanup the inplace metadata input.
+                       ec.removeVariable(input1.getName());
+                       retBlock = inBlock1;
+               }
                else {
-                       // Perform computation using input matrices, and 
produce the result matrix
-                       BinaryOperator bop = (BinaryOperator) _optr;
-                       if(!compressedLeft && compressedRight)
-                               retBlock = ((CompressedMatrixBlock) 
inBlock2).binaryOperationsLeft(bop, inBlock1, new MatrixBlock());
-                       else
-                               retBlock = inBlock1.binaryOperations(bop, 
inBlock2, new MatrixBlock());
+                       
if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) && 
!compressedLeft && !compressedRight)
+                               retBlock = 
LibCommonsMath.matrixMatrixOperations(inBlock1, inBlock2, getOpcode());
+                       else {
+                               // Perform computation using input matrices, 
and produce the result matrix
+                               BinaryOperator bop = (BinaryOperator) _optr;
+                               if(!compressedLeft && compressedRight)
+                                       retBlock = ((CompressedMatrixBlock) 
inBlock2).binaryOperationsLeft(bop, inBlock1, new MatrixBlock());
+                               else
+                                       retBlock = 
inBlock1.binaryOperations(bop, inBlock2, new MatrixBlock());
+                       }
+                       // Release the memory occupied by input matrices
+                       ec.releaseMatrixInput(input1.getName(), 
input2.getName());
+                       // Ensure right dense/sparse output representation 
(guarded by released input memory)
+                       if(checkGuardedRepresentationChange(inBlock1, inBlock2, 
retBlock))
+                               retBlock.examSparsity();
                }
 
-               // Release the memory occupied by input matrices
-               ec.releaseMatrixInput(input1.getName(), input2.getName());
-
-               // Ensure right dense/sparse output representation (guarded by 
released input memory)
-               if(checkGuardedRepresentationChange(inBlock1, inBlock2, 
retBlock))
-                       retBlock.examSparsity();
-
                // Attach result matrix with MatrixObject associated with 
output_name
                ec.setMatrixOutput(output.getName(), retBlock);
        }
-}
\ No newline at end of file
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index 14f83e7de2..33d7bb2da5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -38,6 +38,7 @@ import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlockMCSR;
 import org.apache.sysds.runtime.data.SparseRow;
 import org.apache.sysds.runtime.data.SparseRowVector;
+import org.apache.sysds.runtime.functionobjects.And;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.functionobjects.Divide;
@@ -451,7 +452,7 @@ public class LibMatrixBincell {
                                                        || (clen1 == 1 && rlen2 
== 1 ) );              //VV
 
                if( !isValid ) {
-                       throw new RuntimeException("Block sizes are not matched 
for binary " +
+                       throw new DMLRuntimeException("Block sizes are not 
matched for binary " +
                                        "cell operations: " + rlen1 + "x" + 
clen1 + " vs " + rlen2 + "x" + clen2);
                }
        }
@@ -1589,27 +1590,69 @@ public class LibMatrixBincell {
        }
 
        private static void safeBinaryInPlace(MatrixBlock m1ret, MatrixBlock 
m2, BinaryOperator op) {
-               //early abort on skip and empty 
-               if( (m1ret.isEmpty() && m2.isEmpty() )
-                       || (op.fn instanceof Plus && m2.isEmpty())
-                       || (op.fn instanceof Minus && m2.isEmpty()))
+               // early abort on skip and empty
+               final boolean PoM = op.fn instanceof Plus || op.fn instanceof 
Minus;
+               if((m1ret.isEmpty() && m2.isEmpty()) || (PoM && m2.isEmpty())) {
+                       final boolean isEquals = op.fn instanceof Equals || 
op.fn instanceof LessThanEquals ||
+                               op.fn instanceof GreaterThanEquals;
+
+                       if(isEquals)
+                               m1ret.reset(m1ret.rlen, m1ret.clen, 1);
                        return; // skip entire empty block
-               //special case: start aggregation
-               else if( op.fn instanceof Plus && m1ret.isEmpty() ){
+               }
+               else if(m2.isEmpty() && // empty other side
+                       (op.fn instanceof Multiply || (op.fn instanceof And))) {
+                       m1ret.reset(m1ret.rlen, m1ret.clen, 0);
+                       return;
+               }
+
+               if(m1ret.getNumRows() > 1 && m2.getNumRows() == 1)
+                       safeBinaryInPlaceMatrixRowVector(m1ret, m2, op);
+               else
+                       safeBinaryInPlaceMatrixMatrix(m1ret, m2, op);
+       }
+
+       private static void safeBinaryInPlaceMatrixRowVector(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
+               if(m1ret.sparse) {
+                       if(m2.isInSparseFormat() && !op.isRowSafeLeft(m2))
+                               throw new DMLRuntimeException("Invalid row 
safety of inplace row operation: " + op);
+                       else if(m2.isEmpty())
+                               safeBinaryInPlaceSparseConst(m1ret, 0.0, op);
+                       else if(m2.sparse)
+                               throw new NotImplementedException("Not made 
sparse vector inplace to sparse " + op);
+                       else
+                               safeBinaryInPlaceSparseVector(m1ret, m2, op);
+               }
+               else {
+                       if(!m1ret.isAllocated()) {
+                               LOG.warn("Allocating inplace output block");
+                               m1ret.allocateBlock();
+                       }
+
+                       if(m2.isEmpty())
+                               safeBinaryInPlaceDenseConst(m1ret, 0.0, op);
+                       else if(m2.sparse)
+                               throw new NotImplementedException("Not made 
sparse vector inplace to dense " + op);
+                       else
+                               safeBinaryInPlaceDenseVector(m1ret, m2, op);
+               }
+       }
+
+       private static void safeBinaryInPlaceMatrixMatrix(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
+               if(op.fn instanceof Plus && m1ret.isEmpty()) {
                        m1ret.copy(m2);
-                       return; 
+                       return;
                }
-               
                if(m1ret.sparse && m2.sparse)
                        safeBinaryInPlaceSparse(m1ret, m2, op);
                else if(!m1ret.sparse && !m2.sparse)
                        safeBinaryInPlaceDense(m1ret, m2, op);
                else if(m2.sparse && (op.fn instanceof Plus || op.fn instanceof 
Minus))
                        safeBinaryInPlaceDenseSparseAdd(m1ret, m2, op);
-               else //GENERIC
+               else
                        safeBinaryInPlaceGeneric(m1ret, m2, op);
        }
-       
+
        private static void safeBinaryInPlaceSparse(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
                //allocation and preparation (note: for correctness and 
performance, this 
                //implementation requires the lhs in MCSR and hence we 
explicitly convert)
@@ -1625,6 +1668,9 @@ public class LibMatrixBincell {
                final int rlen = m1ret.rlen;
                final int clen = m1ret.clen;
                
+               final boolean compact = (op.fn instanceof Multiply || op.fn 
instanceof And );
+               final boolean mcsr = c instanceof SparseBlockMCSR;
+
                if( c!=null && b!=null ) {
                        for(int r=0; r<rlen; r++) {
                                if(c.isEmpty(r) && b.isEmpty(r))
@@ -1645,6 +1691,8 @@ public class LibMatrixBincell {
                                        mergeForSparseBinary(op, old.values(), 
old.indexes(), 0, 
                                                old.size(), b.values(r), 
b.indexes(r), b.pos(r), b.size(r), r, m1ret);
                                }
+                               if(compact && mcsr && !c.isEmpty(r))
+                                       c.get(r).compact();
                        }
                }
                else if( c == null ) { //lhs empty
@@ -1660,31 +1708,81 @@ public class LibMatrixBincell {
                                if( c.isEmpty(r) ) continue;
                                zeroRightForSparseBinary(op, r, m1ret);
                        }
+
                }
                
                m1ret.recomputeNonZeros();
        }
 
+       private static void safeBinaryInPlaceSparseConst(MatrixBlock m1ret, 
double m2, BinaryOperator op) {
+               if(m1ret.isEmpty()) // early termination... it is empty and 
safe... just stop.
+                       return;
+               final SparseBlock sb = m1ret.getSparseBlock();
+               final int rlen = m1ret.rlen;
+               for(int r = 0; r < rlen; r++) {
+                       if(sb.isEmpty(r))
+                               continue;
+                       final int apos = sb.pos(r);
+                       final int alen = sb.size(r) + apos;
+                       final double[] avals = sb.values(r);
+                       for(int k = apos; k < alen; k++)
+                               avals[k] = op.fn.execute(avals[k], m2);
+               }
+       }
+
+       private static void safeBinaryInPlaceSparseVector(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
+
+               if(m1ret.isEmpty()) // early termination... it is empty and 
safe... just stop.
+                       return;
+               final SparseBlock sb = m1ret.getSparseBlock();
+               final double[] b = m2.getDenseBlockValues();
+               final int rlen = m1ret.rlen;
+
+               final boolean compact = (op.fn instanceof Multiply || op.fn 
instanceof And) //
+                       && op.isIntroducingZerosRight(m2);
+               final boolean mcsr = sb instanceof SparseBlockMCSR;
+               for(int r = 0; r < rlen; r++) {
+                       if(sb.isEmpty(r))
+                               continue;
+                       final int apos = sb.pos(r);
+                       final int alen = sb.size(r) + apos;
+                       final double[] avals = sb.values(r);
+                       final int[] aix = sb.indexes(r);
+                       for(int k = apos; k < alen; k++)
+                               avals[k] = op.fn.execute(avals[k], b[aix[k]]);
+
+                       if(compact && mcsr) {
+                               SparseRow sr = sb.get(r);
+                               if(sr instanceof SparseRowVector)
+                                       ((SparseRowVector) 
sr).setSize(avals.length);
+                               sr.compact();
+                       }
+               }
+               if(compact && !mcsr) {
+                       ((SparseBlockCSR) sb).compact();
+               }
+       }
+
        private static void safeBinaryInPlaceDense(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
-               //prepare outputs
+               // prepare outputs
                m1ret.allocateDenseBlock();
                DenseBlock a = m1ret.getDenseBlock();
                DenseBlock b = m2.getDenseBlock();
                final int rlen = m1ret.rlen;
                final int clen = m1ret.clen;
-               
+
                long lnnz = 0;
-               if( m2.isEmptyBlock(false) ) {
-                       for(int r=0; r<rlen; r++) {
+               if(m2.isEmptyBlock(false)) {
+                       for(int r = 0; r < rlen; r++) {
                                double[] avals = a.values(r);
-                               for(int c=0, ix=a.pos(r); c<clen; c++, ix++) {
+                               for(int c = 0, ix = a.pos(r); c < clen; c++, 
ix++) {
                                        double tmp = op.fn.execute(avals[ix], 
0);
-                                       lnnz += (avals[ix] = tmp) != 0 ? 1: 0;
+                                       lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
                                }
                        }
                }
-               else if( op.fn instanceof Plus ) {
-                       for(int r=0; r<rlen; r++) {
+               else if(op.fn instanceof Plus) {
+                       for(int r = 0; r < rlen; r++) {
                                int aix = a.pos(r), bix = b.pos(r);
                                double[] avals = a.values(r), bvals = 
b.values(r);
                                LibMatrixMult.vectAdd(bvals, avals, bix, aix, 
clen);
@@ -1692,15 +1790,53 @@ public class LibMatrixBincell {
                        }
                }
                else {
-                       for(int r=0; r<rlen; r++) {
+                       for(int r = 0; r < rlen; r++) {
                                double[] avals = a.values(r), bvals = 
b.values(r);
-                               for(int c=0, ix=a.pos(r); c<clen; c++, ix++) {
+                               for(int c = 0, ix = a.pos(r); c < clen; c++, 
ix++) {
                                        double tmp = op.fn.execute(avals[ix], 
bvals[ix]);
                                        lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
                                }
                        }
                }
-               
+
+               m1ret.setNonZeros(lnnz);
+       }
+
+       private static void safeBinaryInPlaceDenseConst(MatrixBlock m1ret, 
double m2, BinaryOperator op) {
+               // prepare outputs
+               m1ret.allocateDenseBlock();
+               DenseBlock a = m1ret.getDenseBlock();
+               final int rlen = m1ret.rlen;
+               final int clen = m1ret.clen;
+
+               long lnnz = 0;
+               for(int r = 0; r < rlen; r++) {
+                       double[] avals = a.values(r);
+                       for(int c = 0, ix = a.pos(r); c < clen; c++, ix++) {
+                               double tmp = op.fn.execute(avals[ix], m2);
+                               lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
+                       }
+               }
+
+               m1ret.setNonZeros(lnnz);
+       }
+
+       private static void safeBinaryInPlaceDenseVector(MatrixBlock m1ret, 
MatrixBlock m2, BinaryOperator op) {
+               // prepare outputs
+               m1ret.allocateDenseBlock();
+               DenseBlock a = m1ret.getDenseBlock();
+               double[] b = m2.getDenseBlockValues();
+               final int rlen = m1ret.rlen;
+               final int clen = m1ret.clen;
+
+               long lnnz = 0;
+               for(int r = 0; r < rlen; r++) {
+                       double[] avals = a.values(r);
+                       for(int c = 0, ix = a.pos(r); c < clen; c++, ix++) {
+                               double tmp = op.fn.execute(avals[ix], b[ix % 
clen]);
+                               lnnz += (avals[ix] = tmp) != 0 ? 1 : 0;
+                       }
+               }
                m1ret.setNonZeros(lnnz);
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index f1b98cdb8b..992ac5bee9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -21,6 +21,7 @@
 package org.apache.sysds.runtime.matrix.operators;
 
 import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.And;
 import org.apache.sysds.runtime.functionobjects.BitwAnd;
 import org.apache.sysds.runtime.functionobjects.BitwOr;
@@ -29,6 +30,7 @@ import org.apache.sysds.runtime.functionobjects.BitwShiftR;
 import org.apache.sysds.runtime.functionobjects.BitwXor;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.functionobjects.Divide;
 import org.apache.sysds.runtime.functionobjects.Equals;
 import org.apache.sysds.runtime.functionobjects.GreaterThan;
@@ -50,12 +52,22 @@ import org.apache.sysds.runtime.functionobjects.Power;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.functionobjects.Xor;
 
+/**
+ * BinaryOperator class for operations that have two inputs.
+ * 
+ * For instance
+ * 
+ * <pre>
+ *BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
+ *double r = op.execute(5.0, 8.2)
+ * </pre>
+ */
 public class BinaryOperator extends MultiThreadedOperator {
        private static final long serialVersionUID = -2547950181558989209L;
 
        public final ValueFunction fn;
        public final boolean commutative;
-       
+
        public BinaryOperator(ValueFunction p) {
                this(p, 1);
        }
@@ -120,6 +132,12 @@ public class BinaryOperator extends MultiThreadedOperator {
                return commutative;
        }
 
+       /**
+        * Check if the operation returns zeros if the input zero.
+        * 
+        * @param row The values to check
+        * @return If the output is always zero if other value is zero
+        */
        public boolean isRowSafeLeft(double[] row){
                for(double v : row)
                         if(0 !=  fn.execute(v, 0))
@@ -127,6 +145,78 @@ public class BinaryOperator extends MultiThreadedOperator {
                return true;
        }
 
+       /**
+        * Check if the operation returns zeros if the input zero.
+        * 
+        * @param row The values to check
+        * @return If the output is always zero if other value is zero
+        */
+       public boolean isRowSafeLeft(MatrixBlock row) {
+               if(row.isEmpty())
+                       return 0 == fn.execute(0.0, 0.0);
+               else if(row.isInSparseFormat()) {
+                       if(0 != fn.execute(0.0, 0.0))
+                               return false;
+                       SparseBlock sb = row.getSparseBlock();
+                       if(sb.isEmpty(0))
+                               return true;
+                       return isRowSafeLeft(sb.values(0));
+               }
+               else
+                       return isRowSafeLeft(row.getDenseBlockValues());
+       }
+       
+       /**
+        * Check if the operation returns zeros if the input is contained in 
row.
+        * 
+        * @param row The values to check
+        * @return If the output contains zeros
+        */
+       public boolean isIntroducingZerosLeft(MatrixBlock row) {
+               if(row.isEmpty())
+                       return introduceZeroLeft(0.0);
+               else if(row.isInSparseFormat()) {
+                       if(introduceZeroLeft(0.0))
+                               return true;
+                       SparseBlock sb = row.getSparseBlock();
+                       if(sb.isEmpty(0))
+                               return false;
+                       return isIntroducingZerosLeft(sb.values(0));
+               }
+               else
+                       return 
isIntroducingZerosLeft(row.getDenseBlockValues());
+       }
+
+       /**
+        * Check if the operation returns zeros if the input is contained in 
row.
+        * 
+        * @param row The values to check
+        * @return If the output contains zeros
+        */
+       public boolean isIntroducingZerosLeft(double[] row) {
+               for(double v : row)
+                       if(introduceZeroLeft(v))
+                               return true;
+               return false;
+       }
+
+       /**
+        * Check if zero is returned at arbitrary input. The verification is 
done via two different values that hopefully do
+        * not return 0 in both instances unless the operation really have a 
tendency to return zero.
+        * 
+        * @param v The value to check if returns zero
+        * @return if the evaluation return zero
+        */
+       private boolean introduceZeroLeft(double v) {
+               return 0 == fn.execute(v, 11.42) && 0 == fn.execute(v, -11.22);
+       }
+
+       /**
+        * Check if the operation returns zeros if the input zero.
+        * 
+        * @param row The values to check
+        * @return If the output is always zero if other value is zero
+        */
        public boolean isRowSafeRight(double[] row){
                for(double v : row)
                         if(0 !=  fn.execute(0, v))
@@ -134,6 +224,73 @@ public class BinaryOperator extends MultiThreadedOperator {
                return true;
        }
        
+       /**
+        * Check if the operation returns zeros if the input zero.
+        * 
+        * @param row The values to check
+        * @return If the output is always zero if other value is zero
+        */
+       public boolean isRowSafeRight(MatrixBlock row) {
+               if(row.isEmpty())
+                       return 0 == fn.execute(0.0, 0.0);
+               else if(row.isInSparseFormat()) {
+                       if(0 != fn.execute(0.0, 0.0))
+                               return false;
+                       SparseBlock sb = row.getSparseBlock();
+                       if(sb.isEmpty(0))
+                               return true;
+                       return isRowSafeRight(sb.values(0));
+               }
+               else
+                       return isRowSafeRight(row.getDenseBlockValues());
+       }
+
+       /**
+        * Check if the operation returns zeros if the input is contained in 
row.
+        * 
+        * @param row The values to check
+        * @return If the output contains zeros
+        */
+       public boolean isIntroducingZerosRight(MatrixBlock row){
+               if(row.isEmpty())
+                       return  introduceZeroRight(0.0);
+               else if(row.isInSparseFormat()){
+                       if (introduceZeroRight(0.0))
+                               return true;
+                       SparseBlock sb = row.getSparseBlock();
+                       if(sb.isEmpty(0))
+                               return false;
+                       return isIntroducingZerosRight(sb.values(0));   
+               }
+               else 
+                       return 
isIntroducingZerosRight(row.getDenseBlockValues());
+       }
+
+       /**
+        * Check if the operation returns zeros if the input is contained in 
row.
+        * 
+        * @param row The values to check
+        * @return If the output contains zeros
+        */
+       public boolean isIntroducingZerosRight(double[] row){
+               for(double v : row)
+                       if( introduceZeroRight(v))
+                               return true;
+
+               return false;
+       }
+
+       /**
+        * Check if zero is returned at arbitrary input. The verification is 
done via two different values that hopefully do
+        * not return 0 in both instances unless the operation really have a 
tendency to return zero.
+        * 
+        * @param v The value to check if returns zero
+        * @return if the evaluation return zero
+        */
+       private boolean introduceZeroRight(double v) {
+               return 0 == fn.execute(11.42, v) && 0 == fn.execute(-11.22, v);
+       }
+
        @Override
        public String toString() {
                return "BinaryOperator("+fn.getClass().getSimpleName()+")";
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
index 68fcd1be44..bf8258efc6 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTest.java
@@ -19,6 +19,14 @@
 
 package org.apache.sysds.test.component.matrix;
 
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.LessThan;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Or;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -30,36 +38,255 @@ public class BinaryOperationInPlaceTest {
        public void testPlus() {
                MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 1);
                MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 2);
-               execute(m1,m2);
+               executePlus(m1, m2);
        }
 
        @Test
        public void testPlus_emptyInplace() {
-               MatrixBlock m1 = new MatrixBlock(10,10,false);
+               MatrixBlock m1 = new MatrixBlock(10, 10, false);
                MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 2);
-               execute(m1,m2);
+               executePlus(m1, m2);
        }
 
-       @Test 
-       public void testPlus_emptyOther(){
+       @Test
+       public void testPlus_emptyOther() {
                MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 1);
-               MatrixBlock m2 = new MatrixBlock(10,10,false);
-               execute(m1,m2);
+               MatrixBlock m2 = new MatrixBlock(10, 10, false);
+               executePlus(m1, m2);
        }
 
-       @Test 
+       @Test
        public void testPlus_emptyInplace_butAllocatedDense() {
-               MatrixBlock m1 = new MatrixBlock(10,10,false);
+               MatrixBlock m1 = new MatrixBlock(10, 10, false);
                m1.allocateDenseBlock();
                MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 2);
-               execute(m1,m2);
+               executePlus(m1, m2);
+       }
+
+       @Test
+       public void testDivide() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 2);
+               executeDivide(m1, m2);
+       }
+
+       @Test
+       public void testDivide_matrixVector() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 10, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 10, 0, 
10, 1.0, 2);
+               executeDivide(m1, m2);
+       }
+
+       @Test(expected = DMLRuntimeException.class)
+       public void testDivide_Invalid_1() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 10, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 11, 0, 
10, 1.0, 2);
+               executeDivide(m1, m2);
+       }
+
+       @Test(expected = DMLRuntimeException.class)
+       public void testDivide_Invalid_2() {
+               try {
+
+                       MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 
10, 0, 10, 1.0, 1);
+                       MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 
9, 0, 10, 1.0, 2);
+                       executeDivide(m1, m2);
+               }
+               catch(DMLRuntimeException e) {
+                       throw e;
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testDivide_matrixVector_emptyVector() {
+               try {
+
+                       MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 
10, 0, 10, 1.0, 1);
+                       MatrixBlock m2 = new MatrixBlock(1, 10, 0.0);
+                       executeDivide(m1, m2);
+               }
+               catch(DMLRuntimeException e) {
+                       throw e;
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testDivide_matrixVector_sparseBoth() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 1000, 
0, 10, 0.2, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 1000, 0, 
10, 0.2, 2);
+               m1.examSparsity();
+               m2.examSparsity();
+               executeDivide(m1, m2);
+       }
+
+       @Test
+       public void testDivide_matrixVector_oneEmpty() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 0.2, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(1, 10, 0, 
10, 0.0, 2);
+               m1.examSparsity();
+               m2.examSparsity();
+               executeDivide(m1, m2);
+       }
+
+       @Test
+       public void testOr_matrixMatrix_denseDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 2);
+               executeOr(m1, m2);
+       }
+
+       @Test
+       public void testOr_matrixMatrix_denseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 2);
+               executeOr(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_denseDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(10, 10, 0, 
10, 1.0, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_denseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_denseEmpty() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.0, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_EmptyDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_EmptySparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, .1, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_EmptyEmpty() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.0, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testPlus_matrixMatrix_DenseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 2);
+               executePlus(m1, m2);
+       }
+
+       @Test
+       public void testMinus_matrixMatrix_DenseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 2);
+               executeMinus(m1, m2);
+       }
+
+       @Test
+       public void testMinus_matrixMatrix_DenseDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 2);
+               executeMinus(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_DenseDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 2);
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testLT_matrixMatrix_DenseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 2);
+               executeLT(m1, m2);
        }
 
-       private void execute(MatrixBlock m1, MatrixBlock m2){
+       @Test
+       public void testLT_matrixMatrix_SparseDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 2);
+               assertTrue(m1.isInSparseFormat());
+               executeLT(m1, m2);
+       }
+
+       @Test
+       public void testPlus_matrixMatrix_SparseDense() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 1.0, 2);
+               assertTrue(m1.isInSparseFormat());
+               executePlus(m1, m2);
+       }
+
+       @Test
+       public void testPlus_matrixMatrix_SparseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 2);
+               assertTrue(m1.isInSparseFormat());
+               executePlus(m1, m2);
+       }
+
+       @Test
+       public void testDiv_matrixMatrix_SparseSparse() {
+               MatrixBlock m1 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 1);
+               MatrixBlock m2 = TestUtils.generateTestMatrixBlock(100, 100, 0, 
10, 0.1, 2);
+               assertTrue(m1.isInSparseFormat());
+               executeDivide(m1, m2);
+       }
+
+       private void executeDivide(MatrixBlock m1, MatrixBlock m2) {
+               BinaryOperator op = new 
BinaryOperator(Divide.getDivideFnObject());
+               testInplace(m1, m2, op);
+       }
+
+       private void executePlus(MatrixBlock m1, MatrixBlock m2) {
                BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
+               testInplace(m1, m2, op);
+       }
+
+       private void executeMinus(MatrixBlock m1, MatrixBlock m2) {
+               BinaryOperator op = new 
BinaryOperator(Minus.getMinusFnObject());
+               testInplace(m1, m2, op);
+       }
+
+       private void executeOr(MatrixBlock m1, MatrixBlock m2) {
+               BinaryOperator op = new BinaryOperator(Or.getOrFnObject());
+               testInplace(m1, m2, op);
+       }
+
+       private void executeLT(MatrixBlock m1, MatrixBlock m2) {
+               BinaryOperator op = new 
BinaryOperator(LessThan.getLessThanFnObject());
+               testInplace(m1, m2, op);
+       }
+
+       private void testInplace(MatrixBlock m1, MatrixBlock m2, BinaryOperator 
op) {
                MatrixBlock ret1 = m1.binaryOperations(op, m2);
                m1.binaryOperationsInPlace(op, m2);
-
                TestUtils.compareMatricesBitAvgDistance(ret1, m1, 0, 0, "Result 
is incorrect for inplace op");
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java
 
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java
new file mode 100644
index 0000000000..05f29afcc5
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/BinaryOperationInPlaceTestParameterized.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.And;
+import org.apache.sysds.runtime.functionobjects.BitwAnd;
+import org.apache.sysds.runtime.functionobjects.BitwOr;
+import org.apache.sysds.runtime.functionobjects.BitwShiftL;
+import org.apache.sysds.runtime.functionobjects.BitwShiftR;
+import org.apache.sysds.runtime.functionobjects.BitwXor;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Equals;
+import org.apache.sysds.runtime.functionobjects.GreaterThan;
+import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
+import org.apache.sysds.runtime.functionobjects.IntegerDivide;
+import org.apache.sysds.runtime.functionobjects.LessThan;
+import org.apache.sysds.runtime.functionobjects.LessThanEquals;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.MinusNz;
+import org.apache.sysds.runtime.functionobjects.Modulus;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.NotEquals;
+import org.apache.sysds.runtime.functionobjects.Or;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.Power;
+import org.apache.sysds.runtime.functionobjects.Xor;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+/**
+ * Parameterized comparison of {@link MatrixBlock#binaryOperationsInPlace} against
+ * {@link MatrixBlock#binaryOperations}: for every generated (left, right, operator) combination the in-place result
+ * must equal the out-of-place result, and neither operand may change its dimensions.
+ */
+@RunWith(value = Parameterized.class)
+public class BinaryOperationInPlaceTestParameterized {
+       protected static final Log LOG = LogFactory.getLog(BinaryOperationInPlaceTestParameterized.class.getName());
+
+       /** Message fragment of the exception raised when an in-place row operation is rejected for row safety. */
+       private static final String ROW_SAFETY_MSG = "Invalid row safety of inplace row operation: ";
+
+       /** Message fragment of the exception raised for the not yet implemented sparse-vector in-place case. */
+       private static final String NOT_IMPLEMENTED_MSG = "Not made sparse vector inplace";
+
+       private final MatrixBlock left;
+       private final MatrixBlock right;
+       private final BinaryOperator op;
+
+       /**
+        * Create one test case.
+        *
+        * @param left  Left operand. It is copied because the in-place operation overwrites it and the same generated
+        *              block instance is shared by several parameter sets.
+        * @param right Right operand (full matrix or row vector). It is shared across parameter sets and not copied;
+        *              presumably the in-place operation only reads it.
+        * @param op    Binary operator under test.
+        */
+       public BinaryOperationInPlaceTestParameterized(MatrixBlock left, MatrixBlock right, BinaryOperator op) {
+               this.left = new MatrixBlock();
+               this.right = right;
+               this.op = op;
+               this.left.copy(left);
+       }
+
+       /**
+        * Build the parameter sets. Every combination of left sparsity, right sparsity and operator appears twice: once
+        * with a 100x100 right-hand matrix and once with a 1x100 right-hand vector.
+        *
+        * @return Parameter arrays of the form {left, right, op}.
+        */
+       @Parameters
+       public static Collection<Object[]> data() {
+               List<Object[]> tests = new ArrayList<>();
+
+               try {
+                       double[] sparsities = new double[] {0.0, 0.001, 0.1, 0.5, 1.0};
+
+                       BinaryOperator[] operators = new BinaryOperator[] {//
+                               new BinaryOperator(Plus.getPlusFnObject()), //
+                               new BinaryOperator(Minus.getMinusFnObject()), //
+                               new BinaryOperator(Or.getOrFnObject()), //
+                               new BinaryOperator(LessThan.getLessThanFnObject()), //
+                               new BinaryOperator(LessThanEquals.getLessThanEqualsFnObject()), //
+                               new BinaryOperator(GreaterThan.getGreaterThanFnObject()), //
+                               new BinaryOperator(GreaterThanEquals.getGreaterThanEqualsFnObject()), //
+                               new BinaryOperator(Multiply.getMultiplyFnObject()), //
+                               new BinaryOperator(Modulus.getFnObject()), //
+                               new BinaryOperator(IntegerDivide.getFnObject()), //
+                               new BinaryOperator(Equals.getEqualsFnObject()), //
+                               new BinaryOperator(NotEquals.getNotEqualsFnObject()), //
+                               new BinaryOperator(And.getAndFnObject()), //
+                               new BinaryOperator(Xor.getXorFnObject()), //
+                               new BinaryOperator(BitwAnd.getBitwAndFnObject()), //
+                               new BinaryOperator(BitwOr.getBitwOrFnObject()), //
+                               new BinaryOperator(BitwXor.getBitwXorFnObject()), //
+                               new BinaryOperator(BitwShiftL.getBitwShiftLFnObject()), //
+                               new BinaryOperator(BitwShiftR.getBitwShiftRFnObject()), //
+                               new BinaryOperator(Power.getPowerFnObject()), //
+                               new BinaryOperator(MinusNz.getMinusNzFnObject()), //
+                               // Builtin
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MIN)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MAX)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.LOG)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.LOG_NZ)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MAXINDEX)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.MININDEX)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.CUMMAX)), //
+                               new BinaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.CUMMIN)),//
+                       };
+
+                       for(double rightSparsity : sparsities) {
+                               MatrixBlock right = TestUtils.generateTestMatrixBlock(100, 100, 0, 10, rightSparsity, 2);
+                               MatrixBlock rightV = TestUtils.generateTestMatrixBlock(1, 100, 0, 10, rightSparsity, 2);
+                               for(double leftSparsity : sparsities) {
+                                       // The same left block is reused for every operator; the constructor copies it.
+                                       MatrixBlock left = TestUtils.generateTestMatrixBlock(100, 100, 0, 10, leftSparsity, 2);
+                                       for(BinaryOperator op : operators) {
+                                               tests.add(new Object[] {left, right, op});
+                                               tests.add(new Object[] {left, rightV, op});
+                                       }
+                               }
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail("failed constructing tests");
+               }
+
+               return tests;
+       }
+
+       /**
+        * Compute the result out of place, then in place on the (copied) left operand. Require both results to match
+        * with zero tolerance and all operand dimensions to stay unchanged.
+        *
+        * Two known limitations return without failing:
+        * <ul>
+        * <li>row-safety rejections for Divide, Plus, Minus and Or;</li>
+        * <li>the not yet implemented sparse-vector in-place case.</li>
+        * </ul>
+        */
+       @Test
+       public void testInplace() {
+               try {
+                       final int lrb = left.getNumRows();
+                       final int lcb = left.getNumColumns();
+                       final int rrb = right.getNumRows();
+                       final int rcb = right.getNumColumns();
+
+                       final double lspb = left.getSparsity();
+                       final double rspb = right.getSparsity();
+
+                       // Reference result; the left operand must not be modified by the out-of-place call.
+                       final MatrixBlock ret1 = left.binaryOperations(op, right);
+
+                       assertEquals(lrb, left.getNumRows());
+                       assertEquals(lcb, left.getNumColumns());
+                       assertEquals(rrb, right.getNumRows());
+                       assertEquals(rcb, right.getNumColumns());
+
+                       left.binaryOperationsInPlace(op, right);
+
+                       assertEquals(lrb, left.getNumRows());
+                       assertEquals(lcb, left.getNumColumns());
+                       assertEquals(rrb, right.getNumRows());
+                       assertEquals(rcb, right.getNumColumns());
+                       TestUtils.compareMatricesBitAvgDistance(ret1, left, 0, 0, "Result is incorrect for inplace \n" + op + "  "
+                               + lspb + " " + rspb + " (" + lrb + "," + lcb + ")" + " (" + rrb + "," + rcb + ")");
+               }
+               catch(DMLRuntimeException e) {
+                       // A row-safety rejection is tolerated only for these operators.
+                       // NOTE(review): Divide is not in the operator list built by data(), so that case is currently never hit.
+                       if(messageContains(e, ROW_SAFETY_MSG) && //
+                               (op.fn instanceof Divide || op.fn instanceof Plus || op.fn instanceof Minus || op.fn instanceof Or))
+                               return;
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+               catch(NotImplementedException e) {
+                       // TODO fix the not implemented instances.
+                       if(messageContains(e, NOT_IMPLEMENTED_MSG))
+                               return;
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       /**
+        * Null-safe check whether an exception message contains a fragment. Exceptions constructed without a message
+        * would otherwise cause a NullPointerException inside the catch block and hide the original failure.
+        *
+        * @param e        The caught exception.
+        * @param fragment Text to look for.
+        * @return True if the message is non-null and contains the fragment.
+        */
+       private static boolean messageContains(Exception e, String fragment) {
+               final String msg = e.getMessage();
+               return msg != null && msg.contains(fragment);
+       }
+}

Reply via email to