[MINOR] Refactoring agg-binary, ternary, and quaternary instructions. This patch makes a number of refactorings to clean up the MatrixValue abstraction and prepare for the addition of a "real" ternary instruction framework. So far, ternary instructions were only used for ctable and are renamed accordingly.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/57438103 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/57438103 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/57438103 Branch: refs/heads/master Commit: 57438103b36f5c5ac09a3b09fe8d97ae4955439a Parents: 5b0fb0c Author: Matthias Boehm <[email protected]> Authored: Tue Jan 16 17:27:07 2018 -0800 Committer: Matthias Boehm <[email protected]> Committed: Wed Jan 17 15:04:11 2018 -0800 ---------------------------------------------------------------------- A.csv | 3 + A.csv.mtd | 12 + B.csv | 3 + B.csv.mtd | 12 + functions/jmlc/temp/m.binary | Bin 0 -> 255 bytes functions/jmlc/temp/m.binary.mtd | 12 + functions/jmlc/temp/m.csv | 4 + functions/jmlc/temp/m.csv.mtd | 12 + functions/jmlc/temp/m.mm | 8 + functions/jmlc/temp/m.txt | 6 + functions/jmlc/temp/m.txt.mtd | 10 + functions/jmlc/temp/scoring-example.dml | 12 + functions/jmlc/temp/x.csv | 2 + functions/jmlc/temp/x.csv.mtd | 12 + .../java/org/apache/sysml/hops/TernaryOp.java | 42 +- .../hops/cost/CostEstimatorStaticRuntime.java | 10 +- src/main/java/org/apache/sysml/lops/Ctable.java | 354 +++++++++++++ .../java/org/apache/sysml/lops/Ternary.java | 354 ------------- .../runtime/compress/CompressedMatrixBlock.java | 49 +- .../instructions/CPInstructionParser.java | 10 +- .../instructions/MRInstructionParser.java | 16 +- .../instructions/SPInstructionParser.java | 10 +- .../cp/AggregateBinaryCPInstruction.java | 2 +- .../runtime/instructions/cp/CPInstruction.java | 2 +- .../instructions/cp/CtableCPInstruction.java | 186 +++++++ .../cp/QuaternaryCPInstruction.java | 7 +- .../instructions/cp/TernaryCPInstruction.java | 186 ------- .../mr/AggregateBinaryInstruction.java | 37 +- .../mr/CombineTernaryInstruction.java | 4 +- .../instructions/mr/CtableInstruction.java | 264 ++++++++++ .../runtime/instructions/mr/MRInstruction.java | 2 +- 
.../instructions/mr/QuaternaryInstruction.java | 8 +- .../instructions/mr/TernaryInstruction.java | 264 ---------- .../instructions/spark/CtableSPInstruction.java | 510 +++++++++++++++++++ .../instructions/spark/SPInstruction.java | 2 +- .../spark/TernarySPInstruction.java | 510 ------------------- .../instructions/spark/ZipmmSPInstruction.java | 2 +- .../runtime/matrix/MatrixCharacteristics.java | 8 +- .../sysml/runtime/matrix/data/CM_N_COVCell.java | 82 +-- .../sysml/runtime/matrix/data/MatrixBlock.java | 46 +- .../sysml/runtime/matrix/data/MatrixCell.java | 59 +-- .../sysml/runtime/matrix/data/MatrixValue.java | 25 +- .../matrix/data/OperationsOnMatrixValues.java | 31 +- .../sysml/runtime/matrix/mapred/GMRReducer.java | 8 +- .../mapred/MMCJMRReducerWithAggregator.java | 8 +- .../runtime/matrix/mapred/MMRJMRReducer.java | 9 +- .../sysml/runtime/matrix/mapred/ReduceBase.java | 10 +- .../compress/BasicMatrixVectorMultTest.java | 4 +- .../compress/BasicVectorMatrixMultTest.java | 4 +- .../compress/LargeMatrixVectorMultTest.java | 4 +- .../compress/LargeParMatrixVectorMultTest.java | 4 +- .../compress/LargeVectorMatrixMultTest.java | 4 +- .../compress/ParMatrixVectorMultTest.java | 4 +- .../compress/ParVectorMatrixMultTest.java | 4 +- 54 files changed, 1637 insertions(+), 1616 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/A.csv ---------------------------------------------------------------------- diff --git a/A.csv b/A.csv new file mode 100644 index 0000000..e214f8a --- /dev/null +++ b/A.csv @@ -0,0 +1,3 @@ +0.0146817556045713,0.5112049868172497 +1.8844105118944472,1.575721197916694 +0.31179904293595984,1.9060943669721677 http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/A.csv.mtd ---------------------------------------------------------------------- diff --git a/A.csv.mtd b/A.csv.mtd new file mode 100644 index 0000000..268160d --- /dev/null +++ 
b/A.csv.mtd @@ -0,0 +1,12 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 3, + "cols": 2, + "nnz": 6, + "format": "csv", + "header": false, + "sep": ",", + "author": "mboehm", + "created": "2018-01-16 16:09:58 PST" +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/B.csv ---------------------------------------------------------------------- diff --git a/B.csv b/B.csv new file mode 100644 index 0000000..58a7cea --- /dev/null +++ b/B.csv @@ -0,0 +1,3 @@ +0.0146817556045713,0.5112049868172497,1.0,0,1.0,2.5258867424218208 +1.8844105118944472,1.575721197916694,1.0,1.0,1.8844105118944472,7.344542221705589 +0.31179904293595984,1.9060943669721677,1.0,2.0,2.0,7.217893409908127 http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/B.csv.mtd ---------------------------------------------------------------------- diff --git a/B.csv.mtd b/B.csv.mtd new file mode 100644 index 0000000..961e590 --- /dev/null +++ b/B.csv.mtd @@ -0,0 +1,12 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 3, + "cols": 6, + "nnz": 17, + "format": "csv", + "header": false, + "sep": ",", + "author": "mboehm", + "created": "2018-01-16 16:09:58 PST" +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.binary ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.binary b/functions/jmlc/temp/m.binary new file mode 100644 index 0000000..4455f9e Binary files /dev/null and b/functions/jmlc/temp/m.binary differ http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.binary.mtd ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.binary.mtd b/functions/jmlc/temp/m.binary.mtd new file mode 100644 index 0000000..31f6d06 --- /dev/null +++ b/functions/jmlc/temp/m.binary.mtd @@ -0,0 +1,12 @@ +{ + "data_type": "matrix", + "value_type": 
"double", + "rows": 4, + "cols": 3, + "rows_in_block": 1000, + "cols_in_block": 1000, + "nnz": 6, + "format": "binary", + "author": "mboehm", + "created": "2018-01-16 16:10:56 PST" +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.csv ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.csv b/functions/jmlc/temp/m.csv new file mode 100644 index 0000000..37aea08 --- /dev/null +++ b/functions/jmlc/temp/m.csv @@ -0,0 +1,4 @@ +1.0,2.0,3.0 +0,0,0 +7.0,8.0,9.0 +0,0,0 http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.csv.mtd ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.csv.mtd b/functions/jmlc/temp/m.csv.mtd new file mode 100644 index 0000000..2dce672 --- /dev/null +++ b/functions/jmlc/temp/m.csv.mtd @@ -0,0 +1,12 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 4, + "cols": 3, + "nnz": 6, + "format": "csv", + "header": false, + "sep": ",", + "author": "mboehm", + "created": "2018-01-16 16:10:56 PST" +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.mm ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.mm b/functions/jmlc/temp/m.mm new file mode 100644 index 0000000..f7ae3bf --- /dev/null +++ b/functions/jmlc/temp/m.mm @@ -0,0 +1,8 @@ +%%MatrixMarket matrix coordinate real general +4 3 6 +1 1 1.0 +1 2 2.0 +1 3 3.0 +3 1 7.0 +3 2 8.0 +3 3 9.0 http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.txt ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.txt b/functions/jmlc/temp/m.txt new file mode 100644 index 0000000..e11b21c --- /dev/null +++ b/functions/jmlc/temp/m.txt @@ -0,0 +1,6 @@ +1 1 1.0 +1 2 2.0 +1 3 3.0 +3 1 7.0 +3 2 8.0 +3 3 9.0 
http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/m.txt.mtd ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/m.txt.mtd b/functions/jmlc/temp/m.txt.mtd new file mode 100644 index 0000000..ed6376f --- /dev/null +++ b/functions/jmlc/temp/m.txt.mtd @@ -0,0 +1,10 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 4, + "cols": 3, + "nnz": 6, + "format": "text", + "author": "mboehm", + "created": "2018-01-16 16:10:56 PST" +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/scoring-example.dml ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/scoring-example.dml b/functions/jmlc/temp/scoring-example.dml new file mode 100644 index 0000000..d45d95d --- /dev/null +++ b/functions/jmlc/temp/scoring-example.dml @@ -0,0 +1,12 @@ +X = read("./tmp/X", rows=-1, cols=-1); +W = read("./tmp/W", rows=-1, cols=-1); + +numRows = nrow(X); +numCols = ncol(X); +b = W[numCols+1,] +scores = X %*% W[1:numCols,] + b; +predicted_y = rowIndexMax(scores); + +print('pred:' + toString(predicted_y)) + +write(predicted_y, "./tmp", format="text"); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/x.csv ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/x.csv b/functions/jmlc/temp/x.csv new file mode 100644 index 0000000..e055049 --- /dev/null +++ b/functions/jmlc/temp/x.csv @@ -0,0 +1,2 @@ +1.0,2.0 +3.0,4.0 http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/functions/jmlc/temp/x.csv.mtd ---------------------------------------------------------------------- diff --git a/functions/jmlc/temp/x.csv.mtd b/functions/jmlc/temp/x.csv.mtd new file mode 100644 index 0000000..a350176 --- /dev/null +++ b/functions/jmlc/temp/x.csv.mtd @@ -0,0 +1,12 @@ +{ + "data_type": "matrix", + "value_type": "double", + 
"rows": 2, + "cols": 2, + "nnz": 4, + "format": "csv", + "header": false, + "sep": ",", + "author": "mboehm", + "created": "2018-01-16 16:10:56 PST" +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/hops/TernaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java index 47b012e..dc40982 100644 --- a/src/main/java/org/apache/sysml/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java @@ -34,7 +34,7 @@ import org.apache.sysml.lops.PickByCount; import org.apache.sysml.lops.PlusMult; import org.apache.sysml.lops.RepMat; import org.apache.sysml.lops.SortKeys; -import org.apache.sysml.lops.Ternary; +import org.apache.sysml.lops.Ctable; import org.apache.sysml.lops.UnaryCP; import org.apache.sysml.lops.CombineBinary.OperationTypes; import org.apache.sysml.lops.LopProperties.ExecType; @@ -412,7 +412,7 @@ public class TernaryOp extends Hop DataType dt1 = getInput().get(0).getDataType(); DataType dt2 = getInput().get(1).getDataType(); DataType dt3 = getInput().get(2).getDataType(); - Ternary.OperationTypes ternaryOpOrig = Ternary.findCtableOperationByInputDataTypes(dt1, dt2, dt3); + Ctable.OperationTypes ternaryOpOrig = Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3); // Compute lops for all inputs Lop[] inputLops = new Lop[getInput().size()]; @@ -428,8 +428,8 @@ public class TernaryOp extends Hop if ( et == ExecType.CP || et == ExecType.SPARK) { //for CP we support only ctable expand left - Ternary.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ? - Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig; + Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable(true) ? 
+ Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig; boolean ignoreZeros = false; if( isMatrixIgnoreZeroRewriteApplicable() ) { @@ -438,7 +438,7 @@ public class TernaryOp extends Hop inputLops[1] = ((ParameterizedBuiltinOp)getInput().get(1)).getTargetHop().getInput().get(0).constructLops(); } - Ternary ternary = new Ternary(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et); + Ctable ternary = new Ctable(inputLops, ternaryOp, getDataType(), getValueType(), ignoreZeros, et); ternary.getOutputParameters().setDimensions(_dim1, _dim2, getRowsInBlock(), getColsInBlock(), -1); setLineNumbers(ternary); @@ -459,8 +459,8 @@ public class TernaryOp extends Hop else //MR { //for MR we support both ctable expand left and right - Ternary.OperationTypes ternaryOp = isSequenceRewriteApplicable() ? - Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig; + Ctable.OperationTypes ternaryOp = isSequenceRewriteApplicable() ? + Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ternaryOpOrig; Group group1 = null, group2 = null, group3 = null, group4 = null; group1 = new Group(inputLops[0], Group.OperationTypes.Sort, getDataType(), getValueType()); @@ -468,7 +468,7 @@ public class TernaryOp extends Hop getDim2(), getRowsInBlock(), getColsInBlock(), getNnz()); setLineNumbers(group1); - Ternary ternary = null; + Ctable ternary = null; // create "group" lops for MATRIX inputs switch (ternaryOp) { @@ -493,12 +493,12 @@ public class TernaryOp extends Hop setLineNumbers(group3); if ( inputLops.length == 3 ) - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1, group2, group3}, ternaryOp, getDataType(), getValueType(), et); else // output dimensions are given - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1, group2, group3, inputLops[3], inputLops[4]}, ternaryOp, getDataType(), getValueType(), et); break; @@ -515,11 +515,11 @@ public class TernaryOp extends Hop setLineNumbers(group2); if ( inputLops.length == 3) - 
ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1,group2,inputLops[2]}, ternaryOp, getDataType(), getValueType(), et); else - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1,group2,inputLops[2], inputLops[3], inputLops[4]}, ternaryOp, getDataType(), getValueType(), et); @@ -539,14 +539,14 @@ public class TernaryOp extends Hop //TODO remove group, whenever we push it into the map task if (inputLops.length == 3) - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group, //matrix getInput().get(2).constructLops(), //weight new LiteralOp(left).constructLops() //left }, ternaryOp, getDataType(), getValueType(), et); else - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group, //matrix getInput().get(2).constructLops(), //weight new LiteralOp(left).constructLops(), //left @@ -558,14 +558,14 @@ public class TernaryOp extends Hop case CTABLE_TRANSFORM_HISTOGRAM: // F=ctable(A,1) or F = ctable(A,1,1) if ( inputLops.length == 3 ) - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1, getInput().get(1).constructLops(), getInput().get(2).constructLops() }, ternaryOp, getDataType(), getValueType(), et); else - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1, getInput().get(1).constructLops(), getInput().get(2).constructLops(), @@ -587,13 +587,13 @@ public class TernaryOp extends Hop setLineNumbers(group3); if ( inputLops.length == 3) - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1, getInput().get(1).constructLops(), group3}, ternaryOp, getDataType(), getValueType(), et); else - ternary = new Ternary( + ternary = new Ctable( new Lop[] {group1, getInput().get(1).constructLops(), group3, inputLops[3], inputLops[4] }, @@ -611,7 +611,7 @@ public class TernaryOp extends Hop Lop lctable = ternary; - if( !(_disjointInputs || ternaryOp == Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT) ) + if( !(_disjointInputs || ternaryOp == Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT) ) { 
//no need for aggregation if (1) input indexed disjoint or one side is sequence w/ 1 increment @@ -885,9 +885,9 @@ public class TernaryOp extends Hop else if( isSequenceRewriteApplicable(false) ) setDim2( input2._dim1 ); //for ctable_histogram also one dimension is known - Ternary.OperationTypes ternaryOp = Ternary.findCtableOperationByInputDataTypes( + Ctable.OperationTypes ternaryOp = Ctable.findCtableOperationByInputDataTypes( input1.getDataType(), input2.getDataType(), input3.getDataType()); - if( ternaryOp==Ternary.OperationTypes.CTABLE_TRANSFORM_HISTOGRAM + if( ternaryOp==Ctable.OperationTypes.CTABLE_TRANSFORM_HISTOGRAM && input2 instanceof LiteralOp ) { setDim2( HopRewriteUtils.getIntValueSafe((LiteralOp)input2) ); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java index 364ff09..aebfa9d 100644 --- a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java +++ b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java @@ -54,7 +54,7 @@ import org.apache.sysml.runtime.instructions.mr.MRInstruction; import org.apache.sysml.runtime.instructions.mr.MapMultChainInstruction; import org.apache.sysml.runtime.instructions.mr.PickByCountInstruction; import org.apache.sysml.runtime.instructions.mr.RemoveEmptyMRInstruction; -import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; +import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryMRInstructionBase; import org.apache.sysml.runtime.instructions.mr.MRInstruction.MRType; import org.apache.sysml.runtime.matrix.data.MatrixBlock; @@ -408,9 +408,9 @@ public class CostEstimatorStaticRuntime extends 
CostEstimator attr = new String[]{rbinst.isRemoveRows()?"0":"1"}; } } - else if( mrinst instanceof TernaryInstruction ) + else if( mrinst instanceof CtableInstruction ) { - TernaryInstruction tinst = (TernaryInstruction) mrinst; + CtableInstruction tinst = (CtableInstruction) mrinst; vs[0] = stats[ tinst.input1 ]; vs[1] = stats[ tinst.input2 ]; vs[2] = stats[ tinst.input3 ]; @@ -884,7 +884,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator else return d3m*d3n; - case Ternary: //opcodes: ctable + case Ctable: //opcodes: ctable if( optype.equals("ctable") ){ if( leftSparse ) return d1m * d1n * d1s; //add @@ -1139,7 +1139,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator //note: covers scalar, matrix, matrix-scalar return d3m * d3n; - case Ternary: //opcodes: ctabletransform, ctabletransformscalarweight, ctabletransformhistogram, ctabletransformweightedhistogram + case Ctable: //opcodes: ctabletransform, ctabletransformscalarweight, ctabletransformhistogram, ctabletransformweightedhistogram //note: copy from cp if( leftSparse ) return d1m * d1n * d1s; //add http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/lops/Ctable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Ctable.java b/src/main/java/org/apache/sysml/lops/Ctable.java new file mode 100644 index 0000000..7eec1bb --- /dev/null +++ b/src/main/java/org/apache/sysml/lops/Ctable.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.lops; + +import org.apache.sysml.lops.LopProperties.ExecLocation; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.compile.JobType; +import org.apache.sysml.parser.Expression.*; + + +/** + * Lop to perform ternary operation. All inputs must be matrices or vectors. + * For example, this lop is used in evaluating A = ctable(B,C,W) + * + * Currently, this lop is used only in case of CTABLE functionality. + */ + +public class Ctable extends Lop +{ + private boolean _ignoreZeros = false; + + public enum OperationTypes { + CTABLE_TRANSFORM, + CTABLE_TRANSFORM_SCALAR_WEIGHT, + CTABLE_TRANSFORM_HISTOGRAM, + CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM, + CTABLE_EXPAND_SCALAR_WEIGHT, + INVALID + } + + OperationTypes operation; + + + public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, ExecType et) { + this(inputLops, op, dt, vt, false, et); + } + + public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, ExecType et) { + super(Lop.Type.Ternary, dt, vt); + init(inputLops, op, et); + _ignoreZeros = ignoreZeros; + } + + private void init(Lop[] inputLops, OperationTypes op, ExecType et) { + operation = op; + + for(int i=0; i < inputLops.length; i++) { + this.addInput(inputLops[i]); + inputLops[i].addOutput(this); + } + + boolean breaksAlignment = true; + boolean aligner = false; + boolean definesMRJob = false; + + if ( et == ExecType.MR ) { + lps.addCompatibility(JobType.GMR); + //lps.addCompatibility(JobType.DATAGEN); MB: disabled due to 
piggybacking issues + //lps.addCompatibility(JobType.REBLOCK); MB: disabled since no runtime support + + if( operation==OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT ) + this.lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); + //TODO create runtime for ctable in gmr mapper and switch to maporreduce. + //this.lps.setProperties( inputs, et, ExecLocation.MapOrReduce, breaksAlignment, aligner, definesMRJob ); + else + this.lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); + } + else { + lps.addCompatibility(JobType.INVALID); + this.lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + } + } + + @Override + public String toString() { + + return " Operation: " + operation; + + } + + public static OperationTypes findCtableOperationByInputDataTypes(DataType dt1, DataType dt2, DataType dt3) + { + if ( dt1 == DataType.MATRIX ) { + if (dt2 == DataType.MATRIX && dt3 == DataType.SCALAR) { + // F = ctable(A,B) or F = ctable(A,B,1) + return OperationTypes.CTABLE_TRANSFORM_SCALAR_WEIGHT; + } else if (dt2 == DataType.SCALAR && dt3 == DataType.SCALAR) { + // F=ctable(A,1) or F = ctable(A,1,1) + return OperationTypes.CTABLE_TRANSFORM_HISTOGRAM; + } else if (dt2 == DataType.SCALAR && dt3 == DataType.MATRIX) { + // F=ctable(A,1,W) + return OperationTypes.CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM; + } else { + // F=ctable(A,B,W) + return OperationTypes.CTABLE_TRANSFORM; + } + } + else { + return OperationTypes.INVALID; + } + } + + /** + * method to get operation type + * @return operation type + */ + + public OperationTypes getOperationType() + { + return operation; + } + + @Override + public String getInstructions(String input1, String input2, String input3, String output) throws LopsException + { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + sb.append( Lop.OPERAND_DELIMITOR ); + if( operation != 
Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT ) + sb.append( "ctable" ); + else + sb.append( "ctableexpand" ); + sb.append( OPERAND_DELIMITOR ); + + if ( getInputs().get(0).getDataType() == DataType.SCALAR ) { + sb.append ( getInputs().get(0).prepScalarInputOperand(getExecType()) ); + } + else { + sb.append( getInputs().get(0).prepInputOperand(input1)); + } + sb.append( OPERAND_DELIMITOR ); + + if ( getInputs().get(1).getDataType() == DataType.SCALAR ) { + sb.append ( getInputs().get(1).prepScalarInputOperand(getExecType()) ); + } + else { + sb.append( getInputs().get(1).prepInputOperand(input2)); + } + sb.append( OPERAND_DELIMITOR ); + + if ( getInputs().get(2).getDataType() == DataType.SCALAR ) { + sb.append ( getInputs().get(2).prepScalarInputOperand(getExecType()) ); + } + else { + sb.append( getInputs().get(2).prepInputOperand(input3)); + } + sb.append( OPERAND_DELIMITOR ); + + if ( this.getInputs().size() > 3 ) { + sb.append(getInputs().get(3).getOutputParameters().getLabel()); + sb.append(LITERAL_PREFIX); + sb.append((getInputs().get(3).getType() == Type.Data && ((Data)getInputs().get(3)).isLiteral()) ); + sb.append( OPERAND_DELIMITOR ); + + sb.append(getInputs().get(4).getOutputParameters().getLabel()); + sb.append(LITERAL_PREFIX); + sb.append((getInputs().get(4).getType() == Type.Data && ((Data)getInputs().get(4)).isLiteral()) ); + sb.append( OPERAND_DELIMITOR ); + } + else { + sb.append(-1); + sb.append(LITERAL_PREFIX); + sb.append(true); + sb.append( OPERAND_DELIMITOR ); + + sb.append(-1); + sb.append(LITERAL_PREFIX); + sb.append(true); + sb.append( OPERAND_DELIMITOR ); + } + sb.append( this.prepOutputOperand(output)); + + sb.append( OPERAND_DELIMITOR ); + sb.append( _ignoreZeros ); + + return sb.toString(); + } + + @Override + public String getInstructions(int input_index1, int input_index2, int input_index3, int output_index) throws LopsException + { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + sb.append( 
Lop.OPERAND_DELIMITOR ); + switch(operation) { + /* Arithmetic */ + case CTABLE_TRANSFORM: + // F = ctable(A,B,W) + sb.append( "ctabletransform" ); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(0).prepInputOperand(input_index1)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(1).prepInputOperand(input_index2)); + sb.append( OPERAND_DELIMITOR ); + sb.append( getInputs().get(2).prepInputOperand(input_index3)); + sb.append( OPERAND_DELIMITOR ); + + break; + + case CTABLE_TRANSFORM_SCALAR_WEIGHT: + // F = ctable(A,B) or F = ctable(A,B,1) + // third input must be a scalar, and hence input_index3 == -1 + if ( input_index3 != -1 ) { + throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation + " \n"); + } + + int scalarIndex = 2; // index of the scalar input + + sb.append( "ctabletransformscalarweight" ); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(0).prepInputOperand(input_index1)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(1).prepInputOperand(input_index2)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(scalarIndex).prepScalarInputOperand(getExecType())); + sb.append( OPERAND_DELIMITOR ); + + break; + + case CTABLE_EXPAND_SCALAR_WEIGHT: + // F = ctable(seq,B) or F = ctable(seq,B,1) + // second and third inputs must be scalars, and hence input_index2 == -1, input_index3 == -1 + if ( input_index3 != -1 ) { + throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation + " \n"); + } + + int scalarIndex2 = 1; // index of the scalar input + int scalarIndex3 = 2; // index of the scalar input + + sb.append( "ctableexpandscalarweight" ); + sb.append( OPERAND_DELIMITOR ); + //get(0) because input under group + sb.append( getInputs().get(0).prepInputOperand(input_index1)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( 
getInputs().get(scalarIndex2).prepScalarInputOperand(getExecType())); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(scalarIndex3).prepScalarInputOperand(getExecType())); + sb.append( OPERAND_DELIMITOR ); + + break; + + case CTABLE_TRANSFORM_HISTOGRAM: + // F=ctable(A,1) or F = ctable(A,1,1) + if ( input_index2 != -1 || input_index3 != -1) + throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation); + + // 2nd and 3rd inputs are scalar inputs + + sb.append( "ctabletransformhistogram" ); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(0).prepInputOperand(input_index1)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(1).prepScalarInputOperand(getExecType()) ); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(2).prepScalarInputOperand(getExecType()) ); + sb.append( OPERAND_DELIMITOR ); + + break; + + case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: + // F=ctable(A,1,W) + if ( input_index2 != -1 ) + throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation); + + // 2nd input is the scalar input + + sb.append( "ctabletransformweightedhistogram" ); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(0).prepInputOperand(input_index1)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(1).prepScalarInputOperand(getExecType())); + sb.append( OPERAND_DELIMITOR ); + + sb.append( getInputs().get(2).prepInputOperand(input_index3)); + sb.append( OPERAND_DELIMITOR ); + + break; + + default: + throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Tertiary operation: " + operation); + } + + long outputDim1=-1, outputDim2=-1; + if ( getInputs().size() > 3 ) { + sb.append(getInputs().get(3).prepScalarLabel()); + sb.append( OPERAND_DELIMITOR ); + + 
sb.append(getInputs().get(4).prepScalarLabel()); + sb.append( OPERAND_DELIMITOR ); + /*if ( input3 instanceof Data && ((Data)input3).isLiteral() + && input4 instanceof Data && ((Data)input4).isLiteral() ) { + outputDim1 = ((Data)input3).getLongValue(); + outputDim2 = ((Data)input4).getLongValue(); + }*/ + } + else { + sb.append( outputDim1 ); + sb.append( OPERAND_DELIMITOR ); + + sb.append( outputDim2 ); + sb.append( OPERAND_DELIMITOR ); + } + sb.append( this.prepOutputOperand(output_index)); + + return sb.toString(); + } + + public static OperationTypes getOperationType(String opcode) + { + OperationTypes op = null; + + if( opcode.equals("ctabletransform") ) + op = OperationTypes.CTABLE_TRANSFORM; + else if( opcode.equals("ctabletransformscalarweight") ) + op = OperationTypes.CTABLE_TRANSFORM_SCALAR_WEIGHT; + else if( opcode.equals("ctableexpandscalarweight") ) + op = OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT; + else if( opcode.equals("ctabletransformhistogram") ) + op = OperationTypes.CTABLE_TRANSFORM_HISTOGRAM; + else if( opcode.equals("ctabletransformweightedhistogram") ) + op = OperationTypes.CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM; + else + throw new UnsupportedOperationException("Tertiary operation code is not defined: " + opcode); + + return op; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/lops/Ternary.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Ternary.java b/src/main/java/org/apache/sysml/lops/Ternary.java deleted file mode 100644 index ba3f884..0000000 --- a/src/main/java/org/apache/sysml/lops/Ternary.java +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.lops; - -import org.apache.sysml.lops.LopProperties.ExecLocation; -import org.apache.sysml.lops.LopProperties.ExecType; -import org.apache.sysml.lops.compile.JobType; -import org.apache.sysml.parser.Expression.*; - - -/** - * Lop to perform ternary operation. All inputs must be matrices or vectors. - * For example, this lop is used in evaluating A = ctable(B,C,W) - * - * Currently, this lop is used only in case of CTABLE functionality. 
- */ - -public class Ternary extends Lop -{ - private boolean _ignoreZeros = false; - - public enum OperationTypes { - CTABLE_TRANSFORM, - CTABLE_TRANSFORM_SCALAR_WEIGHT, - CTABLE_TRANSFORM_HISTOGRAM, - CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM, - CTABLE_EXPAND_SCALAR_WEIGHT, - INVALID - } - - OperationTypes operation; - - - public Ternary(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, ExecType et) { - this(inputLops, op, dt, vt, false, et); - } - - public Ternary(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, ExecType et) { - super(Lop.Type.Ternary, dt, vt); - init(inputLops, op, et); - _ignoreZeros = ignoreZeros; - } - - private void init(Lop[] inputLops, OperationTypes op, ExecType et) { - operation = op; - - for(int i=0; i < inputLops.length; i++) { - this.addInput(inputLops[i]); - inputLops[i].addOutput(this); - } - - boolean breaksAlignment = true; - boolean aligner = false; - boolean definesMRJob = false; - - if ( et == ExecType.MR ) { - lps.addCompatibility(JobType.GMR); - //lps.addCompatibility(JobType.DATAGEN); MB: disabled due to piggybacking issues - //lps.addCompatibility(JobType.REBLOCK); MB: disabled since no runtime support - - if( operation==OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT ) - this.lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); - //TODO create runtime for ctable in gmr mapper and switch to maporreduce. 
- //this.lps.setProperties( inputs, et, ExecLocation.MapOrReduce, breaksAlignment, aligner, definesMRJob ); - else - this.lps.setProperties( inputs, et, ExecLocation.Reduce, breaksAlignment, aligner, definesMRJob ); - } - else { - lps.addCompatibility(JobType.INVALID); - this.lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); - } - } - - @Override - public String toString() { - - return " Operation: " + operation; - - } - - public static OperationTypes findCtableOperationByInputDataTypes(DataType dt1, DataType dt2, DataType dt3) - { - if ( dt1 == DataType.MATRIX ) { - if (dt2 == DataType.MATRIX && dt3 == DataType.SCALAR) { - // F = ctable(A,B) or F = ctable(A,B,1) - return OperationTypes.CTABLE_TRANSFORM_SCALAR_WEIGHT; - } else if (dt2 == DataType.SCALAR && dt3 == DataType.SCALAR) { - // F=ctable(A,1) or F = ctable(A,1,1) - return OperationTypes.CTABLE_TRANSFORM_HISTOGRAM; - } else if (dt2 == DataType.SCALAR && dt3 == DataType.MATRIX) { - // F=ctable(A,1,W) - return OperationTypes.CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM; - } else { - // F=ctable(A,B,W) - return OperationTypes.CTABLE_TRANSFORM; - } - } - else { - return OperationTypes.INVALID; - } - } - - /** - * method to get operation type - * @return operation type - */ - - public OperationTypes getOperationType() - { - return operation; - } - - @Override - public String getInstructions(String input1, String input2, String input3, String output) throws LopsException - { - StringBuilder sb = new StringBuilder(); - sb.append( getExecType() ); - sb.append( Lop.OPERAND_DELIMITOR ); - if( operation != Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT ) - sb.append( "ctable" ); - else - sb.append( "ctableexpand" ); - sb.append( OPERAND_DELIMITOR ); - - if ( getInputs().get(0).getDataType() == DataType.SCALAR ) { - sb.append ( getInputs().get(0).prepScalarInputOperand(getExecType()) ); - } - else { - sb.append( getInputs().get(0).prepInputOperand(input1)); - } - sb.append( 
OPERAND_DELIMITOR ); - - if ( getInputs().get(1).getDataType() == DataType.SCALAR ) { - sb.append ( getInputs().get(1).prepScalarInputOperand(getExecType()) ); - } - else { - sb.append( getInputs().get(1).prepInputOperand(input2)); - } - sb.append( OPERAND_DELIMITOR ); - - if ( getInputs().get(2).getDataType() == DataType.SCALAR ) { - sb.append ( getInputs().get(2).prepScalarInputOperand(getExecType()) ); - } - else { - sb.append( getInputs().get(2).prepInputOperand(input3)); - } - sb.append( OPERAND_DELIMITOR ); - - if ( this.getInputs().size() > 3 ) { - sb.append(getInputs().get(3).getOutputParameters().getLabel()); - sb.append(LITERAL_PREFIX); - sb.append((getInputs().get(3).getType() == Type.Data && ((Data)getInputs().get(3)).isLiteral()) ); - sb.append( OPERAND_DELIMITOR ); - - sb.append(getInputs().get(4).getOutputParameters().getLabel()); - sb.append(LITERAL_PREFIX); - sb.append((getInputs().get(4).getType() == Type.Data && ((Data)getInputs().get(4)).isLiteral()) ); - sb.append( OPERAND_DELIMITOR ); - } - else { - sb.append(-1); - sb.append(LITERAL_PREFIX); - sb.append(true); - sb.append( OPERAND_DELIMITOR ); - - sb.append(-1); - sb.append(LITERAL_PREFIX); - sb.append(true); - sb.append( OPERAND_DELIMITOR ); - } - sb.append( this.prepOutputOperand(output)); - - sb.append( OPERAND_DELIMITOR ); - sb.append( _ignoreZeros ); - - return sb.toString(); - } - - @Override - public String getInstructions(int input_index1, int input_index2, int input_index3, int output_index) throws LopsException - { - StringBuilder sb = new StringBuilder(); - sb.append( getExecType() ); - sb.append( Lop.OPERAND_DELIMITOR ); - switch(operation) { - /* Arithmetic */ - case CTABLE_TRANSFORM: - // F = ctable(A,B,W) - sb.append( "ctabletransform" ); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(0).prepInputOperand(input_index1)); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(1).prepInputOperand(input_index2)); - sb.append( OPERAND_DELIMITOR ); - 
sb.append( getInputs().get(2).prepInputOperand(input_index3)); - sb.append( OPERAND_DELIMITOR ); - - break; - - case CTABLE_TRANSFORM_SCALAR_WEIGHT: - // F = ctable(A,B) or F = ctable(A,B,1) - // third input must be a scalar, and hence input_index3 == -1 - if ( input_index3 != -1 ) { - throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation + " \n"); - } - - int scalarIndex = 2; // index of the scalar input - - sb.append( "ctabletransformscalarweight" ); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(0).prepInputOperand(input_index1)); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(1).prepInputOperand(input_index2)); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(scalarIndex).prepScalarInputOperand(getExecType())); - sb.append( OPERAND_DELIMITOR ); - - break; - - case CTABLE_EXPAND_SCALAR_WEIGHT: - // F = ctable(seq,B) or F = ctable(seq,B,1) - // second and third inputs must be scalars, and hence input_index2 == -1, input_index3 == -1 - if ( input_index3 != -1 ) { - throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation + " \n"); - } - - int scalarIndex2 = 1; // index of the scalar input - int scalarIndex3 = 2; // index of the scalar input - - sb.append( "ctableexpandscalarweight" ); - sb.append( OPERAND_DELIMITOR ); - //get(0) because input under group - sb.append( getInputs().get(0).prepInputOperand(input_index1)); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(scalarIndex2).prepScalarInputOperand(getExecType())); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(scalarIndex3).prepScalarInputOperand(getExecType())); - sb.append( OPERAND_DELIMITOR ); - - break; - - case CTABLE_TRANSFORM_HISTOGRAM: - // F=ctable(A,1) or F = ctable(A,1,1) - if ( input_index2 != -1 || input_index3 != -1) - throw new 
LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation); - - // 2nd and 3rd inputs are scalar inputs - - sb.append( "ctabletransformhistogram" ); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(0).prepInputOperand(input_index1)); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(1).prepScalarInputOperand(getExecType()) ); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(2).prepScalarInputOperand(getExecType()) ); - sb.append( OPERAND_DELIMITOR ); - - break; - - case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: - // F=ctable(A,1,W) - if ( input_index2 != -1 ) - throw new LopsException(this.printErrorLocation() + "In Tertiary Lop, Unexpected input while computing the instructions for op: " + operation); - - // 2nd input is the scalar input - - sb.append( "ctabletransformweightedhistogram" ); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(0).prepInputOperand(input_index1)); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(1).prepScalarInputOperand(getExecType())); - sb.append( OPERAND_DELIMITOR ); - - sb.append( getInputs().get(2).prepInputOperand(input_index3)); - sb.append( OPERAND_DELIMITOR ); - - break; - - default: - throw new UnsupportedOperationException(this.printErrorLocation() + "Instruction is not defined for Tertiary operation: " + operation); - } - - long outputDim1=-1, outputDim2=-1; - if ( getInputs().size() > 3 ) { - sb.append(getInputs().get(3).prepScalarLabel()); - sb.append( OPERAND_DELIMITOR ); - - sb.append(getInputs().get(4).prepScalarLabel()); - sb.append( OPERAND_DELIMITOR ); - /*if ( input3 instanceof Data && ((Data)input3).isLiteral() - && input4 instanceof Data && ((Data)input4).isLiteral() ) { - outputDim1 = ((Data)input3).getLongValue(); - outputDim2 = ((Data)input4).getLongValue(); - }*/ - } - else { - sb.append( outputDim1 ); - sb.append( OPERAND_DELIMITOR ); - - sb.append( outputDim2 ); - 
sb.append( OPERAND_DELIMITOR ); - } - sb.append( this.prepOutputOperand(output_index)); - - return sb.toString(); - } - - public static OperationTypes getOperationType(String opcode) - { - OperationTypes op = null; - - if( opcode.equals("ctabletransform") ) - op = OperationTypes.CTABLE_TRANSFORM; - else if( opcode.equals("ctabletransformscalarweight") ) - op = OperationTypes.CTABLE_TRANSFORM_SCALAR_WEIGHT; - else if( opcode.equals("ctableexpandscalarweight") ) - op = OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT; - else if( opcode.equals("ctabletransformhistogram") ) - op = OperationTypes.CTABLE_TRANSFORM_HISTOGRAM; - else if( opcode.equals("ctabletransformweightedhistogram") ) - op = OperationTypes.CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM; - else - throw new UnsupportedOperationException("Tertiary operation code is not defined: " + opcode); - - return op; - } -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java index 52ed2b4..0ce3dc8 100644 --- a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java @@ -1086,18 +1086,18 @@ public class CompressedMatrixBlock extends MatrixBlock implements Externalizable } @Override - public MatrixValue aggregateBinaryOperations(MatrixValue mv1, MatrixValue mv2, MatrixValue result, AggregateBinaryOperator op) + public MatrixBlock aggregateBinaryOperations(MatrixBlock mv1, MatrixBlock mv2, MatrixBlock ret, AggregateBinaryOperator op) throws DMLRuntimeException { //call uncompressed matrix mult if necessary if( !isCompressed() ) { - return super.aggregateBinaryOperations(mv1, mv2, result, op); + 
return super.aggregateBinaryOperations(mv1, mv2, ret, op); } //multi-threaded mm of single uncompressed colgroup if( isSingleUncompressedGroup() ){ MatrixBlock tmp = ((ColGroupUncompressed)_colGroups.get(0)).getData(); - return tmp.aggregateBinaryOperations(this==mv1?tmp:mv1, this==mv2?tmp:mv2, result, op); + return tmp.aggregateBinaryOperations(this==mv1?tmp:mv1, this==mv2?tmp:mv2, ret, op); } Timing time = LOG.isDebugEnabled() ? new Timing(true) : null; @@ -1107,7 +1107,6 @@ public class CompressedMatrixBlock extends MatrixBlock implements Externalizable int cl = mv2.getNumColumns(); //create output matrix block - MatrixBlock ret = (MatrixBlock) result; if( ret==null ) ret = new MatrixBlock(rl, cl, false, rl*cl); else @@ -2119,9 +2118,9 @@ public class CompressedMatrixBlock extends MatrixBlock implements Externalizable } @Override - public MatrixValue aggregateBinaryOperations(MatrixIndexes m1Index, - MatrixValue m1Value, MatrixIndexes m2Index, MatrixValue m2Value, - MatrixValue result, AggregateBinaryOperator op) + public MatrixBlock aggregateBinaryOperations(MatrixIndexes m1Index, + MatrixBlock m1Value, MatrixIndexes m2Index, MatrixBlock m2Value, + MatrixBlock result, AggregateBinaryOperator op) throws DMLRuntimeException { printDecompressWarning("aggregateBinaryOperations"); MatrixBlock left = isCompressed() ? decompress() : this; @@ -2199,84 +2198,84 @@ public class CompressedMatrixBlock extends MatrixBlock implements Externalizable } @Override - public void ternaryOperations(Operator op, double scalar, + public void ctableOperations(Operator op, double scalar, MatrixValue that, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock left = isCompressed() ? 
decompress() : this; MatrixBlock right = getUncompressed(that); - left.ternaryOperations(op, scalar, right, resultMap, resultBlock); + left.ctableOperations(op, scalar, right, resultMap, resultBlock); } @Override - public void ternaryOperations(Operator op, double scalar, + public void ctableOperations(Operator op, double scalar, double scalar2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock tmp = isCompressed() ? decompress() : this; - tmp.ternaryOperations(op, scalar, scalar2, resultMap, resultBlock); + tmp.ctableOperations(op, scalar, scalar2, resultMap, resultBlock); } @Override - public void ternaryOperations(Operator op, MatrixIndexes ix1, + public void ctableOperations(Operator op, MatrixIndexes ix1, double scalar, boolean left, int brlen, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock tmp = isCompressed() ? decompress() : this; - tmp.ternaryOperations(op, ix1, scalar, left, brlen, resultMap, resultBlock); + tmp.ctableOperations(op, ix1, scalar, left, brlen, resultMap, resultBlock); } @Override - public void ternaryOperations(Operator op, MatrixValue that, + public void ctableOperations(Operator op, MatrixValue that, double scalar, boolean ignoreZeros, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock left = isCompressed() ? 
decompress() : this; MatrixBlock right = getUncompressed(that); - left.ternaryOperations(op, right, scalar, ignoreZeros, resultMap, resultBlock); + left.ctableOperations(op, right, scalar, ignoreZeros, resultMap, resultBlock); } @Override - public void ternaryOperations(Operator op, MatrixValue that, double scalar, MatrixBlock resultBlock) + public void ctableOperations(Operator op, MatrixValue that, double scalar, MatrixBlock resultBlock) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock left = isCompressed() ? decompress() : this; MatrixBlock right = getUncompressed(that); - left.ternaryOperations(op, right, scalar, resultBlock); + left.ctableOperations(op, right, scalar, resultBlock); } @Override - public void ternaryOperations(Operator op, MatrixValue that, + public void ctableOperations(Operator op, MatrixValue that, MatrixValue that2, CTableMap resultMap) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock left = isCompressed() ? decompress() : this; MatrixBlock right1 = getUncompressed(that); MatrixBlock right2 = getUncompressed(that2); - left.ternaryOperations(op, right1, right2, resultMap); + left.ctableOperations(op, right1, right2, resultMap); } @Override - public void ternaryOperations(Operator op, MatrixValue that, + public void ctableOperations(Operator op, MatrixValue that, MatrixValue that2, CTableMap resultMap, MatrixBlock resultBlock) throws DMLRuntimeException { printDecompressWarning("ternaryOperations"); MatrixBlock left = isCompressed() ? 
decompress() : this; MatrixBlock right1 = getUncompressed(that); MatrixBlock right2 = getUncompressed(that2); - left.ternaryOperations(op, right1, right2, resultMap, resultBlock); + left.ctableOperations(op, right1, right2, resultMap, resultBlock); } @Override - public MatrixValue quaternaryOperations(QuaternaryOperator qop, - MatrixValue um, MatrixValue vm, MatrixValue wm, MatrixValue out) + public MatrixBlock quaternaryOperations(QuaternaryOperator qop, + MatrixBlock um, MatrixBlock vm, MatrixBlock wm, MatrixBlock out) throws DMLRuntimeException { return quaternaryOperations(qop, um, vm, wm, out, 1); } @Override - public MatrixValue quaternaryOperations(QuaternaryOperator qop, MatrixValue um, - MatrixValue vm, MatrixValue wm, MatrixValue out, int k) + public MatrixBlock quaternaryOperations(QuaternaryOperator qop, MatrixBlock um, + MatrixBlock vm, MatrixBlock wm, MatrixBlock out, int k) throws DMLRuntimeException { printDecompressWarning("quaternaryOperations"); MatrixBlock left = isCompressed() ? 
decompress() : this; http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index e2bb20a..6b2e3d6 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -59,7 +59,7 @@ import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.ReorgCPInstruction; import org.apache.sysml.runtime.instructions.cp.SpoofCPInstruction; import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction; -import org.apache.sysml.runtime.instructions.cp.TernaryCPInstruction; +import org.apache.sysml.runtime.instructions.cp.CtableCPInstruction; import org.apache.sysml.runtime.instructions.cp.UaggOuterChainCPInstruction; import org.apache.sysml.runtime.instructions.cp.UnaryCPInstruction; import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; @@ -254,8 +254,8 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( DataGen.SINIT_OPCODE , CPType.StringInit); String2CPInstructionType.put( DataGen.SAMPLE_OPCODE , CPType.Rand); - String2CPInstructionType.put( "ctable", CPType.Ternary); - String2CPInstructionType.put( "ctableexpand", CPType.Ternary); + String2CPInstructionType.put( "ctable", CPType.Ctable); + String2CPInstructionType.put( "ctableexpand", CPType.Ctable); //central moment, covariance, quantiles (sort/pick) String2CPInstructionType.put( "cm" , CPType.CentralMoment); @@ -327,8 +327,8 @@ public class CPInstructionParser extends InstructionParser else return BinaryCPInstruction.parseInstruction(str); - case Ternary: - return 
TernaryCPInstruction.parseInstruction(str); + case Ctable: + return CtableCPInstruction.parseInstruction(str); case Quaternary: return QuaternaryCPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java index debafe1..fae93f3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java @@ -74,7 +74,7 @@ import org.apache.sysml.runtime.instructions.mr.ReorgInstruction; import org.apache.sysml.runtime.instructions.mr.ReplicateInstruction; import org.apache.sysml.runtime.instructions.mr.ScalarInstruction; import org.apache.sysml.runtime.instructions.mr.SeqInstruction; -import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; +import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryInstruction; import org.apache.sysml.runtime.instructions.mr.ZeroOutInstruction; @@ -239,11 +239,11 @@ public class MRInstructionParser extends InstructionParser String2MRInstructionType.put( "csvrblk", MRType.CSVReblock); // Ternary Reorg Instruction Opcodes - String2MRInstructionType.put( "ctabletransform", MRType.Ternary); - String2MRInstructionType.put( "ctabletransformscalarweight", MRType.Ternary); - String2MRInstructionType.put( "ctableexpandscalarweight", MRType.Ternary); - String2MRInstructionType.put( "ctabletransformhistogram", MRType.Ternary); - String2MRInstructionType.put( "ctabletransformweightedhistogram", MRType.Ternary); + String2MRInstructionType.put( 
"ctabletransform", MRType.Ctable); + String2MRInstructionType.put( "ctabletransformscalarweight", MRType.Ctable); + String2MRInstructionType.put( "ctableexpandscalarweight", MRType.Ctable); + String2MRInstructionType.put( "ctabletransformhistogram", MRType.Ctable); + String2MRInstructionType.put( "ctabletransformweightedhistogram", MRType.Ctable); // Quaternary Instruction Opcodes String2MRInstructionType.put( WeightedSquaredLoss.OPCODE, MRType.Quaternary); @@ -362,8 +362,8 @@ public class MRInstructionParser extends InstructionParser case AggregateUnary: return AggregateUnaryInstruction.parseInstruction(str); - case Ternary: - return TernaryInstruction.parseInstruction(str); + case Ctable: + return CtableInstruction.parseInstruction(str); case Quaternary: return QuaternaryInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index 1ba7776..3098c10 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -74,7 +74,7 @@ import org.apache.sysml.runtime.instructions.spark.RmmSPInstruction; import org.apache.sysml.runtime.instructions.spark.SPInstruction; import org.apache.sysml.runtime.instructions.spark.SPInstruction.SPType; import org.apache.sysml.runtime.instructions.spark.SpoofSPInstruction; -import org.apache.sysml.runtime.instructions.spark.TernarySPInstruction; +import org.apache.sysml.runtime.instructions.spark.CtableSPInstruction; import org.apache.sysml.runtime.instructions.spark.Tsmm2SPInstruction; import org.apache.sysml.runtime.instructions.spark.TsmmSPInstruction; import 
org.apache.sysml.runtime.instructions.spark.QuantileSortSPInstruction; @@ -268,8 +268,8 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( DataGen.SAMPLE_OPCODE, SPType.Rand); //ternary instruction opcodes - String2SPInstructionType.put( "ctable", SPType.Ternary); - String2SPInstructionType.put( "ctableexpand", SPType.Ternary); + String2SPInstructionType.put( "ctable", SPType.Ctable); + String2SPInstructionType.put( "ctableexpand", SPType.Ctable); //quaternary instruction opcodes String2SPInstructionType.put( WeightedSquaredLoss.OPCODE, SPType.Quaternary); @@ -381,8 +381,8 @@ public class SPInstructionParser extends InstructionParser return BinarySPInstruction.parseInstruction(str); //ternary instructions - case Ternary: - return TernarySPInstruction.parseInstruction(str); + case Ctable: + return CtableSPInstruction.parseInstruction(str); //quaternary instructions case Quaternary: http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateBinaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateBinaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateBinaryCPInstruction.java index 4f542d6..32732d3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateBinaryCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/AggregateBinaryCPInstruction.java @@ -74,7 +74,7 @@ public class AggregateBinaryCPInstruction extends BinaryCPInstruction { //compute matrix multiplication AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr; MatrixBlock main = (matBlock2 instanceof CompressedMatrixBlock) ? 
matBlock2 : matBlock1; - MatrixBlock ret = (MatrixBlock) main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op); + MatrixBlock ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op); //release inputs/outputs ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java index bebf2ca..6b86029 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java @@ -31,7 +31,7 @@ public abstract class CPInstruction extends Instruction { public enum CPType { AggregateUnary, AggregateBinary, AggregateTernary, - Unary, Binary, Ternary, Quaternary, BuiltinNary, + Unary, Binary, Ctable, Quaternary, BuiltinNary, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, SpoofFused, http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/cp/CtableCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CtableCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CtableCPInstruction.java new file mode 100644 index 0000000..0879b81 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CtableCPInstruction.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.cp; + +import org.apache.sysml.lops.Ctable; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.Instruction; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.CTableMap; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.runtime.matrix.operators.SimpleOperator; +import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.runtime.util.LongLongDoubleHashMap.EntryType; + +public class CtableCPInstruction extends ComputationCPInstruction { + private final String _outDim1; + private final String _outDim2; + private final boolean _dim1Literal; + private final boolean _dim2Literal; + private final boolean _isExpand; + private final boolean _ignoreZeros; + + private CtableCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String outputDim1, boolean dim1Literal, String 
outputDim2, boolean dim2Literal, boolean isExpand, + boolean ignoreZeros, String opcode, String istr) { + super(CPType.Ctable, op, in1, in2, in3, out, opcode, istr); + _outDim1 = outputDim1; + _dim1Literal = dim1Literal; + _outDim2 = outputDim2; + _dim2Literal = dim2Literal; + _isExpand = isExpand; + _ignoreZeros = ignoreZeros; + } + + public static CtableCPInstruction parseInstruction(String inst) + throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst); + InstructionUtils.checkNumFields ( parts, 7 ); + + String opcode = parts[0]; + + //handle opcode + if ( !(opcode.equalsIgnoreCase("ctable") || opcode.equalsIgnoreCase("ctableexpand")) ) { + throw new DMLRuntimeException("Unexpected opcode in TertiaryCPInstruction: " + inst); + } + boolean isExpand = opcode.equalsIgnoreCase("ctableexpand"); + + //handle operands + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + + //handle known dimension information + String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX); + String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX); + + CPOperand out = new CPOperand(parts[6]); + boolean ignoreZeros = Boolean.parseBoolean(parts[7]); + + // ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject + return new CtableCPInstruction(new SimpleOperator(null), in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst); + } + + private Ctable.OperationTypes findCtableOperation() { + DataType dt1 = input1.getDataType(); + DataType dt2 = input2.getDataType(); + DataType dt3 = input3.getDataType(); + return Ctable.findCtableOperationByInputDataTypes(dt1, dt2, dt3); + } + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException { + + MatrixBlock matBlock1 = 
ec.getMatrixInput(input1.getName(), getExtendedOpcode()); + MatrixBlock matBlock2=null, wtBlock=null; + double cst1, cst2; + + CTableMap resultMap = new CTableMap(EntryType.INT); + MatrixBlock resultBlock = null; + Ctable.OperationTypes ctableOp = findCtableOperation(); + ctableOp = _isExpand ? Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp; + + long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue()); + long outputDim2 = (_dim2Literal ? (long) Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue()); + + boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1); + if ( outputDimsKnown ) { + int inputRows = matBlock1.getNumRows(); + int inputCols = matBlock1.getNumColumns(); + boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols); + //only create result block if dense; it is important not to aggregate on sparse result + //blocks because it would implicitly turn the O(N) algorithm into O(N log N). 
+ if( !sparse ) + resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false); + } + if( _isExpand ){ + resultBlock = new MatrixBlock( matBlock1.getNumRows(), Integer.MAX_VALUE, true ); + } + + switch(ctableOp) { + case CTABLE_TRANSFORM: //(VECTOR) + // F=ctable(A,B,W) + matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); + wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode()); + matBlock1.ctableOperations((SimpleOperator)_optr, matBlock2, wtBlock, resultMap, resultBlock); + break; + case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX) + // F = ctable(A,B) or F = ctable(A,B,1) + matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); + cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); + matBlock1.ctableOperations((SimpleOperator)_optr, matBlock2, cst1, _ignoreZeros, resultMap, resultBlock); + break; + case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR) + // F = ctable(seq,A) or F = ctable(seq,B,1) + matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); + cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); + // only resultBlock.rlen known, resultBlock.clen set in operation + matBlock1.ctableOperations((SimpleOperator)_optr, matBlock2, cst1, resultBlock); + break; + case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR) + // F=ctable(A,1) or F = ctable(A,1,1) + cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); + cst2 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); + matBlock1.ctableOperations((SimpleOperator)_optr, cst1, cst2, resultMap, resultBlock); + break; + case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR) + // F=ctable(A,1,W) + wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode()); + cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), 
input2.isLiteral()).getDoubleValue(); + matBlock1.ctableOperations((SimpleOperator)_optr, cst1, wtBlock, resultMap, resultBlock); + break; + + default: + throw new DMLRuntimeException("Encountered an invalid ctable operation ("+ctableOp+") while executing instruction: " + this.toString()); + } + + if(input1.getDataType() == DataType.MATRIX) + ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); + if(input2.getDataType() == DataType.MATRIX) + ec.releaseMatrixInput(input2.getName(), getExtendedOpcode()); + if(input3.getDataType() == DataType.MATRIX) + ec.releaseMatrixInput(input3.getName(), getExtendedOpcode()); + + if ( resultBlock == null ){ + //we need to respect potentially specified output dimensions here, because we might have + //decided for hash-aggregation just to prevent inefficiency in case of sparse outputs. + if( outputDimsKnown ) + resultBlock = DataConverter.convertToMatrixBlock( resultMap, (int)outputDim1, (int)outputDim2 ); + else + resultBlock = DataConverter.convertToMatrixBlock( resultMap ); + } + else + resultBlock.examSparsity(); + + // Ensure right dense/sparse output representation for special cases + // such as ctable expand (guarded by released input memory) + if( checkGuardedRepresentationChange(matBlock1, matBlock2, resultBlock) ) { + resultBlock.examSparsity(); + } + + ec.setMatrixOutput(output.getName(), resultBlock, getExtendedOpcode()); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/cp/QuaternaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/QuaternaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/QuaternaryCPInstruction.java index b0bb392..ce6dd47 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/QuaternaryCPInstruction.java +++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/QuaternaryCPInstruction.java @@ -29,7 +29,6 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.operators.Operator; import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator; @@ -122,7 +121,7 @@ public class QuaternaryCPInstruction extends ComputationCPInstruction { } //core execute - MatrixValue out = matBlock1.quaternaryOperations(qop, matBlock2, matBlock3, matBlock4, new MatrixBlock(), _numThreads); + MatrixBlock out = matBlock1.quaternaryOperations(qop, matBlock2, matBlock3, matBlock4, new MatrixBlock(), _numThreads); //release inputs and output ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); @@ -134,14 +133,14 @@ public class QuaternaryCPInstruction extends ComputationCPInstruction { if (input4.getDataType() == DataType.MATRIX) { ec.releaseMatrixInput(input4.getName(), getExtendedOpcode()); } - ec.setVariable(output.getName(), new DoubleObject(out.getValue(0, 0))); + ec.setVariable(output.getName(), new DoubleObject(out.quickGetValue(0, 0))); } else { //wsigmoid / wdivmm / wumm if( qop.wtype3 != null && qop.wtype3.hasFourInputs() ) if (input4.getDataType() == DataType.MATRIX) { ec.releaseMatrixInput(input4.getName(), getExtendedOpcode()); } - ec.setMatrixOutput(output.getName(), (MatrixBlock)out, getExtendedOpcode()); + ec.setMatrixOutput(output.getName(), out, getExtendedOpcode()); } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java deleted file mode 100644 index a927bba..0000000 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/TernaryCPInstruction.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.sysml.runtime.instructions.cp; - -import org.apache.sysml.lops.Ternary; -import org.apache.sysml.parser.Expression.DataType; -import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.DMLRuntimeException; -import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysml.runtime.instructions.Instruction; -import org.apache.sysml.runtime.instructions.InstructionUtils; -import org.apache.sysml.runtime.matrix.data.CTableMap; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.matrix.operators.Operator; -import org.apache.sysml.runtime.matrix.operators.SimpleOperator; -import org.apache.sysml.runtime.util.DataConverter; -import org.apache.sysml.runtime.util.LongLongDoubleHashMap.EntryType; - -public class TernaryCPInstruction extends ComputationCPInstruction { - private final String _outDim1; - private final String _outDim2; - private final boolean _dim1Literal; - private final boolean _dim2Literal; - private final boolean _isExpand; - private final boolean _ignoreZeros; - - private TernaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, - String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, - boolean ignoreZeros, String opcode, String istr) { - super(CPType.Ternary, op, in1, in2, in3, out, opcode, istr); - _outDim1 = outputDim1; - _dim1Literal = dim1Literal; - _outDim2 = outputDim2; - _dim2Literal = dim2Literal; - _isExpand = isExpand; - _ignoreZeros = ignoreZeros; - } - - public static TernaryCPInstruction parseInstruction(String inst) - throws DMLRuntimeException - { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst); - InstructionUtils.checkNumFields ( parts, 7 ); - - String opcode = parts[0]; - - //handle opcode - if ( !(opcode.equalsIgnoreCase("ctable") || opcode.equalsIgnoreCase("ctableexpand")) ) { - throw new 
DMLRuntimeException("Unexpected opcode in TertiaryCPInstruction: " + inst); - } - boolean isExpand = opcode.equalsIgnoreCase("ctableexpand"); - - //handle operands - CPOperand in1 = new CPOperand(parts[1]); - CPOperand in2 = new CPOperand(parts[2]); - CPOperand in3 = new CPOperand(parts[3]); - - //handle known dimension information - String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX); - String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX); - - CPOperand out = new CPOperand(parts[6]); - boolean ignoreZeros = Boolean.parseBoolean(parts[7]); - - // ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject - return new TernaryCPInstruction(new SimpleOperator(null), in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), isExpand, ignoreZeros, opcode, inst); - } - - private Ternary.OperationTypes findCtableOperation() { - DataType dt1 = input1.getDataType(); - DataType dt2 = input2.getDataType(); - DataType dt3 = input3.getDataType(); - return Ternary.findCtableOperationByInputDataTypes(dt1, dt2, dt3); - } - - @Override - public void processInstruction(ExecutionContext ec) - throws DMLRuntimeException { - - MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName(), getExtendedOpcode()); - MatrixBlock matBlock2=null, wtBlock=null; - double cst1, cst2; - - CTableMap resultMap = new CTableMap(EntryType.INT); - MatrixBlock resultBlock = null; - Ternary.OperationTypes ctableOp = findCtableOperation(); - ctableOp = _isExpand ? Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : ctableOp; - - long outputDim1 = (_dim1Literal ? (long) Double.parseDouble(_outDim1) : (ec.getScalarInput(_outDim1, ValueType.DOUBLE, false)).getLongValue()); - long outputDim2 = (_dim2Literal ? 
(long) Double.parseDouble(_outDim2) : (ec.getScalarInput(_outDim2, ValueType.DOUBLE, false)).getLongValue()); - - boolean outputDimsKnown = (outputDim1 != -1 && outputDim2 != -1); - if ( outputDimsKnown ) { - int inputRows = matBlock1.getNumRows(); - int inputCols = matBlock1.getNumColumns(); - boolean sparse = MatrixBlock.evalSparseFormatInMemory(outputDim1, outputDim2, inputRows*inputCols); - //only create result block if dense; it is important not to aggregate on sparse result - //blocks because it would implicitly turn the O(N) algorithm into O(N log N). - if( !sparse ) - resultBlock = new MatrixBlock((int)outputDim1, (int)outputDim2, false); - } - if( _isExpand ){ - resultBlock = new MatrixBlock( matBlock1.getNumRows(), Integer.MAX_VALUE, true ); - } - - switch(ctableOp) { - case CTABLE_TRANSFORM: //(VECTOR) - // F=ctable(A,B,W) - matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); - wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode()); - matBlock1.ternaryOperations((SimpleOperator)_optr, matBlock2, wtBlock, resultMap, resultBlock); - break; - case CTABLE_TRANSFORM_SCALAR_WEIGHT: //(VECTOR/MATRIX) - // F = ctable(A,B) or F = ctable(A,B,1) - matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); - cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); - matBlock1.ternaryOperations((SimpleOperator)_optr, matBlock2, cst1, _ignoreZeros, resultMap, resultBlock); - break; - case CTABLE_EXPAND_SCALAR_WEIGHT: //(VECTOR) - // F = ctable(seq,A) or F = ctable(seq,B,1) - matBlock2 = ec.getMatrixInput(input2.getName(), getExtendedOpcode()); - cst1 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); - // only resultBlock.rlen known, resultBlock.clen set in operation - matBlock1.ternaryOperations((SimpleOperator)_optr, matBlock2, cst1, resultBlock); - break; - case CTABLE_TRANSFORM_HISTOGRAM: //(VECTOR) - // F=ctable(A,1) or F = 
ctable(A,1,1) - cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); - cst2 = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()).getDoubleValue(); - matBlock1.ternaryOperations((SimpleOperator)_optr, cst1, cst2, resultMap, resultBlock); - break; - case CTABLE_TRANSFORM_WEIGHTED_HISTOGRAM: //(VECTOR) - // F=ctable(A,1,W) - wtBlock = ec.getMatrixInput(input3.getName(), getExtendedOpcode()); - cst1 = ec.getScalarInput(input2.getName(), input2.getValueType(), input2.isLiteral()).getDoubleValue(); - matBlock1.ternaryOperations((SimpleOperator)_optr, cst1, wtBlock, resultMap, resultBlock); - break; - - default: - throw new DMLRuntimeException("Encountered an invalid ctable operation ("+ctableOp+") while executing instruction: " + this.toString()); - } - - if(input1.getDataType() == DataType.MATRIX) - ec.releaseMatrixInput(input1.getName(), getExtendedOpcode()); - if(input2.getDataType() == DataType.MATRIX) - ec.releaseMatrixInput(input2.getName(), getExtendedOpcode()); - if(input3.getDataType() == DataType.MATRIX) - ec.releaseMatrixInput(input3.getName(), getExtendedOpcode()); - - if ( resultBlock == null ){ - //we need to respect potentially specified output dimensions here, because we might have - //decided for hash-aggregation just to prevent inefficiency in case of sparse outputs. 
- if( outputDimsKnown ) - resultBlock = DataConverter.convertToMatrixBlock( resultMap, (int)outputDim1, (int)outputDim2 ); - else - resultBlock = DataConverter.convertToMatrixBlock( resultMap ); - } - else - resultBlock.examSparsity(); - - // Ensure right dense/sparse output representation for special cases - // such as ctable expand (guarded by released input memory) - if( checkGuardedRepresentationChange(matBlock1, matBlock2, resultBlock) ) { - resultBlock.examSparsity(); - } - - ec.setMatrixOutput(output.getName(), resultBlock, getExtendedOpcode()); - } -} http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java index 2766062..1c3d8bc 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateBinaryInstruction.java @@ -28,6 +28,7 @@ import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; @@ -158,14 +159,14 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen //process instruction OperationsOnMatrixValues.performAggregateBinary( - in1.getIndexes(), in1.getValue(), - in2.getIndexes(), in2.getValue(), - out.getIndexes(), out.getValue(), - ((AggregateBinaryOperator)optr)); 
+ in1.getIndexes(), (MatrixBlock) in1.getValue(), + in2.getIndexes(), (MatrixBlock) in2.getValue(), + out.getIndexes(), (MatrixBlock) out.getValue(), + ((AggregateBinaryOperator)optr)); //put the output value in the cache if(out==tempValue) - cachedValues.add(output, out); + cachedValues.add(output, out); } } @@ -192,13 +193,12 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen long in2_cols = dcInput.getNumCols(); long in2_colBlocks = (long)Math.ceil(((double)in2_cols)/dcInput.getNumColsPerBlock()); - for(int bidx=1; bidx <= in2_colBlocks; bidx++) - { + for(int bidx=1; bidx <= in2_colBlocks; bidx++) { // Matrix multiply A[i,k] %*% B[k,bid] // Setup input2 block IndexedMatrixValue in2Block = dcInput.getDataBlock((int)in1.getIndexes().getColumnIndex(), bidx); - + MatrixValue in2BlockValue = in2Block.getValue(); MatrixIndexes in2BlockIndex = in2Block.getIndexes(); @@ -206,10 +206,9 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen IndexedMatrixValue out = cachedValues.holdPlace(output, valueClass); //process instruction - OperationsOnMatrixValues.performAggregateBinary(in1.getIndexes(), in1.getValue(), - in2BlockIndex, in2BlockValue, out.getIndexes(), out.getValue(), - ((AggregateBinaryOperator)optr)); - + OperationsOnMatrixValues.performAggregateBinary(in1.getIndexes(), (MatrixBlock)in1.getValue(), + in2BlockIndex, (MatrixBlock) in2BlockValue, out.getIndexes(), (MatrixBlock)out.getValue(), + ((AggregateBinaryOperator)optr)); removeOutput &= ( !_outputEmptyBlocks && out.getValue().isEmpty() ); } } @@ -226,7 +225,7 @@ public class AggregateBinaryInstruction extends BinaryMRInstructionBase implemen // Setup input2 block IndexedMatrixValue in1Block = dcInput.getDataBlock(bidx, (int)in2.getIndexes().getRowIndex()); - + MatrixValue in1BlockValue = in1Block.getValue(); MatrixIndexes in1BlockIndex = in1Block.getIndexes(); @@ -234,14 +233,14 @@ public class AggregateBinaryInstruction extends 
BinaryMRInstructionBase implemen IndexedMatrixValue out = cachedValues.holdPlace(output, valueClass); //process instruction - OperationsOnMatrixValues.performAggregateBinary(in1BlockIndex, in1BlockValue, - in2.getIndexes(), in2.getValue(), - out.getIndexes(), out.getValue(), - ((AggregateBinaryOperator)optr)); - + OperationsOnMatrixValues.performAggregateBinary( + in1BlockIndex, (MatrixBlock)in1BlockValue, + in2.getIndexes(), (MatrixBlock)in2.getValue(), + out.getIndexes(), (MatrixBlock)out.getValue(), + ((AggregateBinaryOperator)optr)); removeOutput &= ( !_outputEmptyBlocks && out.getValue().isEmpty() ); } - } + } //empty block output filter (enabled by compiler consumer operation is in CP) if( removeOutput ) http://git-wip-us.apache.org/repos/asf/systemml/blob/57438103/src/main/java/org/apache/sysml/runtime/instructions/mr/CombineTernaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/CombineTernaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/CombineTernaryInstruction.java index 4bfc94e..7db8616 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/CombineTernaryInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/CombineTernaryInstruction.java @@ -19,14 +19,14 @@ package org.apache.sysml.runtime.instructions.mr; -import org.apache.sysml.lops.Ternary.OperationTypes; +import org.apache.sysml.lops.Ctable.OperationTypes; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixValue; import org.apache.sysml.runtime.matrix.mapred.CachedValueMap; import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; -public class CombineTernaryInstruction extends TernaryInstruction { +public class CombineTernaryInstruction extends CtableInstruction { private 
CombineTernaryInstruction(OperationTypes op, byte in1, byte in2, byte in3, byte out, String istr) { super(MRType.CombineTernary, op, in1, in2, in3, out, -1, -1, istr);
