Repository: systemml Updated Branches: refs/heads/master c1ed79150 -> 75b93f261
http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/QuaternaryOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/QuaternaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/QuaternaryOperator.java index aa945c2..0107be0 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/QuaternaryOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/QuaternaryOperator.java @@ -32,26 +32,34 @@ import org.apache.sysml.runtime.functionobjects.ValueFunction; public class QuaternaryOperator extends Operator { - private static final long serialVersionUID = -1642908613016116069L; - public WeightsType wtype1 = null; - public WSigmoidType wtype2 = null; - public WDivMMType wtype3 = null; - public WCeMMType wtype4 = null; - public WUMMType wtype5 = null; - - public ValueFunction fn; + public final WeightsType wtype1; + public final WSigmoidType wtype2; + public final WDivMMType wtype3; + public final WCeMMType wtype4; + public final WUMMType wtype5; - private double eps = 0; + public final ValueFunction fn; + private final double eps; /* returned by getScalar(); the visible constructors pass 0 except the (type, epsilon) variants */ + private QuaternaryOperator( WeightsType wt1, WSigmoidType wt2, WDivMMType wt3, WCeMMType wt4, WUMMType wt5, ValueFunction fn, double eps ) { /* delegation target of the public constructors shown below, each passing one non-null wtypeN */ + wtype1 = wt1; + wtype2 = wt2; + wtype3 = wt3; + wtype4 = wt4; + wtype5 = wt5; + this.fn = fn; + this.eps = eps; + } + /** * wsloss * * @param wt Weights type */ public QuaternaryOperator( WeightsType wt ) { - wtype1 = wt; + this(wt, null, null, null, null, null, 0); } /** @@ -60,8 +68,7 @@ public class QuaternaryOperator extends Operator * @param wt WSigmoid type */ public QuaternaryOperator( WSigmoidType wt ) { - wtype2 = wt; - fn = Builtin.getBuiltinFnObject("sigmoid"); + this(null, wt, null, null, null, Builtin.getBuiltinFnObject("sigmoid"), 0); } /** @@ -70,7 +77,7 @@ public class
QuaternaryOperator extends Operator * @param wt WDivMM type */ public QuaternaryOperator( WDivMMType wt ) { - wtype3 = wt; + this(null, null, wt, null, null, null, 0); } /** @@ -80,8 +87,7 @@ public class QuaternaryOperator extends Operator * @param epsilon the epsilon value */ public QuaternaryOperator( WDivMMType wt, double epsilon) { - wtype3 = wt; - eps = epsilon; + this(null, null, wt, null, null, null, epsilon); } /** @@ -90,7 +96,7 @@ public class QuaternaryOperator extends Operator * @param wt WCeMM type */ public QuaternaryOperator( WCeMMType wt ) { - wtype4 = wt; + this(null, null, null, wt, null, null, 0); } /** @@ -100,8 +106,7 @@ public class QuaternaryOperator extends Operator * @param epsilon the epsilon value */ public QuaternaryOperator( WCeMMType wt, double epsilon) { - wtype4 = wt; - eps = epsilon; + this(null, null, null, wt, null, null, epsilon); } /** @@ -111,14 +116,10 @@ public class QuaternaryOperator extends Operator * @param op operator type */ public QuaternaryOperator( WUMMType wt, String op ) { - wtype5 = wt; - - if( op.equals("^2") ) - fn = Power2.getPower2FnObject(); - else if( op.equals("*2") ) - fn = Multiply2.getMultiply2FnObject(); - else - fn = Builtin.getBuiltinFnObject(op); + this(null, null, null, null, wt, + op.equals("^2") ? Power2.getPower2FnObject() : + op.equals("*2") ? 
Multiply2.getMultiply2FnObject() : + Builtin.getBuiltinFnObject(op), 0); } public boolean hasFourInputs() { @@ -135,5 +136,4 @@ public class QuaternaryOperator extends Operator public double getScalar() { return eps; } - } http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/ReIndexOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/ReIndexOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/ReIndexOperator.java index 1376112..37df92f 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/ReIndexOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/ReIndexOperator.java @@ -22,11 +22,9 @@ package org.apache.sysml.runtime.matrix.operators; public class ReIndexOperator extends Operator { - private static final long serialVersionUID = 8603367674384408297L; - public ReIndexOperator() - { - sparseSafe=true; + public ReIndexOperator() { + super(true); /* replaces the removed sparseSafe=true assignment */ } } http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/ReorgOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/ReorgOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/ReorgOperator.java index 54c346f..fcdbe9d 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/ReorgOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/ReorgOperator.java @@ -28,8 +28,8 @@ public class ReorgOperator extends Operator implements Serializable { private static final long serialVersionUID = -5322516429026298404L; - public IndexFunction fn; - private int k; //num threads + public final IndexFunction fn; + private final int k; //num threads public ReorgOperator(IndexFunction p) { //default degree of 
parallelism is 1 @@ -38,12 +38,16 @@ public class ReorgOperator extends Operator implements Serializable } public ReorgOperator(IndexFunction p, int numThreads) { + super(true); fn = p; - sparseSafe = true; k = numThreads; } public int getNumThreads() { return k; } + + public ReorgOperator setFn(IndexFunction fn) { /* returns a new operator with the given fn and the same k; this instance is not modified */ + return new ReorgOperator(fn, k); + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java index 5a75c32..9458270 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/RightScalarOperator.java @@ -38,22 +38,17 @@ public class RightScalarOperator extends ScalarOperator private static final long serialVersionUID = 5148300801904349919L; public RightScalarOperator(ValueFunction p, double cst) { - super(p, cst); + super(p, cst, (p instanceof GreaterThan && cst>=0) + || (p instanceof GreaterThanEquals && cst>0) + || (p instanceof LessThan && cst<=0) + || (p instanceof LessThanEquals && cst<0) + || (p instanceof Divide && cst!=0) + || (p instanceof Power && cst!=0)); } @Override - public void setConstant(double cst) - { - super.setConstant(cst); - - //enable conditionally sparse safe operations - sparseSafe |= (isSparseSafeStatic() - || (fn instanceof GreaterThan && _constant>=0) - || (fn instanceof GreaterThanEquals && _constant>0) - || (fn instanceof LessThan && _constant<=0) - || (fn instanceof LessThanEquals && _constant<0) - || (fn instanceof Divide && _constant!=0) - || (fn instanceof Power && _constant!=0)); + public ScalarOperator setConstant(double cst) { /* returns a new operator for cst; 
callers must use the return value, this instance is not modified */ + return new RightScalarOperator(fn, cst); } @Override
http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java index 5b28f7d..2c5885b 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/ScalarOperator.java @@ -42,28 +42,29 @@ public abstract class ScalarOperator extends Operator { private static final long serialVersionUID = 4547253761093455869L; - public ValueFunction fn; - protected double _constant; + public final ValueFunction fn; + protected final double _constant; public ScalarOperator(ValueFunction p, double cst) { + this(p, cst, false); + } + + protected ScalarOperator(ValueFunction p, double cst, boolean altSparseSafe) { + super( isSparseSafeStatic(p) || altSparseSafe + || (p instanceof NotEquals && cst==0) + || (p instanceof Equals && cst!=0) + || (p instanceof Minus && cst==0) + || (p instanceof Builtin && ((Builtin)p).getBuiltinCode()==BuiltinCode.MAX && cst<=0) + || (p instanceof Builtin && ((Builtin)p).getBuiltinCode()==BuiltinCode.MIN && cst>=0)); fn = p; - //set constant and sparse safe flag - setConstant(cst); + _constant = cst; } public double getConstant() { return _constant; } - public void setConstant(double cst) { - _constant = cst; - sparseSafe = (isSparseSafeStatic() - || (fn instanceof NotEquals && _constant==0) - || (fn instanceof Equals && _constant!=0) - || (fn instanceof Minus && _constant==0) - || (fn instanceof Builtin && ((Builtin)fn).getBuiltinCode()==BuiltinCode.MAX && _constant<=0) - || (fn instanceof Builtin && ((Builtin)fn).getBuiltinCode()==BuiltinCode.MIN && _constant>=0)); - } + public abstract ScalarOperator setConstant(double cst); /** * Apply the scalar operator
over a given input value. @@ -81,7 +82,7 @@ public abstract class ScalarOperator extends Operator * * @return true if function statically sparse safe */ - protected boolean isSparseSafeStatic() { + protected static boolean isSparseSafeStatic(ValueFunction fn) { return ( fn instanceof Multiply || fn instanceof Multiply2 || fn instanceof Power2 || fn instanceof And || fn instanceof MinusNz || fn instanceof Builtin && ((Builtin)fn).getBuiltinCode()==BuiltinCode.LOG_NZ); http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/SimpleOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/SimpleOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/SimpleOperator.java index bbd7a89..a28bc8e 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/SimpleOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/SimpleOperator.java @@ -27,10 +27,9 @@ import org.apache.sysml.runtime.functionobjects.FunctionObject; */ public class SimpleOperator extends Operator { - private static final long serialVersionUID = 625147299273287379L; - public FunctionObject fn; + public final FunctionObject fn; public SimpleOperator ( FunctionObject f ) { fn = f; http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java index 743b1f3..a02c3d2 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/UnaryOperator.java @@ -27,29 +27,24 @@ public class UnaryOperator extends Operator { private
static final long serialVersionUID = 2441990876648978637L; - public ValueFunction fn; - private int k; //num threads + public final ValueFunction fn; + private final int k; //num threads public UnaryOperator(ValueFunction p) { this(p, 1); //default single-threaded } - public UnaryOperator(ValueFunction p, int numThreads) - { + public UnaryOperator(ValueFunction p, int numThreads) { + super(p instanceof Builtin && ( /* instanceof guard must cover every (Builtin) cast below; && binds tighter than || */ + ((Builtin)p).bFunc==Builtin.BuiltinCode.SIN || ((Builtin)p).bFunc==Builtin.BuiltinCode.TAN + // sinh and tanh are zero only at zero, else they are nnz + || ((Builtin)p).bFunc==Builtin.BuiltinCode.SINH || ((Builtin)p).bFunc==Builtin.BuiltinCode.TANH + || ((Builtin)p).bFunc==Builtin.BuiltinCode.ROUND || ((Builtin)p).bFunc==Builtin.BuiltinCode.ABS + || ((Builtin)p).bFunc==Builtin.BuiltinCode.SQRT || ((Builtin)p).bFunc==Builtin.BuiltinCode.SPROP + || ((Builtin)p).bFunc==Builtin.BuiltinCode.SELP || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ + || ((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) ); fn = p; - sparseSafe = false; k = numThreads; - - if( fn instanceof Builtin ) { - Builtin f=(Builtin)fn; - sparseSafe = (f.bFunc==Builtin.BuiltinCode.SIN || f.bFunc==Builtin.BuiltinCode.TAN - // sinh and tanh are zero only at zero, else they are nnz - || f.bFunc==Builtin.BuiltinCode.SINH || f.bFunc==Builtin.BuiltinCode.TANH - || f.bFunc==Builtin.BuiltinCode.ROUND || f.bFunc==Builtin.BuiltinCode.ABS - || f.bFunc==Builtin.BuiltinCode.SQRT || f.bFunc==Builtin.BuiltinCode.SPROP - || f.bFunc==Builtin.BuiltinCode.SELP || f.bFunc==Builtin.BuiltinCode.LOG_NZ - || f.bFunc==Builtin.BuiltinCode.SIGN ); - } } public int getNumThreads() { http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/main/java/org/apache/sysml/runtime/matrix/operators/ZeroOutOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/ZeroOutOperator.java 
b/src/main/java/org/apache/sysml/runtime/matrix/operators/ZeroOutOperator.java index 18b70ad..109c293 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/ZeroOutOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/ZeroOutOperator.java @@ -22,11 +22,9 @@ package org.apache.sysml.runtime.matrix.operators; public class ZeroOutOperator extends Operator { - private static final long serialVersionUID = 8991309598821495444L; - public ZeroOutOperator() - { - sparseSafe=true; + public ZeroOutOperator() { + super(true); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/test/java/org/apache/sysml/test/integration/functions/codegen/CPlanVectorPrimitivesTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CPlanVectorPrimitivesTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CPlanVectorPrimitivesTest.java index 3b65281..36f9f58 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CPlanVectorPrimitivesTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CPlanVectorPrimitivesTest.java @@ -771,13 +771,13 @@ public class CPlanVectorPrimitivesTest extends AutomatedTestBase double[] ret2 = null; if( type1 == InputType.SCALAR ) { ScalarOperator bop = InstructionUtils.parseScalarBinaryOperator(opcode, true); - bop.setConstant(inA.max()); + bop = bop.setConstant(inA.max()); ret2 = DataConverter.convertToDoubleVector((MatrixBlock) in2.scalarOperations(bop, new MatrixBlock()), false); } else if( type2 == InputType.SCALAR ) { ScalarOperator bop = InstructionUtils.parseScalarBinaryOperator(opcode, false); - bop.setConstant(inB.max()); + bop = bop.setConstant(inB.max()); ret2 = DataConverter.convertToDoubleVector((MatrixBlock) in1.scalarOperations(bop, new MatrixBlock()), false); }
http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParUnaryAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParUnaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParUnaryAggregateTest.java index a566682..c89356c 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParUnaryAggregateTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/LargeParUnaryAggregateTest.java @@ -1059,21 +1059,21 @@ public class LargeParUnaryAggregateTest extends AutomatedTestBase //prepare unary aggregate operator AggregateUnaryOperator auop = null; + int k = InfrastructureAnalyzer.getLocalParallelism(); /* passed to every parse call below in place of the removed auop.setNumThreads(...) */ switch (aggtype) { - case SUM: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); break; - case ROWSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+"); break; - case COLSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+"); break; - case SUMSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uasqk+"); break; - case ROWSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarsqk+"); break; - case COLSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacsqk+"); break; - case MAX: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax"); break; - case ROWMAXS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax"); break; - case COLMAXS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax"); break; - case MIN: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin"); break; - case ROWMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin"); break; - case COLMINS: auop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin"); break;
+ case SUM: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+",k); break; + case ROWSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+",k); break; + case COLSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+",k); break; + case SUMSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uasqk+",k); break; + case ROWSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarsqk+",k); break; + case COLSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacsqk+",k); break; + case MAX: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax",k); break; + case ROWMAXS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax",k); break; + case COLMAXS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax",k); break; + case MIN: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin",k); break; + case ROWMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin",k); break; + case COLMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin",k); break; } - auop.setNumThreads(InfrastructureAnalyzer.getLocalParallelism()); //compress given matrix block CompressedMatrixBlock cmb = new CompressedMatrixBlock(mb); http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/test/java/org/apache/sysml/test/integration/functions/compress/ParUnaryAggregateTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/compress/ParUnaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/compress/ParUnaryAggregateTest.java index b883c21..7355645 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/compress/ParUnaryAggregateTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/compress/ParUnaryAggregateTest.java @@ -1058,21 +1058,21 @@ public class ParUnaryAggregateTest extends
AutomatedTestBase //prepare unary aggregate operator AggregateUnaryOperator auop = null; + int k = InfrastructureAnalyzer.getLocalParallelism(); /* passed to every parse call below in place of the removed auop.setNumThreads(...) */ switch (aggtype) { - case SUM: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+"); break; - case ROWSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+"); break; - case COLSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+"); break; - case SUMSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uasqk+"); break; - case ROWSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarsqk+"); break; - case COLSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacsqk+"); break; - case MAX: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax"); break; - case ROWMAXS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax"); break; - case COLMAXS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax"); break; - case MIN: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin"); break; - case ROWMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin"); break; - case COLMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin"); break; + case SUM: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+",k); break; + case ROWSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+",k); break; + case COLSUMS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+",k); break; + case SUMSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uasqk+",k); break; + case ROWSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarsqk+",k); break; + case COLSUMSSQ: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacsqk+",k); break; + case MAX: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax",k); break; + case ROWMAXS: auop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uarmax",k); break; + case COLMAXS: auop = 
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax",k); break; + case MIN: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin",k); break; + case ROWMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin",k); break; + case COLMINS: auop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin",k); break; } - auop.setNumThreads(InfrastructureAnalyzer.getLocalParallelism()); //compress given matrix block CompressedMatrixBlock cmb = new CompressedMatrixBlock(mb); http://git-wip-us.apache.org/repos/asf/systemml/blob/75b93f26/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCClonedPreparedScriptTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCClonedPreparedScriptTest.java b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCClonedPreparedScriptTest.java index d0667e0..ca80e47 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCClonedPreparedScriptTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCClonedPreparedScriptTest.java @@ -39,35 +39,68 @@ import org.apache.sysml.utils.Statistics; public class JMLCClonedPreparedScriptTest extends AutomatedTestBase { + //basic script with parfor loop + private static final String SCRIPT1 = + "X = matrix(7, 10, 10);" + + "R = matrix(0, 10, 1)" + + "parfor(i in 1:nrow(X))" + + " R[i,] = sum(X[i,])" + + "out = sum(R)" + + "write(out, 'tmp/out')"; + + //script with dml-bodied and external functions + private static final String SCRIPT2 = + "foo1 = externalFunction(int numInputs, boolean stretch, Matrix[double] A, Matrix[double] B, Matrix[double] C) " + + " return (Matrix[double] D)" + + " implemented in (classname='org.apache.sysml.udf.lib.MultiInputCbind', exectype='mem');" + + "foo2 = function(Matrix[double] A, Matrix[double] B, Matrix[double] C)" + + " return (Matrix[double] D) {" + + "
while(FALSE){}" + + " D = cbind(A, B, C)" + + "}" + + "X = matrix(7, 10, 10);" + + "R = matrix(0, 10, 1)" + + "for(i in 1:nrow(X)) {" + + " D = foo1(3, FALSE, X[i,], X[i,], X[i,])" + + " E = foo2(D, D, D)" + + " R[i,] = sum(E)/9" + + "}" + + "out = sum(R)" + + "write(out, 'tmp/out')"; + + @Override public void setUp() { //do nothing } @Test - public void testSinglePreparedScript128() throws IOException { - runJMLCClonedTest(128, false); + public void testSinglePreparedScript1T128() throws IOException { + runJMLCClonedTest(SCRIPT1, 128, false); + } + + @Test + public void testClonedPreparedScript1T128() throws IOException { + runJMLCClonedTest(SCRIPT1, 128, true); + } + + @Test + public void testSinglePreparedScript2T128() throws IOException { + runJMLCClonedTest(SCRIPT2, 128, false); } @Test - public void testClonedPreparedScript128() throws IOException { - runJMLCClonedTest(128, true); + public void testClonedPreparedScript2T128() throws IOException { + runJMLCClonedTest(SCRIPT2, 128, true); } - private void runJMLCClonedTest(int num, boolean clone) + private void runJMLCClonedTest(String script, int num, boolean clone) throws IOException { int k = InfrastructureAnalyzer.getLocalParallelism(); boolean failed = false; try( Connection conn = new Connection() ) { - String script = - " X = matrix(7, 10, 10);" - + "R = matrix(0, 10, 1)" - + "parfor(i in 1:nrow(X))" - + " R[i,] = sum(X[i,])" - + "out = sum(R)" - + "write(out, 'tmp/out')"; DMLScript.STATISTICS = true; Statistics.reset(); PreparedScript pscript = conn.prepareScript(