Repository: systemml Updated Branches: refs/heads/master ec0448850 -> 61925ab49
[SYSTEMML-2244] Fix handling of compressed blocks in a few spark mm ops This patch fixes the missing handling of compressed right-hand-side blocks in spark cpmm, rmm, zipmm, and tsmm2 instructions. Similar to mapmm, tsmm, mapmmchain, we now use a common primitive that internally handles this case by calling binary operations on the compressed rhs. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5d149a0a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5d149a0a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5d149a0a Branch: refs/heads/master Commit: 5d149a0af2a0921581b702a0da62d79279b6aab8 Parents: ec04488 Author: Matthias Boehm <[email protected]> Authored: Sat Apr 14 01:54:39 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Apr 14 01:54:39 2018 -0700 ---------------------------------------------------------------------- .../runtime/instructions/spark/CpmmSPInstruction.java | 8 +++++--- .../runtime/instructions/spark/MapmmSPInstruction.java | 12 ++++++------ .../runtime/instructions/spark/RmmSPInstruction.java | 5 +++-- .../runtime/instructions/spark/Tsmm2SPInstruction.java | 2 +- .../runtime/instructions/spark/ZipmmSPInstruction.java | 4 +++- .../runtime/matrix/data/OperationsOnMatrixValues.java | 7 +++---- 6 files changed, 21 insertions(+), 17 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java index 5c98964..de08d83 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java +++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java @@ -43,6 +43,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; @@ -203,10 +204,10 @@ public class CpmmSPInstruction extends BinarySPInstruction { MatrixBlock blkIn1 = (MatrixBlock)arg0._2()._1().getValue(); MatrixBlock blkIn2 = (MatrixBlock)arg0._2()._2().getValue(); MatrixIndexes ixOut = new MatrixIndexes(); - MatrixBlock blkOut = new MatrixBlock(); //core block matrix multiplication - blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, blkOut, _op); + MatrixBlock blkOut = OperationsOnMatrixValues + .performAggregateBinaryIgnoreIndexes(blkIn1, blkIn2, new MatrixBlock(), _op); //return target block ixOut.setIndexes(arg0._2()._1().getIndexes().getRowIndex(), @@ -234,7 +235,8 @@ public class CpmmSPInstruction extends BinarySPInstruction { MatrixBlock in2 = (MatrixBlock)arg0._2() .reorgOperations(_rop, new MatrixBlock(), 0, 0, 0); //core block matrix multiplication - return in1.aggregateBinaryOperations(in1, in2, new MatrixBlock(), _op); + return OperationsOnMatrixValues + .performAggregateBinaryIgnoreIndexes(in1, in2, new MatrixBlock(), _op); } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java index d43b6f8..d54ccf8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/MapmmSPInstruction.java @@ -327,8 +327,8 @@ public class MapmmSPInstruction extends BinarySPInstruction { MatrixBlock left = _pbc.getBlock(1, (int)ixIn.getRowIndex()); //execute matrix-vector mult - return (MatrixBlock) OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( - left, blkIn, new MatrixBlock(), _op); + return OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( + left, blkIn, new MatrixBlock(), _op); } else //if( _type == CacheType.RIGHT ) { @@ -336,8 +336,8 @@ public class MapmmSPInstruction extends BinarySPInstruction { MatrixBlock right = _pbc.getBlock((int)ixIn.getColumnIndex(), 1); //execute matrix-vector mult - return (MatrixBlock) OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( - blkIn, right, new MatrixBlock(), _op); + return OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( + blkIn, right, new MatrixBlock(), _op); } } } @@ -392,7 +392,7 @@ public class MapmmSPInstruction extends BinarySPInstruction { MatrixBlock left = _pbc.getBlock(1, (int)ixIn.getRowIndex()); //execute index preserving matrix multiplication - left.aggregateBinaryOperations(left, blkIn, blkOut, _op); + OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(left, blkIn, blkOut, _op); } else //if( _type == CacheType.RIGHT ) { @@ -400,7 +400,7 @@ public class MapmmSPInstruction extends BinarySPInstruction { MatrixBlock right = _pbc.getBlock((int)ixIn.getColumnIndex(), 1); //execute index preserving matrix multiplication - blkIn.aggregateBinaryOperations(blkIn, right, blkOut, _op); + OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(blkIn, right, blkOut, _op); } return new Tuple2<>(ixIn, blkOut); 
http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java index 05f3870..294c142 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/RmmSPInstruction.java @@ -43,6 +43,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysml.runtime.matrix.data.TripleIndexes; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; @@ -188,10 +189,10 @@ public class RmmSPInstruction extends BinarySPInstruction { MatrixIndexes ixOut = new MatrixIndexes(ixIn.getFirstIndex(), ixIn.getSecondIndex()); //i,j MatrixBlock blkIn1 = arg0._2()._1(); MatrixBlock blkIn2 = arg0._2()._2(); - MatrixBlock blkOut = new MatrixBlock(); //core block matrix multiplication - blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, blkOut, _op); + MatrixBlock blkOut = OperationsOnMatrixValues + .performAggregateBinaryIgnoreIndexes(blkIn1, blkIn2, new MatrixBlock(), _op); //output new tuple return new Tuple2<>(ixOut, blkOut); http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java ---------------------------------------------------------------------- diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java index b5e8d87..cabc2c8 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/Tsmm2SPInstruction.java @@ -215,7 +215,7 @@ public class Tsmm2SPInstruction extends UnarySPInstruction { (int)(_type.isLeft()?1:ixin.getColumnIndex())); MatrixBlock mbin2t = transpose(mbin2, new MatrixBlock()); //prep for transpose rewrite mm - MatrixBlock out2 = (MatrixBlock) OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( //mm + MatrixBlock out2 = OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes( //mm _type.isLeft() ? mbin2t : mbin, _type.isLeft() ? mbin : mbin2t, new MatrixBlock(), _op); MatrixIndexes ixout2 = _type.isLeft() ? new MatrixIndexes(2,1) : new MatrixIndexes(1,2); http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java index ec0b300..4f168c1 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ZipmmSPInstruction.java @@ -36,6 +36,7 @@ import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import 
org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.Operator; @@ -124,7 +125,8 @@ public class ZipmmSPInstruction extends BinarySPInstruction { MatrixBlock tmp = (MatrixBlock)in2.reorgOperations(_rop, new MatrixBlock(), 0, 0, 0); //core matrix multiplication (for t(y)%*%X or t(X)%*%y) - return tmp.aggregateBinaryOperations(tmp, in1, new MatrixBlock(), _abop); + return OperationsOnMatrixValues + .performAggregateBinaryIgnoreIndexes(tmp, in1, new MatrixBlock(), _abop); } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/5d149a0a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java index 6b5b280..3715404 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java @@ -228,14 +228,13 @@ public class OperationsOnMatrixValues value1.aggregateBinaryOperations(indexes1, value1, indexes2, value2, valueOut, op); } - public static MatrixValue performAggregateBinaryIgnoreIndexes(MatrixBlock value1, MatrixBlock value2, + public static MatrixBlock performAggregateBinaryIgnoreIndexes(MatrixBlock value1, MatrixBlock value2, MatrixBlock valueOut, AggregateBinaryOperator op) { //perform on the value if( value2 instanceof CompressedMatrixBlock ) - value2.aggregateBinaryOperations(value1, value2, valueOut, op); + return value2.aggregateBinaryOperations(value1, value2, valueOut, op); else - value1.aggregateBinaryOperations(value1, value2, valueOut, op); - return valueOut; + return value1.aggregateBinaryOperations(value1, value2, valueOut, op); } @SuppressWarnings("rawtypes")
