Repository: systemml Updated Branches: refs/heads/master 5b0fb0ccf -> ce9e42fef
http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java index b1190ad..a9b5b37 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixCell.java @@ -32,12 +32,10 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.CTable; import org.apache.sysml.runtime.functionobjects.ReduceDiag; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; -import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.runtime.matrix.operators.Operator; -import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.matrix.operators.ScalarOperator; import org.apache.sysml.runtime.matrix.operators.UnaryOperator; @@ -170,20 +168,6 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria } @Override - public MatrixValue aggregateBinaryOperations(MatrixValue value1, - MatrixValue value2, MatrixValue result, AggregateBinaryOperator op) - throws DMLRuntimeException { - - MatrixCell c1=checkType(value1); - MatrixCell c2=checkType(value2); - MatrixCell c3=checkType(result); - if(c3==null) - c3=new MatrixCell(); - c3.setValue(op.binaryFn.execute(c1.getValue(), c2.getValue())); - return c3; - } - - @Override public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int brlen, int bclen, MatrixIndexes indexesIn) throws DMLRuntimeException { @@ -314,7 +298,7 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria } @Override - public void ternaryOperations(Operator op, MatrixValue that, + public void ctableOperations(Operator op, MatrixValue that, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -328,32 +312,32 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria } @Override - public void ternaryOperations(Operator op, MatrixValue that, double scalarThat2, boolean ignoreZeros, + public void ctableOperations(Operator op, MatrixValue that, double scalarThat2, boolean ignoreZeros, CTableMap ctableResult, MatrixBlock ctableResultBlock) throws DMLRuntimeException { MatrixCell c2=checkType(that); CTable ctable = CTable.getCTableFnObject(); if ( ctableResult != null) - ctable.execute(this.value, c2.value, scalarThat2, ignoreZeros, ctableResult); + ctable.execute(this.value, c2.value, scalarThat2, ignoreZeros, ctableResult); else - ctable.execute(this.value, c2.value, scalarThat2, ignoreZeros, ctableResultBlock); + ctable.execute(this.value, c2.value, scalarThat2, ignoreZeros, ctableResultBlock); } @Override - public void ternaryOperations(Operator op, double scalarThat, + public void ctableOperations(Operator op, double scalarThat, double scalarThat2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { CTable ctable = CTable.getCTableFnObject(); if ( resultMap != null) - ctable.execute(this.value, scalarThat, scalarThat2, false, resultMap); + ctable.execute(this.value, scalarThat, scalarThat2, false, resultMap); else - ctable.execute(this.value, scalarThat, scalarThat2, false, resultBlock); + ctable.execute(this.value, scalarThat, scalarThat2, false, resultBlock); } @Override - public void ternaryOperations(Operator op, MatrixIndexes ix1, double scalarThat, boolean left, int brlen, + public void ctableOperations(Operator op, MatrixIndexes ix1, double scalarThat, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -361,20 +345,20 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria CTable ctable = CTable.getCTableFnObject(); if ( resultMap != null ) { if( left ) - ctable.execute(ix1.getRowIndex(), this.value, scalarThat, false, resultMap); + ctable.execute(ix1.getRowIndex(), this.value, scalarThat, false, resultMap); else - ctable.execute(this.value, ix1.getRowIndex(), scalarThat, false, resultMap); + ctable.execute(this.value, ix1.getRowIndex(), scalarThat, false, resultMap); } else { if( left ) - ctable.execute(ix1.getRowIndex(), this.value, scalarThat, false, resultBlock); + ctable.execute(ix1.getRowIndex(), this.value, scalarThat, false, resultBlock); else - ctable.execute(this.value, ix1.getRowIndex(), scalarThat, false, resultBlock); + ctable.execute(this.value, ix1.getRowIndex(), scalarThat, false, resultBlock); } } @Override - public void ternaryOperations(Operator op, double scalarThat, + public void ctableOperations(Operator op, double scalarThat, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { @@ -386,13 +370,6 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria ctable.execute(this.value, scalarThat, c3.value, false, resultBlock); } - - @Override - public MatrixValue quaternaryOperations(QuaternaryOperator qop, MatrixValue um, MatrixValue vm, MatrixValue wm, MatrixValue out) - throws DMLRuntimeException - { - throw new DMLRuntimeException("operation not supported fro MatrixCell"); - } @Override public void sliceOperations(ArrayList<IndexedMatrixValue> outlist, @@ -407,7 +384,7 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria public MatrixValue replaceOperations(MatrixValue result, double pattern, double replacement) throws DMLRuntimeException { - MatrixCell out = checkType(result); + MatrixCell out = checkType(result); if( value == pattern || (Double.isNaN(pattern) && Double.isNaN(value)) ) out.value = replacement; else @@ -424,14 +401,6 @@ public class MatrixCell extends MatrixValue implements WritableComparable, Seria } @Override - public MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, - MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value, - MatrixValue result, AggregateBinaryOperator op) - throws DMLRuntimeException { - throw new DMLRuntimeException("MatrixCell.aggregateBinaryOperations should never be called"); - } - - @Override public void appendOperations(MatrixValue valueIn2, ArrayList<IndexedMatrixValue> outlist, int blockRowFactor, int blockColFactor, boolean cbind, boolean m2IsLast, int nextNCol) throws DMLRuntimeException { http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java index c6361bc..4f225be 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixValue.java @@ -26,12 +26,10 @@ import org.apache.hadoop.io.WritableComparable; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; -import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; import org.apache.sysml.runtime.matrix.operators.Operator; -import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.matrix.operators.ScalarOperator; import org.apache.sysml.runtime.matrix.operators.UnaryOperator; @@ -116,24 +114,21 @@ public abstract class MatrixValue implements WritableComparable int startRow, int startColumn, int length) throws DMLRuntimeException; - public abstract void ternaryOperations(Operator op, MatrixValue that, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) + public abstract void ctableOperations(Operator op, MatrixValue that, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException; - public abstract void ternaryOperations(Operator op, MatrixValue that, double scalar_that2, boolean ignoreZeros, CTableMap resultMap, MatrixBlock resultBlock) + public abstract void ctableOperations(Operator op, MatrixValue that, double scalar_that2, boolean ignoreZeros, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException; - public abstract void ternaryOperations(Operator op, double scalar_that, double scalar_that2, CTableMap resultMap, MatrixBlock resultBlock) + public abstract void ctableOperations(Operator op, double scalar_that, double scalar_that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException; - public abstract void ternaryOperations(Operator op, MatrixIndexes ix1, double scalar_that, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock) + public abstract void ctableOperations(Operator op, MatrixIndexes ix1, double scalar_that, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException; - public abstract void ternaryOperations(Operator op, double scalarThat, MatrixValue that2, CTableMap ctableResult, MatrixBlock ctableResultBlock) + public abstract void ctableOperations(Operator op, double scalarThat, MatrixValue that2, CTableMap ctableResult, MatrixBlock ctableResultBlock) throws DMLRuntimeException; - - public abstract MatrixValue quaternaryOperations(QuaternaryOperator qop, MatrixValue um, MatrixValue vm, MatrixValue wm, MatrixValue out) - throws DMLRuntimeException; - + public abstract MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int brlen, int bclen, MatrixIndexes indexesIn) throws DMLRuntimeException; @@ -141,14 +136,6 @@ public abstract class MatrixValue implements WritableComparable int blockingFactorRow, int blockingFactorCol, MatrixIndexes indexesIn, boolean inCP) throws DMLRuntimeException; - public abstract MatrixValue aggregateBinaryOperations(MatrixValue m1Value, MatrixValue m2Value, - MatrixValue result, AggregateBinaryOperator op) - throws DMLRuntimeException; - - public abstract MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value, - MatrixValue result, AggregateBinaryOperator op) - throws DMLRuntimeException; - public abstract MatrixValue unaryOperations(UnaryOperator op, MatrixValue result) throws DMLRuntimeException; http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java index 90a8d15..c11688e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java @@ -71,44 +71,44 @@ public class OperationsOnMatrixValues } // ------------- Ternary Operations ------------- - public static void performTernary(MatrixIndexes indexesIn1, MatrixValue valueIn1, MatrixIndexes indexesIn2, MatrixValue valueIn2, + public static void performCtable(MatrixIndexes indexesIn1, MatrixValue valueIn1, MatrixIndexes indexesIn2, MatrixValue valueIn2, MatrixIndexes indexesIn3, MatrixValue valueIn3, CTableMap resultMap, MatrixBlock resultBlock, Operator op ) throws DMLRuntimeException { //operation on the cells inside the value - valueIn1.ternaryOperations(op, valueIn2, valueIn3, resultMap, resultBlock); + valueIn1.ctableOperations(op, valueIn2, valueIn3, resultMap, resultBlock); } - public static void performTernary(MatrixIndexes indexesIn1, MatrixValue valueIn1, MatrixIndexes indexesIn2, MatrixValue valueIn2, + public static void performCtable(MatrixIndexes indexesIn1, MatrixValue valueIn1, MatrixIndexes indexesIn2, MatrixValue valueIn2, double scalarIn3, CTableMap resultMap, MatrixBlock resultBlock, Operator op) throws DMLRuntimeException { //operation on the cells inside the value - valueIn1.ternaryOperations(op, valueIn2, scalarIn3, false, resultMap, resultBlock); + valueIn1.ctableOperations(op, valueIn2, scalarIn3, false, resultMap, resultBlock); } - public static void performTernary(MatrixIndexes indexesIn1, MatrixValue valueIn1, double scalarIn2, + public static void performCtable(MatrixIndexes indexesIn1, MatrixValue valueIn1, double scalarIn2, double scalarIn3, CTableMap resultMap, MatrixBlock resultBlock, Operator op ) throws DMLRuntimeException { //operation on the cells inside the value - valueIn1.ternaryOperations(op, scalarIn2, scalarIn3, resultMap, resultBlock); + valueIn1.ctableOperations(op, scalarIn2, scalarIn3, resultMap, resultBlock); } - public static void performTernary(MatrixIndexes indexesIn1, MatrixValue valueIn1, double scalarIn2, boolean left, + public static void performCtable(MatrixIndexes indexesIn1, MatrixValue valueIn1, double scalarIn2, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock, Operator op ) throws DMLRuntimeException { //operation on the cells inside the value - valueIn1.ternaryOperations(op, indexesIn1, scalarIn2, left, brlen, resultMap, resultBlock); + valueIn1.ctableOperations(op, indexesIn1, scalarIn2, left, brlen, resultMap, resultBlock); } - public static void performTernary(MatrixIndexes indexesIn1, MatrixValue valueIn1, double scalarIn2, + public static void performCtable(MatrixIndexes indexesIn1, MatrixValue valueIn1, double scalarIn2, MatrixIndexes indexesIn3, MatrixValue valueIn3, CTableMap resultMap, MatrixBlock resultBlock, Operator op ) throws DMLRuntimeException { //operation on the cells inside the value - valueIn1.ternaryOperations(op, scalarIn2, valueIn3, resultMap, resultBlock); + valueIn1.ctableOperations(op, scalarIn2, valueIn3, resultMap, resultBlock); } // ----------------------------------------------------- @@ -240,13 +240,12 @@ public class OperationsOnMatrixValues valueIn.aggregateUnaryOperations(op, valueOut, brlen, bclen, indexesIn); } - public static void performAggregateBinary(MatrixIndexes indexes1, MatrixValue value1, MatrixIndexes indexes2, MatrixValue value2, - MatrixIndexes indexesOut, MatrixValue valueOut, AggregateBinaryOperator op) + public static void performAggregateBinary(MatrixIndexes indexes1, MatrixBlock value1, MatrixIndexes indexes2, MatrixBlock value2, + MatrixIndexes indexesOut, MatrixBlock valueOut, AggregateBinaryOperator op) throws DMLRuntimeException { //compute output index indexesOut.setIndexes(indexes1.getRowIndex(), indexes2.getColumnIndex()); - //perform on the value if( value2 instanceof CompressedMatrixBlock ) value2.aggregateBinaryOperations(value1, value2, valueOut, op); @@ -254,11 +253,9 @@ public class OperationsOnMatrixValues value1.aggregateBinaryOperations(indexes1, value1, indexes2, value2, valueOut, op); } - public static MatrixValue performAggregateBinaryIgnoreIndexes( - MatrixValue value1, MatrixValue value2, - MatrixValue valueOut, AggregateBinaryOperator op) + public static MatrixValue performAggregateBinaryIgnoreIndexes(MatrixBlock value1, MatrixBlock value2, + MatrixBlock valueOut, AggregateBinaryOperator op) throws DMLRuntimeException { - //perform on the value if( value2 instanceof CompressedMatrixBlock ) value2.aggregateBinaryOperations(value1, value2, valueOut, op); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/mapred/GMRReducer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GMRReducer.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GMRReducer.java index 3859636..a8ed5d9 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GMRReducer.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GMRReducer.java @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.Reporter; import org.apache.sysml.runtime.instructions.mr.AppendRInstruction; import org.apache.sysml.runtime.instructions.mr.MRInstruction; -import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; +import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixCell; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; @@ -78,10 +78,10 @@ implements Reducer<MatrixIndexes, TaggedMatrixValue, MatrixIndexes, MatrixValue> { for(MRInstruction ins: mixed_instructions) { - if(ins instanceof TernaryInstruction) + if(ins instanceof CtableInstruction) { - MatrixCharacteristics dim = dimensions.get(((TernaryInstruction) ins).input1); - ((TernaryInstruction) ins).processInstruction(valueClass, cachedValues, zeroInput, _buff.getMapBuffer(), _buff.getBlockBuffer(), dim.getRowsPerBlock(), dim.getColsPerBlock()); + MatrixCharacteristics dim = dimensions.get(((CtableInstruction) ins).input1); + ((CtableInstruction) ins).processInstruction(valueClass, cachedValues, zeroInput, _buff.getMapBuffer(), _buff.getBlockBuffer(), dim.getRowsPerBlock(), dim.getColsPerBlock()); if( _buff.getBufferSize() > GMRCtableBuffer.MAX_BUFFER_SIZE ) _buff.flushBuffer(cachedReporter); //prevent oom for large/many ctables } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMCJMRReducerWithAggregator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMCJMRReducerWithAggregator.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMCJMRReducerWithAggregator.java index 85306a6..a9e1714 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMCJMRReducerWithAggregator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMCJMRReducerWithAggregator.java @@ -123,15 +123,15 @@ public class MMCJMRReducerWithAggregator extends MMCJMRCombinerReducerBase { //perform matrix multiplication indexesbuffer.setIndexes(tmp.getKey().getRowIndex(), inIndex); - OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(tmp.getValue(), inValue, valueBuffer, - (AggregateBinaryOperator)aggBinInstruction.getOperator()); + OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes((MatrixBlock)tmp.getValue(), + (MatrixBlock)inValue, (MatrixBlock)valueBuffer, (AggregateBinaryOperator)aggBinInstruction.getOperator()); } else //right cached { //perform matrix multiplication indexesbuffer.setIndexes(inIndex, tmp.getKey().getColumnIndex()); - OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(inValue, tmp.getValue(), valueBuffer, - (AggregateBinaryOperator)aggBinInstruction.getOperator()); + OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes((MatrixBlock)inValue, + (MatrixBlock)tmp.getValue(), (MatrixBlock)valueBuffer, (AggregateBinaryOperator)aggBinInstruction.getOperator()); } //aggregate block to output buffer or direct output http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMRJMRReducer.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMRJMRReducer.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMRJMRReducer.java index 4bcb9c6..1572719 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMRJMRReducer.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/MMRJMRReducer.java @@ -30,6 +30,7 @@ import org.apache.hadoop.mapred.Reporter; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.mr.AggregateBinaryInstruction; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; @@ -42,9 +43,8 @@ import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; public class MMRJMRReducer extends ReduceBase implements Reducer<TripleIndexes, TaggedMatrixValue, MatrixIndexes, MatrixValue> { - private Reporter cachedReporter=null; - private MatrixValue resultblock=null; + private MatrixBlock resultblock=null; private MatrixIndexes aggIndexes=new MatrixIndexes(); private TripleIndexes prevIndexes=new TripleIndexes(-1, -1, -1); //aggregate binary instruction for the mmrj @@ -106,14 +106,11 @@ implements Reducer<TripleIndexes, TaggedMatrixValue, MatrixIndexes, MatrixValue> { IndexedMatrixValue left = cachedValues.getFirst(aggBinInstruction.input1); IndexedMatrixValue right= cachedValues.getFirst(aggBinInstruction.input2); - // System.out.println("left: \n"+left.getValue()); - // System.out.println("right: \n"+right.getValue()); if(left!=null && right!=null) { try { - resultblock=left.getValue().aggregateBinaryOperations(left.getValue(), right.getValue(), + resultblock=((MatrixBlock)left.getValue()).aggregateBinaryOperations((MatrixBlock)left.getValue(), (MatrixBlock)right.getValue(), resultblock, (AggregateBinaryOperator) aggBinInstruction.getOperator()); - // System.out.println("resultblock: \n"+resultblock); IndexedMatrixValue out=cachedValues.getFirst(aggBinInstruction.output); if(out==null) { http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/matrix/mapred/ReduceBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/ReduceBase.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/ReduceBase.java index 80df55d..15441ec 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/ReduceBase.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/ReduceBase.java @@ -33,7 +33,7 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.mr.AggregateInstruction; import org.apache.sysml.runtime.instructions.mr.MRInstruction; -import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; +import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; @@ -367,7 +367,7 @@ public class ReduceBase extends MRBaseForCommonInstructions { if( mixed_instructions != null ) for(MRInstruction inst : mixed_instructions) - if( inst instanceof TernaryInstruction ) + if( inst instanceof CtableInstruction ) return true; return false; } @@ -375,7 +375,7 @@ public class ReduceBase extends MRBaseForCommonInstructions protected boolean dimsKnownForTernaryInstructions() { if( mixed_instructions != null ) for(MRInstruction inst : mixed_instructions) - if( inst instanceof TernaryInstruction && !((TernaryInstruction)inst).knownOutputDims() ) + if( inst instanceof CtableInstruction && !((CtableInstruction)inst).knownOutputDims() ) return false; return true; } @@ -384,9 +384,9 @@ public class ReduceBase extends MRBaseForCommonInstructions { if( mixed_instructions != null ) for(MRInstruction inst : mixed_instructions) - if( inst instanceof TernaryInstruction ) + if( inst instanceof CtableInstruction ) { - TernaryInstruction tinst = (TernaryInstruction) inst; + CtableInstruction tinst = (CtableInstruction) inst; if( tinst.input1!=-1 ) dimensions.put(tinst.input1, MRJobConfiguration.getMatrixCharacteristicsForInput(job, tinst.input1)); //extend as required, currently only ctableexpand needs blocksizes http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicMatrixVectorMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicMatrixVectorMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicMatrixVectorMultTest.java index 537bef6..67e464a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicMatrixVectorMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicMatrixVectorMultTest.java @@ -171,10 +171,10 @@ public class BasicMatrixVectorMultTest extends AutomatedTestBase //matrix-vector uncompressed AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop); - MatrixBlock ret1 = (MatrixBlock)mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); + MatrixBlock ret1 = mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicVectorMatrixMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicVectorMatrixMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicVectorMatrixMultTest.java index 29832f1..c8592af 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicVectorMatrixMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/BasicVectorMatrixMultTest.java @@ -171,10 +171,10 @@ public class BasicVectorMatrixMultTest extends AutomatedTestBase //matrix-vector uncompressed AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop); - MatrixBlock ret1 = (MatrixBlock)vector.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop); + MatrixBlock ret1 = vector.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(vector, cmb, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(vector, cmb, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeMatrixVectorMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeMatrixVectorMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeMatrixVectorMultTest.java index 990845b..3401c77 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeMatrixVectorMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeMatrixVectorMultTest.java @@ -171,10 +171,10 @@ public class LargeMatrixVectorMultTest extends AutomatedTestBase //matrix-vector uncompressed AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop); - MatrixBlock ret1 = (MatrixBlock)mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); + MatrixBlock ret1 = mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParMatrixVectorMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParMatrixVectorMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParMatrixVectorMultTest.java index c5f9560..19d971a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParMatrixVectorMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParMatrixVectorMultTest.java @@ -173,10 +173,10 @@ public class LargeParMatrixVectorMultTest extends AutomatedTestBase AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop, InfrastructureAnalyzer.getLocalParallelism()); - MatrixBlock ret1 = (MatrixBlock)mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); + MatrixBlock ret1 = mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeVectorMatrixMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeVectorMatrixMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeVectorMatrixMultTest.java index 982c0c1..d7b9094 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeVectorMatrixMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeVectorMatrixMultTest.java @@ -171,10 +171,10 @@ public class LargeVectorMatrixMultTest extends AutomatedTestBase //matrix-vector uncompressed AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop); - MatrixBlock ret1 = (MatrixBlock)vector.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop); + MatrixBlock ret1 = vector.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(vector, cmb, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(vector, cmb, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/ParMatrixVectorMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/ParMatrixVectorMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/ParMatrixVectorMultTest.java index c34216c..3fdeb19 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/ParMatrixVectorMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/ParMatrixVectorMultTest.java @@ -173,10 +173,10 @@ public class ParMatrixVectorMultTest extends AutomatedTestBase AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop, InfrastructureAnalyzer.getLocalParallelism()); - MatrixBlock ret1 = (MatrixBlock)mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); + MatrixBlock ret1 = mb.aggregateBinaryOperations(mb, vector, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(cmb, vector, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/test/java/org/apache/sysml/test/integration/functions/compress/ParVectorMatrixMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/ParVectorMatrixMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/ParVectorMatrixMultTest.java index 1fee256..6bbcde9 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/ParVectorMatrixMultTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/ParVectorMatrixMultTest.java @@ -172,10 +172,10 @@ public class ParVectorMatrixMultTest extends AutomatedTestBase AggregateOperator aop = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator abop = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), aop, InfrastructureAnalyzer.getLocalParallelism()); - MatrixBlock ret1 = (MatrixBlock)vector.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop); + MatrixBlock ret1 = vector.aggregateBinaryOperations(vector, mb, new MatrixBlock(), abop); //matrix-vector compressed - MatrixBlock ret2 = (MatrixBlock)cmb.aggregateBinaryOperations(vector, cmb, new MatrixBlock(), abop); + MatrixBlock ret2 = cmb.aggregateBinaryOperations(vector, cmb, new MatrixBlock(), abop); //compare result with input double[][] d1 = DataConverter.convertToDoubleMatrix(ret1);
