Repository: systemml
Updated Branches:
  refs/heads/master ec0448850 -> 61925ab49


[SYSTEMML-2244] Fix handling of compressed blocks in a few Spark mm ops

This patch adds the missing handling of compressed right-hand-side
blocks to the Spark cpmm, rmm, zipmm, and tsmm2 instructions. Similar to
mapmm, tsmm, and mapmmchain, these instructions now use a common primitive
(OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes) that
internally handles this case by invoking the aggregate binary operation
on the compressed rhs block.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5d149a0a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5d149a0a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5d149a0a

Branch: refs/heads/master
Commit: 5d149a0af2a0921581b702a0da62d79279b6aab8
Parents: ec04488
Author: Matthias Boehm <[email protected]>
Authored: Sat Apr 14 01:54:39 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Apr 14 01:54:39 2018 -0700

----------------------------------------------------------------------
 .../runtime/instructions/spark/CpmmSPInstruction.java   |  8 +++++---
 .../runtime/instructions/spark/MapmmSPInstruction.java  | 12 ++++++------
 .../runtime/instructions/spark/RmmSPInstruction.java    |  5 +++--
 .../runtime/instructions/spark/Tsmm2SPInstruction.java  |  2 +-
 .../runtime/instructions/spark/ZipmmSPInstruction.java  |  4 +++-
 .../runtime/matrix/data/OperationsOnMatrixValues.java   |  7 +++----
 6 files changed, 21 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
index 5c98964..de08d83 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
@@ -43,6 +43,7 @@ import 
org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
@@ -203,10 +204,10 @@ public class CpmmSPInstruction extends 
BinarySPInstruction {
                        MatrixBlock blkIn1 = 
(MatrixBlock)arg0._2()._1().getValue();
                        MatrixBlock blkIn2 = 
(MatrixBlock)arg0._2()._2().getValue();
                        MatrixIndexes ixOut = new MatrixIndexes();
-                       MatrixBlock blkOut = new MatrixBlock();
                        
                        //core block matrix multiplication 
-                       blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, 
blkOut, _op);
+                       MatrixBlock blkOut = OperationsOnMatrixValues
+                               .performAggregateBinaryIgnoreIndexes(blkIn1, 
blkIn2, new MatrixBlock(), _op);
                        
                        //return target block
                        
