Repository: incubator-systemml Updated Branches: refs/heads/master 01d9fdb45 -> b0d3c6c85
[SYSTEMML-766] New simplification rewrite/runtime axpy (+*, -*) Closes #179. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b0d3c6c8 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b0d3c6c8 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b0d3c6c8 Branch: refs/heads/master Commit: b0d3c6c85135c51177dbe67c4f944b1bf7dcc498 Parents: 01d9fdb Author: tgamal <[email protected]> Authored: Sat Jul 16 17:08:50 2016 -0700 Committer: Matthias Boehm <[email protected]> Committed: Sat Jul 16 17:08:50 2016 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 4 +- .../java/org/apache/sysml/hops/TernaryOp.java | 36 +++- .../RewriteAlgebraicSimplificationStatic.java | 45 ++++- src/main/java/org/apache/sysml/lops/Lop.java | 1 + .../java/org/apache/sysml/lops/PlusMult.java | 107 ++++++++++++ .../runtime/functionobjects/MinusMultiply.java | 43 +++++ .../runtime/functionobjects/PlusMultiply.java | 43 +++++ .../ValueFunctionWithConstant.java | 38 +++++ .../instructions/CPInstructionParser.java | 12 +- .../instructions/SPInstructionParser.java | 16 +- .../instructions/cp/PlusMultCPInstruction.java | 64 ++++++++ .../spark/PlusMultSPInstruction.java | 87 ++++++++++ .../misc/RewriteFuseBinaryOpChainTest.java | 164 +++++++++++++++++++ .../RewriteSimplifyRowColSumMVMultTest.java | 8 +- .../misc/RewriteFuseBinaryOpChainTest1.R | 28 ++++ .../misc/RewriteFuseBinaryOpChainTest1.dml | 27 +++ .../misc/RewriteFuseBinaryOpChainTest2.R | 28 ++++ .../misc/RewriteFuseBinaryOpChainTest2.dml | 27 +++ .../functions/misc/ZPackageSuite.java | 2 + 19 files changed, 764 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/hops/Hop.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 8c4999e..144ca20 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -1064,7 +1064,7 @@ public abstract class Hop // Operations that require 3 operands public enum OpOp3 { - QUANTILE, INTERQUANTILE, CTABLE, CENTRALMOMENT, COVARIANCE, INVALID + QUANTILE, INTERQUANTILE, CTABLE, CENTRALMOMENT, COVARIANCE, INVALID, PLUS_MULT, MINUS_MULT }; // Operations that require 4 operands @@ -1416,6 +1416,8 @@ public abstract class Hop HopsOpOp3String.put(OpOp3.CTABLE, "ctable"); HopsOpOp3String.put(OpOp3.CENTRALMOMENT, "cm"); HopsOpOp3String.put(OpOp3.COVARIANCE, "cov"); + HopsOpOp3String.put(OpOp3.PLUS_MULT, "+*"); + HopsOpOp3String.put(OpOp3.MINUS_MULT, "-*"); } protected static final HashMap<Hop.OpOp4, String> HopsOpOp4String; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/hops/TernaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java index e353273..72e7624 100644 --- a/src/main/java/org/apache/sysml/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java @@ -30,6 +30,7 @@ import org.apache.sysml.lops.Group; import org.apache.sysml.lops.Lop; import org.apache.sysml.lops.LopsException; import org.apache.sysml.lops.PickByCount; +import org.apache.sysml.lops.PlusMult; import org.apache.sysml.lops.SortKeys; import org.apache.sysml.lops.Ternary; import org.apache.sysml.lops.UnaryCP; @@ -138,6 +139,11 @@ public class TernaryOp extends Hop case CTABLE: constructLopsCtable(); break; + + case PLUS_MULT: + case MINUS_MULT: + constructLopsPlusMult(); + break; default: throw new HopsException(this.printErrorLocation() + "Unknown 
TernaryOp (" + _op + ") while constructing Lops \n"); @@ -621,7 +627,16 @@ public class TernaryOp extends Hop } } } - + private void constructLopsPlusMult() throws HopsException, LopsException { + if ( _op != OpOp3.PLUS_MULT && _op != OpOp3.MINUS_MULT ) + throw new HopsException("Unexpected operation: " + _op + ", expecting " + OpOp3.PLUS_MULT + " or" + OpOp3.MINUS_MULT); + + ExecType et = optFindExecType(); + PlusMult plusmult = new PlusMult(getInput().get(0).constructLops(),getInput().get(1).constructLops(),getInput().get(2).constructLops(), _op, getDataType(),getValueType(), et ); + setOutputDimensions(plusmult); + setLineNumbers(plusmult); + setLops(plusmult); + } @Override public String getOpString() { String s = new String(""); @@ -667,7 +682,10 @@ public class TernaryOp extends Hop // This part of the code is executed only when a vector of quantiles are computed // Output is a vector of length = #of quantiles to be computed, and it is likely to be dense. return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0); - + case PLUS_MULT: + case MINUS_MULT: + sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); default: throw new RuntimeException("Memory for operation (" + _op + ") can not be estimated."); } @@ -742,7 +760,12 @@ public class TernaryOp extends Hop if( mc[2].dimsKnown() ) return new long[]{mc[2].getRows(), 1, mc[2].getRows()}; break; - + case PLUS_MULT: + case MINUS_MULT: + //compute back NNz + double sp1 = OptimizerUtils.getSparsity(mc[0].getRows(), mc[0].getRows(), mc[0].getNonZeros()); + double sp2 = OptimizerUtils.getSparsity(mc[2].getRows(), mc[2].getRows(), mc[2].getNonZeros()); + return new long[]{mc[0].getRows(), mc[0].getCols(), (long) Math.min(sp1+sp2,1)}; default: throw new RuntimeException("Memory for operation (" + _op + ") can not be estimated."); } @@ -845,7 +868,12 @@ public class TernaryOp extends Hop // Output is a vector of length = #of quantiles to 
be computed, and it is likely to be dense. // TODO qx1 break; - + + case PLUS_MULT: + case MINUS_MULT: + setDim1( getInput().get(0)._dim1 ); + setDim2( getInput().get(0)._dim2 ); + break; default: throw new RuntimeException("Size information for operation (" + _op + ") can not be updated."); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index e903a03..43d5791 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -25,6 +25,8 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.hops.AggBinaryOp; import org.apache.sysml.hops.AggUnaryOp; import org.apache.sysml.hops.BinaryOp; @@ -32,6 +34,7 @@ import org.apache.sysml.hops.DataGenOp; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.Hop.OpOp1; import org.apache.sysml.hops.IndexingOp; +import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.TernaryOp; import org.apache.sysml.hops.UnaryOp; import org.apache.sysml.hops.Hop.AggOp; @@ -162,8 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5) hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false) hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, 
nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) - - + hi = fuseBinaryOperationChain(hop, hi, i); //e.g., X + lamda*Y -> X +* lambda Y //hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X)) //process childs recursively after rewrites (to investigate pattern newly created by rewrites) @@ -174,7 +176,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hop.setVisited(Hop.VisitStatus.DONE); } - /** * * @param hi @@ -1908,4 +1909,42 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + + /** + * + * @param parent + * @param hi + * @param pos + * @return + * @throws HopsException + */ + private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) { + //pattern: X + lamda*Y -> X +* lambda Y + if( hi instanceof BinaryOp + && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) + && ((BinaryOp)hi).getInput().get(0).getDataType()==DataType.MATRIX && ((BinaryOp)hi).getInput().get(1) instanceof BinaryOp + && (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) ) + { + //Check that the inner binary Op is a product of Scalar times Matrix or viceversa + Hop innerBinaryOp = ((BinaryOp)hi).getInput().get(1); + if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) + || (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR)) + { + //check which operand is the Scalar and which is the matrix + Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); + Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? 
innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); + + OpOp3 operator = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT; + TernaryOp ternOp=new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, operator, ((BinaryOp)hi).getInput().get(0), lamda, matrix); + + HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); + HopRewriteUtils.addChildReference(parent, ternOp, pos); + + LOG.debug("Applied fuseBinaryOperationChain. (line " +hi.getBeginLine()+")"); + return ternOp; + } + } + return hi; + + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/lops/Lop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java index d930da0..412e6d4 100644 --- a/src/main/java/org/apache/sysml/lops/Lop.java +++ b/src/main/java/org/apache/sysml/lops/Lop.java @@ -59,6 +59,7 @@ public abstract class Lop WeightedSquaredLoss, WeightedSigmoid, WeightedDivMM, WeightedCeMM, WeightedUMM, SortKeys, PickValues, Checkpoint, //Spark persist into storage level + PlusMult, MinusMult, //CP }; /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/lops/PlusMult.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PlusMult.java b/src/main/java/org/apache/sysml/lops/PlusMult.java new file mode 100644 index 0000000..2dc16e9 --- /dev/null +++ b/src/main/java/org/apache/sysml/lops/PlusMult.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.lops; + +import org.apache.sysml.hops.Hop.OpOp3; +import org.apache.sysml.lops.LopProperties.ExecLocation; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.lops.compile.JobType; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.parser.Expression.*; + + +/** + * Lop to perform Sum of a matrix with another matrix multiplied by Scalar. 
+ */ +public class PlusMult extends Lop +{ + + private void init(Lop input1, Lop input2, Lop input3, ExecType et) { + this.addInput(input1); + this.addInput(input2); + this.addInput(input3); + input1.addOutput(this); + input2.addOutput(this); + input3.addOutput(this); + + boolean breaksAlignment = false; + boolean aligner = false; + boolean definesMRJob = false; + + if ( et == ExecType.CP || et == ExecType.SPARK ){ + lps.addCompatibility(JobType.INVALID); + this.lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob ); + } + } + + public PlusMult(Lop input1, Lop input2, Lop input3, OpOp3 op, DataType dt, ValueType vt, ExecType et) { + super(Lop.Type.PlusMult, dt, vt); + if(op == OpOp3.MINUS_MULT) + type=Lop.Type.MinusMult; + init(input1, input2, input3, et); + } + + @Override + public String toString() { + + return "Operation = PlusMult"; + } + + + /** + * Function to generate CP Sum of a matrix with another matrix multiplied by Scalar. + * + * input1: matrix1 + * input2: Scalar + * input3: matrix2 + */ + @Override + public String getInstructions(String input1, String input2, String input3, String output) { + StringBuilder sb = new StringBuilder(); + sb.append( getExecType() ); + sb.append( OPERAND_DELIMITOR ); + if(type==Lop.Type.PlusMult) + sb.append( "+*" ); + else + sb.append( "-*" ); + sb.append( OPERAND_DELIMITOR ); + + // Matrix1 + sb.append( getInputs().get(0).prepInputOperand(input1) ); + sb.append( OPERAND_DELIMITOR ); + + // Matrix2 + sb.append( getInputs().get(1).prepScalarInputOperand(input2) ); + sb.append( OPERAND_DELIMITOR ); + + // Scalar + sb.append( getInputs().get(2).prepInputOperand(input3)); + sb.append( OPERAND_DELIMITOR ); + + sb.append( prepOutputOperand(output)); + + return sb.toString(); + } + + + + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java new file mode 100644 index 0000000..ee7a8fb --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.functionobjects; + +import java.io.Serializable; + +public class MinusMultiply extends ValueFunctionWithConstant implements Serializable +{ + + private static final long serialVersionUID = 2801982061205871665L; + + public MinusMultiply() { + // nothing to do here + } + public Object clone() throws CloneNotSupportedException { + // cloning is not supported for singleton classes + throw new CloneNotSupportedException(); + } + @Override + public double execute(double in1, double in2) + { + return in1 - _constant*in2; + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java new file mode 100644 index 0000000..87eb47b --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.functionobjects; + +import java.io.Serializable; + +public class PlusMultiply extends ValueFunctionWithConstant implements Serializable +{ + + private static final long serialVersionUID = 2801982061205871665L; + + public PlusMultiply() { + // nothing to do here + } + public Object clone() throws CloneNotSupportedException { + // cloning is not supported for singleton classes + throw new CloneNotSupportedException(); + } + @Override + public double execute(double in1, double in2) + { + return in1 + _constant*in2; + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java new file mode 100644 index 0000000..2820875 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/functionobjects/ValueFunctionWithConstant.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.functionobjects; + +import java.io.Serializable; + +public abstract class ValueFunctionWithConstant extends ValueFunction implements Serializable +{ + private static final long serialVersionUID = -4985988545393861058L; + protected double _constant; + + public void setConstant(double constant) + { + _constant = constant; + } + + public double getConstant() + { + return _constant; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java index f3a1621..c91ad8c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java @@ -51,6 +51,7 @@ import org.apache.sysml.runtime.instructions.cp.MultiReturnBuiltinCPInstruction; import org.apache.sysml.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction; import org.apache.sysml.runtime.instructions.cp.PMMJCPInstruction; import org.apache.sysml.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; +import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuantilePickCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuantileSortCPInstruction; import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction; @@ -63,6 +64,7 @@ import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction; import org.apache.sysml.runtime.instructions.cpfile.ParameterizedBuiltinCPFileInstruction; 
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator; public class CPInstructionParser extends InstructionParser { @@ -120,7 +122,9 @@ public class CPInstructionParser extends InstructionParser String2CPInstructionType.put( "^2" , CPINSTRUCTION_TYPE.ArithmeticBinary); //special ^ case String2CPInstructionType.put( "*2" , CPINSTRUCTION_TYPE.ArithmeticBinary); //special * case String2CPInstructionType.put( "-nz" , CPINSTRUCTION_TYPE.ArithmeticBinary); //special - case - + String2CPInstructionType.put( "+*" , CPINSTRUCTION_TYPE.ArithmeticBinary); + String2CPInstructionType.put( "-*" , CPINSTRUCTION_TYPE.ArithmeticBinary); + // Boolean Instruction Opcodes String2CPInstructionType.put( "&&" , CPINSTRUCTION_TYPE.BooleanBinary); @@ -306,7 +310,11 @@ public class CPInstructionParser extends InstructionParser return AggregateTernaryCPInstruction.parseInstruction(str); case ArithmeticBinary: - return ArithmeticBinaryCPInstruction.parseInstruction(str); + String opcode = InstructionUtils.getOpCode(str); + if( opcode.equals("+*") || opcode.equals("-*") ) + return PlusMultCPInstruction.parseInstruction(str); + else + return ArithmeticBinaryCPInstruction.parseInstruction(str); case Ternary: return TernaryCPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index e0c4631..a9a34f5 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -34,6 +34,9 @@ import org.apache.sysml.lops.WeightedSquaredLossR; import org.apache.sysml.lops.WeightedUnaryMM; import 
org.apache.sysml.lops.WeightedUnaryMMR; import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.instructions.cp.ArithmeticBinaryCPInstruction; +import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction; +import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE; import org.apache.sysml.runtime.instructions.spark.AggregateTernarySPInstruction; import org.apache.sysml.runtime.instructions.spark.AggregateUnarySPInstruction; import org.apache.sysml.runtime.instructions.spark.AppendGAlignedSPInstruction; @@ -59,6 +62,7 @@ import org.apache.sysml.runtime.instructions.spark.MatrixReshapeSPInstruction; import org.apache.sysml.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction; import org.apache.sysml.runtime.instructions.spark.PMapmmSPInstruction; import org.apache.sysml.runtime.instructions.spark.ParameterizedBuiltinSPInstruction; +import org.apache.sysml.runtime.instructions.spark.PlusMultSPInstruction; import org.apache.sysml.runtime.instructions.spark.PmmSPInstruction; import org.apache.sysml.runtime.instructions.spark.QuantilePickSPInstruction; import org.apache.sysml.runtime.instructions.spark.QuaternarySPInstruction; @@ -148,7 +152,9 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "1-*" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "^" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "^2" , SPINSTRUCTION_TYPE.ArithmeticBinary); - String2SPInstructionType.put( "*2" , SPINSTRUCTION_TYPE.ArithmeticBinary); + String2SPInstructionType.put( "*2" , SPINSTRUCTION_TYPE.ArithmeticBinary); + String2SPInstructionType.put( "+*" , SPINSTRUCTION_TYPE.ArithmeticBinary); + String2SPInstructionType.put( "-*" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "map+" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "map-" , SPINSTRUCTION_TYPE.ArithmeticBinary); 
String2SPInstructionType.put( "map*" , SPINSTRUCTION_TYPE.ArithmeticBinary); @@ -157,6 +163,8 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( "map%/%" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "map1-*" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "map^" , SPINSTRUCTION_TYPE.ArithmeticBinary); + String2SPInstructionType.put( "map+*" , SPINSTRUCTION_TYPE.ArithmeticBinary); + String2SPInstructionType.put( "map-*" , SPINSTRUCTION_TYPE.ArithmeticBinary); String2SPInstructionType.put( "map>" , SPINSTRUCTION_TYPE.RelationalBinary); String2SPInstructionType.put( "map>=" , SPINSTRUCTION_TYPE.RelationalBinary); String2SPInstructionType.put( "map<" , SPINSTRUCTION_TYPE.RelationalBinary); @@ -326,7 +334,11 @@ public class SPInstructionParser extends InstructionParser return ReorgSPInstruction.parseInstruction(str); case ArithmeticBinary: - return ArithmeticBinarySPInstruction.parseInstruction(str); + String opcode = InstructionUtils.getOpCode(str); + if( opcode.equals("+*") || opcode.equals("-*") ) + return PlusMultSPInstruction.parseInstruction(str); + else + return ArithmeticBinarySPInstruction.parseInstruction(str); case RelationalBinary: return RelationalBinarySPInstruction.parseInstruction(str); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java new file mode 100644 index 0000000..8b01cb7 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/PlusMultCPInstruction.java @@ -0,0 +1,64 @@ +package org.apache.sysml.runtime.instructions.cp; + +import org.apache.sysml.parser.Expression.DataType; +import 
org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.functionobjects.CM; +import org.apache.sysml.runtime.functionobjects.MinusMultiply; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; +import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; +import org.apache.sysml.runtime.matrix.operators.CMOperator; +import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes; + +public class PlusMultCPInstruction extends ArithmeticBinaryCPInstruction { + public PlusMultCPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2, + CPOperand in3, CPOperand out, String opcode, String str) + { + super(op, in1, in2, out, opcode, str); + input3=in3; + } + public static PlusMultCPInstruction parseInstruction(String str) + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode=parts[0]; + CPOperand operand1 = new CPOperand(parts[1]); + CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier + CPOperand operand3 = new CPOperand(parts[2]); + CPOperand outOperand = new CPOperand(parts[4]); + BinaryOperator bOperator = null; + if(opcode.equals("+*")) + bOperator = new BinaryOperator(new PlusMultiply()); + else if (opcode.equals("-*")) + bOperator = new BinaryOperator(new MinusMultiply()); + return new PlusMultCPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str); + + } + @Override + public void processInstruction( ExecutionContext ec ) + throws DMLRuntimeException + { + + String 
output_name = output.getName(); + + //get all the inputs + MatrixBlock matrix1 = ec.getMatrixInput(input1.getName()); + MatrixBlock matrix2 = ec.getMatrixInput(input2.getName()); + ScalarObject lambda = ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); + + + //execution + ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(lambda.getDoubleValue()); + MatrixBlock out = (MatrixBlock) matrix1.binaryOperations((BinaryOperator) _optr, matrix2, new MatrixBlock()); + + //release the matrices + ec.releaseMatrixInput(input1.getName()); + ec.releaseMatrixInput(input2.getName()); + + ec.setMatrixOutput(output_name, out); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java new file mode 100644 index 0000000..89de821 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/PlusMultSPInstruction.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.spark; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.MinusMultiply; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; +import org.apache.sysml.runtime.functionobjects.ValueFunctionWithConstant; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.PlusMultCPInstruction; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction; +import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.operators.BinaryOperator; +import org.apache.sysml.runtime.matrix.operators.Operator; + +public class PlusMultSPInstruction extends ArithmeticBinarySPInstruction +{ + public PlusMultSPInstruction(BinaryOperator op, CPOperand in1, CPOperand in2, + CPOperand in3, CPOperand out, String opcode, String str) throws DMLRuntimeException + { + super(op, in1, in2, out, opcode, str); + input3= in3; + + //sanity check opcodes + if ( !( opcode.equalsIgnoreCase("+*") || opcode.equalsIgnoreCase("-*") ) ) + { + throw new DMLRuntimeException("Unknown opcode in PlusMultSPInstruction: " + toString()); + } + } + public static PlusMultSPInstruction parseInstruction(String 
str) throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode=parts[0]; + CPOperand operand1 = new CPOperand(parts[1]); + CPOperand operand2 = new CPOperand(parts[3]); //put the second matrix (parts[3]) in Operand2 to make using Binary matrix operations easier + CPOperand operand3 = new CPOperand(parts[2]); + CPOperand outOperand = new CPOperand(parts[4]); + BinaryOperator bOperator = null; + if(opcode.equals("+*")) + bOperator = new BinaryOperator(new PlusMultiply()); + else if (opcode.equals("-*")) + bOperator = new BinaryOperator(new MinusMultiply()); + return new PlusMultSPInstruction(bOperator,operand1, operand2, operand3, outOperand, opcode,str); + } + + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + SparkExecutionContext sec = (SparkExecutionContext)ec; + + //pass the scalar + ScalarObject constant = (ScalarObject) ec.getScalarInput(input3.getName(), input3.getValueType(), input3.isLiteral()); + ((ValueFunctionWithConstant) ((BinaryOperator)_optr).fn).setConstant(constant.getDoubleValue()); + + super.processMatrixMatrixBinaryInstruction(sec); + + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java new file mode 100644 index 0000000..e010083 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.misc; + +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; +import org.apache.sysml.utils.Statistics; + +/** + * Regression test for function recompile-once issue with literal replacement. 
+ * + */ +public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase +{ + + private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; + private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; + + private static final String TEST_DIR = "functions/misc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/"; + + //private static final int rows = 1234; + //private static final int cols = 567; + private static final double eps = Math.pow(10, -10); + + @Override + public void setUp() + { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + } + + @Test + public void testFuseBinaryPlusNoRewrite() + { + testFuseBinaryChain( TEST_NAME1, false, ExecType.CP ); + } + + + @Test + public void testFuseBinaryPlusRewrite() + { + testFuseBinaryChain( TEST_NAME1, true, ExecType.CP); + } + @Test + public void testFuseBinaryMinusNoRewrite() + { + testFuseBinaryChain( TEST_NAME2, false, ExecType.CP ); + } + + @Test + public void testFuseBinaryMinusRewrite() + { + testFuseBinaryChain( TEST_NAME2, true, ExecType.CP ); + } + + + + @Test + public void testSpFuseBinaryPlusNoRewrite() + { + testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK ); + } + + + @Test + public void testSpFuseBinaryPlusRewrite() + { + testFuseBinaryChain( TEST_NAME1, true, ExecType.SPARK ); + } + + + @Test + public void testSpFuseBinaryMinusNoRewrite() + { + testFuseBinaryChain( TEST_NAME2, false, ExecType.SPARK ); + } + + @Test + public void testSpFuseBinaryMinusRewrite() + { + testFuseBinaryChain( TEST_NAME2, true, ExecType.SPARK ); + } + + + /** + * + * @param condition + * @param branchRemoval + * @param IPA + */ + private void testFuseBinaryChain( String testname, boolean rewrites, ExecType instType 
) + { + RUNTIME_PLATFORM platformOld = rtplatform; + switch( instType ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID; break; + } + + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + boolean rewritesOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + try + { + + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{"-explain", "-stats","-args", output("S") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), expectedDir()); + + runTest(true, false, null, -1); + runRScript(true); + + //compare matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("S"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("S"); + Assert.assertTrue(TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R")); + } + finally + { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewritesOld; + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java index d68d3b7..2829bab 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java 
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteSimplifyRowColSumMVMultTest.java @@ -56,25 +56,25 @@ public class RewriteSimplifyRowColSumMVMultTest extends AutomatedTestBase } @Test - public void testRewriteRowSumsMVMultNoRewrite() + public void testMultiScalarToBinaryNoRewrite() { testRewriteRowColSumsMVMult( TEST_NAME1, false ); } @Test - public void testRewriteRowSumsMVMultRewrite() + public void testMultiScalarToBinaryRewrite() { testRewriteRowColSumsMVMult( TEST_NAME1, true ); } @Test - public void testRewriteColSumsMVMultNoRewrite() + public void testMultiBinaryToScalarNoRewrite() { testRewriteRowColSumsMVMult( TEST_NAME2, false ); } @Test - public void testRewriteColSumsMVMultRewrite() + public void testMultiBinaryToScalarRewrite() { testRewriteRowColSumsMVMult( TEST_NAME2, true ); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R new file mode 100644 index 0000000..c34948c --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.R @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference result for RewriteFuseBinaryOpChainTest1: S = X + lamda*Y on two
+# 10x10 all-ones matrices. S is written to the file "S" under the path prefix
+# given as the second command-line argument.
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+lamda <- 7
+X <- matrix(1, nrow=10, ncol=10)
+Y <- matrix(1, nrow=10, ncol=10)
+S <- X + lamda * Y
+
+outFile <- paste(args[2], "S", sep="")
+writeMM(as(S, "CsparseMatrix"), outFile)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml
new file mode 100644
index 0000000..077b8a9
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest1.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Test script for RewriteFuseBinaryOpChainTest1: builds two 10x10 all-ones matrices and writes S = X + lamda*Y. +X=matrix(1,rows=10,cols=10) +Y=matrix(1,rows=10,cols=10) +# NOTE(review): empty if block; presumably forces a statement-block boundary so the expression below is compiled separately from the matrix definitions -- confirm. +if(1==1){} +lamda=7 +S=X+lamda*Y +# $1 is the first script argument and names the output location of S. +write(S,$1) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R new file mode 100644 index 0000000..1caff09 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.R @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference result for RewriteFuseBinaryOpChainTest2: S = X - lamda*Y on two
+# 10x10 all-ones matrices. S is written to the file "S" under the path prefix
+# given as the second command-line argument.
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+lamda <- 7
+X <- matrix(1, nrow=10, ncol=10)
+Y <- matrix(1, nrow=10, ncol=10)
+S <- X - lamda * Y
+
+outFile <- paste(args[2], "S", sep="")
+writeMM(as(S, "CsparseMatrix"), outFile)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml
new file mode 100644
index 0000000..f3c6b9a
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest2.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# +#------------------------------------------------------------- + +# Test script for RewriteFuseBinaryOpChainTest2: builds two 10x10 all-ones matrices and writes S = X - lamda*Y. +X=matrix(1,rows=10,cols=10) +Y=matrix(1,rows=10,cols=10) +# NOTE(review): empty if block; presumably forces a statement-block boundary so the expression below is compiled separately from the matrix definitions -- confirm. +if(1==1){} +lamda=7 +S=X-lamda*Y +# $1 is the first script argument and names the output location of S. +write(S,$1) \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b0d3c6c8/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java index 4720595..6c40dd7 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java @@ -49,8 +49,10 @@ import org.junit.runners.Suite; RewriteFusedRandTest.class, RewritePushdownSumOnBinaryTest.class, RewritePushdownUaggTest.class, + RewritePushdownSumBinaryMult.class, RewriteSimplifyRowColSumMVMultTest.class, RewriteSlicedMatrixMultTest.class, + RewriteFuseBinaryOpChainTest.class, ScalarAssignmentTest.class, ScalarFunctionTest.class, ScalarMatrixUnaryBinaryTermTest.class,
