http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java new file mode 100644 index 0000000..96fd310 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/TernarySPInstruction.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.spark; + +import java.io.Serializable; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.operators.TernaryOperator; + +import scala.Tuple2; + +/** + * Spark instruction for cell-wise ternary operations (e.g., ifelse) over a mix + * of matrix and scalar inputs (at least one matrix); scalars become 1x1 blocks. + */ +public class TernarySPInstruction extends ComputationSPInstruction { + private TernarySPInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String opcode, String str) throws DMLRuntimeException { + super(SPType.Ternary, op, in1, in2, in3, out, opcode, str); + } + + public static TernarySPInstruction parseInstruction(String str) throws DMLRuntimeException + { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode=parts[0]; + CPOperand operand1 = new CPOperand(parts[1]); + CPOperand operand2 = new CPOperand(parts[2]); + CPOperand operand3 = new CPOperand(parts[3]); + CPOperand outOperand = new CPOperand(parts[4]); + TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode); + return new TernarySPInstruction(op,operand1, operand2, operand3, outOperand, opcode,str); + } + + @Override + public void processInstruction(ExecutionContext ec) + throws DMLRuntimeException + { + SparkExecutionContext sec = (SparkExecutionContext)ec; + JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = !input1.isMatrix() ? 
null : + sec.getBinaryBlockRDDHandleForVariable(input1.getName()); + JavaPairRDD<MatrixIndexes,MatrixBlock> in2 = !input2.isMatrix() ? null : + sec.getBinaryBlockRDDHandleForVariable(input2.getName()); + JavaPairRDD<MatrixIndexes,MatrixBlock> in3 = !input3.isMatrix() ? null : + sec.getBinaryBlockRDDHandleForVariable(input3.getName()); + MatrixBlock m1 = input1.isMatrix() ? null : + new MatrixBlock(ec.getScalarInput(input1).getDoubleValue()); + MatrixBlock m2 = input2.isMatrix() ? null : + new MatrixBlock(ec.getScalarInput(input2).getDoubleValue()); + MatrixBlock m3 = input3.isMatrix() ? null : + new MatrixBlock(ec.getScalarInput(input3).getDoubleValue()); + + TernaryOperator op = (TernaryOperator) _optr; + if( !input1.isMatrix() && !input2.isMatrix() && !input3.isMatrix() ) throw new DMLRuntimeException("Ternary spark instruction requires at least one matrix input."); + JavaPairRDD<MatrixIndexes,MatrixBlock> out = null; + if( input1.isMatrix() && !input2.isMatrix() && !input3.isMatrix() ) + out = in1.mapValues(new TernaryFunctionMSS(op, m1, m2, m3)); + else if( !input1.isMatrix() && input2.isMatrix() && !input3.isMatrix() ) + out = in2.mapValues(new TernaryFunctionSMS(op, m1, m2, m3)); + else if( !input1.isMatrix() && !input2.isMatrix() && input3.isMatrix() ) + out = in3.mapValues(new TernaryFunctionSSM(op, m1, m2, m3)); + else if( input1.isMatrix() && input2.isMatrix() && !input3.isMatrix() ) + out = in1.join(in2).mapValues(new TernaryFunctionMMS(op, m1, m2, m3)); + else if( input1.isMatrix() && !input2.isMatrix() && input3.isMatrix() ) + out = in1.join(in3).mapValues(new TernaryFunctionMSM(op, m1, m2, m3)); + else if( !input1.isMatrix() && input2.isMatrix() && input3.isMatrix() ) + out = in2.join(in3).mapValues(new TernaryFunctionSMM(op, m1, m2, m3)); + else // all matrices (the all-scalar case is rejected above) + out = in1.join(in2).join(in3).mapValues(new TernaryFunctionMMM(op, m1, m2, m3)); + + //set output RDD + updateTernaryOutputMatrixCharacteristics(sec); + sec.setRDDHandleForVariable(output.getName(), out); + if( 
input1.isMatrix() ) + sec.addLineageRDD(output.getName(), input1.getName()); + if( input2.isMatrix() ) + sec.addLineageRDD(output.getName(), 
input2.getName()); + if( input3.isMatrix() ) + sec.addLineageRDD(output.getName(), input3.getName()); + } + + /** + * Sets the output dimensions and block sizes from the matrix inputs with known + * dimensions. The nnz is deliberately not copied, because the output sparsity + * generally differs from that of any single input. + */ + protected void updateTernaryOutputMatrixCharacteristics(SparkExecutionContext sec) + throws DMLRuntimeException + { + MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName()); + for(CPOperand input : new CPOperand[]{input1, input2, input3}) + if( input.isMatrix() ) { + MatrixCharacteristics mc = sec.getMatrixCharacteristics(input.getName()); + if( mc.dimsKnown() ) + mcOut.set(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock()); + } + } + + private static abstract class TernaryFunction implements Serializable { + private static final long serialVersionUID = 8345737737972434426L; + protected final TernaryOperator _op; + protected final MatrixBlock _m1, _m2, _m3; + public TernaryFunction(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + _op = op; _m1 = m1; _m2 = m2; _m3 = m3; + } + } + + private static class TernaryFunctionMSS extends TernaryFunction implements Function<MatrixBlock, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionMSS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(MatrixBlock v1) throws Exception { + return v1.ternaryOperations(_op, _m2, _m3, new MatrixBlock()); + } + } + + private static class TernaryFunctionSMS extends TernaryFunction implements Function<MatrixBlock, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionSMS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(MatrixBlock v1) throws Exception { + return 
_m1.ternaryOperations(_op, v1, _m3, new MatrixBlock()); + } + } + + private static class TernaryFunctionSSM extends TernaryFunction implements Function<MatrixBlock, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionSSM(TernaryOperator op, 
MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(MatrixBlock v1) throws Exception { + return _m1.ternaryOperations(_op, _m2, v1, new MatrixBlock()); + } + } + + private static class TernaryFunctionMMS extends TernaryFunction implements Function<Tuple2<MatrixBlock,MatrixBlock>, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionMMS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception { + return v1._1().ternaryOperations(_op, v1._2(), _m3, new MatrixBlock()); + } + } + + private static class TernaryFunctionMSM extends TernaryFunction implements Function<Tuple2<MatrixBlock,MatrixBlock>, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionMSM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception { + return v1._1().ternaryOperations(_op, _m2, v1._2(), new MatrixBlock()); + } + } + + private static class TernaryFunctionSMM extends TernaryFunction implements Function<Tuple2<MatrixBlock,MatrixBlock>, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionSMM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { + super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception { + return _m1.ternaryOperations(_op, v1._1(), v1._2(), new MatrixBlock()); + } + } + + private static class TernaryFunctionMMM extends TernaryFunction implements Function<Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>, MatrixBlock> { + private static final long serialVersionUID = 1L; + public TernaryFunctionMMM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) { 
+ super(op, m1, m2, m3); + } + @Override + public MatrixBlock call(Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock> v1) throws Exception { + return v1._1()._1().ternaryOperations(_op, v1._1()._2(), v1._2(), new MatrixBlock()); + } + } +}
http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java index b447b23..0c06424 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/MatrixCharacteristics.java @@ -56,6 +56,7 @@ import org.apache.sysml.runtime.instructions.mr.ReorgInstruction; import org.apache.sysml.runtime.instructions.mr.ReplicateInstruction; import org.apache.sysml.runtime.instructions.mr.ScalarInstruction; import org.apache.sysml.runtime.instructions.mr.SeqInstruction; +import org.apache.sysml.runtime.instructions.mr.TernaryInstruction; import org.apache.sysml.runtime.instructions.mr.CtableInstruction; import org.apache.sysml.runtime.instructions.mr.UaggOuterChainInstruction; import org.apache.sysml.runtime.instructions.mr.UnaryInstruction; @@ -376,6 +377,9 @@ public class MatrixCharacteristics implements Serializable dimOut.set(mc1); } } + else if( ins instanceof TernaryInstruction ) { + dimOut.set(dims.get(ins.getInputIndexes()[0])); + } else if (ins instanceof CombineTernaryInstruction ) { CtableInstruction realIns=(CtableInstruction)ins; dimOut.set(dims.get(realIns.input1)); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 098427f..81cd3a9 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ 
-78,6 +78,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator; import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator; import org.apache.sysml.runtime.matrix.operators.ReorgOperator; import org.apache.sysml.runtime.matrix.operators.ScalarOperator; +import org.apache.sysml.runtime.matrix.operators.TernaryOperator; import org.apache.sysml.runtime.matrix.operators.UnaryOperator; import org.apache.sysml.runtime.util.DataConverter; import org.apache.sysml.runtime.util.FastBufferedDataInputStream; @@ -157,6 +158,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab copy(that); } + /** + * Constructs a dense 1x1 matrix block holding the given value. + * @param val value of the single cell + */ + public MatrixBlock(double val) { + reset(1, 1, false, 1, val); + nonZeros = (val != 0) ? 1 : 0; + } + /** * Constructs a sparse {@link MatrixBlock} with a given instance of a {@link SparseBlock} * @param rl number of rows @@ -2777,6 +2787,62 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab LibMatrixBincell.bincellOpInPlace(this, that, op); } + /** + * Computes the cell-wise ternary operation op(this, m2, m3) into ret. + * Each input must either match the m x n output dimensions, be a 1x1 + * block (broadcast as a scalar), or be a row/column vector whose unit + * dimension is broadcast along the output. + * + * @param op ternary operator whose function is applied to each cell + * @param m2 second input block + * @param m3 third input block + * @param ret output block, reset to m x n and overwritten + * @return the output block ret + * @throws DMLRuntimeException if the input dimensions are incompatible + */ + public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret) + throws DMLRuntimeException + { + //TODO perf for special cases like ifelse + + final int m = Math.max(Math.max(rlen, m2.rlen), m3.rlen); + final int n = Math.max(Math.max(clen, m2.clen), m3.clen); + + 
//error handling + if( (rlen != 1 && rlen != m) || (clen != 1 && clen != n) + || (m2.rlen != 1 && m2.rlen != m) || (m2.clen != 1 && m2.clen != n) + || (m3.rlen != 1 && m3.rlen != m) || (m3.clen != 1 && m3.clen != n) ) { + throw new DMLRuntimeException("Block sizes are not matched for ternary cell operations: " + rlen + "x" + clen + " vs " + m2.rlen + "x" + m2.clen + " vs " + m3.rlen + "x" + m3.clen); + } + + //prepare inputs + final boolean s1 = (rlen==1 && clen==1); + final boolean s2 = (m2.rlen==1 && m2.clen==1); + final boolean s3 = (m3.rlen==1 && m3.clen==1); + final double d1 = s1 ? quickGetValue(0, 0) : Double.NaN; + final double d2 = s2 ? 
m2.quickGetValue(0, 0) : Double.NaN; + final double d3 = s3 ? m3.quickGetValue(0, 0) : Double.NaN; + + //prepare result + ret.reset(m, n, false); + ret.allocateDenseBlock(); + + //basic ternary operations (row/column vector inputs are broadcast by clamping their unit dimension to index 0) + for( int i=0; i<m; i++ ) + for( int j=0; j<n; j++ ) { + double in1 = s1 ? d1 : quickGetValue((rlen==1) ? 0 : i, (clen==1) ? 0 : j); + double in2 = s2 ? d2 : m2.quickGetValue((m2.rlen==1) ? 0 : i, (m2.clen==1) ? 0 : j); + double in3 = s3 ? d3 : m3.quickGetValue((m3.rlen==1) ? 0 : i, (m3.clen==1) ? 0 : j); + ret.appendValue(i, j, op.fn.execute(in1, in2, in3)); + } + + //ensure correct output representation + ret.examSparsity(); + + return ret; + } + @Override public void incrementalAggregate(AggregateOperator aggOp, MatrixValue correction, MatrixValue newWithCorrection) http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java index 48af5e1..5245db5 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java @@ -39,14 +39,12 @@ import org.apache.sysml.runtime.functionobjects.IntegerDivide; import org.apache.sysml.runtime.functionobjects.LessThan; import org.apache.sysml.runtime.functionobjects.LessThanEquals; import org.apache.sysml.runtime.functionobjects.Minus; -import org.apache.sysml.runtime.functionobjects.MinusMultiply; import org.apache.sysml.runtime.functionobjects.MinusNz; import org.apache.sysml.runtime.functionobjects.Modulus; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.NotEquals; import 
org.apache.sysml.runtime.functionobjects.Or; import org.apache.sysml.runtime.functionobjects.Plus; -import org.apache.sysml.runtime.functionobjects.PlusMultiply; import 
org.apache.sysml.runtime.functionobjects.Power; import org.apache.sysml.runtime.functionobjects.ValueFunction; import org.apache.sysml.runtime.functionobjects.Xor; @@ -61,7 +59,6 @@ public class BinaryOperator extends Operator implements Serializable //binaryop is sparse-safe iff (0 op 0) == 0 super (p instanceof Plus || p instanceof Multiply || p instanceof Minus || p instanceof And || p instanceof Or || p instanceof Xor - || p instanceof PlusMultiply || p instanceof MinusMultiply || p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor || p instanceof BitwShiftL || p instanceof BitwShiftR); fn = p; http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/main/java/org/apache/sysml/runtime/matrix/operators/TernaryOperator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/TernaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/TernaryOperator.java new file mode 100644 index 0000000..45be887 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/TernaryOperator.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + + +package org.apache.sysml.runtime.matrix.operators; + +import java.io.Serializable; + +import org.apache.sysml.runtime.functionobjects.IfElse; +import org.apache.sysml.runtime.functionobjects.MinusMultiply; +import org.apache.sysml.runtime.functionobjects.PlusMultiply; +import org.apache.sysml.runtime.functionobjects.TernaryValueFunction; +/** Operator wrapping a {@link TernaryValueFunction} that is applied cell-wise to three inputs. */ +public class TernaryOperator extends Operator implements Serializable +{ + private static final long serialVersionUID = 3456088891054083634L; + /** function applied to each (in1, in2, in3) cell triple */ + public final TernaryValueFunction fn; + + public TernaryOperator(TernaryValueFunction p) { + //ternaryop is sparse-safe iff (op 0 0 0) == 0; only the function types listed below are marked as such + super (p instanceof PlusMultiply || p instanceof MinusMultiply || p instanceof IfElse); + fn = p; + } + + @Override + public String toString() { + return "TernaryOperator("+fn.getClass().getSimpleName()+")"; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/test/java/org/apache/sysml/test/integration/functions/ternary/FullIfElseTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/FullIfElseTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/FullIfElseTest.java new file mode 100644 index 0000000..581a6cd --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/FullIfElseTest.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.functions.ternary; + +import java.util.HashMap; + +import org.junit.Test; + +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.lops.LopProperties.ExecType; +import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.TestConfiguration; +import org.apache.sysml.test.utils.TestUtils; + +public class FullIfElseTest extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "TernaryIfElse"; + + private final static String TEST_DIR = "functions/ternary/"; + private final static String TEST_CLASS_DIR = TEST_DIR + FullIfElseTest.class.getSimpleName() + "/"; + + private final static int rows = 2111; + private final static int cols = 30; + //input sparsity of the dense (sparsity1) and sparse (sparsity2) test variants + private final static double sparsity1 = 0.6; + private final static double sparsity2 = 0.1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + } + + @Test + public void testScalarScalarScalarDenseCP() { + runIfElseTest(false, false, false, false, ExecType.CP); + } + + @Test + public void testMatrixScalarScalarDenseCP() { + runIfElseTest(true, false, false, false, ExecType.CP); + } + + @Test + public void testScalarMatrixScalarDenseCP() { + runIfElseTest(false, true, false, false, ExecType.CP); + } + + @Test + public void 
testMatrixMatrixScalarDenseCP() { + runIfElseTest(true, true, false, false, ExecType.CP); + } + + @Test + public void testScalarScalarMatrixDenseCP() { + runIfElseTest(false, false, true, false, ExecType.CP); + } + + @Test + public void testMatrixScalarMatrixDenseCP() { + runIfElseTest(true, false, true, false, ExecType.CP); + } + + @Test + public void testScalarMatrixMatrixDenseCP() { + runIfElseTest(false, true, true, false, ExecType.CP); + } + + @Test + public void testMatrixMatrixMatrixDenseCP() { + runIfElseTest(true, true, true, false, ExecType.CP); + } + + @Test + public void testScalarScalarScalarSparseCP() { + runIfElseTest(false, false, false, true, ExecType.CP); + } + + @Test + public void testMatrixScalarScalarSparseCP() { + runIfElseTest(true, false, false, true, ExecType.CP); + } + + @Test + public void testScalarMatrixScalarSparseCP() { + runIfElseTest(false, true, false, true, ExecType.CP); + } + + @Test + public void testMatrixMatrixScalarSparseCP() { + runIfElseTest(true, true, false, true, ExecType.CP); + } + + @Test + public void testScalarScalarMatrixSparseCP() { + runIfElseTest(false, false, true, true, ExecType.CP); + } + + @Test + public void testMatrixScalarMatrixSparseCP() { + runIfElseTest(true, false, true, true, ExecType.CP); + } + + @Test + public void testScalarMatrixMatrixSparseCP() { + runIfElseTest(false, true, true, true, ExecType.CP); + } + + @Test + public void testMatrixMatrixMatrixSparseCP() { + runIfElseTest(true, true, true, true, ExecType.CP); + } + + //SPARK + + @Test + public void testScalarScalarScalarDenseSP() { + runIfElseTest(false, false, false, false, ExecType.SPARK); + } + + @Test + public void testMatrixScalarScalarDenseSP() { + runIfElseTest(true, false, false, false, ExecType.SPARK); + } + + @Test + public void testScalarMatrixScalarDenseSP() { + runIfElseTest(false, true, false, false, ExecType.SPARK); + } + + @Test + public void testMatrixMatrixScalarDenseSP() { + runIfElseTest(true, true, false, false, 
ExecType.SPARK); + } + + @Test + public void testScalarScalarMatrixDenseSP() { + runIfElseTest(false, false, true, false, ExecType.SPARK); + } + + @Test + public void testMatrixScalarMatrixDenseSP() { + runIfElseTest(true, false, true, false, ExecType.SPARK); + } + + @Test + public void testScalarMatrixMatrixDenseSP() { + runIfElseTest(false, true, true, false, ExecType.SPARK); + } + + @Test + public void testMatrixMatrixMatrixDenseSP() { + runIfElseTest(true, true, true, false, ExecType.SPARK); + } + + @Test + public void testScalarScalarScalarSparseSP() { + runIfElseTest(false, false, false, true, ExecType.SPARK); + } + + @Test + public void testMatrixScalarScalarSparseSP() { + runIfElseTest(true, false, false, true, ExecType.SPARK); + } + + @Test + public void testScalarMatrixScalarSparseSP() { + runIfElseTest(false, true, false, true, ExecType.SPARK); + } + + @Test + public void testMatrixMatrixScalarSparseSP() { + runIfElseTest(true, true, false, true, ExecType.SPARK); + } + + @Test + public void testScalarScalarMatrixSparseSP() { + runIfElseTest(false, false, true, true, ExecType.SPARK); + } + + @Test + public void testMatrixScalarMatrixSparseSP() { + runIfElseTest(true, false, true, true, ExecType.SPARK); + } + + @Test + public void testScalarMatrixMatrixSparseSP() { + runIfElseTest(false, true, true, true, ExecType.SPARK); + } + + @Test + public void testMatrixMatrixMatrixSparseSP() { + runIfElseTest(true, true, true, true, ExecType.SPARK); + } + + //MR + + @Test + public void testScalarScalarScalarDenseMR() { + runIfElseTest(false, false, false, false, ExecType.MR); + } + + @Test + public void testMatrixScalarScalarDenseMR() { + runIfElseTest(true, false, false, false, ExecType.MR); + } + + @Test + public void testScalarMatrixScalarDenseMR() { + runIfElseTest(false, true, false, false, ExecType.MR); + } + + @Test + public void testMatrixMatrixScalarDenseMR() { + runIfElseTest(true, true, false, false, ExecType.MR); + } + + @Test + public void 
testScalarScalarMatrixDenseMR() { + runIfElseTest(false, false, true, false, ExecType.MR); + } + + @Test + public void testMatrixScalarMatrixDenseMR() { + runIfElseTest(true, false, true, false, ExecType.MR); + } + + @Test + public void testScalarMatrixMatrixDenseMR() { + runIfElseTest(false, true, true, false, ExecType.MR); + } + + @Test + public void testMatrixMatrixMatrixDenseMR() { + runIfElseTest(true, true, true, false, ExecType.MR); + } + + @Test + public void testScalarScalarScalarSparseMR() { + runIfElseTest(false, false, false, true, ExecType.MR); + } + + @Test + public void testMatrixScalarScalarSparseMR() { + runIfElseTest(true, false, false, true, ExecType.MR); + } + + @Test + public void testScalarMatrixScalarSparseMR() { + runIfElseTest(false, true, false, true, ExecType.MR); + } + + @Test + public void testMatrixMatrixScalarSparseMR() { + runIfElseTest(true, true, false, true, ExecType.MR); + } + + @Test + public void testScalarScalarMatrixSparseMR() { + runIfElseTest(false, false, true, true, ExecType.MR); + } + + @Test + public void testMatrixScalarMatrixSparseMR() { + runIfElseTest(true, false, true, true, ExecType.MR); + } + + @Test + public void testScalarMatrixMatrixSparseMR() { + runIfElseTest(false, true, true, true, ExecType.MR); + } + + @Test + public void testMatrixMatrixMatrixSparseMR() { + runIfElseTest(true, true, true, true, ExecType.MR); + } + //runs R = ifelse(A, B, C), where each input is a rows x cols matrix (matrixX=true) or a 1x1 scalar, and compares the result against R + private void runIfElseTest(boolean matrix1, boolean matrix2, boolean matrix3, boolean sparse, ExecType et) + { + //select the runtime platform for the requested exec type (CP tests run in hybrid_spark mode) + RUNTIME_PLATFORM platformOld = rtplatform; + switch( et ){ + case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break; + case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break; + default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break; + } + + boolean 
sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if( rtplatform == RUNTIME_PLATFORM.SPARK || rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK ) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + + try + { + 
TestConfiguration config = getTestConfiguration(TEST_NAME1); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[]{"-explain","-args", input("A"), input("B"), input("C"), output("R")}; + fullRScriptName = HOME + TEST_NAME1 + ".R"; + rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir(); + + //generate actual datasets (matrices and scalars) + double sparsity = sparse ? sparsity2 : sparsity1; + double[][] A = matrix1 ? getRandomMatrix(rows, cols, 0, 1, sparsity, 1) : getScalar(1); + writeInputMatrixWithMTD("A", A, true); + double[][] B = matrix2 ? getRandomMatrix(rows, cols, 0, 1, sparsity, 2) : getScalar(2); + writeInputMatrixWithMTD("B", B, true); + double[][] C = matrix3 ? getRandomMatrix(rows, cols, 0, 1, sparsity, 3) : getScalar(3); + writeInputMatrixWithMTD("C", C, true); + + //run test cases + runTest(true, false, null, -1); + runRScript(true); + + //compare output matrices + HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R"); + HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R"); + TestUtils.compareMatrices(dmlfile, rfile, 0, "Stat-DML", "Stat-R"); + } + finally { + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } + //1x1 input holding 7*input, used when an input is a scalar + private double[][] getScalar(int input) { + return new double[][]{{7d*input}}; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/test/scripts/functions/ternary/TernaryIfElse.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/ternary/TernaryIfElse.R b/src/test/scripts/functions/ternary/TernaryIfElse.R new file mode 100644 index 0000000..a1e8a03 --- /dev/null +++ b/src/test/scripts/functions/ternary/TernaryIfElse.R @@ -0,0 +1,45 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under 
one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +# reference result for R = ifelse(A, B, C); 1-row (1x1) inputs are expanded to m x n +args <- commandArgs(TRUE) +options(digits=22) + +library("Matrix") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = as.matrix(readMM(paste(args[1], "B.mtx", sep=""))) +C = as.matrix(readMM(paste(args[1], "C.mtx", sep=""))) +m = max(max(nrow(A), nrow(B)), nrow(C)) +n = max(max(ncol(A), ncol(B)), ncol(C)) + +if( nrow(A)==1 ) { + A = matrix(A, m, n); +} +if( nrow(B)==1 ) { + B = matrix(B, m, n); +} +if( nrow(C)==1 ) { + C = matrix(C, m, n); +} +# cell-wise selection: B where A is non-zero, C otherwise +R = matrix(ifelse(as.vector(A), as.vector(B), as.vector(C)), m, n); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/test/scripts/functions/ternary/TernaryIfElse.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/ternary/TernaryIfElse.dml b/src/test/scripts/functions/ternary/TernaryIfElse.dml new file mode 100644 index 0000000..12a11cc --- /dev/null +++ b/src/test/scripts/functions/ternary/TernaryIfElse.dml @@ -0,0 +1,43 @@ 
+#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +# computes R = ifelse(A, B, C); 1-row (1x1 in the test) inputs are passed as scalars via as.scalar +A = read($1); +B = read($2); +C = read($3); + +if( nrow(A)==1 & nrow(B)==1 & nrow(C)==1 ) + R = as.matrix(ifelse(as.scalar(A), as.scalar(B), as.scalar(C))); +else if( nrow(A)>1 & nrow(B)==1 & nrow(C)==1 ) + R = ifelse(A, as.scalar(B), as.scalar(C)); +else if( nrow(A)==1 & nrow(B)>1 & nrow(C)==1 ) + R = ifelse(as.scalar(A), B, as.scalar(C)); +else if( nrow(A)>1 & nrow(B)>1 & nrow(C)==1 ) + R = ifelse(A, B, as.scalar(C)); +else if( nrow(A)==1 & nrow(B)==1 & nrow(C)>1 ) + R = ifelse(as.scalar(A), as.scalar(B), C); +else if( nrow(A)>1 & nrow(B)==1 & nrow(C)>1 ) + R = ifelse(A, as.scalar(B), C); +else if( nrow(A)==1 & nrow(B)>1 & nrow(C)>1 ) + R = ifelse(as.scalar(A), B, C); +else if( nrow(A)>1 & nrow(B)>1 & nrow(C)>1 ) + R = ifelse(A, B, C); + +write(R, $4); http://git-wip-us.apache.org/repos/asf/systemml/blob/ce9e42fe/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java index ee14359..4b83df3 100644 --- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java @@ -31,6 +31,7 @@ import org.junit.runners.Suite; CovarianceWeightsTest.class, CTableMatrixIgnoreZerosTest.class, CTableSequenceTest.class, + FullIfElseTest.class, QuantileWeightsTest.class, TableOutputTest.class, TernaryAggregateTest.class,