ixOut.setIndexes(arg0._2()._1().getIndexes().getRowIndex(),
@@ -234,7 +235,8 @@ public class CpmmSPInstruction extends BinarySPInstruction {
                        MatrixBlock in2 = (MatrixBlock)arg0._2()
                                .reorgOperations(_rop, new MatrixBlock(), 0, 0, 
0);
                        //core block matrix multiplication
-                       return in1.aggregateBinaryOperations(in1, in2, new 
MatrixBlock(), _op);
+                       return OperationsOnMatrixValues
+                               .performAggregateBinaryIgnoreIndexes(in1, in2, 
new MatrixBlock(), _op);
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
index d43b6f8..d54ccf8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java
@@ -327,8 +327,8 @@ public class MapmmSPInstruction extends BinarySPInstruction 
{
                                MatrixBlock left = _pbc.getBlock(1, 
(int)ixIn.getRowIndex());
                                
                                //execute matrix-vector mult
-                               return (MatrixBlock) 
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( 
-                                               left, blkIn, new MatrixBlock(), 
_op);                                           
+                               return 
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( 
+                                       left, blkIn, new MatrixBlock(), _op);
                        }
                        else //if( _type == CacheType.RIGHT )
                        {
@@ -336,8 +336,8 @@ public class MapmmSPInstruction extends BinarySPInstruction 
{
                                MatrixBlock right = 
_pbc.getBlock((int)ixIn.getColumnIndex(), 1);
                                
                                //execute matrix-vector mult
-                               return (MatrixBlock) 
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(
-                                               blkIn, right, new 
MatrixBlock(), _op);
+                               return 
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(
+                                       blkIn, right, new MatrixBlock(), _op);
                        }
                }
        }
@@ -392,7 +392,7 @@ public class MapmmSPInstruction extends BinarySPInstruction 
{
                                        MatrixBlock left = _pbc.getBlock(1, 
(int)ixIn.getRowIndex());
                                        
                                        //execute index preserving matrix 
multiplication
-                                       left.aggregateBinaryOperations(left, 
blkIn, blkOut, _op);
+                                       
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(left, blkIn, 
blkOut, _op);
                                }
                                else //if( _type == CacheType.RIGHT )
                                {
@@ -400,7 +400,7 @@ public class MapmmSPInstruction extends BinarySPInstruction 
{
                                        MatrixBlock right = 
_pbc.getBlock((int)ixIn.getColumnIndex(), 1);
 
                                        //execute index preserving matrix 
multiplication
-                                       blkIn.aggregateBinaryOperations(blkIn, 
right, blkOut, _op);     
+                                       
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(blkIn, right, 
blkOut, _op);
                                }
                        
                                return new Tuple2<>(ixIn, blkOut);

http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java
index 05f3870..294c142 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java
@@ -43,6 +43,7 @@ import 
org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysml.runtime.matrix.data.TripleIndexes;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
@@ -188,10 +189,10 @@ public class RmmSPInstruction extends BinarySPInstruction 
{
                        MatrixIndexes ixOut = new 
MatrixIndexes(ixIn.getFirstIndex(), ixIn.getSecondIndex()); //i,j
                        MatrixBlock blkIn1 = arg0._2()._1();
                        MatrixBlock blkIn2 = arg0._2()._2();
-                       MatrixBlock blkOut = new MatrixBlock();
                        
                        //core block matrix multiplication 
-                       blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, 
blkOut, _op);
+                       MatrixBlock blkOut = OperationsOnMatrixValues
+                               .performAggregateBinaryIgnoreIndexes(blkIn1, 
blkIn2, new MatrixBlock(), _op);
                        
                        //output new tuple
                        return new Tuple2<>(ixOut, blkOut);

http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java
index b5e8d87..cabc2c8 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java
@@ -215,7 +215,7 @@ public class Tsmm2SPInstruction extends UnarySPInstruction {
                                                
(int)(_type.isLeft()?1:ixin.getColumnIndex()));
                                MatrixBlock mbin2t = transpose(mbin2, new 
MatrixBlock()); //prep for transpose rewrite mm
                                
-                               MatrixBlock out2 = (MatrixBlock) 
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( //mm
+                               MatrixBlock out2 = 
OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( //mm
                                                _type.isLeft() ? mbin2t : mbin, 
_type.isLeft() ? mbin : mbin2t, new MatrixBlock(), _op);
                                
                                MatrixIndexes ixout2 = _type.isLeft() ? new 
MatrixIndexes(2,1) : new MatrixIndexes(1,2);

http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java
index ec0b300..4f168c1 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java
@@ -36,6 +36,7 @@ import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -124,7 +125,8 @@ public class ZipmmSPInstruction extends BinarySPInstruction 
{
                        MatrixBlock tmp = 
(MatrixBlock)in2.reorgOperations(_rop, new MatrixBlock(), 0, 0, 0);
                                
                        //core matrix multiplication (for t(y)%*%X or t(X)%*%y)
-                       return tmp.aggregateBinaryOperations(tmp, in1, new 
MatrixBlock(), _abop);
+                       return OperationsOnMatrixValues
+                               .performAggregateBinaryIgnoreIndexes(tmp, in1, 
new MatrixBlock(), _abop);
                }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
 
b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
index 6b5b280..3715404 100644
--- 
a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
+++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
@@ -228,14 +228,13 @@ public class OperationsOnMatrixValues
                        value1.aggregateBinaryOperations(indexes1, value1, 
indexes2, value2, valueOut, op);
        }
 
-       public static MatrixValue 
performAggregateBinaryIgnoreIndexes(MatrixBlock value1, MatrixBlock value2,
+       public static MatrixBlock 
performAggregateBinaryIgnoreIndexes(MatrixBlock value1, MatrixBlock value2,
                        MatrixBlock valueOut, AggregateBinaryOperator op) {
                //perform on the value
                if( value2 instanceof CompressedMatrixBlock )
-                       value2.aggregateBinaryOperations(value1, value2, 
valueOut, op);
+                       return value2.aggregateBinaryOperations(value1, value2, 
valueOut, op);
                else
-                       value1.aggregateBinaryOperations(value1, value2, 
valueOut, op);
-               return valueOut;
+                       return value1.aggregateBinaryOperations(value1, value2, 
valueOut, op);
        }
 
        @SuppressWarnings("rawtypes")

Reply via email to